feat(auth): wire auth end-to-end (middleware + frontend replacement)

Backend:
- Port auth_middleware, csrf_middleware, langgraph_auth, routers/auth
- Port authz decorator (owner_filter_key defaults to 'owner_id')
- Merge app.py: register AuthMiddleware + CSRFMiddleware + CORS, add
  _ensure_admin_user lifespan hook, _migrate_orphaned_threads helper,
  register auth router
- Merge deps.py: add get_local_provider, get_current_user_from_request,
  get_optional_user_from_request; keep get_current_user as thin str|None
  adapter for feedback router
- langgraph.json: add auth path pointing to langgraph_auth.py:auth
- Rename metadata['user_id'] -> metadata['owner_id'] in langgraph_auth
  (both metadata write and LangGraph filter dict) + test fixtures

Frontend:
- Delete better-auth library and api catch-all route
- Remove better-auth npm dependency and env vars (BETTER_AUTH_SECRET,
  BETTER_AUTH_GITHUB_*) from env.js
- Port frontend/src/core/auth/* (AuthProvider, gateway-config,
  proxy-policy, server-side getServerSideUser, types)
- Port frontend/src/core/api/fetcher.ts
- Port (auth)/layout, (auth)/login, (auth)/setup pages
- Rewrite workspace/layout.tsx as server component that calls
  getServerSideUser and wraps in AuthProvider
- Port workspace/workspace-content.tsx for the client-side sidebar logic

Tests:
- Port 5 auth test files (test_auth, test_auth_middleware,
  test_auth_type_system, test_ensure_admin, test_langgraph_auth)
- 176 auth tests PASS

After this commit: login/logout/registration flow works, but persistence
layer does not yet filter by owner_id. Commit 4 closes that gap.
This commit is contained in:
greatmengqi 2026-04-08 09:41:56 +08:00
parent 03c3b18565
commit f942e4e597
32 changed files with 3859 additions and 268 deletions

View File

@ -1,15 +1,21 @@
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
auth,
channels,
feedback,
mcp,
@ -34,6 +40,92 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
async def _ensure_admin_user(app: FastAPI) -> None:
"""Auto-create the admin user on first boot if no users exist.
Prints the generated password to stdout so the operator can log in.
On subsequent boots, warns if any user still needs setup.
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races.
Only the worker that successfully creates/updates the admin prints the
password; losers silently skip.
"""
import secrets
from app.gateway.deps import get_local_provider
provider = get_local_provider()
user_count = await provider.count_users()
if user_count == 0:
password = secrets.token_urlsafe(16)
try:
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
except ValueError:
return # Another worker already created the admin.
# Migrate orphaned threads (no owner_id) to this admin
store = getattr(app.state, "store", None)
if store is not None:
await _migrate_orphaned_threads(store, str(admin.id))
logger.info("=" * 60)
logger.info(" Admin account created on first boot")
logger.info(" Email: %s", admin.email)
logger.info(" Password: %s", password)
logger.info(" Change it after login: Settings -> Account")
logger.info("=" * 60)
return
# Admin exists but setup never completed — reset password so operator
# can always find it in the console without needing the CLI.
# Multi-worker guard: if admin was created less than 30s ago, another
# worker just created it and will print the password — skip reset.
admin = await provider.get_user_by_email("admin@deerflow.dev")
if admin and admin.needs_setup:
import time
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
if age < 30:
return # Just created by another worker in this startup; its password is still valid.
from app.gateway.auth.password import hash_password_async
password = secrets.token_urlsafe(16)
admin.password_hash = await hash_password_async(password)
admin.token_version += 1
await provider.update_user(admin)
logger.info("=" * 60)
logger.info(" Admin account setup incomplete — password reset")
logger.info(" Email: %s", admin.email)
logger.info(" Password: %s", password)
logger.info(" Change it after login: Settings -> Account")
logger.info("=" * 60)
async def _migrate_orphaned_threads(store, admin_user_id: str) -> None:
"""Migrate threads with no owner_id to the given admin.
NOTE: This is the initial port. Commit 5 will replace the hardcoded
limit=1000 with cursor pagination and extend to SQL persistence tables.
"""
try:
migrated = 0
results = await store.asearch(("threads",), limit=1000)
for item in results:
metadata = item.value.get("metadata", {})
if not metadata.get("owner_id"):
metadata["owner_id"] = admin_user_id
item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value)
migrated += 1
if migrated:
logger.info("Migrated %d orphaned thread(s) to admin", migrated)
except Exception:
logger.exception("Thread migration failed (non-fatal)")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
@ -53,6 +145,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
@ -164,7 +260,35 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
],
)
# CORS is handled by nginx - no need for FastAPI middleware
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
app.add_middleware(AuthMiddleware)
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error(
"GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. "
"This is a security misconfiguration — browsers will reject the response. "
"Use explicit scheme://host:port origins instead."
)
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
# Models API is mounted at /api/models
@ -200,6 +324,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)

View File

@ -0,0 +1,71 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401.
Fine-grained permission checks remain in authz.py decorators.
"""
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
"/health",
"/docs",
"/redoc",
"/openapi.json",
)
# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
}
)
def _is_public(path: str) -> bool:
stripped = path.rstrip("/")
if stripped in _PUBLIC_EXACT_PATHS:
return True
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Coarse-grained auth gate: reject requests without a valid session cookie.
This does NOT verify JWT signature or user existence that is the job of
``get_current_user_from_request`` in deps.py (called by ``@require_auth``).
The middleware only checks *presence* of the cookie so that new endpoints
that forget ``@require_auth`` are not completely exposed.
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if _is_public(request.url.path):
return await call_next(request)
# Non-public path: require session cookie
if not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": {
"code": AuthErrorCode.NOT_AUTHENTICATED,
"message": "Authentication required",
}
},
)
return await call_next(request)

View File

@ -0,0 +1,261 @@
"""Authorization decorators and context for DeerFlow.
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
**Usage:**
1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top
**Example:**
@router.get("/{thread_id}")
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
# User is authenticated and has threads:read permission
...
**Permission Model:**
- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""
from __future__ import annotations
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
if TYPE_CHECKING:
from app.gateway.auth.models import User
P = ParamSpec("P")
T = TypeVar("T")
# Permission constants
class Permissions:
"""Permission constants for resource:action format."""
# Threads
THREADS_READ = "threads:read"
THREADS_WRITE = "threads:write"
THREADS_DELETE = "threads:delete"
# Runs
RUNS_CREATE = "runs:create"
RUNS_READ = "runs:read"
RUNS_CANCEL = "runs:cancel"
class AuthContext:
"""Authentication context for the current request.
Stored in request.state.auth after require_auth decoration.
Attributes:
user: The authenticated user, or None if anonymous
permissions: List of permission strings (e.g., "threads:read")
"""
__slots__ = ("user", "permissions")
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
self.user = user
self.permissions = permissions or []
@property
def is_authenticated(self) -> bool:
"""Check if user is authenticated."""
return self.user is not None
def has_permission(self, resource: str, action: str) -> bool:
"""Check if context has permission for resource:action.
Args:
resource: Resource name (e.g., "threads")
action: Action name (e.g., "read")
Returns:
True if user has permission
"""
permission = f"{resource}:{action}"
return permission in self.permissions
def require_user(self) -> User:
"""Get user or raise 401.
Raises:
HTTPException 401 if not authenticated
"""
if not self.user:
raise HTTPException(status_code=401, detail="Authentication required")
return self.user
def get_auth_context(request: Request) -> AuthContext | None:
"""Get AuthContext from request state."""
return getattr(request.state, "auth", None)
_ALL_PERMISSIONS: list[str] = [
Permissions.THREADS_READ,
Permissions.THREADS_WRITE,
Permissions.THREADS_DELETE,
Permissions.RUNS_CREATE,
Permissions.RUNS_READ,
Permissions.RUNS_CANCEL,
]
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
Delegates to deps.get_optional_user_from_request() for the JWTUser pipeline.
Returns AuthContext with user=None for anonymous requests.
"""
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return AuthContext(user=None, permissions=[])
# In future, permissions could be stored in user record
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
Must be placed ABOVE other decorators (executes after them).
Usage:
@router.get("/{thread_id}")
@require_auth # Bottom decorator (executes first after permission check)
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
auth: AuthContext = request.state.auth
...
Raises:
ValueError: If 'request' parameter is missing
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_auth decorator requires 'request' parameter")
# Authenticate and set context
auth_context = await _authenticate(request)
request.state.auth = auth_context
return await func(*args, **kwargs)
return wrapper
def require_permission(
resource: str,
action: str,
owner_check: bool = False,
owner_filter_key: str = "owner_id",
inject_record: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator that checks permission for resource:action.
Must be used AFTER @require_auth.
Args:
resource: Resource name (e.g., "threads", "runs")
action: Action name (e.g., "read", "write", "delete")
owner_check: If True, validates that the current user owns the resource.
Requires 'thread_id' path parameter and performs ownership check.
owner_filter_key: Field name for ownership filter (default: "owner_id")
inject_record: If True and owner_check is True, injects the thread record
into kwargs['thread_record'] for use in the handler.
Usage:
# Simple permission check
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
...
# With ownership check (for /threads/{thread_id} endpoints)
@require_permission("threads", "delete", owner_check=True)
async def delete_thread(thread_id: str, request: Request):
...
# With ownership check and record injection
@require_permission("threads", "delete", owner_check=True, inject_record=True)
async def delete_thread(thread_id: str, request: Request, thread_record: dict = None):
# thread_record is injected if found
...
Raises:
HTTPException 401: If authentication required but user is anonymous
HTTPException 403: If user lacks permission
HTTPException 404: If owner_check=True but user doesn't own the thread
ValueError: If owner_check=True but 'thread_id' parameter is missing
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_permission decorator requires 'request' parameter")
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
auth = await _authenticate(request)
request.state.auth = auth
if not auth.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
# Check permission
if not auth.has_permission(resource, action):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {resource}:{action}",
)
# Owner check for thread-specific resources
if owner_check:
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
# Get thread and verify ownership
from app.gateway.routers.threads import _store_get, get_store
store = get_store(request)
if store is not None:
record = await _store_get(store, thread_id)
if record:
owner_id = record.get("metadata", {}).get(owner_filter_key)
if owner_id and owner_id != str(auth.user.id):
raise HTTPException(
status_code=404,
detail=f"Thread {thread_id} not found",
)
# Inject record if requested
if inject_record:
kwargs["thread_record"] = record
return await func(*args, **kwargs)
return wrapper
return decorator

View File

@ -0,0 +1,112 @@
"""CSRF protection middleware for FastAPI.
Per RFC-001:
State-changing operations require CSRF protection.
"""
import secrets
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
def generate_csrf_token() -> str:
"""Generate a secure random CSRF token."""
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
def should_check_csrf(request: Request) -> bool:
"""Determine if a request needs CSRF validation.
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
"""
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
return False
return True
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/logout",
"/api/v1/auth/register",
}
)
def is_auth_endpoint(request: Request) -> bool:
"""Check if the request is to an auth endpoint.
Auth endpoints don't need CSRF validation on first call (no token).
"""
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
if not cookie_token or not header_token:
return JSONResponse(
status_code=403,
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
)
if not secrets.compare_digest(cookie_token, header_token):
return JSONResponse(
status_code=403,
content={"detail": "CSRF token mismatch."},
)
response = await call_next(request)
# For auth endpoints that set up session, also set CSRF cookie
if _is_auth and request.method == "POST":
# Generate a new CSRF token for the session
csrf_token = generate_csrf_token()
is_https = is_secure_request(request)
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
secure=is_https,
samesite="strict",
)
return response
def get_csrf_token(request: Request) -> str | None:
"""Get the CSRF token from the current request's cookies.
This is useful for server-side rendering where you need to embed
token in forms or headers.
"""
return request.cookies.get(CSRF_COOKIE_NAME)

View File

@ -11,11 +11,16 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunContext, RunManager
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
@ -127,10 +132,86 @@ def get_run_context(request: Request) -> RunContext:
)
async def get_current_user(request: Request) -> str | None:
"""Extract user identity from request.
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------
Phase 2: always returns None (no authentication).
Phase 3: extract user_id from JWT / session / API key header.
# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton."""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
_cached_repo = SQLiteUserRepository()
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
return _cached_local_provider
async def get_current_user_from_request(request: Request):
"""Get the current authenticated user from the request cookie.
Raises HTTPException 401 if not authenticated.
"""
return None
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
access_token = request.cookies.get("access_token")
if not access_token:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
)
payload = decode_token(access_token)
if isinstance(payload, TokenError):
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
)
provider = get_local_provider()
user = await provider.get_user(payload.sub)
if user is None:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
)
# Token version mismatch → password was changed, token is stale
if user.token_version != payload.ver:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
)
return user
async def get_optional_user_from_request(request: Request):
"""Get optional authenticated user from request.
Returns None if not authenticated.
"""
try:
return await get_current_user_from_request(request)
except HTTPException:
return None
async def get_current_user(request: Request) -> str | None:
"""Extract user_id from request cookie, or None if not authenticated.
Thin adapter that returns the string id for callers that only need
identification (e.g., ``feedback.py``). Full-user callers should use
``get_current_user_from_request`` or ``get_optional_user_from_request``.
"""
user = await get_optional_user_from_request(request)
return str(user.id) if user else None

View File

@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate validates JWT cookie, extracts user_id,
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on returns metadata filter so each user only sees own threads
"""
import secrets
from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider
auth = Auth()
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
def _check_csrf(request) -> None:
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
proxied directly by nginx have the same CSRF protection.
"""
method = getattr(request, "method", "") or ""
if method.upper() not in _CSRF_METHODS:
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
if not cookie_token or not header_token:
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token missing. Include X-CSRF-Token header.",
)
if not secrets.compare_digest(cookie_token, header_token):
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token mismatch.",
)
@auth.authenticate
async def authenticate(request):
"""Validate the session cookie, decode JWT, and check token_version.
Same validation chain as Gateway's get_current_user_from_request:
cookie decode JWT DB lookup token_version match
Also enforces CSRF on state-changing methods.
"""
# CSRF check before authentication so forged cross-site requests
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Not authenticated",
)
payload = decode_token(token)
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
)
user = await get_local_provider().get_user(payload.sub)
if user is None:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="User not found",
)
if user.token_version != payload.ver:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Token revoked (password changed)",
)
return payload.sub
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject owner_id metadata on writes; filter by owner_id on reads.
Gateway stores thread ownership as ``metadata.owner_id``.
This handler ensures LangGraph Server enforces the same isolation.
"""
# On create/update: stamp owner_id into metadata
metadata = value.setdefault("metadata", {})
metadata["owner_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete
return {"owner_id": ctx.user.identity}

View File

@ -0,0 +1,303 @@
"""Authentication endpoints."""
import logging
import time
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
$remote_addr``). Nginx unconditionally overwrites any client-supplied
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
that nginx observed it cannot be spoofed by the client.
``request.client.host`` is NOT reliable because uvicorn's default
``proxy_headers=True`` replaces it with the *first* entry from
``X-Forwarded-For``, which IS client-spoofable.
``X-Forwarded-For`` is intentionally NOT used for the same reason.
"""
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
# Fallback: direct connection without nginx (e.g. unit tests, dev).
return request.client.host if request.client else "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status():
"""Check if admin account exists. Always False after first boot."""
user_count = await get_local_provider().count_users()
return {"needs_setup": user_count == 0}
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)

View File

@ -8,6 +8,9 @@
"graphs": {
"lead_agent": "deerflow.agents:make_lead_agent"
},
"auth": {
"path": "./app/gateway/langgraph_auth.py:auth"
},
"checkpointer": {
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
}

506
backend/tests/test_auth.py Normal file
View File

@ -0,0 +1,506 @@
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.authz import (
AuthContext,
Permissions,
get_auth_context,
require_auth,
require_permission,
)
# ── Password Hashing ────────────────────────────────────────────────────────
def test_hash_password_and_verify():
"""Hashing and verification round-trip."""
password = "s3cr3tP@ssw0rd!"
hashed = hash_password(password)
assert hashed != password
assert verify_password(password, hashed) is True
assert verify_password("wrongpassword", hashed) is False
def test_hash_password_different_each_time():
"""bcrypt generates unique salts, so same password has different hashes."""
password = "testpassword"
h1 = hash_password(password)
h2 = hash_password(password)
assert h1 != h2 # Different salts
# But both verify correctly
assert verify_password(password, h1) is True
assert verify_password(password, h2) is True
def test_verify_password_rejects_empty():
"""Empty password should not verify."""
hashed = hash_password("nonempty")
assert verify_password("", hashed) is False
# ── JWT ─────────────────────────────────────────────────────────────────────
def test_create_and_decode_token():
"""JWT creation and decoding round-trip."""
user_id = str(uuid4())
# Set a valid JWT secret for this test
import os
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(user_id)
assert isinstance(token, str)
payload = decode_token(token)
assert payload is not None
assert payload.sub == user_id
def test_decode_token_expired():
"""Expired token returns TokenError.EXPIRED."""
from app.gateway.auth.errors import TokenError
user_id = str(uuid4())
# Create token that expires immediately
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
payload = decode_token(token)
assert payload == TokenError.EXPIRED
def test_decode_token_invalid():
"""Invalid token returns TokenError."""
from app.gateway.auth.errors import TokenError
assert isinstance(decode_token("not.a.valid.token"), TokenError)
assert isinstance(decode_token(""), TokenError)
assert isinstance(decode_token("completely-wrong"), TokenError)
def test_create_token_custom_expiry():
"""Custom expiry is respected."""
user_id = str(uuid4())
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
payload = decode_token(token)
assert payload is not None
assert payload.sub == user_id
# ── AuthContext ────────────────────────────────────────────────────────────
def test_auth_context_unauthenticated():
"""AuthContext with no user."""
ctx = AuthContext(user=None, permissions=[])
assert ctx.is_authenticated is False
assert ctx.has_permission("threads", "read") is False
def test_auth_context_authenticated_no_perms():
"""AuthContext with user but no permissions."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
assert ctx.is_authenticated is True
assert ctx.has_permission("threads", "read") is False
def test_auth_context_has_permission():
"""AuthContext permission checking."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
ctx = AuthContext(user=user, permissions=perms)
assert ctx.has_permission("threads", "read") is True
assert ctx.has_permission("threads", "write") is True
assert ctx.has_permission("threads", "delete") is False
assert ctx.has_permission("runs", "read") is False
def test_auth_context_require_user_raises():
"""require_user raises 401 when not authenticated."""
ctx = AuthContext(user=None, permissions=[])
with pytest.raises(HTTPException) as exc_info:
ctx.require_user()
assert exc_info.value.status_code == 401
def test_auth_context_require_user_returns_user():
"""require_user returns user when authenticated."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
returned = ctx.require_user()
assert returned == user
# ── get_auth_context helper ─────────────────────────────────────────────────
def test_get_auth_context_not_set():
"""get_auth_context returns None when auth not set on request."""
mock_request = MagicMock()
# Make getattr return None (simulating attribute not set)
mock_request.state = MagicMock()
del mock_request.state.auth
assert get_auth_context(mock_request) is None
def test_get_auth_context_set():
"""get_auth_context returns the AuthContext from request."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
mock_request = MagicMock()
mock_request.state.auth = ctx
assert get_auth_context(mock_request) == ctx
# ── require_auth decorator ──────────────────────────────────────────────────
def test_require_auth_sets_auth_context():
"""require_auth sets auth context on request from cookie."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_auth
async def endpoint(request: Request):
ctx = get_auth_context(request)
return {"authenticated": ctx.is_authenticated}
with TestClient(app) as client:
# No cookie → anonymous
response = client.get("/test")
assert response.status_code == 200
assert response.json()["authenticated"] is False
def test_require_auth_requires_request_param():
"""require_auth raises ValueError if request parameter is missing."""
import asyncio
@require_auth
async def bad_endpoint(): # Missing `request` parameter
pass
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
asyncio.run(bad_endpoint())
# ── require_permission decorator ─────────────────────────────────────────────
def test_require_permission_requires_auth():
"""require_permission raises 401 when not authenticated."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_permission("threads", "read")
async def endpoint(request: Request):
return {"ok": True}
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["detail"]
def test_require_permission_denies_wrong_permission():
"""User without required permission gets 403."""
from fastapi import Request
app = FastAPI()
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
@app.get("/test")
@require_permission("threads", "delete")
async def endpoint(request: Request):
return {"ok": True}
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 403
assert "Permission denied" in response.json()["detail"]
# ── Weak JWT secret warning ──────────────────────────────────────────────────
# ── User Model Fields ──────────────────────────────────────────────────────
def test_user_model_has_needs_setup_default_false():
"""New users default to needs_setup=False."""
user = User(email="test@example.com", password_hash="hash")
assert user.needs_setup is False
def test_user_model_has_token_version_default_zero():
"""New users default to token_version=0."""
user = User(email="test@example.com", password_hash="hash")
assert user.token_version == 0
def test_user_model_needs_setup_true():
"""Auto-created admin has needs_setup=True."""
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
assert user.needs_setup is True
def test_sqlite_round_trip_new_fields():
"""needs_setup and token_version survive create → read round-trip."""
import asyncio
import os
import tempfile
from pathlib import Path
from app.gateway.auth.repositories import sqlite as sqlite_mod
with tempfile.TemporaryDirectory() as tmpdir:
db_path = os.path.join(tmpdir, "test_users.db")
old_path = sqlite_mod._resolved_db_path
old_init = sqlite_mod._table_initialized
sqlite_mod._resolved_db_path = Path(db_path)
sqlite_mod._table_initialized = False
try:
repo = sqlite_mod.SQLiteUserRepository()
user = User(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
created = asyncio.run(repo.create_user(user))
assert created.needs_setup is True
assert created.token_version == 3
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
assert fetched is not None
assert fetched.needs_setup is True
assert fetched.token_version == 3
fetched.needs_setup = False
fetched.token_version = 4
asyncio.run(repo.update_user(fetched))
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
assert refetched.needs_setup is False
assert refetched.token_version == 4
finally:
sqlite_mod._resolved_db_path = old_path
sqlite_mod._table_initialized = old_init
# ── Token Versioning ───────────────────────────────────────────────────────
def test_jwt_encodes_ver():
"""JWT payload includes ver field."""
import os
from app.gateway.auth.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()), token_version=3)
payload = decode_token(token)
assert not isinstance(payload, TokenError)
assert payload.ver == 3
def test_jwt_default_ver_zero():
"""JWT ver defaults to 0."""
import os
from app.gateway.auth.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()))
payload = decode_token(token)
assert not isinstance(payload, TokenError)
assert payload.ver == 0
def test_token_version_mismatch_rejects():
"""Token with stale ver is rejected by get_current_user_from_request."""
import asyncio
import os
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
user_id = str(uuid4())
token = create_access_token(user_id, token_version=0)
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
mock_request = MagicMock()
mock_request.cookies = {"access_token": token}
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
mock_provider = MagicMock()
mock_provider.get_user = AsyncMock(return_value=mock_user)
mock_provider_fn.return_value = mock_provider
from app.gateway.deps import get_current_user_from_request
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail).lower()
# ── change-password extension ──────────────────────────────────────────────
def test_change_password_request_accepts_new_email():
"""ChangePasswordRequest model accepts optional new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
req = ChangePasswordRequest(
current_password="old",
new_password="newpassword",
new_email="new@example.com",
)
assert req.new_email == "new@example.com"
def test_change_password_request_new_email_optional():
"""ChangePasswordRequest model works without new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
assert req.new_email is None
def test_login_response_includes_needs_setup():
"""LoginResponse includes needs_setup field."""
from app.gateway.routers.auth import LoginResponse
resp = LoginResponse(expires_in=3600, needs_setup=True)
assert resp.needs_setup is True
resp2 = LoginResponse(expires_in=3600)
assert resp2.needs_setup is False
# ── Rate Limiting ──────────────────────────────────────────────────────────
def test_rate_limiter_allows_under_limit():
"""Requests under the limit are allowed."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
_login_attempts.clear()
_check_rate_limit("192.168.1.1") # Should not raise
def test_rate_limiter_blocks_after_max_failures():
"""IP is blocked after 5 consecutive failures."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
_login_attempts.clear()
ip = "10.0.0.1"
for _ in range(5):
_record_login_failure(ip)
with pytest.raises(HTTPException) as exc_info:
_check_rate_limit(ip)
assert exc_info.value.status_code == 429
def test_rate_limiter_resets_on_success():
"""Successful login clears the failure counter."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
_login_attempts.clear()
ip = "10.0.0.2"
for _ in range(4):
_record_login_failure(ip)
_record_login_success(ip)
_check_rate_limit(ip) # Should not raise
# ── Client IP extraction ─────────────────────────────────────────────────
def test_get_client_ip_direct_connection():
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.42"
req.headers = {}
assert _get_client_ip(req) == "203.0.113.42"
def test_get_client_ip_uses_x_real_ip():
"""X-Real-IP (set by nginx) is used when present."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
req.headers = {"x-real-ip": "203.0.113.42"}
assert _get_client_ip(req) == "203.0.113.42"
def test_get_client_ip_xff_ignored():
"""X-Forwarded-For is never used; only X-Real-IP matters."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1"
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
assert _get_client_ip(req) == "198.51.100.5"
def test_get_client_ip_no_real_ip_fallback():
"""No X-Real-IP → falls back to client.host (direct connection)."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "127.0.0.1"
req.headers = {}
assert _get_client_ip(req) == "127.0.0.1"
def test_get_client_ip_x_real_ip_always_preferred():
"""X-Real-IP is always preferred over client.host regardless of IP."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.99"
req.headers = {"x-real-ip": "198.51.100.7"}
assert _get_client_ip(req) == "198.51.100.7"
# ── Weak JWT secret warning ──────────────────────────────────────────────────
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as config_module
config_module._auth_config = None
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
with caplog.at_level(logging.WARNING):
config = config_module.get_auth_config()
assert config.jwt_secret # non-empty ephemeral secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
# Cleanup
config_module._auth_config = None

View File

@ -0,0 +1,216 @@
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
import pytest
from starlette.testclient import TestClient
from app.gateway.auth_middleware import AuthMiddleware, _is_public
# ── _is_public unit tests ─────────────────────────────────────────────────
@pytest.mark.parametrize(
"path",
[
"/health",
"/health/",
"/docs",
"/docs/",
"/redoc",
"/openapi.json",
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
],
)
def test_public_paths(path: str):
assert _is_public(path) is True
@pytest.mark.parametrize(
"path",
[
"/api/models",
"/api/mcp/config",
"/api/memory",
"/api/skills",
"/api/threads/123",
"/api/threads/123/uploads",
"/api/agents",
"/api/channels",
"/api/runs/stream",
"/api/threads/123/runs",
"/api/v1/auth/me",
"/api/v1/auth/change-password",
],
)
def test_protected_paths(path: str):
assert _is_public(path) is False
# ── Trailing slash / normalization edge cases ─────────────────────────────
@pytest.mark.parametrize(
"path",
[
"/api/v1/auth/login/local/",
"/api/v1/auth/register/",
"/api/v1/auth/logout/",
"/api/v1/auth/setup-status/",
],
)
def test_public_auth_paths_with_trailing_slash(path: str):
assert _is_public(path) is True
@pytest.mark.parametrize(
"path",
[
"/api/models/",
"/api/v1/auth/me/",
"/api/v1/auth/change-password/",
],
)
def test_protected_paths_with_trailing_slash(path: str):
assert _is_public(path) is False
def test_unknown_api_path_is_protected():
"""Fail-closed: any new /api/* path is protected by default."""
assert _is_public("/api/new-feature") is False
assert _is_public("/api/v2/something") is False
assert _is_public("/api/v1/auth/new-endpoint") is False
# ── Middleware integration tests ──────────────────────────────────────────
def _make_app():
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
from fastapi import FastAPI
app = FastAPI()
app.add_middleware(AuthMiddleware)
@app.get("/health")
async def health():
return {"status": "ok"}
@app.get("/api/v1/auth/me")
async def auth_me():
return {"id": "1", "email": "test@test.com"}
@app.get("/api/v1/auth/setup-status")
async def setup_status():
return {"needs_setup": False}
@app.get("/api/models")
async def models_get():
return {"models": []}
@app.put("/api/mcp/config")
async def mcp_put():
return {"ok": True}
@app.delete("/api/threads/abc")
async def thread_delete():
return {"ok": True}
@app.patch("/api/threads/abc")
async def thread_patch():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def stream():
return {"ok": True}
@app.get("/api/future-endpoint")
async def future():
return {"ok": True}
return app
@pytest.fixture
def client():
return TestClient(_make_app())
def test_public_path_no_cookie(client):
res = client.get("/health")
assert res.status_code == 200
def test_public_auth_path_no_cookie(client):
"""Public auth endpoints (login/register) pass without cookie."""
res = client.get("/api/v1/auth/setup-status")
assert res.status_code == 200
def test_protected_auth_path_no_cookie(client):
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
res = client.get("/api/v1/auth/me")
assert res.status_code == 401
def test_protected_path_no_cookie_returns_401(client):
res = client.get("/api/models")
assert res.status_code == 401
body = res.json()
assert body["detail"]["code"] == "not_authenticated"
def test_protected_path_with_cookie_passes(client):
res = client.get("/api/models", cookies={"access_token": "some-token"})
assert res.status_code == 200
def test_protected_post_no_cookie_returns_401(client):
res = client.post("/api/threads/abc/runs/stream")
assert res.status_code == 401
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
def test_protected_put_no_cookie(client):
res = client.put("/api/mcp/config")
assert res.status_code == 401
def test_protected_delete_no_cookie(client):
res = client.delete("/api/threads/abc")
assert res.status_code == 401
def test_protected_patch_no_cookie(client):
res = client.patch("/api/threads/abc")
assert res.status_code == 401
def test_put_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.put("/api/mcp/config")
assert res.status_code == 200
def test_delete_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.delete("/api/threads/abc")
assert res.status_code == 200
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
def test_unknown_endpoint_no_cookie_returns_401(client):
"""Any new /api/* endpoint is blocked by default without cookie."""
res = client.get("/api/future-endpoint")
assert res.status_code == 401
def test_unknown_endpoint_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.get("/api/future-endpoint")
assert res.status_code == 200

View File

@ -0,0 +1,675 @@
"""Tests for auth type system hardening.
Covers structured error responses, typed decode_token callers,
CSRF middleware path matching, config-driven cookie security,
and unhappy paths / edge cases for all auth boundaries.
"""
import os
import secrets
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.csrf_middleware import (
CSRF_COOKIE_NAME,
CSRF_HEADER_NAME,
CSRFMiddleware,
is_auth_endpoint,
should_check_csrf,
)
# ── Setup ────────────────────────────────────────────────────────────
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
def _setup_config():
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
# ── CSRF Middleware Path Matching ────────────────────────────────────
class _FakeRequest:
"""Minimal request mock for CSRF path matching tests."""
def __init__(self, path: str, method: str = "POST"):
self.method = method
class _URL:
def __init__(self, p):
self.path = p
self.url = _URL(path)
self.cookies = {}
self.headers = {}
def test_csrf_exempts_login_local():
"""login/local (actual route) should be exempt from CSRF."""
req = _FakeRequest("/api/v1/auth/login/local")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_login_local_trailing_slash():
"""Trailing slash should also be exempt."""
req = _FakeRequest("/api/v1/auth/login/local/")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_logout():
req = _FakeRequest("/api/v1/auth/logout")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_register():
req = _FakeRequest("/api/v1/auth/register")
assert is_auth_endpoint(req) is True
def test_csrf_does_not_exempt_old_login_path():
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
req = _FakeRequest("/api/v1/auth/login")
assert is_auth_endpoint(req) is False
def test_csrf_does_not_exempt_me():
req = _FakeRequest("/api/v1/auth/me")
assert is_auth_endpoint(req) is False
def test_csrf_skips_get_requests():
req = _FakeRequest("/api/v1/auth/me", method="GET")
assert should_check_csrf(req) is False
def test_csrf_checks_post_to_protected():
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
assert should_check_csrf(req) is True
# ── Structured Error Response Format ────────────────────────────────
def test_auth_error_response_has_code_and_message():
"""All auth errors should have structured {code, message} format."""
err = AuthErrorResponse(
code=AuthErrorCode.INVALID_CREDENTIALS,
message="Wrong password",
)
d = err.model_dump()
assert "code" in d
assert "message" in d
assert d["code"] == "invalid_credentials"
def test_auth_error_response_all_codes_serializable():
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
for code in AuthErrorCode:
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
d = err.model_dump()
assert d["code"] == code.value
# ── decode_token Caller Pattern ──────────────────────────────────────
def test_decode_token_expired_maps_to_token_expired_code():
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
_setup_config()
from datetime import UTC, datetime, timedelta
import jwt as pyjwt
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
result = decode_token(token)
assert result == TokenError.EXPIRED
# Verify the mapping pattern used in route handlers
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_EXPIRED
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
_setup_config()
from datetime import UTC, datetime, timedelta
import jwt as pyjwt
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_INVALID
def test_decode_token_malformed_maps_to_token_invalid_code():
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
_setup_config()
result = decode_token("garbage")
assert result == TokenError.MALFORMED
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_INVALID
# ── Login Response Format ────────────────────────────────────────────
def test_login_response_model_has_no_access_token():
"""LoginResponse should NOT contain access_token field (RFC-001)."""
from app.gateway.routers.auth import LoginResponse
resp = LoginResponse(expires_in=604800)
d = resp.model_dump()
assert "access_token" not in d
assert "expires_in" in d
assert d["expires_in"] == 604800
def test_login_response_model_fields():
"""LoginResponse has expires_in and needs_setup."""
from app.gateway.routers.auth import LoginResponse
fields = set(LoginResponse.model_fields.keys())
assert fields == {"expires_in", "needs_setup"}
# ── AuthConfig in Route ──────────────────────────────────────────────
def test_auth_config_token_expiry_used_in_login_response():
"""LoginResponse.expires_in should come from config.token_expiry_days."""
from app.gateway.routers.auth import LoginResponse
expected_seconds = 14 * 24 * 3600
resp = LoginResponse(expires_in=expected_seconds)
assert resp.expires_in == expected_seconds
# ── UserResponse Type Preservation ───────────────────────────────────
def test_user_response_system_role_literal():
"""UserResponse.system_role should only accept 'admin' or 'user'."""
from app.gateway.auth.models import UserResponse
# Valid roles
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
assert resp.system_role == "admin"
resp = UserResponse(id="1", email="a@b.com", system_role="user")
assert resp.system_role == "user"
def test_user_response_rejects_invalid_role():
"""UserResponse should reject invalid system_role values."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="superadmin")
# ══════════════════════════════════════════════════════════════════════
# UNHAPPY PATHS / EDGE CASES
# ══════════════════════════════════════════════════════════════════════
# ── get_current_user structured 401 responses ────────────────────────
def test_get_current_user_no_cookie_returns_not_authenticated():
"""No cookie → 401 with code=not_authenticated."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
mock_request = type("MockRequest", (), {"cookies": {}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "not_authenticated"
def test_get_current_user_expired_token_returns_token_expired():
"""Expired token → 401 with code=token_expired."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_expired"
def test_get_current_user_invalid_token_returns_token_invalid():
"""Bad signature → 401 with code=token_invalid."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_invalid"
def test_get_current_user_malformed_token_returns_token_invalid():
"""Garbage token → 401 with code=token_invalid."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_invalid"
# ── decode_token edge cases ──────────────────────────────────────────
def test_decode_token_empty_string_returns_malformed():
_setup_config()
result = decode_token("")
assert result == TokenError.MALFORMED
def test_decode_token_whitespace_returns_malformed():
_setup_config()
result = decode_token(" ")
assert result == TokenError.MALFORMED
# ── AuthConfig validation edge cases ─────────────────────────────────
def test_auth_config_missing_jwt_secret_raises():
"""AuthConfig requires jwt_secret — no default allowed."""
with pytest.raises(ValidationError):
AuthConfig()
def test_auth_config_token_expiry_zero_raises():
"""token_expiry_days must be >= 1."""
with pytest.raises(ValidationError):
AuthConfig(jwt_secret="secret", token_expiry_days=0)
def test_auth_config_token_expiry_31_raises():
"""token_expiry_days must be <= 30."""
with pytest.raises(ValidationError):
AuthConfig(jwt_secret="secret", token_expiry_days=31)
def test_auth_config_token_expiry_boundary_1_ok():
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
assert config.token_expiry_days == 1
def test_auth_config_token_expiry_boundary_30_ok():
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
assert config.token_expiry_days == 30
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old
# ── CSRF middleware integration (unhappy paths) ──────────────────────
def _make_csrf_app():
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
from fastapi import HTTPException as _HTTPException
from fastapi.responses import JSONResponse as _JSONResponse
app = FastAPI()
@app.exception_handler(_HTTPException)
async def _http_exc_handler(request, exc):
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/test/protected")
async def protected():
return {"ok": True}
@app.post("/api/v1/auth/login/local")
async def login():
return {"ok": True}
@app.get("/api/v1/test/read")
async def read_endpoint():
return {"ok": True}
return app
def test_csrf_middleware_blocks_post_without_token():
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/test/protected")
assert resp.status_code == 403
assert "CSRF" in resp.json()["detail"]
assert "missing" in resp.json()["detail"].lower()
def test_csrf_middleware_blocks_post_with_mismatched_token():
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
client = TestClient(_make_csrf_app())
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
resp = client.post(
"/api/v1/test/protected",
headers={CSRF_HEADER_NAME: "token-b"},
)
assert resp.status_code == 403
assert "mismatch" in resp.json()["detail"].lower()
def test_csrf_middleware_allows_post_with_matching_token():
"""POST with matching CSRF cookie/header → 200."""
client = TestClient(_make_csrf_app())
token = secrets.token_urlsafe(64)
client.cookies.set(CSRF_COOKIE_NAME, token)
resp = client.post(
"/api/v1/test/protected",
headers={CSRF_HEADER_NAME: token},
)
assert resp.status_code == 200
def test_csrf_middleware_allows_get_without_token():
"""GET requests bypass CSRF check."""
client = TestClient(_make_csrf_app())
resp = client.get("/api/v1/test/read")
assert resp.status_code == 200
def test_csrf_middleware_exempts_login_local():
"""POST to login/local is exempt from CSRF (no token yet)."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/auth/login/local")
assert resp.status_code == 200
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
"""Auth endpoints should receive a CSRF cookie in response."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/auth/login/local")
assert CSRF_COOKIE_NAME in resp.cookies
# ── UserResponse edge cases ──────────────────────────────────────────
def test_user_response_missing_required_fields():
"""UserResponse with missing fields → ValidationError."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1") # missing email, system_role
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com") # missing system_role
def test_user_response_empty_string_role_rejected():
"""Empty string is not a valid role."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="")
# ══════════════════════════════════════════════════════════════════════
# HTTP-LEVEL API CONTRACT TESTS
# ══════════════════════════════════════════════════════════════════════
def _make_auth_app():
"""Create FastAPI app with auth routes for contract testing."""
from app.gateway.app import create_app
return create_app()
def _get_auth_client():
"""Get TestClient for auth API contract tests."""
return TestClient(_make_auth_app())
def test_api_auth_me_no_cookie_returns_structured_401():
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
_setup_config()
client = _get_auth_client()
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "not_authenticated"
assert "message" in body["detail"]
def test_api_auth_me_expired_token_returns_structured_401():
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_expired"
def test_api_auth_me_invalid_sig_returns_structured_401():
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_invalid"
def test_api_login_bad_credentials_returns_structured_401():
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
)
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "invalid_credentials"
def test_api_login_success_no_token_in_body():
"""Successful login → response body has expires_in but NOT access_token."""
_setup_config()
client = _get_auth_client()
# Register first
client.post(
"/api/v1/auth/register",
json={"email": "contract-test@test.com", "password": "securepassword123"},
)
# Login
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "contract-test@test.com", "password": "securepassword123"},
)
assert resp.status_code == 200
body = resp.json()
assert "expires_in" in body
assert "access_token" not in body
# Token should be in cookie, not body
assert "access_token" in resp.cookies
def test_api_register_duplicate_returns_structured_400():
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
_setup_config()
client = _get_auth_client()
email = "dup-contract-test@test.com"
# First register
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
# Duplicate
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
assert resp.status_code == 400
body = resp.json()
assert body["detail"]["code"] == "email_already_exists"
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
def _unique_email(prefix: str) -> str:
return f"{prefix}-{secrets.token_hex(4)}@test.com"
def _get_set_cookie_headers(resp) -> list[str]:
"""Extract all set-cookie header values from a TestClient response."""
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
def test_register_http_cookie_httponly_true_secure_false():
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("http-cookie"), "password": "password123"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" not in cookie_header.lower().replace("samesite", "")
def test_register_https_cookie_httponly_true_secure_true():
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("https-cookie"), "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
assert "max-age" in cookie_header.lower()
def test_login_https_sets_secure_cookie():
"""HTTPS login → access_token cookie has secure flag."""
_setup_config()
client = _get_auth_client()
email = _unique_email("https-login")
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
resp = client.post(
"/api/v1/auth/login/local",
data={"username": email, "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 200
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
def test_csrf_cookie_secure_on_https():
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-https"), "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
csrf_header = csrf_cookies[0]
assert "secure" in csrf_header.lower()
assert "httponly" not in csrf_header.lower()
def test_csrf_cookie_not_secure_on_http():
"""HTTP register → csrf_token cookie does NOT have secure flag."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-http"), "password": "password123"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
csrf_header = csrf_cookies[0]
assert "secure" not in csrf_header.lower().replace("samesite", "")

View File

@ -0,0 +1,214 @@
"""Tests for _ensure_admin_user() in app.py.
Covers: first-boot admin creation, auto-reset on needs_setup=True,
no-op on needs_setup=False, migration, and edge cases.
"""
import asyncio
import os
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.models import User
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
yield
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _make_app_stub(store=None):
"""Minimal app-like object with state.store."""
app = SimpleNamespace()
app.state = SimpleNamespace()
app.state.store = store
return app
def _make_provider(user_count=0, admin_user=None):
p = AsyncMock()
p.count_users = AsyncMock(return_value=user_count)
p.create_user = AsyncMock(
side_effect=lambda **kw: User(
email=kw["email"],
password_hash="hashed",
system_role=kw.get("system_role", "user"),
needs_setup=kw.get("needs_setup", False),
)
)
p.get_user_by_email = AsyncMock(return_value=admin_user)
p.update_user = AsyncMock(side_effect=lambda u: u)
return p
# ── First boot: no users ─────────────────────────────────────────────────
def test_first_boot_creates_admin():
"""count_users==0 → create admin with needs_setup=True."""
provider = _make_provider(user_count=0)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
call_kwargs = provider.create_user.call_args[1]
assert call_kwargs["email"] == "admin@deerflow.dev"
assert call_kwargs["system_role"] == "admin"
assert call_kwargs["needs_setup"] is True
assert len(call_kwargs["password"]) > 10 # random password generated
def test_first_boot_triggers_migration_if_store_present():
"""First boot with store → _migrate_orphaned_threads called."""
provider = _make_provider(user_count=0)
store = AsyncMock()
store.asearch = AsyncMock(return_value=[])
app = _make_app_stub(store=store)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
store.asearch.assert_called_once()
def test_first_boot_no_store_skips_migration():
"""First boot without store → no crash, migration skipped."""
provider = _make_provider(user_count=0)
app = _make_app_stub(store=None)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
def test_needs_setup_true_resets_password():
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
admin = User(
email="admin@deerflow.dev",
password_hash="old-hash",
system_role="admin",
needs_setup=True,
token_version=0,
created_at=datetime.now(UTC) - timedelta(seconds=30),
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
# Password was reset
provider.update_user.assert_called_once()
updated = provider.update_user.call_args[0][0]
assert updated.password_hash == "new-hash"
assert updated.token_version == 1
def test_needs_setup_true_consecutive_resets_increment_version():
"""Two boots with needs_setup=True → token_version increments each time."""
admin = User(
email="admin@deerflow.dev",
password_hash="hash",
system_role="admin",
needs_setup=True,
token_version=3,
created_at=datetime.now(UTC) - timedelta(seconds=30),
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
updated = provider.update_user.call_args[0][0]
assert updated.token_version == 4
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
def test_needs_setup_false_no_reset():
"""Admin with needs_setup=False → no password reset, no update."""
admin = User(
email="admin@deerflow.dev",
password_hash="stable-hash",
system_role="admin",
needs_setup=False,
token_version=2,
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.update_user.assert_not_called()
assert admin.password_hash == "stable-hash"
assert admin.token_version == 2
# ── Edge cases ───────────────────────────────────────────────────────────
def test_no_admin_email_found_no_crash():
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
provider = _make_provider(user_count=3, admin_user=None)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.update_user.assert_not_called()
provider.create_user.assert_not_called()
def test_migration_failure_is_non_fatal():
"""_migrate_orphaned_threads exception is caught and logged."""
provider = _make_provider(user_count=0)
store = AsyncMock()
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
app = _make_app_stub(store=store)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
# Should not raise
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()

View File

@ -0,0 +1,312 @@
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
Validates that the LangGraph auth layer enforces the same rules as Gateway:
cookie JWT decode DB lookup token_version check owner filter
"""
import asyncio
import os
from datetime import timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
from langgraph_sdk import Auth
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.langgraph_auth import add_owner_filter, authenticate
# ── Helpers ───────────────────────────────────────────────────────────────
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
yield
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _req(cookies=None, method="GET", headers=None):
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
def _user(user_id=None, token_version=0):
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
def _mock_provider(user=None):
p = AsyncMock()
p.get_user = AsyncMock(return_value=user)
return p
# ── @auth.authenticate ───────────────────────────────────────────────────
def test_no_cookie_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req()))
assert exc.value.status_code == 401
assert "Not authenticated" in str(exc.value.detail)
def test_invalid_jwt_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": "garbage"})))
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
def test_expired_jwt_raises_401():
token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
def test_user_not_found_raises_401():
token = create_access_token("ghost")
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
assert "User not found" in str(exc.value.detail)
def test_token_version_mismatch_raises_401():
user = _user(token_version=2)
token = create_access_token(str(user.id), token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
assert "revoked" in str(exc.value.detail).lower()
def test_valid_token_returns_user_id():
user = _user(token_version=0)
token = create_access_token(str(user.id), token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
assert result == str(user.id)
def test_valid_token_matching_version():
user = _user(token_version=5)
token = create_access_token(str(user.id), token_version=5)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
assert result == str(user.id)
# ── @auth.authenticate edge cases ────────────────────────────────────────
def test_provider_exception_propagates():
"""Provider raises → should not be swallowed silently."""
token = create_access_token("user-1")
p = AsyncMock()
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
with pytest.raises(RuntimeError, match="DB down"):
asyncio.run(authenticate(_req({"access_token": token})))
def test_jwt_missing_ver_defaults_to_zero():
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": raw})))
assert result == uid
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
assert exc.value.status_code == 401
def test_wrong_secret_raises_401():
"""Token signed with different secret → 401."""
import jwt as pyjwt
raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
assert exc.value.status_code == 401
# ── @auth.on (owner filter) ──────────────────────────────────────────────
class _FakeUser:
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
def __init__(self, identity: str):
self.identity = identity
self.is_authenticated = True
self.display_name = identity
def _make_ctx(user_id):
return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])
def test_filter_injects_user_id():
value = {}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["owner_id"] == "user-a"
def test_filter_preserves_existing_metadata():
value = {"metadata": {"title": "hello"}}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["owner_id"] == "user-a"
assert value["metadata"]["title"] == "hello"
def test_filter_returns_user_id_dict():
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
assert result == {"owner_id": "user-x"}
def test_filter_read_write_consistency():
value = {}
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
def test_different_users_different_filters():
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
assert f_a["owner_id"] != f_b["owner_id"]
def test_filter_overrides_conflicting_user_id():
"""If value already has a different user_id in metadata, it gets overwritten."""
value = {"metadata": {"owner_id": "attacker"}}
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
assert value["metadata"]["owner_id"] == "real-owner"
def test_filter_with_empty_metadata():
"""Explicit empty metadata dict is fine."""
value = {"metadata": {}}
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
assert value["metadata"]["owner_id"] == "user-z"
assert result == {"owner_id": "user-z"}
# ── Gateway parity ───────────────────────────────────────────────────────
def test_shared_jwt_secret():
token = create_access_token("user-1", token_version=3)
payload = decode_token(token)
from app.gateway.auth.errors import TokenError
assert not isinstance(payload, TokenError)
assert payload.sub == "user-1"
assert payload.ver == 3
def test_langgraph_json_has_auth_path():
import json
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
assert "auth" in config
assert "langgraph_auth" in config["auth"]["path"]
def test_auth_handler_has_both_layers():
from app.gateway.langgraph_auth import auth
assert auth._authenticate_handler is not None
assert len(auth._global_handlers) == 1
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
def test_csrf_get_no_check():
"""GET requests skip CSRF — should proceed to JWT validation."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="GET")))
# Rejected by missing cookie, NOT by CSRF
assert exc.value.status_code == 401
assert "Not authenticated" in str(exc.value.detail)
def test_csrf_post_missing_token():
"""POST without CSRF token → 403."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
assert exc.value.status_code == 403
assert "CSRF token missing" in str(exc.value.detail)
def test_csrf_post_mismatched_token():
"""POST with mismatched CSRF tokens → 403."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(
authenticate(
_req(
method="POST",
cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
headers={"x-csrf-token": "wrong-token"},
)
)
)
assert exc.value.status_code == 403
assert "mismatch" in str(exc.value.detail)
def test_csrf_post_matching_token_proceeds_to_jwt():
"""POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(
authenticate(
_req(
method="POST",
cookies={"access_token": "garbage", "csrf_token": "same-token"},
headers={"x-csrf-token": "same-token"},
)
)
)
# Past CSRF, rejected by JWT decode
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
def test_csrf_put_requires_token():
"""PUT also requires CSRF."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
assert exc.value.status_code == 403
def test_csrf_delete_requires_token():
"""DELETE also requires CSRF."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
assert exc.value.status_code == 403

View File

@ -52,7 +52,6 @@
"@xyflow/react": "^12.10.0",
"ai": "^6.0.33",
"best-effort-json-parser": "^1.2.1",
"better-auth": "^1.3",
"canvas-confetti": "^1.9.4",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",

183
frontend/pnpm-lock.yaml generated
View File

@ -113,9 +113,6 @@ importers:
best-effort-json-parser:
specifier: ^1.2.1
version: 1.2.1
better-auth:
specifier: ^1.3
version: 1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3))
canvas-confetti:
specifier: ^1.9.4
version: 1.9.4
@ -317,27 +314,6 @@ packages:
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
engines: {node: '>=6.9.0'}
'@better-auth/core@1.4.18':
resolution: {integrity: sha512-q+awYgC7nkLEBdx2sW0iJjkzgSHlIxGnOpsN1r/O1+a4m7osJNHtfK2mKJSL1I+GfNyIlxJF8WvD/NLuYMpmcg==}
peerDependencies:
'@better-auth/utils': 0.3.0
'@better-fetch/fetch': 1.1.21
better-call: 1.1.8
jose: ^6.1.0
kysely: ^0.28.5
nanostores: ^1.0.1
'@better-auth/telemetry@1.4.18':
resolution: {integrity: sha512-e5rDF8S4j3Um/0LIVATL2in9dL4lfO2fr2v1Wio4qTMRbfxqnUDTa+6SZtwdeJrbc4O+a3c+IyIpjG9Q/6GpfQ==}
peerDependencies:
'@better-auth/core': 1.4.18
'@better-auth/utils@0.3.0':
resolution: {integrity: sha512-W+Adw6ZA6mgvnSnhOki270rwJ42t4XzSK6YWGF//BbVXL6SwCLWfyzBc1lN2m/4RM28KubdBKQ4X5VMoLRNPQw==}
'@better-fetch/fetch@1.1.21':
resolution: {integrity: sha512-/ImESw0sskqlVR94jB+5+Pxjf+xBwDZF/N5+y2/q4EqD7IARUTSpPfIo8uf39SYpCxyOCtbyYpUrZ3F/k0zT4A==}
'@braintree/sanitize-url@7.1.2':
resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==}
@ -1116,14 +1092,6 @@ packages:
cpu: [x64]
os: [win32]
'@noble/ciphers@2.1.1':
resolution: {integrity: sha512-bysYuiVfhxNJuldNXlFEitTVdNnYUc+XNJZd7Qm2a5j1vZHgY+fazadNFWFaMK/2vye0JVlxV3gHmC0WDfAOQw==}
engines: {node: '>= 20.19.0'}
'@noble/hashes@2.0.1':
resolution: {integrity: sha512-XlOlEbQcE9fmuXxrVTXCTlG2nlRXa9Rj3rr5Ue/+tX+nmkgbX720YHh0VR3hBF9xDvwnb8D2shVGOwNx+ulArw==}
engines: {node: '>= 20.19.0'}
'@nodelib/fs.scandir@2.1.5':
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
engines: {node: '>= 8'}
@ -2696,76 +2664,6 @@ packages:
best-effort-json-parser@1.2.1:
resolution: {integrity: sha512-UICSLibQdzS1f+PBsi3u2YE3SsdXcWicHUg3IMvfuaePS2AYnZJdJeKhGv5OM8/mqJwPt79aDrEJ1oa84tELvw==}
better-auth@1.4.18:
resolution: {integrity: sha512-bnyifLWBPcYVltH3RhS7CM62MoelEqC6Q+GnZwfiDWNfepXoQZBjEvn4urcERC7NTKgKq5zNBM8rvPvRBa6xcg==}
peerDependencies:
'@lynx-js/react': '*'
'@prisma/client': ^5.0.0 || ^6.0.0 || ^7.0.0
'@sveltejs/kit': ^2.0.0
'@tanstack/react-start': ^1.0.0
'@tanstack/solid-start': ^1.0.0
better-sqlite3: ^12.0.0
drizzle-kit: '>=0.31.4'
drizzle-orm: '>=0.41.0'
mongodb: ^6.0.0 || ^7.0.0
mysql2: ^3.0.0
next: ^14.0.0 || ^15.0.0 || ^16.0.0
pg: ^8.0.0
prisma: ^5.0.0 || ^6.0.0 || ^7.0.0
react: ^18.0.0 || ^19.0.0
react-dom: ^18.0.0 || ^19.0.0
solid-js: ^1.0.0
svelte: ^4.0.0 || ^5.0.0
vitest: ^2.0.0 || ^3.0.0 || ^4.0.0
vue: ^3.0.0
peerDependenciesMeta:
'@lynx-js/react':
optional: true
'@prisma/client':
optional: true
'@sveltejs/kit':
optional: true
'@tanstack/react-start':
optional: true
'@tanstack/solid-start':
optional: true
better-sqlite3:
optional: true
drizzle-kit:
optional: true
drizzle-orm:
optional: true
mongodb:
optional: true
mysql2:
optional: true
next:
optional: true
pg:
optional: true
prisma:
optional: true
react:
optional: true
react-dom:
optional: true
solid-js:
optional: true
svelte:
optional: true
vitest:
optional: true
vue:
optional: true
better-call@1.1.8:
resolution: {integrity: sha512-XMQ2rs6FNXasGNfMjzbyroSwKwYbZ/T3IxruSS6U2MJRsSYh3wYtG3o6H00ZlKZ/C/UPOAD97tqgQJNsxyeTXw==}
peerDependencies:
zod: ^4.0.0
peerDependenciesMeta:
zod:
optional: true
better-react-mathjax@2.3.0:
resolution: {integrity: sha512-K0ceQC+jQmB+NLDogO5HCpqmYf18AU2FxDbLdduYgkHYWZApFggkHE4dIaXCV1NqeoscESYXXo1GSkY6fA295w==}
peerDependencies:
@ -3973,9 +3871,6 @@ packages:
resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==}
hasBin: true
jose@6.1.3:
resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==}
js-tiktoken@1.0.21:
resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==}
@ -4026,10 +3921,6 @@ packages:
knitwork@1.3.0:
resolution: {integrity: sha512-4LqMNoONzR43B1W0ek0fhXMsDNW/zxa1NdFAVMY+k28pgZLovR4G3PB5MrpTxCy1QaZCqNoiaKPr5w5qZHfSNw==}
kysely@0.28.11:
resolution: {integrity: sha512-zpGIFg0HuoC893rIjYX1BETkVWdDnzTzF5e0kWXJFg5lE0k1/LfNWBejrcnOFu8Q2Rfq/hTDTU7XLUM8QOrpzg==}
engines: {node: '>=20.0.0'}
langium@3.3.1:
resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==}
engines: {node: '>=16.0.0'}
@ -4458,10 +4349,6 @@ packages:
engines: {node: ^18 || >=20}
hasBin: true
nanostores@1.1.0:
resolution: {integrity: sha512-yJBmDJr18xy47dbNVlHcgdPrulSn1nhSE6Ns9vTG+Nx9VPT6iV1MD6aQFp/t52zpf82FhLLTXAXr30NuCnxvwA==}
engines: {node: ^20.0.0 || >=22.0.0}
napi-postinstall@0.3.4:
resolution: {integrity: sha512-PHI5f1O0EP5xJ9gQmFGMS6IZcrVvTjpXjz7Na41gTE7eE2hK11lg04CECCYEEjdc17EV4DO+fkGEtt7TpTaTiQ==}
engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0}
@ -5050,9 +4937,6 @@ packages:
engines: {node: '>=18.0.0', npm: '>=8.0.0'}
hasBin: true
rou3@0.7.12:
resolution: {integrity: sha512-iFE4hLDuloSWcD7mjdCDhx2bKcIsYbtOTpfH5MHHLSKMOUyjqQXTeZVa289uuwEGEKFoE/BAPbhaU4B774nceg==}
roughjs@4.6.6:
resolution: {integrity: sha512-ZUz/69+SYpFN/g/lUlo2FXcIjRkSu3nDarreVdGGndHEBJ6cXPdKguS8JGxwj5HA5xIbVKSmLgr5b3AWxtRfvQ==}
@ -5105,9 +4989,6 @@ packages:
server-only@0.0.1:
resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==}
set-cookie-parser@2.7.2:
resolution: {integrity: sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==}
set-function-length@1.2.2:
resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==}
engines: {node: '>= 0.4'}
@ -5802,27 +5683,6 @@ snapshots:
'@babel/helper-string-parser': 7.27.1
'@babel/helper-validator-identifier': 7.28.5
'@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)':
dependencies:
'@better-auth/utils': 0.3.0
'@better-fetch/fetch': 1.1.21
'@standard-schema/spec': 1.1.0
better-call: 1.1.8(zod@4.3.6)
jose: 6.1.3
kysely: 0.28.11
nanostores: 1.1.0
zod: 4.3.6
'@better-auth/telemetry@1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))':
dependencies:
'@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)
'@better-auth/utils': 0.3.0
'@better-fetch/fetch': 1.1.21
'@better-auth/utils@0.3.0': {}
'@better-fetch/fetch@1.1.21': {}
'@braintree/sanitize-url@7.1.2': {}
'@cfworker/json-schema@4.1.1': {}
@ -6671,10 +6531,6 @@ snapshots:
'@next/swc-win32-x64-msvc@16.1.7':
optional: true
'@noble/ciphers@2.1.1': {}
'@noble/hashes@2.0.1': {}
'@nodelib/fs.scandir@2.1.5':
dependencies:
'@nodelib/fs.stat': 2.0.5
@ -8242,35 +8098,6 @@ snapshots:
best-effort-json-parser@1.2.1: {}
better-auth@1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3)):
dependencies:
'@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)
'@better-auth/telemetry': 1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))
'@better-auth/utils': 0.3.0
'@better-fetch/fetch': 1.1.21
'@noble/ciphers': 2.1.1
'@noble/hashes': 2.0.1
better-call: 1.1.8(zod@4.3.6)
defu: 6.1.4
jose: 6.1.3
kysely: 0.28.11
nanostores: 1.1.0
zod: 4.3.6
optionalDependencies:
next: 16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
react: 19.2.4
react-dom: 19.2.4(react@19.2.4)
vue: 3.5.28(typescript@5.9.3)
better-call@1.1.8(zod@4.3.6):
dependencies:
'@better-auth/utils': 0.3.0
'@better-fetch/fetch': 1.1.21
rou3: 0.7.12
set-cookie-parser: 2.7.2
optionalDependencies:
zod: 4.3.6
better-react-mathjax@2.3.0(react@19.2.4):
dependencies:
mathjax-full: 3.2.2
@ -9786,8 +9613,6 @@ snapshots:
jiti@2.6.1: {}
jose@6.1.3: {}
js-tiktoken@1.0.21:
dependencies:
base64-js: 1.5.1
@ -9833,8 +9658,6 @@ snapshots:
knitwork@1.3.0: {}
kysely@0.28.11: {}
langium@3.3.1:
dependencies:
chevrotain: 11.0.3
@ -10529,8 +10352,6 @@ snapshots:
nanoid@5.1.6: {}
nanostores@1.1.0: {}
napi-postinstall@0.3.4: {}
natural-compare@1.4.0: {}
@ -11305,8 +11126,6 @@ snapshots:
'@rollup/rollup-win32-x64-msvc': 4.60.0
fsevents: 2.3.3
rou3@0.7.12: {}
roughjs@4.6.6:
dependencies:
hachure-fill: 0.5.2
@ -11373,8 +11192,6 @@ snapshots:
server-only@0.0.1: {}
set-cookie-parser@2.7.2: {}
set-function-length@1.2.2:
dependencies:
define-data-property: 1.1.4

View File

@ -0,0 +1,45 @@
import Link from "next/link";
import { redirect } from "next/navigation";
import { type ReactNode } from "react";
import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types";
export const dynamic = "force-dynamic";
export default async function AuthLayout({
children,
}: {
children: ReactNode;
}) {
const result = await getServerSideUser();
switch (result.tag) {
case "authenticated":
redirect("/workspace");
case "needs_setup":
// Allow access to setup page
return <AuthProvider initialUser={result.user}>{children}</AuthProvider>;
case "unauthenticated":
return <AuthProvider initialUser={null}>{children}</AuthProvider>;
case "gateway_unavailable":
return (
<div className="flex h-screen flex-col items-center justify-center gap-4">
<p className="text-muted-foreground">
Service temporarily unavailable.
</p>
<Link
href="/login"
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
>
Retry
</Link>
</div>
);
case "config_error":
throw new Error(result.message);
default:
assertNever(result);
}
}

View File

@ -0,0 +1,183 @@
"use client";
import Link from "next/link";
import { useRouter, useSearchParams } from "next/navigation";
import { useEffect, useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
/**
* Validate next parameter
* Prevent open redirect attacks
* Per RFC-001: Only allow relative paths starting with /
*/
function validateNextParam(next: string | null): string | null {
if (!next) {
return null;
}
// Need start with / (relative path)
if (!next.startsWith("/")) {
return null;
}
// Disallow protocol-relative URLs
if (
next.startsWith("//") ||
next.startsWith("http://") ||
next.startsWith("https://")
) {
return null;
}
// Disallow URLs with different protocols (e.g., javascript:, data:, etc)
if (next.includes(":") && !next.startsWith("/")) {
return null;
}
// Valid relative path
return next;
}
export default function LoginPage() {
const router = useRouter();
const searchParams = useSearchParams();
const { isAuthenticated } = useAuth();
const [email, setEmail] = useState("");
const [password, setPassword] = useState("");
const [isLogin, setIsLogin] = useState(true);
const [error, setError] = useState("");
const [loading, setLoading] = useState(false);
// Get next parameter for validated redirect
const nextParam = searchParams.get("next");
const redirectPath = validateNextParam(nextParam) ?? "/workspace";
// Redirect if already authenticated (client-side, post-login)
useEffect(() => {
if (isAuthenticated) {
router.push(redirectPath);
}
}, [isAuthenticated, redirectPath, router]);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setError("");
setLoading(true);
try {
const endpoint = isLogin
? "/api/v1/auth/login/local"
: "/api/v1/auth/register";
const body = isLogin
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
: JSON.stringify({ email, password });
const headers: HeadersInit = isLogin
? { "Content-Type": "application/x-www-form-urlencoded" }
: { "Content-Type": "application/json" };
const res = await fetch(endpoint, {
method: "POST",
headers,
body,
credentials: "include", // Important: include HttpOnly cookie
});
if (!res.ok) {
const data = await res.json();
const authError = parseAuthError(data);
setError(authError.message);
return;
}
// Both login and register set a cookie — redirect to workspace
router.push(redirectPath);
} catch (_err) {
setError("Network error. Please try again.");
} finally {
setLoading(false);
}
};
return (
<div className="flex min-h-screen items-center justify-center bg-[#0a0a0a]">
<div className="border-border/20 w-full max-w-md space-y-6 rounded-lg border bg-black/50 p-8 backdrop-blur-sm">
<div className="text-center">
<h1 className="font-serif text-3xl">DeerFlow</h1>
<p className="text-muted-foreground mt-2">
{isLogin ? "Sign in to your account" : "Create a new account"}
</p>
</div>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label htmlFor="email" className="text-sm font-medium">
Email
</label>
<Input
id="email"
type="email"
value={email}
onChange={(e) => setEmail(e.target.value)}
placeholder="you@example.com"
required
className="mt-1 bg-white text-black"
/>
</div>
<div>
<label htmlFor="password" className="text-sm font-medium">
Password
</label>
<Input
id="password"
type="password"
value={password}
onChange={(e) => setPassword(e.target.value)}
placeholder="•••••••"
required
minLength={isLogin ? 6 : 8}
className="mt-1 bg-white text-black"
/>
</div>
{error && <p className="text-sm text-red-500">{error}</p>}
<Button type="submit" className="w-full" disabled={loading}>
{loading
? "Please wait..."
: isLogin
? "Sign In"
: "Create Account"}
</Button>
</form>
<div className="text-center text-sm">
<button
type="button"
onClick={() => {
setIsLogin(!isLogin);
setError("");
}}
className="text-blue-500 hover:underline"
>
{isLogin
? "Don't have an account? Sign up"
: "Already have an account? Sign in"}
</button>
</div>
<div className="text-muted-foreground text-center text-xs">
<Link href="/" className="hover:underline">
Back to home
</Link>
</div>
</div>
</div>
);
}

View File

@ -0,0 +1,115 @@
"use client";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { getCsrfHeaders } from "@/core/api/fetcher";
import { parseAuthError } from "@/core/auth/types";
export default function SetupPage() {
const router = useRouter();
const [email, setEmail] = useState("");
const [newPassword, setNewPassword] = useState("");
const [confirmPassword, setConfirmPassword] = useState("");
const [currentPassword, setCurrentPassword] = useState("");
const [error, setError] = useState("");
const [loading, setLoading] = useState(false);
const handleSetup = async (e: React.FormEvent) => {
e.preventDefault();
setError("");
if (newPassword !== confirmPassword) {
setError("Passwords do not match");
return;
}
if (newPassword.length < 8) {
setError("Password must be at least 8 characters");
return;
}
setLoading(true);
try {
const res = await fetch("/api/v1/auth/change-password", {
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
credentials: "include",
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
});
if (!res.ok) {
const data = await res.json();
const authError = parseAuthError(data);
setError(authError.message);
return;
}
router.push("/workspace");
} catch {
setError("Network error. Please try again.");
} finally {
setLoading(false);
}
};
return (
<div className="flex min-h-screen items-center justify-center">
<div className="w-full max-w-sm space-y-6 p-6">
<div className="text-center">
<h1 className="font-serif text-3xl">DeerFlow</h1>
<p className="text-muted-foreground mt-2">
Complete admin account setup
</p>
<p className="text-muted-foreground mt-1 text-xs">
Set your real email and a new password.
</p>
</div>
<form onSubmit={handleSetup} className="space-y-4">
<Input
type="email"
placeholder="Your email"
value={email}
onChange={(e) => setEmail(e.target.value)}
required
/>
<Input
type="password"
placeholder="Current password (from console log)"
value={currentPassword}
onChange={(e) => setCurrentPassword(e.target.value)}
required
/>
<Input
type="password"
placeholder="New password"
value={newPassword}
onChange={(e) => setNewPassword(e.target.value)}
required
minLength={8}
/>
<Input
type="password"
placeholder="Confirm new password"
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
required
minLength={8}
/>
{error && <p className="text-sm text-red-500">{error}</p>}
<Button type="submit" className="w-full" disabled={loading}>
{loading ? "Setting up..." : "Complete Setup"}
</Button>
</form>
</div>
</div>
);
}

View File

@ -1,5 +0,0 @@
import { toNextJsHandler } from "better-auth/next-js";
import { auth } from "@/server/better-auth";
export const { GET, POST } = toNextJsHandler(auth.handler);

View File

@ -1,47 +1,58 @@
"use client";
import Link from "next/link";
import { redirect } from "next/navigation";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { useCallback, useEffect, useLayoutEffect, useState } from "react";
import { Toaster } from "sonner";
import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types";
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
import { CommandPalette } from "@/components/workspace/command-palette";
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
import { getLocalSettings, useLocalSettings } from "@/core/settings";
import { WorkspaceContent } from "./workspace-content";
const queryClient = new QueryClient();
export const dynamic = "force-dynamic";
export default function WorkspaceLayout({
export default async function WorkspaceLayout({
children,
}: Readonly<{ children: React.ReactNode }>) {
const [settings, setSettings] = useLocalSettings();
const [open, setOpen] = useState(false); // SSR default: open (matches server render)
useLayoutEffect(() => {
// Runs synchronously before first paint on the client — no visual flash
setOpen(!getLocalSettings().layout.sidebar_collapsed);
}, []);
useEffect(() => {
setOpen(!settings.layout.sidebar_collapsed);
}, [settings.layout.sidebar_collapsed]);
const handleOpenChange = useCallback(
(open: boolean) => {
setOpen(open);
setSettings("layout", { sidebar_collapsed: !open });
},
[setSettings],
);
return (
<QueryClientProvider client={queryClient}>
<SidebarProvider
className="h-screen"
open={open}
onOpenChange={handleOpenChange}
>
<WorkspaceSidebar />
<SidebarInset className="min-w-0">{children}</SidebarInset>
</SidebarProvider>
<CommandPalette />
<Toaster position="top-center" />
</QueryClientProvider>
);
const result = await getServerSideUser();
switch (result.tag) {
case "authenticated":
return (
<AuthProvider initialUser={result.user}>
<WorkspaceContent>{children}</WorkspaceContent>
</AuthProvider>
);
case "needs_setup":
redirect("/setup");
case "unauthenticated":
redirect("/login");
case "gateway_unavailable":
return (
<div className="flex h-screen flex-col items-center justify-center gap-4">
<p className="text-muted-foreground">
Service temporarily unavailable.
</p>
<p className="text-muted-foreground text-xs">
The backend may be restarting. Please wait a moment and try again.
</p>
<div className="flex gap-3">
<Link
href="/workspace"
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
>
Retry
</Link>
<Link
href="/api/v1/auth/logout"
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
>
Logout &amp; Reset
</Link>
</div>
</div>
);
case "config_error":
throw new Error(result.message);
default:
assertNever(result);
}
}

View File

@ -0,0 +1,50 @@
"use client";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { useCallback, useEffect, useLayoutEffect, useState } from "react";
import { Toaster } from "sonner";
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
import { CommandPalette } from "@/components/workspace/command-palette";
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
import { getLocalSettings, useLocalSettings } from "@/core/settings";
export function WorkspaceContent({
children,
}: Readonly<{ children: React.ReactNode }>) {
const [queryClient] = useState(() => new QueryClient());
const [settings, setSettings] = useLocalSettings();
const [open, setOpen] = useState(false); // SSR default: open (matches server render)
useLayoutEffect(() => {
// Runs synchronously before first paint on the client — no visual flash
setOpen(!getLocalSettings().layout.sidebar_collapsed);
}, []);
useEffect(() => {
setOpen(!settings.layout.sidebar_collapsed);
}, [settings.layout.sidebar_collapsed]);
const handleOpenChange = useCallback(
(open: boolean) => {
setOpen(open);
setSettings("layout", { sidebar_collapsed: !open });
},
[setSettings],
);
return (
<QueryClientProvider client={queryClient}>
<SidebarProvider
className="h-screen"
open={open}
onOpenChange={handleOpenChange}
>
<WorkspaceSidebar />
<SidebarInset className="min-w-0">{children}</SidebarInset>
</SidebarProvider>
<CommandPalette />
<Toaster position="top-center" />
</QueryClientProvider>
);
}

View File

@ -0,0 +1,39 @@
import { buildLoginUrl } from "@/core/auth/types";
/**
* Fetch with credentials. Automatically redirects to login on 401.
*/
export async function fetchWithAuth(
input: RequestInfo | string,
init?: RequestInit,
): Promise<Response> {
const url = typeof input === "string" ? input : input.url;
const res = await fetch(url, {
...init,
credentials: "include",
});
if (res.status === 401) {
window.location.href = buildLoginUrl(window.location.pathname);
throw new Error("Unauthorized");
}
return res;
}
/**
* Build headers for CSRF-protected requests
* Per RFC-001: Double Submit Cookie pattern
*/
export function getCsrfHeaders(): HeadersInit {
const token = getCsrfToken();
return token ? { "X-CSRF-Token": token } : {};
}
/**
* Get CSRF token from cookie
*/
function getCsrfToken(): string | null {
const match = /csrf_token=([^;]+)/.exec(document.cookie);
return match?.[1] ?? null;
}

View File

@ -0,0 +1,165 @@
"use client";
import { useRouter, usePathname } from "next/navigation";
import React, {
createContext,
useContext,
useState,
useCallback,
useEffect,
type ReactNode,
} from "react";
import { type User, buildLoginUrl } from "./types";
// Re-export for consumers
export type { User };
/**
* Authentication context provided to consuming components
*/
interface AuthContextType {
user: User | null;
isAuthenticated: boolean;
isLoading: boolean;
logout: () => Promise<void>;
refreshUser: () => Promise<void>;
}
const AuthContext = createContext<AuthContextType | undefined>(undefined);
interface AuthProviderProps {
children: ReactNode;
initialUser: User | null;
}
/**
* AuthProvider - Unified authentication context for the application
*
* Per RFC-001:
* - Only holds display information (user), never JWT or tokens
* - initialUser comes from server-side guard, avoiding client flicker
* - Provides logout and refresh capabilities
*/
export function AuthProvider({ children, initialUser }: AuthProviderProps) {
const [user, setUser] = useState<User | null>(initialUser);
const [isLoading, setIsLoading] = useState(false);
const router = useRouter();
const pathname = usePathname();
const isAuthenticated = user !== null;
/**
* Fetch current user from FastAPI
* Used when initialUser might be stale (e.g., after tab was inactive)
*/
const refreshUser = useCallback(async () => {
try {
setIsLoading(true);
const res = await fetch("/api/v1/auth/me", {
credentials: "include",
});
if (res.ok) {
const data = await res.json();
setUser(data);
} else if (res.status === 401) {
// Session expired or invalid
setUser(null);
// Redirect to login if on a protected route
if (pathname?.startsWith("/workspace")) {
router.push(buildLoginUrl(pathname));
}
}
} catch (err) {
console.error("Failed to refresh user:", err);
setUser(null);
} finally {
setIsLoading(false);
}
}, [pathname, router]);
/**
* Logout - call FastAPI logout endpoint and clear local state
* Per RFC-001: Immediately clear local state, don't wait for server confirmation
*/
const logout = useCallback(async () => {
// Immediately clear local state to prevent UI flicker
setUser(null);
try {
await fetch("/api/v1/auth/logout", {
method: "POST",
credentials: "include",
});
} catch (err) {
console.error("Logout request failed:", err);
// Still redirect even if logout request fails
}
// Redirect to home page
router.push("/");
}, [router]);
/**
* Handle visibility change - refresh user when tab becomes visible again.
* Throttled to at most once per 60 s to avoid spamming the backend on rapid tab switches.
*/
const lastCheckRef = React.useRef(0);
useEffect(() => {
const handleVisibilityChange = () => {
if (document.visibilityState !== "visible" || user === null) return;
const now = Date.now();
if (now - lastCheckRef.current < 60_000) return;
lastCheckRef.current = now;
void refreshUser();
};
document.addEventListener("visibilitychange", handleVisibilityChange);
return () => {
document.removeEventListener("visibilitychange", handleVisibilityChange);
};
}, [user, refreshUser]);
const value: AuthContextType = {
user,
isAuthenticated,
isLoading,
logout,
refreshUser,
};
return <AuthContext.Provider value={value}>{children}</AuthContext.Provider>;
}
/**
* Hook to access authentication context
* Throws if used outside AuthProvider - this is intentional for proper usage
*/
export function useAuth(): AuthContextType {
const context = useContext(AuthContext);
if (context === undefined) {
throw new Error("useAuth must be used within an AuthProvider");
}
return context;
}
/**
* Hook to require authentication - redirects to login if not authenticated
* Useful for client-side checks in addition to server-side guards
*/
export function useRequireAuth(): AuthContextType {
const auth = useAuth();
const router = useRouter();
const pathname = usePathname();
useEffect(() => {
// Only redirect if we're sure user is not authenticated (not just loading)
if (!auth.isLoading && !auth.isAuthenticated) {
router.push(buildLoginUrl(pathname || "/workspace"));
}
}, [auth.isAuthenticated, auth.isLoading, router, pathname]);
return auth;
}

View File

@ -0,0 +1,34 @@
import { z } from "zod";
const gatewayConfigSchema = z.object({
internalGatewayUrl: z.string().url(),
trustedOrigins: z.array(z.string()).min(1),
});
export type GatewayConfig = z.infer<typeof gatewayConfigSchema>;
let _cached: GatewayConfig | null = null;
export function getGatewayConfig(): GatewayConfig {
if (_cached) return _cached;
const isDev = process.env.NODE_ENV === "development";
const rawUrl = process.env.DEER_FLOW_INTERNAL_GATEWAY_BASE_URL?.trim();
const internalGatewayUrl =
rawUrl?.replace(/\/+$/, "") ??
(isDev ? "http://localhost:8001" : undefined);
const rawOrigins = process.env.DEER_FLOW_TRUSTED_ORIGINS?.trim();
const trustedOrigins = rawOrigins
? rawOrigins
.split(",")
.map((s) => s.trim())
.filter(Boolean)
: isDev
? ["http://localhost:3000"]
: undefined;
_cached = gatewayConfigSchema.parse({ internalGatewayUrl, trustedOrigins });
return _cached;
}

View File

@ -0,0 +1,55 @@
export interface ProxyPolicy {
/** Allowed upstream path prefixes */
readonly allowedPaths: readonly string[];
/** Request headers to strip before forwarding */
readonly strippedRequestHeaders: ReadonlySet<string>;
/** Response headers to strip before returning */
readonly strippedResponseHeaders: ReadonlySet<string>;
/** Credential mode: which cookie to forward */
readonly credential: { readonly type: "cookie"; readonly name: string };
/** Timeout in ms */
readonly timeoutMs: number;
/** CSRF: required for non-GET/HEAD */
readonly csrf: boolean;
}
export const LANGGRAPH_COMPAT_POLICY: ProxyPolicy = {
allowedPaths: [
"threads",
"runs",
"assistants",
"store",
"models",
"mcp",
"skills",
"memory",
],
strippedRequestHeaders: new Set([
"host",
"connection",
"keep-alive",
"transfer-encoding",
"te",
"trailer",
"upgrade",
"authorization",
"x-api-key",
"origin",
"referer",
"proxy-authorization",
"proxy-authenticate",
]),
strippedResponseHeaders: new Set([
"connection",
"keep-alive",
"transfer-encoding",
"te",
"trailer",
"upgrade",
"content-length",
"set-cookie",
]),
credential: { type: "cookie", name: "access_token" },
timeoutMs: 120_000,
csrf: true,
};

View File

@ -0,0 +1,57 @@
import { cookies } from "next/headers";
import { getGatewayConfig } from "./gateway-config";
import { type AuthResult, userSchema } from "./types";
const SSR_AUTH_TIMEOUT_MS = 5_000;
/**
* Fetch the authenticated user from the gateway using the request's cookies.
* Returns a tagged AuthResult callers use exhaustive switch, no try/catch.
*/
export async function getServerSideUser(): Promise<AuthResult> {
const cookieStore = await cookies();
const sessionCookie = cookieStore.get("access_token");
let internalGatewayUrl: string;
try {
internalGatewayUrl = getGatewayConfig().internalGatewayUrl;
} catch (err) {
return { tag: "config_error", message: String(err) };
}
if (!sessionCookie) return { tag: "unauthenticated" };
const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), SSR_AUTH_TIMEOUT_MS);
try {
const res = await fetch(`${internalGatewayUrl}/api/v1/auth/me`, {
headers: { Cookie: `access_token=${sessionCookie.value}` },
cache: "no-store",
signal: controller.signal,
});
clearTimeout(timeout); // Clear immediately — covers all response branches
if (res.ok) {
const parsed = userSchema.safeParse(await res.json());
if (!parsed.success) {
console.error("[SSR auth] Malformed /auth/me response:", parsed.error);
return { tag: "gateway_unavailable" };
}
if (parsed.data.needs_setup) {
return { tag: "needs_setup", user: parsed.data };
}
return { tag: "authenticated", user: parsed.data };
}
if (res.status === 401 || res.status === 403) {
return { tag: "unauthenticated" };
}
console.error(`[SSR auth] /api/v1/auth/me responded ${res.status}`);
return { tag: "gateway_unavailable" };
} catch (err) {
clearTimeout(timeout);
console.error("[SSR auth] Failed to reach gateway:", err);
return { tag: "gateway_unavailable" };
}
}

View File

@ -0,0 +1,72 @@
import { z } from "zod";
// ── User schema (single source of truth) ──────────────────────────
export const userSchema = z.object({
id: z.string(),
email: z.string().email(),
system_role: z.enum(["admin", "user"]),
needs_setup: z.boolean().optional().default(false),
});
export type User = z.infer<typeof userSchema>;
// ── SSR auth result (tagged union) ────────────────────────────────
export type AuthResult =
| { tag: "authenticated"; user: User }
| { tag: "needs_setup"; user: User }
| { tag: "unauthenticated" }
| { tag: "gateway_unavailable" }
| { tag: "config_error"; message: string };
export function assertNever(x: never): never {
throw new Error(`Unexpected auth result: ${JSON.stringify(x)}`);
}
export function buildLoginUrl(returnPath: string): string {
return `/login?next=${encodeURIComponent(returnPath)}`;
}
// ── Backend error response parsing ────────────────────────────────
const AUTH_ERROR_CODES = [
"invalid_credentials",
"token_expired",
"token_invalid",
"user_not_found",
"email_already_exists",
"provider_not_found",
"not_authenticated",
] as const;
export type AuthErrorCode = (typeof AUTH_ERROR_CODES)[number];
export interface AuthErrorResponse {
code: AuthErrorCode;
message: string;
}
const authErrorSchema = z.object({
code: z.enum(AUTH_ERROR_CODES),
message: z.string(),
});
export function parseAuthError(data: unknown): AuthErrorResponse {
// Try top-level {code, message} first
const parsed = authErrorSchema.safeParse(data);
if (parsed.success) return parsed.data;
// Unwrap FastAPI's {detail: {code, message}} envelope
if (typeof data === "object" && data !== null && "detail" in data) {
const detail = (data as Record<string, unknown>).detail;
const nested = authErrorSchema.safeParse(detail);
if (nested.success) return nested.data;
// Legacy string-detail responses
if (typeof detail === "string") {
return { code: "invalid_credentials", message: detail };
}
}
return { code: "invalid_credentials", message: "Authentication failed" };
}

View File

@ -7,12 +7,6 @@ export const env = createEnv({
* isn't built with invalid env vars.
*/
server: {
BETTER_AUTH_SECRET:
process.env.NODE_ENV === "production"
? z.string()
: z.string().optional(),
BETTER_AUTH_GITHUB_CLIENT_ID: z.string().optional(),
BETTER_AUTH_GITHUB_CLIENT_SECRET: z.string().optional(),
GITHUB_OAUTH_TOKEN: z.string().optional(),
NODE_ENV: z
.enum(["development", "test", "production"])
@ -35,10 +29,6 @@ export const env = createEnv({
* middlewares) or client-side so we need to destruct manually.
*/
runtimeEnv: {
BETTER_AUTH_SECRET: process.env.BETTER_AUTH_SECRET,
BETTER_AUTH_GITHUB_CLIENT_ID: process.env.BETTER_AUTH_GITHUB_CLIENT_ID,
BETTER_AUTH_GITHUB_CLIENT_SECRET:
process.env.BETTER_AUTH_GITHUB_CLIENT_SECRET,
NODE_ENV: process.env.NODE_ENV,
NEXT_PUBLIC_BACKEND_BASE_URL: process.env.NEXT_PUBLIC_BACKEND_BASE_URL,

View File

@ -1,5 +0,0 @@
import { createAuthClient } from "better-auth/react";
export const authClient = createAuthClient();
export type Session = typeof authClient.$Infer.Session;

View File

@ -1,9 +0,0 @@
import { betterAuth } from "better-auth";
export const auth = betterAuth({
emailAndPassword: {
enabled: true,
},
});
export type Session = typeof auth.$Infer.Session;

View File

@ -1 +0,0 @@
export { auth } from "./config";

View File

@ -1,8 +0,0 @@
import { headers } from "next/headers";
import { cache } from "react";
import { auth } from ".";
export const getSession = cache(async () =>
auth.api.getSession({ headers: await headers() }),
);