mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(auth): wire auth end-to-end (middleware + frontend replacement)
Backend: - Port auth_middleware, csrf_middleware, langgraph_auth, routers/auth - Port authz decorator (owner_filter_key defaults to 'owner_id') - Merge app.py: register AuthMiddleware + CSRFMiddleware + CORS, add _ensure_admin_user lifespan hook, _migrate_orphaned_threads helper, register auth router - Merge deps.py: add get_local_provider, get_current_user_from_request, get_optional_user_from_request; keep get_current_user as thin str|None adapter for feedback router - langgraph.json: add auth path pointing to langgraph_auth.py:auth - Rename metadata['user_id'] -> metadata['owner_id'] in langgraph_auth (both metadata write and LangGraph filter dict) + test fixtures Frontend: - Delete better-auth library and api catch-all route - Remove better-auth npm dependency and env vars (BETTER_AUTH_SECRET, BETTER_AUTH_GITHUB_*) from env.js - Port frontend/src/core/auth/* (AuthProvider, gateway-config, proxy-policy, server-side getServerSideUser, types) - Port frontend/src/core/api/fetcher.ts - Port (auth)/layout, (auth)/login, (auth)/setup pages - Rewrite workspace/layout.tsx as server component that calls getServerSideUser and wraps in AuthProvider - Port workspace/workspace-content.tsx for the client-side sidebar logic Tests: - Port 5 auth test files (test_auth, test_auth_middleware, test_auth_type_system, test_ensure_admin, test_langgraph_auth) - 176 auth tests PASS After this commit: login/logout/registration flow works, but persistence layer does not yet filter by owner_id. Commit 4 closes that gap.
This commit is contained in:
parent
03c3b18565
commit
f942e4e597
@ -1,15 +1,21 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
from app.gateway.routers import (
|
||||
agents,
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
@ -34,6 +40,92 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
"""Auto-create the admin user on first boot if no users exist.
|
||||
|
||||
Prints the generated password to stdout so the operator can log in.
|
||||
On subsequent boots, warns if any user still needs setup.
|
||||
|
||||
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races.
|
||||
Only the worker that successfully creates/updates the admin prints the
|
||||
password; losers silently skip.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
provider = get_local_provider()
|
||||
user_count = await provider.count_users()
|
||||
|
||||
if user_count == 0:
|
||||
password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
||||
except ValueError:
|
||||
return # Another worker already created the admin.
|
||||
|
||||
# Migrate orphaned threads (no owner_id) to this admin
|
||||
store = getattr(app.state, "store", None)
|
||||
if store is not None:
|
||||
await _migrate_orphaned_threads(store, str(admin.id))
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(" Admin account created on first boot")
|
||||
logger.info(" Email: %s", admin.email)
|
||||
logger.info(" Password: %s", password)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
return
|
||||
|
||||
# Admin exists but setup never completed — reset password so operator
|
||||
# can always find it in the console without needing the CLI.
|
||||
# Multi-worker guard: if admin was created less than 30s ago, another
|
||||
# worker just created it and will print the password — skip reset.
|
||||
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
||||
if admin and admin.needs_setup:
|
||||
import time
|
||||
|
||||
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
||||
if age < 30:
|
||||
return # Just created by another worker in this startup; its password is still valid.
|
||||
|
||||
from app.gateway.auth.password import hash_password_async
|
||||
|
||||
password = secrets.token_urlsafe(16)
|
||||
admin.password_hash = await hash_password_async(password)
|
||||
admin.token_version += 1
|
||||
await provider.update_user(admin)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(" Admin account setup incomplete — password reset")
|
||||
logger.info(" Email: %s", admin.email)
|
||||
logger.info(" Password: %s", password)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> None:
|
||||
"""Migrate threads with no owner_id to the given admin.
|
||||
|
||||
NOTE: This is the initial port. Commit 5 will replace the hardcoded
|
||||
limit=1000 with cursor pagination and extend to SQL persistence tables.
|
||||
"""
|
||||
try:
|
||||
migrated = 0
|
||||
results = await store.asearch(("threads",), limit=1000)
|
||||
for item in results:
|
||||
metadata = item.value.get("metadata", {})
|
||||
if not metadata.get("owner_id"):
|
||||
metadata["owner_id"] = admin_user_id
|
||||
item.value["metadata"] = metadata
|
||||
await store.aput(("threads",), item.key, item.value)
|
||||
migrated += 1
|
||||
if migrated:
|
||||
logger.info("Migrated %d orphaned thread(s) to admin", migrated)
|
||||
except Exception:
|
||||
logger.exception("Thread migration failed (non-fatal)")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan handler."""
|
||||
@ -53,6 +145,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async with langgraph_runtime(app):
|
||||
logger.info("LangGraph runtime initialised")
|
||||
|
||||
# Ensure admin user exists (auto-create on first boot)
|
||||
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
|
||||
await _ensure_admin_user(app)
|
||||
|
||||
# Start IM channel service if any channels are configured
|
||||
try:
|
||||
from app.channels.service import start_channel_service
|
||||
@ -164,7 +260,35 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
||||
],
|
||||
)
|
||||
|
||||
# CORS is handled by nginx - no need for FastAPI middleware
|
||||
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# CSRF: Double Submit Cookie pattern for state-changing requests
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
|
||||
# In production, nginx handles CORS and no middleware is needed.
|
||||
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
|
||||
if cors_origins_env:
|
||||
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
|
||||
# Validate: wildcard origin with credentials is a security misconfiguration
|
||||
for origin in cors_origins:
|
||||
if origin == "*":
|
||||
logger.error(
|
||||
"GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. "
|
||||
"This is a security misconfiguration — browsers will reject the response. "
|
||||
"Use explicit scheme://host:port origins instead."
|
||||
)
|
||||
cors_origins = [o for o in cors_origins if o != "*"]
|
||||
break
|
||||
if cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
# Models API is mounted at /api/models
|
||||
@ -200,6 +324,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
||||
# Assistants compatibility API (LangGraph Platform stub)
|
||||
app.include_router(assistants_compat.router)
|
||||
|
||||
# Auth API is mounted at /api/v1/auth
|
||||
app.include_router(auth.router)
|
||||
|
||||
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||
app.include_router(feedback.router)
|
||||
|
||||
|
||||
71
backend/app/gateway/auth_middleware.py
Normal file
71
backend/app/gateway/auth_middleware.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Global authentication middleware — fail-closed safety net.
|
||||
|
||||
Rejects unauthenticated requests to non-public paths with 401.
|
||||
Fine-grained permission checks remain in authz.py decorators.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.errors import AuthErrorCode
|
||||
|
||||
# Paths that never require authentication.
|
||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
)
|
||||
|
||||
# Exact auth paths that are public (login/register/status check).
|
||||
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
|
||||
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_public(path: str) -> bool:
|
||||
stripped = path.rstrip("/")
|
||||
if stripped in _PUBLIC_EXACT_PATHS:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Coarse-grained auth gate: reject requests without a valid session cookie.
|
||||
|
||||
This does NOT verify JWT signature or user existence — that is the job of
|
||||
``get_current_user_from_request`` in deps.py (called by ``@require_auth``).
|
||||
The middleware only checks *presence* of the cookie so that new endpoints
|
||||
that forget ``@require_auth`` are not completely exposed.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
if _is_public(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Non-public path: require session cookie
|
||||
if not request.cookies.get("access_token"):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"detail": {
|
||||
"code": AuthErrorCode.NOT_AUTHENTICATED,
|
||||
"message": "Authentication required",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
261
backend/app/gateway/authz.py
Normal file
261
backend/app/gateway/authz.py
Normal file
@ -0,0 +1,261 @@
|
||||
"""Authorization decorators and context for DeerFlow.
|
||||
|
||||
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
|
||||
|
||||
**Usage:**
|
||||
|
||||
1. Use ``@require_auth`` on routes that need authentication
|
||||
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
|
||||
3. The decorator chain processes from bottom to top
|
||||
|
||||
**Example:**
|
||||
|
||||
@router.get("/{thread_id}")
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
# User is authenticated and has threads:read permission
|
||||
...
|
||||
|
||||
**Permission Model:**
|
||||
|
||||
- threads:read - View thread
|
||||
- threads:write - Create/update thread
|
||||
- threads:delete - Delete thread
|
||||
- runs:create - Run agent
|
||||
- runs:read - View run
|
||||
- runs:cancel - Cancel run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Permission constants
|
||||
class Permissions:
|
||||
"""Permission constants for resource:action format."""
|
||||
|
||||
# Threads
|
||||
THREADS_READ = "threads:read"
|
||||
THREADS_WRITE = "threads:write"
|
||||
THREADS_DELETE = "threads:delete"
|
||||
|
||||
# Runs
|
||||
RUNS_CREATE = "runs:create"
|
||||
RUNS_READ = "runs:read"
|
||||
RUNS_CANCEL = "runs:cancel"
|
||||
|
||||
|
||||
class AuthContext:
|
||||
"""Authentication context for the current request.
|
||||
|
||||
Stored in request.state.auth after require_auth decoration.
|
||||
|
||||
Attributes:
|
||||
user: The authenticated user, or None if anonymous
|
||||
permissions: List of permission strings (e.g., "threads:read")
|
||||
"""
|
||||
|
||||
__slots__ = ("user", "permissions")
|
||||
|
||||
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
||||
self.user = user
|
||||
self.permissions = permissions or []
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if user is authenticated."""
|
||||
return self.user is not None
|
||||
|
||||
def has_permission(self, resource: str, action: str) -> bool:
|
||||
"""Check if context has permission for resource:action.
|
||||
|
||||
Args:
|
||||
resource: Resource name (e.g., "threads")
|
||||
action: Action name (e.g., "read")
|
||||
|
||||
Returns:
|
||||
True if user has permission
|
||||
"""
|
||||
permission = f"{resource}:{action}"
|
||||
return permission in self.permissions
|
||||
|
||||
def require_user(self) -> User:
|
||||
"""Get user or raise 401.
|
||||
|
||||
Raises:
|
||||
HTTPException 401 if not authenticated
|
||||
"""
|
||||
if not self.user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return self.user
|
||||
|
||||
|
||||
def get_auth_context(request: Request) -> AuthContext | None:
|
||||
"""Get AuthContext from request state."""
|
||||
return getattr(request.state, "auth", None)
|
||||
|
||||
|
||||
_ALL_PERMISSIONS: list[str] = [
|
||||
Permissions.THREADS_READ,
|
||||
Permissions.THREADS_WRITE,
|
||||
Permissions.THREADS_DELETE,
|
||||
Permissions.RUNS_CREATE,
|
||||
Permissions.RUNS_READ,
|
||||
Permissions.RUNS_CANCEL,
|
||||
]
|
||||
|
||||
|
||||
async def _authenticate(request: Request) -> AuthContext:
|
||||
"""Authenticate request and return AuthContext.
|
||||
|
||||
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
|
||||
Returns AuthContext with user=None for anonymous requests.
|
||||
"""
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
if user is None:
|
||||
return AuthContext(user=None, permissions=[])
|
||||
|
||||
# In future, permissions could be stored in user record
|
||||
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||
|
||||
|
||||
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||
"""Decorator that authenticates the request and sets AuthContext.
|
||||
|
||||
Must be placed ABOVE other decorators (executes after them).
|
||||
|
||||
Usage:
|
||||
@router.get("/{thread_id}")
|
||||
@require_auth # Bottom decorator (executes first after permission check)
|
||||
@require_permission("threads", "read")
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
auth: AuthContext = request.state.auth
|
||||
...
|
||||
|
||||
Raises:
|
||||
ValueError: If 'request' parameter is missing
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = kwargs.get("request")
|
||||
if request is None:
|
||||
raise ValueError("require_auth decorator requires 'request' parameter")
|
||||
|
||||
# Authenticate and set context
|
||||
auth_context = await _authenticate(request)
|
||||
request.state.auth = auth_context
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_permission(
|
||||
resource: str,
|
||||
action: str,
|
||||
owner_check: bool = False,
|
||||
owner_filter_key: str = "owner_id",
|
||||
inject_record: bool = False,
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
"""Decorator that checks permission for resource:action.
|
||||
|
||||
Must be used AFTER @require_auth.
|
||||
|
||||
Args:
|
||||
resource: Resource name (e.g., "threads", "runs")
|
||||
action: Action name (e.g., "read", "write", "delete")
|
||||
owner_check: If True, validates that the current user owns the resource.
|
||||
Requires 'thread_id' path parameter and performs ownership check.
|
||||
owner_filter_key: Field name for ownership filter (default: "owner_id")
|
||||
inject_record: If True and owner_check is True, injects the thread record
|
||||
into kwargs['thread_record'] for use in the handler.
|
||||
|
||||
Usage:
|
||||
# Simple permission check
|
||||
@require_permission("threads", "read")
|
||||
async def get_thread(thread_id: str, request: Request):
|
||||
...
|
||||
|
||||
# With ownership check (for /threads/{thread_id} endpoints)
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_thread(thread_id: str, request: Request):
|
||||
...
|
||||
|
||||
# With ownership check and record injection
|
||||
@require_permission("threads", "delete", owner_check=True, inject_record=True)
|
||||
async def delete_thread(thread_id: str, request: Request, thread_record: dict = None):
|
||||
# thread_record is injected if found
|
||||
...
|
||||
|
||||
Raises:
|
||||
HTTPException 401: If authentication required but user is anonymous
|
||||
HTTPException 403: If user lacks permission
|
||||
HTTPException 404: If owner_check=True but user doesn't own the thread
|
||||
ValueError: If owner_check=True but 'thread_id' parameter is missing
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = kwargs.get("request")
|
||||
if request is None:
|
||||
raise ValueError("require_permission decorator requires 'request' parameter")
|
||||
|
||||
auth: AuthContext = getattr(request.state, "auth", None)
|
||||
if auth is None:
|
||||
auth = await _authenticate(request)
|
||||
request.state.auth = auth
|
||||
|
||||
if not auth.is_authenticated:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Check permission
|
||||
if not auth.has_permission(resource, action):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Permission denied: {resource}:{action}",
|
||||
)
|
||||
|
||||
# Owner check for thread-specific resources
|
||||
if owner_check:
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
# Get thread and verify ownership
|
||||
from app.gateway.routers.threads import _store_get, get_store
|
||||
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
record = await _store_get(store, thread_id)
|
||||
if record:
|
||||
owner_id = record.get("metadata", {}).get(owner_filter_key)
|
||||
if owner_id and owner_id != str(auth.user.id):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Thread {thread_id} not found",
|
||||
)
|
||||
# Inject record if requested
|
||||
if inject_record:
|
||||
kwargs["thread_record"] = record
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
112
backend/app/gateway/csrf_middleware.py
Normal file
112
backend/app/gateway/csrf_middleware.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""CSRF protection middleware for FastAPI.
|
||||
|
||||
Per RFC-001:
|
||||
State-changing operations require CSRF protection.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
CSRF_COOKIE_NAME = "csrf_token"
|
||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
|
||||
|
||||
def is_secure_request(request: Request) -> bool:
|
||||
"""Detect whether the original client request was made over HTTPS."""
|
||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""Generate a secure random CSRF token."""
|
||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
|
||||
def should_check_csrf(request: Request) -> bool:
|
||||
"""Determine if a request needs CSRF validation.
|
||||
|
||||
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
|
||||
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
|
||||
"""
|
||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||
return False
|
||||
|
||||
path = request.url.path.rstrip("/")
|
||||
# Exempt /api/v1/auth/me endpoint
|
||||
if path == "/api/v1/auth/me":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/register",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_auth_endpoint(request: Request) -> bool:
|
||||
"""Check if the request is to an auth endpoint.
|
||||
|
||||
Auth endpoints don't need CSRF validation on first call (no token).
|
||||
"""
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
_is_auth = is_auth_endpoint(request)
|
||||
|
||||
if should_check_csrf(request) and not _is_auth:
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token mismatch."},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# For auth endpoints that set up session, also set CSRF cookie
|
||||
if _is_auth and request.method == "POST":
|
||||
# Generate a new CSRF token for the session
|
||||
csrf_token = generate_csrf_token()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key=CSRF_COOKIE_NAME,
|
||||
value=csrf_token,
|
||||
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
|
||||
secure=is_https,
|
||||
samesite="strict",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_csrf_token(request: Request) -> str | None:
|
||||
"""Get the CSRF token from the current request's cookies.
|
||||
|
||||
This is useful for server-side rendering where you need to embed
|
||||
token in forms or headers.
|
||||
"""
|
||||
return request.cookies.get(CSRF_COOKIE_NAME)
|
||||
@ -11,11 +11,16 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import RunContext, RunManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
@ -127,10 +132,86 @@ def get_run_context(request: Request) -> RunContext:
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> str | None:
|
||||
"""Extract user identity from request.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py and auth middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Phase 2: always returns None (no authentication).
|
||||
Phase 3: extract user_id from JWT / session / API key header.
|
||||
# Cached singletons to avoid repeated instantiation per request
|
||||
_cached_local_provider: LocalAuthProvider | None = None
|
||||
_cached_repo: SQLiteUserRepository | None = None
|
||||
|
||||
|
||||
def get_local_provider() -> LocalAuthProvider:
|
||||
"""Get or create the cached LocalAuthProvider singleton."""
|
||||
global _cached_local_provider, _cached_repo
|
||||
if _cached_repo is None:
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
_cached_repo = SQLiteUserRepository()
|
||||
if _cached_local_provider is None:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
|
||||
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
|
||||
return _cached_local_provider
|
||||
|
||||
|
||||
async def get_current_user_from_request(request: Request):
|
||||
"""Get the current authenticated user from the request cookie.
|
||||
|
||||
Raises HTTPException 401 if not authenticated.
|
||||
"""
|
||||
return None
|
||||
from app.gateway.auth import decode_token
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||
|
||||
access_token = request.cookies.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
||||
)
|
||||
|
||||
payload = decode_token(access_token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
|
||||
)
|
||||
|
||||
provider = get_local_provider()
|
||||
user = await provider.get_user(payload.sub)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
||||
)
|
||||
|
||||
# Token version mismatch → password was changed, token is stale
|
||||
if user.token_version != payload.ver:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_optional_user_from_request(request: Request):
|
||||
"""Get optional authenticated user from request.
|
||||
|
||||
Returns None if not authenticated.
|
||||
"""
|
||||
try:
|
||||
return await get_current_user_from_request(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> str | None:
|
||||
"""Extract user_id from request cookie, or None if not authenticated.
|
||||
|
||||
Thin adapter that returns the string id for callers that only need
|
||||
identification (e.g., ``feedback.py``). Full-user callers should use
|
||||
``get_current_user_from_request`` or ``get_optional_user_from_request``.
|
||||
"""
|
||||
user = await get_optional_user_from_request(request)
|
||||
return str(user.id) if user else None
|
||||
|
||||
106
backend/app/gateway/langgraph_auth.py
Normal file
106
backend/app/gateway/langgraph_auth.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""LangGraph Server auth handler — shares JWT logic with Gateway.
|
||||
|
||||
Loaded by LangGraph Server via langgraph.json ``auth.path``.
|
||||
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
|
||||
so both modes validate tokens with the same secret and rules.
|
||||
|
||||
Two layers:
|
||||
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
||||
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
|
||||
2. @auth.on — returns metadata filter so each user only sees own threads
|
||||
"""
|
||||
|
||||
import secrets
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
auth = Auth()
|
||||
|
||||
# Methods that require CSRF validation (state-changing per RFC 7231).
|
||||
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
|
||||
def _check_csrf(request) -> None:
|
||||
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
|
||||
|
||||
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
|
||||
proxied directly by nginx have the same CSRF protection.
|
||||
"""
|
||||
method = getattr(request, "method", "") or ""
|
||||
if method.upper() not in _CSRF_METHODS:
|
||||
return
|
||||
|
||||
cookie_token = request.cookies.get("csrf_token")
|
||||
header_token = request.headers.get("x-csrf-token")
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token missing. Include X-CSRF-Token header.",
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token mismatch.",
|
||||
)
|
||||
|
||||
|
||||
@auth.authenticate
|
||||
async def authenticate(request):
|
||||
"""Validate the session cookie, decode JWT, and check token_version.
|
||||
|
||||
Same validation chain as Gateway's get_current_user_from_request:
|
||||
cookie → decode JWT → DB lookup → token_version match
|
||||
Also enforces CSRF on state-changing methods.
|
||||
"""
|
||||
# CSRF check before authentication so forged cross-site requests
|
||||
# are rejected early, even if the cookie carries a valid JWT.
|
||||
_check_csrf(request)
|
||||
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="Not authenticated",
|
||||
)
|
||||
|
||||
payload = decode_token(token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Token error: {payload.value}",
|
||||
)
|
||||
|
||||
user = await get_local_provider().get_user(payload.sub)
|
||||
if user is None:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="User not found",
|
||||
)
|
||||
if user.token_version != payload.ver:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=401,
|
||||
detail="Token revoked (password changed)",
|
||||
)
|
||||
|
||||
return payload.sub
|
||||
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||
|
||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||
This handler ensures LangGraph Server enforces the same isolation.
|
||||
"""
|
||||
# On create/update: stamp owner_id into metadata
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["owner_id"] = ctx.user.identity
|
||||
|
||||
# Return filter dict — LangGraph applies it to search/read/delete
|
||||
return {"owner_id": ctx.user.identity}
|
||||
303
backend/app/gateway/routers/auth.py
Normal file
303
backend/app/gateway/routers/auth.py
Normal file
@ -0,0 +1,303 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
from app.gateway.auth import (
|
||||
UserResponse,
|
||||
create_access_token,
|
||||
)
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||
from app.gateway.csrf_middleware import is_secure_request
|
||||
from app.gateway.deps import get_current_user_from_request, get_local_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request/Response Models ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""Response model for login — token only lives in HttpOnly cookie."""
|
||||
|
||||
expires_in: int # seconds
|
||||
needs_setup: bool = False
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""Request model for user registration."""
|
||||
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""Request model for password change (also handles setup flow)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str = Field(..., min_length=8)
|
||||
new_email: EmailStr | None = None
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Generic message response."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
||||
"""Set the access_token HttpOnly cookie on the response."""
|
||||
config = get_auth_config()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=is_https,
|
||||
samesite="lax",
|
||||
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
||||
)
|
||||
|
||||
|
||||
# ── Rate Limiting ────────────────────────────────────────────────────────
|
||||
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
|
||||
|
||||
_MAX_LOGIN_ATTEMPTS = 5
|
||||
_LOCKOUT_SECONDS = 300 # 5 minutes
|
||||
|
||||
# ip → (fail_count, lock_until_timestamp)
|
||||
_login_attempts: dict[str, tuple[int, float]] = {}
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract the real client IP for rate limiting.
|
||||
|
||||
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
|
||||
$remote_addr``). Nginx unconditionally overwrites any client-supplied
|
||||
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
|
||||
that nginx observed — it cannot be spoofed by the client.
|
||||
|
||||
``request.client.host`` is NOT reliable because uvicorn's default
|
||||
``proxy_headers=True`` replaces it with the *first* entry from
|
||||
``X-Forwarded-For``, which IS client-spoofable.
|
||||
|
||||
``X-Forwarded-For`` is intentionally NOT used for the same reason.
|
||||
"""
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fallback: direct connection without nginx (e.g. unit tests, dev).
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def _check_rate_limit(ip: str) -> None:
|
||||
"""Raise 429 if the IP is currently locked out."""
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
return
|
||||
fail_count, lock_until = record
|
||||
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
||||
if time.time() < lock_until:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many login attempts. Try again later.",
|
||||
)
|
||||
del _login_attempts[ip]
|
||||
|
||||
|
||||
_MAX_TRACKED_IPS = 10000
|
||||
|
||||
|
||||
def _record_login_failure(ip: str) -> None:
|
||||
"""Record a failed login attempt for the given IP."""
|
||||
# Evict expired lockouts when dict grows too large
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
now = time.time()
|
||||
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
||||
for k in expired:
|
||||
del _login_attempts[k]
|
||||
# If still too large, evict cheapest-to-lose half: below-threshold
|
||||
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
||||
for k, _ in by_time[: len(by_time) // 2]:
|
||||
del _login_attempts[k]
|
||||
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
_login_attempts[ip] = (1, 0.0)
|
||||
else:
|
||||
new_count = record[0] + 1
|
||||
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
||||
_login_attempts[ip] = (new_count, lock_until)
|
||||
|
||||
|
||||
def _record_login_success(ip: str) -> None:
|
||||
"""Clear failure counter for the given IP on successful login."""
|
||||
_login_attempts.pop(ip, None)
|
||||
|
||||
|
||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/login/local", response_model=LoginResponse)
|
||||
async def login_local(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
"""Local email/password login."""
|
||||
client_ip = _get_client_ip(request)
|
||||
_check_rate_limit(client_ip)
|
||||
|
||||
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
||||
|
||||
if user is None:
|
||||
_record_login_failure(client_ip)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
||||
)
|
||||
|
||||
_record_login_success(client_ip)
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return LoginResponse(
|
||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||
needs_setup=user.needs_setup,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
||||
"""Register a new user account (always 'user' role).
|
||||
|
||||
Admin is auto-created on first boot. This endpoint creates regular users.
|
||||
Auto-login by setting the session cookie.
|
||||
"""
|
||||
try:
|
||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
||||
)
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: Request, response: Response):
|
||||
"""Logout current user by clearing the cookie."""
|
||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/change-password", response_model=MessageResponse)
|
||||
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
||||
"""Change password for the currently authenticated user.
|
||||
|
||||
Also handles the first-boot setup flow:
|
||||
- If new_email is provided, updates email (checks uniqueness)
|
||||
- If user.needs_setup is True and new_email is given, clears needs_setup
|
||||
- Always increments token_version to invalidate old sessions
|
||||
- Re-issues session cookie with new token_version
|
||||
"""
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||
|
||||
if not await verify_password_async(body.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
||||
|
||||
provider = get_local_provider()
|
||||
|
||||
# Update email if provided
|
||||
if body.new_email is not None:
|
||||
existing = await provider.get_user_by_email(body.new_email)
|
||||
if existing and str(existing.id) != str(user.id):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
||||
user.email = body.new_email
|
||||
|
||||
# Update password + bump version
|
||||
user.password_hash = await hash_password_async(body.new_password)
|
||||
user.token_version += 1
|
||||
|
||||
# Clear setup flag if this is the setup flow
|
||||
if user.needs_setup and body.new_email is not None:
|
||||
user.needs_setup = False
|
||||
|
||||
await provider.update_user(user)
|
||||
|
||||
# Re-issue cookie with new token_version
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return MessageResponse(message="Password changed successfully")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(request: Request):
|
||||
"""Get current authenticated user info."""
|
||||
user = await get_current_user_from_request(request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||
|
||||
|
||||
@router.get("/setup-status")
|
||||
async def setup_status():
|
||||
"""Check if admin account exists. Always False after first boot."""
|
||||
user_count = await get_local_provider().count_users()
|
||||
return {"needs_setup": user_count == 0}
|
||||
|
||||
|
||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}")
|
||||
async def oauth_login(provider: str):
|
||||
"""Initiate OAuth login flow.
|
||||
|
||||
Redirects to the OAuth provider's authorization URL.
|
||||
Currently a placeholder - requires OAuth provider implementation.
|
||||
"""
|
||||
if provider not in ["github", "google"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported OAuth provider: {provider}",
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth login not yet implemented",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/callback/{provider}")
|
||||
async def oauth_callback(provider: str, code: str, state: str):
|
||||
"""OAuth callback endpoint.
|
||||
|
||||
Handles the OAuth provider's callback after user authorization.
|
||||
Currently a placeholder.
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth callback not yet implemented",
|
||||
)
|
||||
@ -8,6 +8,9 @@
|
||||
"graphs": {
|
||||
"lead_agent": "deerflow.agents:make_lead_agent"
|
||||
},
|
||||
"auth": {
|
||||
"path": "./app/gateway/langgraph_auth.py:auth"
|
||||
},
|
||||
"checkpointer": {
|
||||
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
|
||||
}
|
||||
|
||||
506
backend/tests/test_auth.py
Normal file
506
backend/tests/test_auth.py
Normal file
@ -0,0 +1,506 @@
|
||||
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
get_auth_context,
|
||||
require_auth,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hash_password_and_verify():
|
||||
"""Hashing and verification round-trip."""
|
||||
password = "s3cr3tP@ssw0rd!"
|
||||
hashed = hash_password(password)
|
||||
assert hashed != password
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("wrongpassword", hashed) is False
|
||||
|
||||
|
||||
def test_hash_password_different_each_time():
|
||||
"""bcrypt generates unique salts, so same password has different hashes."""
|
||||
password = "testpassword"
|
||||
h1 = hash_password(password)
|
||||
h2 = hash_password(password)
|
||||
assert h1 != h2 # Different salts
|
||||
# But both verify correctly
|
||||
assert verify_password(password, h1) is True
|
||||
assert verify_password(password, h2) is True
|
||||
|
||||
|
||||
def test_verify_password_rejects_empty():
|
||||
"""Empty password should not verify."""
|
||||
hashed = hash_password("nonempty")
|
||||
assert verify_password("", hashed) is False
|
||||
|
||||
|
||||
# ── JWT ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
"""JWT creation and decoding round-trip."""
|
||||
user_id = str(uuid4())
|
||||
# Set a valid JWT secret for this test
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(user_id)
|
||||
assert isinstance(token, str)
|
||||
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
def test_decode_token_expired():
|
||||
"""Expired token returns TokenError.EXPIRED."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
user_id = str(uuid4())
|
||||
# Create token that expires immediately
|
||||
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
|
||||
payload = decode_token(token)
|
||||
assert payload == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Invalid token returns TokenError."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
assert isinstance(decode_token("not.a.valid.token"), TokenError)
|
||||
assert isinstance(decode_token(""), TokenError)
|
||||
assert isinstance(decode_token("completely-wrong"), TokenError)
|
||||
|
||||
|
||||
def test_create_token_custom_expiry():
|
||||
"""Custom expiry is respected."""
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
# ── AuthContext ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_context_unauthenticated():
|
||||
"""AuthContext with no user."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
assert ctx.is_authenticated is False
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_authenticated_no_perms():
|
||||
"""AuthContext with user but no permissions."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
assert ctx.is_authenticated is True
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_has_permission():
|
||||
"""AuthContext permission checking."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
|
||||
ctx = AuthContext(user=user, permissions=perms)
|
||||
assert ctx.has_permission("threads", "read") is True
|
||||
assert ctx.has_permission("threads", "write") is True
|
||||
assert ctx.has_permission("threads", "delete") is False
|
||||
assert ctx.has_permission("runs", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_require_user_raises():
|
||||
"""require_user raises 401 when not authenticated."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
ctx.require_user()
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_auth_context_require_user_returns_user():
|
||||
"""require_user returns user when authenticated."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
returned = ctx.require_user()
|
||||
assert returned == user
|
||||
|
||||
|
||||
# ── get_auth_context helper ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_auth_context_not_set():
|
||||
"""get_auth_context returns None when auth not set on request."""
|
||||
mock_request = MagicMock()
|
||||
# Make getattr return None (simulating attribute not set)
|
||||
mock_request.state = MagicMock()
|
||||
del mock_request.state.auth
|
||||
assert get_auth_context(mock_request) is None
|
||||
|
||||
|
||||
def test_get_auth_context_set():
|
||||
"""get_auth_context returns the AuthContext from request."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.auth = ctx
|
||||
|
||||
assert get_auth_context(mock_request) == ctx
|
||||
|
||||
|
||||
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_auth_sets_auth_context():
|
||||
"""require_auth sets auth context on request from cookie."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_auth
|
||||
async def endpoint(request: Request):
|
||||
ctx = get_auth_context(request)
|
||||
return {"authenticated": ctx.is_authenticated}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# No cookie → anonymous
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["authenticated"] is False
|
||||
|
||||
|
||||
def test_require_auth_requires_request_param():
|
||||
"""require_auth raises ValueError if request parameter is missing."""
|
||||
import asyncio
|
||||
|
||||
@require_auth
|
||||
async def bad_endpoint(): # Missing `request` parameter
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
|
||||
asyncio.run(bad_endpoint())
|
||||
|
||||
|
||||
# ── require_permission decorator ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_permission_requires_auth():
|
||||
"""require_permission raises 401 when not authenticated."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "read")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_require_permission_denies_wrong_permission():
|
||||
"""User without required permission gets 403."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "delete")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 403
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── User Model Fields ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_model_has_needs_setup_default_false():
|
||||
"""New users default to needs_setup=False."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.needs_setup is False
|
||||
|
||||
|
||||
def test_user_model_has_token_version_default_zero():
|
||||
"""New users default to token_version=0."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.token_version == 0
|
||||
|
||||
|
||||
def test_user_model_needs_setup_true():
|
||||
"""Auto-created admin has needs_setup=True."""
|
||||
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
|
||||
assert user.needs_setup is True
|
||||
|
||||
|
||||
def test_sqlite_round_trip_new_fields():
|
||||
"""needs_setup and token_version survive create → read round-trip."""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.gateway.auth.repositories import sqlite as sqlite_mod
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_users.db")
|
||||
old_path = sqlite_mod._resolved_db_path
|
||||
old_init = sqlite_mod._table_initialized
|
||||
sqlite_mod._resolved_db_path = Path(db_path)
|
||||
sqlite_mod._table_initialized = False
|
||||
try:
|
||||
repo = sqlite_mod.SQLiteUserRepository()
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = asyncio.run(repo.create_user(user))
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
|
||||
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
asyncio.run(repo.update_user(fetched))
|
||||
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
sqlite_mod._resolved_db_path = old_path
|
||||
sqlite_mod._table_initialized = old_init
|
||||
|
||||
|
||||
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_jwt_encodes_ver():
|
||||
"""JWT payload includes ver field."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()), token_version=3)
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 3
|
||||
|
||||
|
||||
def test_jwt_default_ver_zero():
|
||||
"""JWT ver defaults to 0."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()))
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 0
|
||||
|
||||
|
||||
def test_token_version_mismatch_rejects():
|
||||
"""Token with stale ver is rejected by get_current_user_from_request."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, token_version=0)
|
||||
|
||||
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {"access_token": token}
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_user = AsyncMock(return_value=mock_user)
|
||||
mock_provider_fn.return_value = mock_provider
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
# ── change-password extension ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_change_password_request_accepts_new_email():
|
||||
"""ChangePasswordRequest model accepts optional new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(
|
||||
current_password="old",
|
||||
new_password="newpassword",
|
||||
new_email="new@example.com",
|
||||
)
|
||||
assert req.new_email == "new@example.com"
|
||||
|
||||
|
||||
def test_change_password_request_new_email_optional():
|
||||
"""ChangePasswordRequest model works without new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
|
||||
assert req.new_email is None
|
||||
|
||||
|
||||
def test_login_response_includes_needs_setup():
|
||||
"""LoginResponse includes needs_setup field."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=3600, needs_setup=True)
|
||||
assert resp.needs_setup is True
|
||||
resp2 = LoginResponse(expires_in=3600)
|
||||
assert resp2.needs_setup is False
|
||||
|
||||
|
||||
# ── Rate Limiting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_rate_limiter_allows_under_limit():
|
||||
"""Requests under the limit are allowed."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
|
||||
|
||||
_login_attempts.clear()
|
||||
_check_rate_limit("192.168.1.1") # Should not raise
|
||||
|
||||
|
||||
def test_rate_limiter_blocks_after_max_failures():
|
||||
"""IP is blocked after 5 consecutive failures."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.1"
|
||||
for _ in range(5):
|
||||
_record_login_failure(ip)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_check_rate_limit(ip)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
def test_rate_limiter_resets_on_success():
|
||||
"""Successful login clears the failure counter."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.2"
|
||||
for _ in range(4):
|
||||
_record_login_failure(ip)
|
||||
_record_login_success(ip)
|
||||
_check_rate_limit(ip) # Should not raise
|
||||
|
||||
|
||||
# ── Client IP extraction ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_client_ip_direct_connection():
|
||||
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.42"
|
||||
req.headers = {}
|
||||
assert _get_client_ip(req) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_uses_x_real_ip():
|
||||
"""X-Real-IP (set by nginx) is used when present."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
|
||||
req.headers = {"x-real-ip": "203.0.113.42"}
|
||||
assert _get_client_ip(req) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_xff_ignored():
|
||||
"""X-Forwarded-For is never used; only X-Real-IP matters."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1"
|
||||
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
|
||||
assert _get_client_ip(req) == "198.51.100.5"
|
||||
|
||||
|
||||
def test_get_client_ip_no_real_ip_fallback():
|
||||
"""No X-Real-IP → falls back to client.host (direct connection)."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "127.0.0.1"
|
||||
req.headers = {}
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
|
||||
def test_get_client_ip_x_real_ip_always_preferred():
|
||||
"""X-Real-IP is always preferred over client.host regardless of IP."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.99"
|
||||
req.headers = {"x-real-ip": "198.51.100.7"}
|
||||
assert _get_client_ip(req) == "198.51.100.7"
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
|
||||
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as config_module
|
||||
|
||||
config_module._auth_config = None
|
||||
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = config_module.get_auth_config()
|
||||
|
||||
assert config.jwt_secret # non-empty ephemeral secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
|
||||
# Cleanup
|
||||
config_module._auth_config = None
|
||||
216
backend/tests/test_auth_middleware.py
Normal file
216
backend/tests/test_auth_middleware.py
Normal file
@ -0,0 +1,216 @@
|
||||
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||
|
||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/health",
|
||||
"/health/",
|
||||
"/docs",
|
||||
"/docs/",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
],
|
||||
)
|
||||
def test_public_paths(path: str):
|
||||
assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/models",
|
||||
"/api/mcp/config",
|
||||
"/api/memory",
|
||||
"/api/skills",
|
||||
"/api/threads/123",
|
||||
"/api/threads/123/uploads",
|
||||
"/api/agents",
|
||||
"/api/channels",
|
||||
"/api/runs/stream",
|
||||
"/api/threads/123/runs",
|
||||
"/api/v1/auth/me",
|
||||
"/api/v1/auth/change-password",
|
||||
],
|
||||
)
|
||||
def test_protected_paths(path: str):
|
||||
assert _is_public(path) is False
|
||||
|
||||
|
||||
# ── Trailing slash / normalization edge cases ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/v1/auth/login/local/",
|
||||
"/api/v1/auth/register/",
|
||||
"/api/v1/auth/logout/",
|
||||
"/api/v1/auth/setup-status/",
|
||||
],
|
||||
)
|
||||
def test_public_auth_paths_with_trailing_slash(path: str):
|
||||
assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/models/",
|
||||
"/api/v1/auth/me/",
|
||||
"/api/v1/auth/change-password/",
|
||||
],
|
||||
)
|
||||
def test_protected_paths_with_trailing_slash(path: str):
|
||||
assert _is_public(path) is False
|
||||
|
||||
|
||||
def test_unknown_api_path_is_protected():
|
||||
"""Fail-closed: any new /api/* path is protected by default."""
|
||||
assert _is_public("/api/new-feature") is False
|
||||
assert _is_public("/api/v2/something") is False
|
||||
assert _is_public("/api/v1/auth/new-endpoint") is False
|
||||
|
||||
|
||||
# ── Middleware integration tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_app():
|
||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/v1/auth/me")
|
||||
async def auth_me():
|
||||
return {"id": "1", "email": "test@test.com"}
|
||||
|
||||
@app.get("/api/v1/auth/setup-status")
|
||||
async def setup_status():
|
||||
return {"needs_setup": False}
|
||||
|
||||
@app.get("/api/models")
|
||||
async def models_get():
|
||||
return {"models": []}
|
||||
|
||||
@app.put("/api/mcp/config")
|
||||
async def mcp_put():
|
||||
return {"ok": True}
|
||||
|
||||
@app.delete("/api/threads/abc")
|
||||
async def thread_delete():
|
||||
return {"ok": True}
|
||||
|
||||
@app.patch("/api/threads/abc")
|
||||
async def thread_patch():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def stream():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/future-endpoint")
|
||||
async def future():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(_make_app())
|
||||
|
||||
|
||||
def test_public_path_no_cookie(client):
|
||||
res = client.get("/health")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_public_auth_path_no_cookie(client):
|
||||
"""Public auth endpoints (login/register) pass without cookie."""
|
||||
res = client.get("/api/v1/auth/setup-status")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_auth_path_no_cookie(client):
|
||||
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
|
||||
res = client.get("/api/v1/auth/me")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_path_no_cookie_returns_401(client):
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
body = res.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_protected_path_with_cookie_passes(client):
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_post_no_cookie_returns_401(client):
|
||||
res = client.post("/api/threads/abc/runs/stream")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
def test_protected_put_no_cookie(client):
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_delete_no_cookie(client):
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_patch_no_cookie(client):
|
||||
res = client.patch("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_put_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_delete_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_endpoint_no_cookie_returns_401(client):
|
||||
"""Any new /api/* endpoint is blocked by default without cookie."""
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_unknown_endpoint_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 200
|
||||
675
backend/tests/test_auth_type_system.py
Normal file
675
backend/tests/test_auth_type_system.py
Normal file
@ -0,0 +1,675 @@
|
||||
"""Tests for auth type system hardening.
|
||||
|
||||
Covers structured error responses, typed decode_token callers,
|
||||
CSRF middleware path matching, config-driven cookie security,
|
||||
and unhappy paths / edge cases for all auth boundaries.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.csrf_middleware import (
|
||||
CSRF_COOKIE_NAME,
|
||||
CSRF_HEADER_NAME,
|
||||
CSRFMiddleware,
|
||||
is_auth_endpoint,
|
||||
should_check_csrf,
|
||||
)
|
||||
|
||||
# ── Setup ────────────────────────────────────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||
|
||||
|
||||
def _setup_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
# ── CSRF Middleware Path Matching ────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
"""Minimal request mock for CSRF path matching tests."""
|
||||
|
||||
def __init__(self, path: str, method: str = "POST"):
|
||||
self.method = method
|
||||
|
||||
class _URL:
|
||||
def __init__(self, p):
|
||||
self.path = p
|
||||
|
||||
self.url = _URL(path)
|
||||
self.cookies = {}
|
||||
self.headers = {}
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local():
|
||||
"""login/local (actual route) should be exempt from CSRF."""
|
||||
req = _FakeRequest("/api/v1/auth/login/local")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local_trailing_slash():
|
||||
"""Trailing slash should also be exempt."""
|
||||
req = _FakeRequest("/api/v1/auth/login/local/")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_logout():
|
||||
req = _FakeRequest("/api/v1/auth/logout")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_register():
|
||||
req = _FakeRequest("/api/v1/auth/register")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_old_login_path():
|
||||
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
|
||||
req = _FakeRequest("/api/v1/auth/login")
|
||||
assert is_auth_endpoint(req) is False
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_me():
|
||||
req = _FakeRequest("/api/v1/auth/me")
|
||||
assert is_auth_endpoint(req) is False
|
||||
|
||||
|
||||
def test_csrf_skips_get_requests():
|
||||
req = _FakeRequest("/api/v1/auth/me", method="GET")
|
||||
assert should_check_csrf(req) is False
|
||||
|
||||
|
||||
def test_csrf_checks_post_to_protected():
|
||||
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
|
||||
assert should_check_csrf(req) is True
|
||||
|
||||
|
||||
# ── Structured Error Response Format ────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_error_response_has_code_and_message():
|
||||
"""All auth errors should have structured {code, message} format."""
|
||||
err = AuthErrorResponse(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Wrong password",
|
||||
)
|
||||
d = err.model_dump()
|
||||
assert "code" in d
|
||||
assert "message" in d
|
||||
assert d["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_auth_error_response_all_codes_serializable():
|
||||
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
|
||||
for code in AuthErrorCode:
|
||||
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
|
||||
d = err.model_dump()
|
||||
assert d["code"] == code.value
|
||||
|
||||
|
||||
# ── decode_token Caller Pattern ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_expired_maps_to_token_expired_code():
|
||||
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
|
||||
_setup_config()
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.EXPIRED
|
||||
|
||||
# Verify the mapping pattern used in route handlers
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
|
||||
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
|
||||
_setup_config()
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
def test_decode_token_malformed_maps_to_token_invalid_code():
|
||||
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
|
||||
_setup_config()
|
||||
result = decode_token("garbage")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
# ── Login Response Format ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_login_response_model_has_no_access_token():
|
||||
"""LoginResponse should NOT contain access_token field (RFC-001)."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=604800)
|
||||
d = resp.model_dump()
|
||||
assert "access_token" not in d
|
||||
assert "expires_in" in d
|
||||
assert d["expires_in"] == 604800
|
||||
|
||||
|
||||
def test_login_response_model_fields():
|
||||
"""LoginResponse has expires_in and needs_setup."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
fields = set(LoginResponse.model_fields.keys())
|
||||
assert fields == {"expires_in", "needs_setup"}
|
||||
|
||||
|
||||
# ── AuthConfig in Route ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_used_in_login_response():
|
||||
"""LoginResponse.expires_in should come from config.token_expiry_days."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
expected_seconds = 14 * 24 * 3600
|
||||
resp = LoginResponse(expires_in=expected_seconds)
|
||||
assert resp.expires_in == expected_seconds
|
||||
|
||||
|
||||
# ── UserResponse Type Preservation ───────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_system_role_literal():
|
||||
"""UserResponse.system_role should only accept 'admin' or 'user'."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
# Valid roles
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
|
||||
assert resp.system_role == "admin"
|
||||
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="user")
|
||||
assert resp.system_role == "user"
|
||||
|
||||
|
||||
def test_user_response_rejects_invalid_role():
|
||||
"""UserResponse should reject invalid system_role values."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="superadmin")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# UNHAPPY PATHS / EDGE CASES
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
# ── get_current_user structured 401 responses ────────────────────────
|
||||
|
||||
|
||||
def test_get_current_user_no_cookie_returns_not_authenticated():
|
||||
"""No cookie → 401 with code=not_authenticated."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_get_current_user_expired_token_returns_token_expired():
|
||||
"""Expired token → 401 with code=token_expired."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token_returns_token_invalid():
|
||||
"""Bad signature → 401 with code=token_invalid."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_get_current_user_malformed_token_returns_token_invalid():
|
||||
"""Garbage token → 401 with code=token_invalid."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
# ── decode_token edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_empty_string_returns_malformed():
|
||||
_setup_config()
|
||||
result = decode_token("")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_whitespace_returns_malformed():
|
||||
_setup_config()
|
||||
result = decode_token(" ")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
# ── AuthConfig validation edge cases ─────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_missing_jwt_secret_raises():
|
||||
"""AuthConfig requires jwt_secret — no default allowed."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig()
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_zero_raises():
|
||||
"""token_expiry_days must be >= 1."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig(jwt_secret="secret", token_expiry_days=0)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_31_raises():
|
||||
"""token_expiry_days must be <= 30."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig(jwt_secret="secret", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_1_ok():
|
||||
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
|
||||
assert config.token_expiry_days == 1
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_30_ok():
|
||||
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
|
||||
assert config.token_expiry_days == 30
|
||||
|
||||
|
||||
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
|
||||
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
# ── CSRF middleware integration (unhappy paths) ──────────────────────
|
||||
|
||||
|
||||
def _make_csrf_app():
|
||||
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
|
||||
from fastapi import HTTPException as _HTTPException
|
||||
from fastapi.responses import JSONResponse as _JSONResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.exception_handler(_HTTPException)
|
||||
async def _http_exc_handler(request, exc):
|
||||
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/api/v1/test/protected")
|
||||
async def protected():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/v1/auth/login/local")
|
||||
async def login():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/v1/test/read")
|
||||
async def read_endpoint():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_without_token():
|
||||
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/test/protected")
|
||||
assert resp.status_code == 403
|
||||
assert "CSRF" in resp.json()["detail"]
|
||||
assert "missing" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_with_mismatched_token():
|
||||
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
|
||||
resp = client.post(
|
||||
"/api/v1/test/protected",
|
||||
headers={CSRF_HEADER_NAME: "token-b"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "mismatch" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_post_with_matching_token():
|
||||
"""POST with matching CSRF cookie/header → 200."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
token = secrets.token_urlsafe(64)
|
||||
client.cookies.set(CSRF_COOKIE_NAME, token)
|
||||
resp = client.post(
|
||||
"/api/v1/test/protected",
|
||||
headers={CSRF_HEADER_NAME: token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_get_without_token():
|
||||
"""GET requests bypass CSRF check."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.get("/api/v1/test/read")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_exempts_login_local():
|
||||
"""POST to login/local is exempt from CSRF (no token yet)."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/auth/login/local")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
|
||||
"""Auth endpoints should receive a CSRF cookie in response."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/auth/login/local")
|
||||
assert CSRF_COOKIE_NAME in resp.cookies
|
||||
|
||||
|
||||
# ── UserResponse edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_missing_required_fields():
|
||||
"""UserResponse with missing fields → ValidationError."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1") # missing email, system_role
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com") # missing system_role
|
||||
|
||||
|
||||
def test_user_response_empty_string_role_rejected():
|
||||
"""Empty string is not a valid role."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# HTTP-LEVEL API CONTRACT TESTS
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_auth_app():
|
||||
"""Create FastAPI app with auth routes for contract testing."""
|
||||
from app.gateway.app import create_app
|
||||
|
||||
return create_app()
|
||||
|
||||
|
||||
def _get_auth_client():
|
||||
"""Get TestClient for auth API contract tests."""
|
||||
return TestClient(_make_auth_app())
|
||||
|
||||
|
||||
def test_api_auth_me_no_cookie_returns_structured_401():
|
||||
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
assert "message" in body["detail"]
|
||||
|
||||
|
||||
def test_api_auth_me_expired_token_returns_structured_401():
|
||||
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_api_auth_me_invalid_sig_returns_structured_401():
|
||||
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_api_login_bad_credentials_returns_structured_401():
|
||||
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_api_login_success_no_token_in_body():
|
||||
"""Successful login → response body has expires_in but NOT access_token."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
# Register first
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
# Login
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "expires_in" in body
|
||||
assert "access_token" not in body
|
||||
# Token should be in cookie, not body
|
||||
assert "access_token" in resp.cookies
|
||||
|
||||
|
||||
def test_api_register_duplicate_returns_structured_400():
|
||||
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = "dup-contract-test@test.com"
|
||||
# First register
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||
# Duplicate
|
||||
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
|
||||
assert resp.status_code == 400
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "email_already_exists"
|
||||
|
||||
|
||||
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
|
||||
|
||||
|
||||
def _unique_email(prefix: str) -> str:
|
||||
return f"{prefix}-{secrets.token_hex(4)}@test.com"
|
||||
|
||||
|
||||
def _get_set_cookie_headers(resp) -> list[str]:
|
||||
"""Extract all set-cookie header values from a TestClient response."""
|
||||
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
|
||||
|
||||
|
||||
def test_register_http_cookie_httponly_true_secure_false():
|
||||
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("http-cookie"), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" not in cookie_header.lower().replace("samesite", "")
|
||||
|
||||
|
||||
def test_register_https_cookie_httponly_true_secure_true():
|
||||
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("https-cookie"), "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
assert "max-age" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_login_https_sets_secure_cookie():
|
||||
"""HTTPS login → access_token cookie has secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = _unique_email("https-login")
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": email, "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_secure_on_https():
|
||||
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-https"), "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" in csrf_header.lower()
|
||||
assert "httponly" not in csrf_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_not_secure_on_http():
|
||||
"""HTTP register → csrf_token cookie does NOT have secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-http"), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" not in csrf_header.lower().replace("samesite", "")
|
||||
214
backend/tests/test_ensure_admin.py
Normal file
214
backend/tests/test_ensure_admin.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""Tests for _ensure_admin_user() in app.py.
|
||||
|
||||
Covers: first-boot admin creation, auto-reset on needs_setup=True,
|
||||
no-op on needs_setup=False, migration, and edge cases.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
yield
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
|
||||
|
||||
def _make_app_stub(store=None):
|
||||
"""Minimal app-like object with state.store."""
|
||||
app = SimpleNamespace()
|
||||
app.state = SimpleNamespace()
|
||||
app.state.store = store
|
||||
return app
|
||||
|
||||
|
||||
def _make_provider(user_count=0, admin_user=None):
|
||||
p = AsyncMock()
|
||||
p.count_users = AsyncMock(return_value=user_count)
|
||||
p.create_user = AsyncMock(
|
||||
side_effect=lambda **kw: User(
|
||||
email=kw["email"],
|
||||
password_hash="hashed",
|
||||
system_role=kw.get("system_role", "user"),
|
||||
needs_setup=kw.get("needs_setup", False),
|
||||
)
|
||||
)
|
||||
p.get_user_by_email = AsyncMock(return_value=admin_user)
|
||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||
return p
|
||||
|
||||
|
||||
# ── First boot: no users ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_first_boot_creates_admin():
|
||||
"""count_users==0 → create admin with needs_setup=True."""
|
||||
provider = _make_provider(user_count=0)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
call_kwargs = provider.create_user.call_args[1]
|
||||
assert call_kwargs["email"] == "admin@deerflow.dev"
|
||||
assert call_kwargs["system_role"] == "admin"
|
||||
assert call_kwargs["needs_setup"] is True
|
||||
assert len(call_kwargs["password"]) > 10 # random password generated
|
||||
|
||||
|
||||
def test_first_boot_triggers_migration_if_store_present():
|
||||
"""First boot with store → _migrate_orphaned_threads called."""
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(return_value=[])
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
store.asearch.assert_called_once()
|
||||
|
||||
|
||||
def test_first_boot_no_store_skips_migration():
|
||||
"""First boot without store → no crash, migration skipped."""
|
||||
provider = _make_provider(user_count=0)
|
||||
app = _make_app_stub(store=None)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
|
||||
|
||||
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_true_resets_password():
|
||||
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="old-hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=0,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
# Password was reset
|
||||
provider.update_user.assert_called_once()
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.password_hash == "new-hash"
|
||||
assert updated.token_version == 1
|
||||
|
||||
|
||||
def test_needs_setup_true_consecutive_resets_increment_version():
|
||||
"""Two boots with needs_setup=True → token_version increments each time."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.token_version == 4
|
||||
|
||||
|
||||
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_false_no_reset():
|
||||
"""Admin with needs_setup=False → no password reset, no update."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="stable-hash",
|
||||
system_role="admin",
|
||||
needs_setup=False,
|
||||
token_version=2,
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
assert admin.password_hash == "stable-hash"
|
||||
assert admin.token_version == 2
|
||||
|
||||
|
||||
# ── Edge cases ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_admin_email_found_no_crash():
|
||||
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
||||
provider = _make_provider(user_count=3, admin_user=None)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
provider.create_user.assert_not_called()
|
||||
|
||||
|
||||
def test_migration_failure_is_non_fatal():
|
||||
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
# Should not raise
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
312
backend/tests/test_langgraph_auth.py
Normal file
312
backend/tests/test_langgraph_auth.py
Normal file
@ -0,0 +1,312 @@
|
||||
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
|
||||
|
||||
Validates that the LangGraph auth layer enforces the same rules as Gateway:
|
||||
cookie → JWT decode → DB lookup → token_version check → owner filter
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
yield
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
|
||||
|
||||
def _req(cookies=None, method="GET", headers=None):
|
||||
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
|
||||
|
||||
|
||||
def _user(user_id=None, token_version=0):
|
||||
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
|
||||
|
||||
|
||||
def _mock_provider(user=None):
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(return_value=user)
|
||||
return p
|
||||
|
||||
|
||||
# ── @auth.authenticate ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_cookie_raises_401():
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req()))
|
||||
assert exc.value.status_code == 401
|
||||
assert "Not authenticated" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_invalid_jwt_raises_401():
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
||||
assert exc.value.status_code == 401
|
||||
assert "Token error" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_expired_jwt_raises_401():
|
||||
token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_user_not_found_raises_401():
|
||||
token = create_access_token("ghost")
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert exc.value.status_code == 401
|
||||
assert "User not found" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_token_version_mismatch_raises_401():
|
||||
user = _user(token_version=2)
|
||||
token = create_access_token(str(user.id), token_version=1)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert exc.value.status_code == 401
|
||||
assert "revoked" in str(exc.value.detail).lower()
|
||||
|
||||
|
||||
def test_valid_token_returns_user_id():
|
||||
user = _user(token_version=0)
|
||||
token = create_access_token(str(user.id), token_version=0)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert result == str(user.id)
|
||||
|
||||
|
||||
def test_valid_token_matching_version():
|
||||
user = _user(token_version=5)
|
||||
token = create_access_token(str(user.id), token_version=5)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert result == str(user.id)
|
||||
|
||||
|
||||
# ── @auth.authenticate edge cases ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_provider_exception_propagates():
|
||||
"""Provider raises → should not be swallowed silently."""
|
||||
token = create_access_token("user-1")
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
|
||||
with pytest.raises(RuntimeError, match="DB down"):
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
|
||||
|
||||
def test_jwt_missing_ver_defaults_to_zero():
|
||||
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
|
||||
import jwt as pyjwt
|
||||
|
||||
uid = str(uuid4())
|
||||
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||
user = _user(user_id=uid, token_version=0)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
assert result == uid
|
||||
|
||||
|
||||
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
|
||||
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
|
||||
import jwt as pyjwt
|
||||
|
||||
uid = str(uuid4())
|
||||
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||
user = _user(user_id=uid, token_version=1)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_wrong_secret_raises_401():
|
||||
"""Token signed with different secret → 401."""
|
||||
import jwt as pyjwt
|
||||
|
||||
raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
# ── @auth.on (owner filter) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeUser:
|
||||
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
|
||||
|
||||
def __init__(self, identity: str):
|
||||
self.identity = identity
|
||||
self.is_authenticated = True
|
||||
self.display_name = identity
|
||||
|
||||
|
||||
def _make_ctx(user_id):
|
||||
return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])
|
||||
|
||||
|
||||
def test_filter_injects_user_id():
|
||||
value = {}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
|
||||
|
||||
def test_filter_preserves_existing_metadata():
|
||||
value = {"metadata": {"title": "hello"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
assert value["metadata"]["title"] == "hello"
|
||||
|
||||
|
||||
def test_filter_returns_user_id_dict():
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||
assert result == {"owner_id": "user-x"}
|
||||
|
||||
|
||||
def test_filter_read_write_consistency():
|
||||
value = {}
|
||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
||||
|
||||
|
||||
def test_different_users_different_filters():
|
||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||
assert f_a["owner_id"] != f_b["owner_id"]
|
||||
|
||||
|
||||
def test_filter_overrides_conflicting_user_id():
|
||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||
value = {"metadata": {"owner_id": "attacker"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||
assert value["metadata"]["owner_id"] == "real-owner"
|
||||
|
||||
|
||||
def test_filter_with_empty_metadata():
|
||||
"""Explicit empty metadata dict is fine."""
|
||||
value = {"metadata": {}}
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-z"
|
||||
assert result == {"owner_id": "user-z"}
|
||||
|
||||
|
||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_shared_jwt_secret():
|
||||
token = create_access_token("user-1", token_version=3)
|
||||
payload = decode_token(token)
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.sub == "user-1"
|
||||
assert payload.ver == 3
|
||||
|
||||
|
||||
def test_langgraph_json_has_auth_path():
|
||||
import json
|
||||
|
||||
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
|
||||
assert "auth" in config
|
||||
assert "langgraph_auth" in config["auth"]["path"]
|
||||
|
||||
|
||||
def test_auth_handler_has_both_layers():
|
||||
from app.gateway.langgraph_auth import auth
|
||||
|
||||
assert auth._authenticate_handler is not None
|
||||
assert len(auth._global_handlers) == 1
|
||||
|
||||
|
||||
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_csrf_get_no_check():
|
||||
"""GET requests skip CSRF — should proceed to JWT validation."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="GET")))
|
||||
# Rejected by missing cookie, NOT by CSRF
|
||||
assert exc.value.status_code == 401
|
||||
assert "Not authenticated" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_post_missing_token():
|
||||
"""POST without CSRF token → 403."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
|
||||
assert exc.value.status_code == 403
|
||||
assert "CSRF token missing" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_post_mismatched_token():
|
||||
"""POST with mismatched CSRF tokens → 403."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(
|
||||
authenticate(
|
||||
_req(
|
||||
method="POST",
|
||||
cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
|
||||
headers={"x-csrf-token": "wrong-token"},
|
||||
)
|
||||
)
|
||||
)
|
||||
assert exc.value.status_code == 403
|
||||
assert "mismatch" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_post_matching_token_proceeds_to_jwt():
|
||||
"""POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(
|
||||
authenticate(
|
||||
_req(
|
||||
method="POST",
|
||||
cookies={"access_token": "garbage", "csrf_token": "same-token"},
|
||||
headers={"x-csrf-token": "same-token"},
|
||||
)
|
||||
)
|
||||
)
|
||||
# Past CSRF, rejected by JWT decode
|
||||
assert exc.value.status_code == 401
|
||||
assert "Token error" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_put_requires_token():
|
||||
"""PUT also requires CSRF."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
def test_csrf_delete_requires_token():
|
||||
"""DELETE also requires CSRF."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
|
||||
assert exc.value.status_code == 403
|
||||
@ -52,7 +52,6 @@
|
||||
"@xyflow/react": "^12.10.0",
|
||||
"ai": "^6.0.33",
|
||||
"best-effort-json-parser": "^1.2.1",
|
||||
"better-auth": "^1.3",
|
||||
"canvas-confetti": "^1.9.4",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
|
||||
183
frontend/pnpm-lock.yaml
generated
183
frontend/pnpm-lock.yaml
generated
@ -113,9 +113,6 @@ importers:
|
||||
best-effort-json-parser:
|
||||
specifier: ^1.2.1
|
||||
version: 1.2.1
|
||||
better-auth:
|
||||
specifier: ^1.3
|
||||
version: 1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3))
|
||||
canvas-confetti:
|
||||
specifier: ^1.9.4
|
||||
version: 1.9.4
|
||||
@ -317,27 +314,6 @@ packages:
|
||||
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@better-auth/core@1.4.18':
|
||||
resolution: {integrity: sha512-q+awYgC7nkLEBdx2sW0iJjkzgSHlIxGnOpsN1r/O1+a4m7osJNHtfK2mKJSL1I+GfNyIlxJF8WvD/NLuYMpmcg==}
|
||||
peerDependencies:
|
||||
'@better-auth/utils': 0.3.0
|
||||
'@better-fetch/fetch': 1.1.21
|
||||
better-call: 1.1.8
|
||||
jose: ^6.1.0
|
||||
kysely: ^0.28.5
|
||||
nanostores: ^1.0.1
|
||||
|
||||
'@better-auth/telemetry@1.4.18':
|
||||
resolution: {integrity: sha512-e5rDF8S4j3Um/0LIVATL2in9dL4lfO2fr2v1Wio4qTMRbfxqnUDTa+6SZtwdeJrbc4O+a3c+IyIpjG9Q/6GpfQ==}
|
||||
peerDependencies:
|
||||
'@better-auth/core': 1.4.18
|
||||
|
||||
'@better-auth/utils@0.3.0':
|
||||
resolution: {integrity: sha512-W+Adw6ZA6mgvnSnhOki270rwJ42t4XzSK6YWGF//BbVXL6SwCLWfyzBc1lN2m/4RM28KubdBKQ4X5VMoLRNPQw==}
|
||||
|
||||
'@better-fetch/fetch@1.1.21':
|
||||
resolution: {integrity: sha512-/ImESw0sskqlVR94jB+5+Pxjf+xBwDZF/N5+y2/q4EqD7IARUTSpPfIo8uf39SYpCxyOCtbyYpUrZ3F/k0zT4A==}
|
||||
|
||||
'@braintree/sanitize-url@7.1.2':
|
||||
resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==}
|
||||
|
||||
@ -1116,14 +1092,6 @@ packages:
|
||||
cpu: [x64]
|
||||
os: [win32]
|
||||
|
||||
'@noble/ciphers@2.1.1':
|
||||
resolution: {integrity: sha512-bysYuiVfhxNJuldNXlFEitTVdNnYUc+XNJZd7Qm2a5j1vZHgY+fazadNFWFaMK/2vye0JVlxV3gHmC0WDfAOQw==}
|
||||
engines: {node: '>= 20.19.0'}
|
||||
|
||||
'@noble/hashes@2.0.1':
|
||||
resolution: {integrity: sha512-XlOlEbQcE9fmuXxrVTXCTlG2nlRXa9Rj3rr5Ue/+tX+nmkgbX720YHh0VR3hBF9xDvwnb8D2shVGOwNx+ulArw==}
|
||||
engines: {node: '>= 20.19.0'}
|
||||
|
||||
'@nodelib/fs.scandir@2.1.5':
|
||||
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
|
||||
engines: {node: '>= 8'}
|
||||
@ -2696,76 +2664,6 @@ packages:
|
||||
best-effort-json-parser@1.2.1:
|
||||
resolution: {integrity: sha512-UICSLibQdzS1f+PBsi3u2YE3SsdXcWicHUg3IMvfuaePS2AYnZJdJeKhGv5OM8/mqJwPt79aDrEJ1oa84tELvw==}
|
||||
|
||||
better-auth@1.4.18:
|
||||
resolution: {integrity: sha512-bnyifLWBPcYVltH3RhS7CM62MoelEqC6Q+GnZwfiDWNfepXoQZBjEvn4urcERC7NTKgKq5zNBM8rvPvRBa6xcg==}
|
||||
peerDependencies:
|
||||
'@lynx-js/react': '*'
|
||||
'@prisma/client': ^5.0.0 || ^6.0.0 || ^7.0.0
|
||||
'@sveltejs/kit': ^2.0.0
|
||||
'@tanstack/react-start': ^1.0.0
|
||||
'@tanstack/solid-start': ^1.0.0
|
||||
better-sqlite3: ^12.0.0
|
||||
drizzle-kit: '>=0.31.4'
|
||||
drizzle-orm: '>=0.41.0'
|
||||
mongodb: ^6.0.0 || ^7.0.0
|
||||
mysql2: ^3.0.0
|
||||
next: ^14.0.0 || ^15.0.0 || ^16.0.0
|
||||
pg: ^8.0.0
|
||||
prisma: ^5.0.0 || ^6.0.0 || ^7.0.0
|
||||
react: ^18.0.0 || ^19.0.0
|
||||
react-dom: ^18.0.0 || ^19.0.0
|
||||
solid-js: ^1.0.0
|
||||
svelte: ^4.0.0 || ^5.0.0
|
||||
vitest: ^2.0.0 || ^3.0.0 || ^4.0.0
|
||||
vue: ^3.0.0
|
||||
peerDependenciesMeta:
|
||||
'@lynx-js/react':
|
||||
optional: true
|
||||
'@prisma/client':
|
||||
optional: true
|
||||
'@sveltejs/kit':
|
||||
optional: true
|
||||
'@tanstack/react-start':
|
||||
optional: true
|
||||
'@tanstack/solid-start':
|
||||
optional: true
|
||||
better-sqlite3:
|
||||
optional: true
|
||||
drizzle-kit:
|
||||
optional: true
|
||||
drizzle-orm:
|
||||
optional: true
|
||||
mongodb:
|
||||
optional: true
|
||||
mysql2:
|
||||
optional: true
|
||||
next:
|
||||
optional: true
|
||||
pg:
|
||||
optional: true
|
||||
prisma:
|
||||
optional: true
|
||||
react:
|
||||
optional: true
|
||||
react-dom:
|
||||
optional: true
|
||||
solid-js:
|
||||
optional: true
|
||||
svelte:
|
||||
optional: true
|
||||
vitest:
|
||||
optional: true
|
||||
vue:
|
||||
optional: true
|
||||
|
||||
better-call@1.1.8:
|
||||
resolution: {integrity: sha512-XMQ2rs6FNXasGNfMjzbyroSwKwYbZ/T3IxruSS6U2MJRsSYh3wYtG3o6H00ZlKZ/C/UPOAD97tqgQJNsxyeTXw==}
|
||||
peerDependencies:
|
||||
zod: ^4.0.0
|
||||
peerDependenciesMeta:
|
||||
zod:
|
||||
optional: true
|
||||
|
||||
better-react-mathjax@2.3.0:
|
||||
resolution: {integrity: sha512-K0ceQC+jQmB+NLDogO5HCpqmYf18AU2FxDbLdduYgkHYWZApFggkHE4dIaXCV1NqeoscESYXXo1GSkY6fA295w==}
|
||||
peerDependencies:
|
||||
@ -3973,9 +3871,6 @@ packages:
|
||||
resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==}
|
||||
hasBin: true
|
||||
|
||||
jose@6.1.3:
|
||||
resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==}
|
||||
|
||||
js-tiktoken@1.0.21:
|
||||
resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==}
|
||||
|
||||
@ -4026,10 +3921,6 @@ packages:
|
||||
knitwork@1.3.0:
|
||||
resolution: {integrity: sha512-4LqMNoONzR43B1W0ek0fhXMsDNW/zxa1NdFAVMY+k28pgZLovR4G3PB5MrpTxCy1QaZCqNoiaKPr5w5qZHfSNw==}
|
||||
|
||||
kysely@0.28.11:
|
||||
resolution: {integrity: sha512-zpGIFg0HuoC893rIjYX1BETkVWdDnzTzF5e0kWXJFg5lE0k1/LfNWBejrcnOFu8Q2Rfq/hTDTU7XLUM8QOrpzg==}
|
||||
engines: {node: '>=20.0.0'}
|
||||
|
||||
langium@3.3.1:
|
||||
resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==}
|
||||
engines: {node: '>=16.0.0'}
|
||||
@ -4458,10 +4349,6 @@ packages:
|
||||
engines: {node: ^18 || >=20}
|
||||
hasBin: true
|
||||
|
||||
nanostores@1.1.0:
|
||||
resolution: {integrity: sha512-yJBmDJr18xy47dbNVlHcgdPrulSn1nhSE6Ns9vTG+Nx9VPT6iV1MD6aQFp/t52zpf82FhLLTXAXr30NuCnxvwA==}
|
||||
engines: {node: ^20.0.0 || >=22.0.0}
|
||||
|
||||
napi-postinstall@0.3.4:
|
||||
resolution: {integrity: sha512-PHI5f1O0EP5xJ9gQmFGMS6IZcrVvTjpXjz7Na41gTE7eE2hK11lg04CECCYEEjdc17EV4DO+fkGEtt7TpTaTiQ==}
|
||||
engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0}
|
||||
@ -5050,9 +4937,6 @@ packages:
|
||||
engines: {node: '>=18.0.0', npm: '>=8.0.0'}
|
||||
hasBin: true
|
||||
|
||||
rou3@0.7.12:
|
||||
resolution: {integrity: sha512-iFE4hLDuloSWcD7mjdCDhx2bKcIsYbtOTpfH5MHHLSKMOUyjqQXTeZVa289uuwEGEKFoE/BAPbhaU4B774nceg==}
|
||||
|
||||
roughjs@4.6.6:
|
||||
resolution: {integrity: sha512-ZUz/69+SYpFN/g/lUlo2FXcIjRkSu3nDarreVdGGndHEBJ6cXPdKguS8JGxwj5HA5xIbVKSmLgr5b3AWxtRfvQ==}
|
||||
|
||||
@ -5105,9 +4989,6 @@ packages:
|
||||
server-only@0.0.1:
|
||||
resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==}
|
||||
|
||||
set-cookie-parser@2.7.2:
|
||||
resolution: {integrity: sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==}
|
||||
|
||||
set-function-length@1.2.2:
|
||||
resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==}
|
||||
engines: {node: '>= 0.4'}
|
||||
@ -5802,27 +5683,6 @@ snapshots:
|
||||
'@babel/helper-string-parser': 7.27.1
|
||||
'@babel/helper-validator-identifier': 7.28.5
|
||||
|
||||
'@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)':
|
||||
dependencies:
|
||||
'@better-auth/utils': 0.3.0
|
||||
'@better-fetch/fetch': 1.1.21
|
||||
'@standard-schema/spec': 1.1.0
|
||||
better-call: 1.1.8(zod@4.3.6)
|
||||
jose: 6.1.3
|
||||
kysely: 0.28.11
|
||||
nanostores: 1.1.0
|
||||
zod: 4.3.6
|
||||
|
||||
'@better-auth/telemetry@1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))':
|
||||
dependencies:
|
||||
'@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)
|
||||
'@better-auth/utils': 0.3.0
|
||||
'@better-fetch/fetch': 1.1.21
|
||||
|
||||
'@better-auth/utils@0.3.0': {}
|
||||
|
||||
'@better-fetch/fetch@1.1.21': {}
|
||||
|
||||
'@braintree/sanitize-url@7.1.2': {}
|
||||
|
||||
'@cfworker/json-schema@4.1.1': {}
|
||||
@ -6671,10 +6531,6 @@ snapshots:
|
||||
'@next/swc-win32-x64-msvc@16.1.7':
|
||||
optional: true
|
||||
|
||||
'@noble/ciphers@2.1.1': {}
|
||||
|
||||
'@noble/hashes@2.0.1': {}
|
||||
|
||||
'@nodelib/fs.scandir@2.1.5':
|
||||
dependencies:
|
||||
'@nodelib/fs.stat': 2.0.5
|
||||
@ -8242,35 +8098,6 @@ snapshots:
|
||||
|
||||
best-effort-json-parser@1.2.1: {}
|
||||
|
||||
better-auth@1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3)):
|
||||
dependencies:
|
||||
'@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)
|
||||
'@better-auth/telemetry': 1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))
|
||||
'@better-auth/utils': 0.3.0
|
||||
'@better-fetch/fetch': 1.1.21
|
||||
'@noble/ciphers': 2.1.1
|
||||
'@noble/hashes': 2.0.1
|
||||
better-call: 1.1.8(zod@4.3.6)
|
||||
defu: 6.1.4
|
||||
jose: 6.1.3
|
||||
kysely: 0.28.11
|
||||
nanostores: 1.1.0
|
||||
zod: 4.3.6
|
||||
optionalDependencies:
|
||||
next: 16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
react: 19.2.4
|
||||
react-dom: 19.2.4(react@19.2.4)
|
||||
vue: 3.5.28(typescript@5.9.3)
|
||||
|
||||
better-call@1.1.8(zod@4.3.6):
|
||||
dependencies:
|
||||
'@better-auth/utils': 0.3.0
|
||||
'@better-fetch/fetch': 1.1.21
|
||||
rou3: 0.7.12
|
||||
set-cookie-parser: 2.7.2
|
||||
optionalDependencies:
|
||||
zod: 4.3.6
|
||||
|
||||
better-react-mathjax@2.3.0(react@19.2.4):
|
||||
dependencies:
|
||||
mathjax-full: 3.2.2
|
||||
@ -9786,8 +9613,6 @@ snapshots:
|
||||
|
||||
jiti@2.6.1: {}
|
||||
|
||||
jose@6.1.3: {}
|
||||
|
||||
js-tiktoken@1.0.21:
|
||||
dependencies:
|
||||
base64-js: 1.5.1
|
||||
@ -9833,8 +9658,6 @@ snapshots:
|
||||
|
||||
knitwork@1.3.0: {}
|
||||
|
||||
kysely@0.28.11: {}
|
||||
|
||||
langium@3.3.1:
|
||||
dependencies:
|
||||
chevrotain: 11.0.3
|
||||
@ -10529,8 +10352,6 @@ snapshots:
|
||||
|
||||
nanoid@5.1.6: {}
|
||||
|
||||
nanostores@1.1.0: {}
|
||||
|
||||
napi-postinstall@0.3.4: {}
|
||||
|
||||
natural-compare@1.4.0: {}
|
||||
@ -11305,8 +11126,6 @@ snapshots:
|
||||
'@rollup/rollup-win32-x64-msvc': 4.60.0
|
||||
fsevents: 2.3.3
|
||||
|
||||
rou3@0.7.12: {}
|
||||
|
||||
roughjs@4.6.6:
|
||||
dependencies:
|
||||
hachure-fill: 0.5.2
|
||||
@ -11373,8 +11192,6 @@ snapshots:
|
||||
|
||||
server-only@0.0.1: {}
|
||||
|
||||
set-cookie-parser@2.7.2: {}
|
||||
|
||||
set-function-length@1.2.2:
|
||||
dependencies:
|
||||
define-data-property: 1.1.4
|
||||
|
||||
45
frontend/src/app/(auth)/layout.tsx
Normal file
45
frontend/src/app/(auth)/layout.tsx
Normal file
@ -0,0 +1,45 @@
|
||||
import Link from "next/link";
|
||||
import { redirect } from "next/navigation";
|
||||
import { type ReactNode } from "react";
|
||||
|
||||
import { AuthProvider } from "@/core/auth/AuthProvider";
|
||||
import { getServerSideUser } from "@/core/auth/server";
|
||||
import { assertNever } from "@/core/auth/types";
|
||||
|
||||
export const dynamic = "force-dynamic";
|
||||
|
||||
export default async function AuthLayout({
|
||||
children,
|
||||
}: {
|
||||
children: ReactNode;
|
||||
}) {
|
||||
const result = await getServerSideUser();
|
||||
|
||||
switch (result.tag) {
|
||||
case "authenticated":
|
||||
redirect("/workspace");
|
||||
case "needs_setup":
|
||||
// Allow access to setup page
|
||||
return <AuthProvider initialUser={result.user}>{children}</AuthProvider>;
|
||||
case "unauthenticated":
|
||||
return <AuthProvider initialUser={null}>{children}</AuthProvider>;
|
||||
case "gateway_unavailable":
|
||||
return (
|
||||
<div className="flex h-screen flex-col items-center justify-center gap-4">
|
||||
<p className="text-muted-foreground">
|
||||
Service temporarily unavailable.
|
||||
</p>
|
||||
<Link
|
||||
href="/login"
|
||||
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
|
||||
>
|
||||
Retry
|
||||
</Link>
|
||||
</div>
|
||||
);
|
||||
case "config_error":
|
||||
throw new Error(result.message);
|
||||
default:
|
||||
assertNever(result);
|
||||
}
|
||||
}
|
||||
183
frontend/src/app/(auth)/login/page.tsx
Normal file
183
frontend/src/app/(auth)/login/page.tsx
Normal file
@ -0,0 +1,183 @@
|
||||
"use client";
|
||||
|
||||
import Link from "next/link";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { useAuth } from "@/core/auth/AuthProvider";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
|
||||
/**
|
||||
* Validate next parameter
|
||||
* Prevent open redirect attacks
|
||||
* Per RFC-001: Only allow relative paths starting with /
|
||||
*/
|
||||
function validateNextParam(next: string | null): string | null {
|
||||
if (!next) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Need start with / (relative path)
|
||||
if (!next.startsWith("/")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Disallow protocol-relative URLs
|
||||
if (
|
||||
next.startsWith("//") ||
|
||||
next.startsWith("http://") ||
|
||||
next.startsWith("https://")
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Disallow URLs with different protocols (e.g., javascript:, data:, etc)
|
||||
if (next.includes(":") && !next.startsWith("/")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Valid relative path
|
||||
return next;
|
||||
}
|
||||
|
||||
export default function LoginPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const { isAuthenticated } = useAuth();
|
||||
|
||||
const [email, setEmail] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [isLogin, setIsLogin] = useState(true);
|
||||
const [error, setError] = useState("");
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
// Get next parameter for validated redirect
|
||||
const nextParam = searchParams.get("next");
|
||||
const redirectPath = validateNextParam(nextParam) ?? "/workspace";
|
||||
|
||||
// Redirect if already authenticated (client-side, post-login)
|
||||
useEffect(() => {
|
||||
if (isAuthenticated) {
|
||||
router.push(redirectPath);
|
||||
}
|
||||
}, [isAuthenticated, redirectPath, router]);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setError("");
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
const endpoint = isLogin
|
||||
? "/api/v1/auth/login/local"
|
||||
: "/api/v1/auth/register";
|
||||
const body = isLogin
|
||||
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
|
||||
: JSON.stringify({ email, password });
|
||||
|
||||
const headers: HeadersInit = isLogin
|
||||
? { "Content-Type": "application/x-www-form-urlencoded" }
|
||||
: { "Content-Type": "application/json" };
|
||||
|
||||
const res = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body,
|
||||
credentials: "include", // Important: include HttpOnly cookie
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
const authError = parseAuthError(data);
|
||||
setError(authError.message);
|
||||
return;
|
||||
}
|
||||
|
||||
// Both login and register set a cookie — redirect to workspace
|
||||
router.push(redirectPath);
|
||||
} catch (_err) {
|
||||
setError("Network error. Please try again.");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center bg-[#0a0a0a]">
|
||||
<div className="border-border/20 w-full max-w-md space-y-6 rounded-lg border bg-black/50 p-8 backdrop-blur-sm">
|
||||
<div className="text-center">
|
||||
<h1 className="font-serif text-3xl">DeerFlow</h1>
|
||||
<p className="text-muted-foreground mt-2">
|
||||
{isLogin ? "Sign in to your account" : "Create a new account"}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4">
|
||||
<div>
|
||||
<label htmlFor="email" className="text-sm font-medium">
|
||||
Email
|
||||
</label>
|
||||
<Input
|
||||
id="email"
|
||||
type="email"
|
||||
value={email}
|
||||
onChange={(e) => setEmail(e.target.value)}
|
||||
placeholder="you@example.com"
|
||||
required
|
||||
className="mt-1 bg-white text-black"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label htmlFor="password" className="text-sm font-medium">
|
||||
Password
|
||||
</label>
|
||||
<Input
|
||||
id="password"
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
placeholder="•••••••"
|
||||
required
|
||||
minLength={isLogin ? 6 : 8}
|
||||
className="mt-1 bg-white text-black"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||
|
||||
<Button type="submit" className="w-full" disabled={loading}>
|
||||
{loading
|
||||
? "Please wait..."
|
||||
: isLogin
|
||||
? "Sign In"
|
||||
: "Create Account"}
|
||||
</Button>
|
||||
</form>
|
||||
|
||||
<div className="text-center text-sm">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setIsLogin(!isLogin);
|
||||
setError("");
|
||||
}}
|
||||
className="text-blue-500 hover:underline"
|
||||
>
|
||||
{isLogin
|
||||
? "Don't have an account? Sign up"
|
||||
: "Already have an account? Sign in"}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="text-muted-foreground text-center text-xs">
|
||||
<Link href="/" className="hover:underline">
|
||||
← Back to home
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
115
frontend/src/app/(auth)/setup/page.tsx
Normal file
115
frontend/src/app/(auth)/setup/page.tsx
Normal file
@ -0,0 +1,115 @@
|
||||
"use client";
|
||||
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState } from "react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { getCsrfHeaders } from "@/core/api/fetcher";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
|
||||
export default function SetupPage() {
|
||||
const router = useRouter();
|
||||
const [email, setEmail] = useState("");
|
||||
const [newPassword, setNewPassword] = useState("");
|
||||
const [confirmPassword, setConfirmPassword] = useState("");
|
||||
const [currentPassword, setCurrentPassword] = useState("");
|
||||
const [error, setError] = useState("");
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const handleSetup = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setError("");
|
||||
|
||||
if (newPassword !== confirmPassword) {
|
||||
setError("Passwords do not match");
|
||||
return;
|
||||
}
|
||||
if (newPassword.length < 8) {
|
||||
setError("Password must be at least 8 characters");
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await fetch("/api/v1/auth/change-password", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...getCsrfHeaders(),
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify({
|
||||
current_password: currentPassword,
|
||||
new_password: newPassword,
|
||||
new_email: email || undefined,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json();
|
||||
const authError = parseAuthError(data);
|
||||
setError(authError.message);
|
||||
return;
|
||||
}
|
||||
|
||||
router.push("/workspace");
|
||||
} catch {
|
||||
setError("Network error. Please try again.");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center">
|
||||
<div className="w-full max-w-sm space-y-6 p-6">
|
||||
<div className="text-center">
|
||||
<h1 className="font-serif text-3xl">DeerFlow</h1>
|
||||
<p className="text-muted-foreground mt-2">
|
||||
Complete admin account setup
|
||||
</p>
|
||||
<p className="text-muted-foreground mt-1 text-xs">
|
||||
Set your real email and a new password.
|
||||
</p>
|
||||
</div>
|
||||
<form onSubmit={handleSetup} className="space-y-4">
|
||||
<Input
|
||||
type="email"
|
||||
placeholder="Your email"
|
||||
value={email}
|
||||
onChange={(e) => setEmail(e.target.value)}
|
||||
required
|
||||
/>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Current password (from console log)"
|
||||
value={currentPassword}
|
||||
onChange={(e) => setCurrentPassword(e.target.value)}
|
||||
required
|
||||
/>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="New password"
|
||||
value={newPassword}
|
||||
onChange={(e) => setNewPassword(e.target.value)}
|
||||
required
|
||||
minLength={8}
|
||||
/>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Confirm new password"
|
||||
value={confirmPassword}
|
||||
onChange={(e) => setConfirmPassword(e.target.value)}
|
||||
required
|
||||
minLength={8}
|
||||
/>
|
||||
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||
<Button type="submit" className="w-full" disabled={loading}>
|
||||
{loading ? "Setting up..." : "Complete Setup"}
|
||||
</Button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
import { toNextJsHandler } from "better-auth/next-js";
|
||||
|
||||
import { auth } from "@/server/better-auth";
|
||||
|
||||
export const { GET, POST } = toNextJsHandler(auth.handler);
|
||||
@ -1,47 +1,58 @@
|
||||
"use client";
|
||||
import Link from "next/link";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useLayoutEffect, useState } from "react";
|
||||
import { Toaster } from "sonner";
|
||||
import { AuthProvider } from "@/core/auth/AuthProvider";
|
||||
import { getServerSideUser } from "@/core/auth/server";
|
||||
import { assertNever } from "@/core/auth/types";
|
||||
|
||||
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { CommandPalette } from "@/components/workspace/command-palette";
|
||||
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
|
||||
import { getLocalSettings, useLocalSettings } from "@/core/settings";
|
||||
import { WorkspaceContent } from "./workspace-content";
|
||||
|
||||
const queryClient = new QueryClient();
|
||||
export const dynamic = "force-dynamic";
|
||||
|
||||
export default function WorkspaceLayout({
|
||||
export default async function WorkspaceLayout({
|
||||
children,
|
||||
}: Readonly<{ children: React.ReactNode }>) {
|
||||
const [settings, setSettings] = useLocalSettings();
|
||||
const [open, setOpen] = useState(false); // SSR default: open (matches server render)
|
||||
useLayoutEffect(() => {
|
||||
// Runs synchronously before first paint on the client — no visual flash
|
||||
setOpen(!getLocalSettings().layout.sidebar_collapsed);
|
||||
}, []);
|
||||
useEffect(() => {
|
||||
setOpen(!settings.layout.sidebar_collapsed);
|
||||
}, [settings.layout.sidebar_collapsed]);
|
||||
const handleOpenChange = useCallback(
|
||||
(open: boolean) => {
|
||||
setOpen(open);
|
||||
setSettings("layout", { sidebar_collapsed: !open });
|
||||
},
|
||||
[setSettings],
|
||||
);
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<SidebarProvider
|
||||
className="h-screen"
|
||||
open={open}
|
||||
onOpenChange={handleOpenChange}
|
||||
>
|
||||
<WorkspaceSidebar />
|
||||
<SidebarInset className="min-w-0">{children}</SidebarInset>
|
||||
</SidebarProvider>
|
||||
<CommandPalette />
|
||||
<Toaster position="top-center" />
|
||||
</QueryClientProvider>
|
||||
);
|
||||
const result = await getServerSideUser();
|
||||
|
||||
switch (result.tag) {
|
||||
case "authenticated":
|
||||
return (
|
||||
<AuthProvider initialUser={result.user}>
|
||||
<WorkspaceContent>{children}</WorkspaceContent>
|
||||
</AuthProvider>
|
||||
);
|
||||
case "needs_setup":
|
||||
redirect("/setup");
|
||||
case "unauthenticated":
|
||||
redirect("/login");
|
||||
case "gateway_unavailable":
|
||||
return (
|
||||
<div className="flex h-screen flex-col items-center justify-center gap-4">
|
||||
<p className="text-muted-foreground">
|
||||
Service temporarily unavailable.
|
||||
</p>
|
||||
<p className="text-muted-foreground text-xs">
|
||||
The backend may be restarting. Please wait a moment and try again.
|
||||
</p>
|
||||
<div className="flex gap-3">
|
||||
<Link
|
||||
href="/workspace"
|
||||
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
|
||||
>
|
||||
Retry
|
||||
</Link>
|
||||
<Link
|
||||
href="/api/v1/auth/logout"
|
||||
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
|
||||
>
|
||||
Logout & Reset
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case "config_error":
|
||||
throw new Error(result.message);
|
||||
default:
|
||||
assertNever(result);
|
||||
}
|
||||
}
|
||||
|
||||
50
frontend/src/app/workspace/workspace-content.tsx
Normal file
50
frontend/src/app/workspace/workspace-content.tsx
Normal file
@ -0,0 +1,50 @@
|
||||
"use client";
|
||||
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useLayoutEffect, useState } from "react";
|
||||
import { Toaster } from "sonner";
|
||||
|
||||
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { CommandPalette } from "@/components/workspace/command-palette";
|
||||
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
|
||||
import { getLocalSettings, useLocalSettings } from "@/core/settings";
|
||||
|
||||
export function WorkspaceContent({
|
||||
children,
|
||||
}: Readonly<{ children: React.ReactNode }>) {
|
||||
const [queryClient] = useState(() => new QueryClient());
|
||||
const [settings, setSettings] = useLocalSettings();
|
||||
const [open, setOpen] = useState(false); // SSR default: open (matches server render)
|
||||
|
||||
useLayoutEffect(() => {
|
||||
// Runs synchronously before first paint on the client — no visual flash
|
||||
setOpen(!getLocalSettings().layout.sidebar_collapsed);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
setOpen(!settings.layout.sidebar_collapsed);
|
||||
}, [settings.layout.sidebar_collapsed]);
|
||||
|
||||
const handleOpenChange = useCallback(
|
||||
(open: boolean) => {
|
||||
setOpen(open);
|
||||
setSettings("layout", { sidebar_collapsed: !open });
|
||||
},
|
||||
[setSettings],
|
||||
);
|
||||
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<SidebarProvider
|
||||
className="h-screen"
|
||||
open={open}
|
||||
onOpenChange={handleOpenChange}
|
||||
>
|
||||
<WorkspaceSidebar />
|
||||
<SidebarInset className="min-w-0">{children}</SidebarInset>
|
||||
</SidebarProvider>
|
||||
<CommandPalette />
|
||||
<Toaster position="top-center" />
|
||||
</QueryClientProvider>
|
||||
);
|
||||
}
|
||||
39
frontend/src/core/api/fetcher.ts
Normal file
39
frontend/src/core/api/fetcher.ts
Normal file
@ -0,0 +1,39 @@
|
||||
import { buildLoginUrl } from "@/core/auth/types";
|
||||
|
||||
/**
|
||||
* Fetch with credentials. Automatically redirects to login on 401.
|
||||
*/
|
||||
export async function fetchWithAuth(
|
||||
input: RequestInfo | string,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
const url = typeof input === "string" ? input : input.url;
|
||||
const res = await fetch(url, {
|
||||
...init,
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
if (res.status === 401) {
|
||||
window.location.href = buildLoginUrl(window.location.pathname);
|
||||
throw new Error("Unauthorized");
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build headers for CSRF-protected requests
|
||||
* Per RFC-001: Double Submit Cookie pattern
|
||||
*/
|
||||
export function getCsrfHeaders(): HeadersInit {
|
||||
const token = getCsrfToken();
|
||||
return token ? { "X-CSRF-Token": token } : {};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get CSRF token from cookie
|
||||
*/
|
||||
function getCsrfToken(): string | null {
|
||||
const match = /csrf_token=([^;]+)/.exec(document.cookie);
|
||||
return match?.[1] ?? null;
|
||||
}
|
||||
165
frontend/src/core/auth/AuthProvider.tsx
Normal file
165
frontend/src/core/auth/AuthProvider.tsx
Normal file
@ -0,0 +1,165 @@
|
||||
"use client";
|
||||
|
||||
import { useRouter, usePathname } from "next/navigation";
|
||||
import React, {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useCallback,
|
||||
useEffect,
|
||||
type ReactNode,
|
||||
} from "react";
|
||||
|
||||
import { type User, buildLoginUrl } from "./types";
|
||||
|
||||
// Re-export for consumers
|
||||
export type { User };
|
||||
|
||||
/**
|
||||
* Authentication context provided to consuming components
|
||||
*/
|
||||
interface AuthContextType {
|
||||
user: User | null;
|
||||
isAuthenticated: boolean;
|
||||
isLoading: boolean;
|
||||
logout: () => Promise<void>;
|
||||
refreshUser: () => Promise<void>;
|
||||
}
|
||||
|
||||
const AuthContext = createContext<AuthContextType | undefined>(undefined);
|
||||
|
||||
interface AuthProviderProps {
|
||||
children: ReactNode;
|
||||
initialUser: User | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* AuthProvider - Unified authentication context for the application
|
||||
*
|
||||
* Per RFC-001:
|
||||
* - Only holds display information (user), never JWT or tokens
|
||||
* - initialUser comes from server-side guard, avoiding client flicker
|
||||
* - Provides logout and refresh capabilities
|
||||
*/
|
||||
export function AuthProvider({ children, initialUser }: AuthProviderProps) {
|
||||
const [user, setUser] = useState<User | null>(initialUser);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
|
||||
const isAuthenticated = user !== null;
|
||||
|
||||
/**
|
||||
* Fetch current user from FastAPI
|
||||
* Used when initialUser might be stale (e.g., after tab was inactive)
|
||||
*/
|
||||
const refreshUser = useCallback(async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
const res = await fetch("/api/v1/auth/me", {
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
setUser(data);
|
||||
} else if (res.status === 401) {
|
||||
// Session expired or invalid
|
||||
setUser(null);
|
||||
// Redirect to login if on a protected route
|
||||
if (pathname?.startsWith("/workspace")) {
|
||||
router.push(buildLoginUrl(pathname));
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to refresh user:", err);
|
||||
setUser(null);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [pathname, router]);
|
||||
|
||||
/**
|
||||
* Logout - call FastAPI logout endpoint and clear local state
|
||||
* Per RFC-001: Immediately clear local state, don't wait for server confirmation
|
||||
*/
|
||||
const logout = useCallback(async () => {
|
||||
// Immediately clear local state to prevent UI flicker
|
||||
setUser(null);
|
||||
|
||||
try {
|
||||
await fetch("/api/v1/auth/logout", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
} catch (err) {
|
||||
console.error("Logout request failed:", err);
|
||||
// Still redirect even if logout request fails
|
||||
}
|
||||
|
||||
// Redirect to home page
|
||||
router.push("/");
|
||||
}, [router]);
|
||||
|
||||
/**
|
||||
* Handle visibility change - refresh user when tab becomes visible again.
|
||||
* Throttled to at most once per 60 s to avoid spamming the backend on rapid tab switches.
|
||||
*/
|
||||
const lastCheckRef = React.useRef(0);
|
||||
|
||||
useEffect(() => {
|
||||
const handleVisibilityChange = () => {
|
||||
if (document.visibilityState !== "visible" || user === null) return;
|
||||
const now = Date.now();
|
||||
if (now - lastCheckRef.current < 60_000) return;
|
||||
lastCheckRef.current = now;
|
||||
void refreshUser();
|
||||
};
|
||||
|
||||
document.addEventListener("visibilitychange", handleVisibilityChange);
|
||||
return () => {
|
||||
document.removeEventListener("visibilitychange", handleVisibilityChange);
|
||||
};
|
||||
}, [user, refreshUser]);
|
||||
|
||||
const value: AuthContextType = {
|
||||
user,
|
||||
isAuthenticated,
|
||||
isLoading,
|
||||
logout,
|
||||
refreshUser,
|
||||
};
|
||||
|
||||
return <AuthContext.Provider value={value}>{children}</AuthContext.Provider>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access authentication context
|
||||
* Throws if used outside AuthProvider - this is intentional for proper usage
|
||||
*/
|
||||
export function useAuth(): AuthContextType {
|
||||
const context = useContext(AuthContext);
|
||||
if (context === undefined) {
|
||||
throw new Error("useAuth must be used within an AuthProvider");
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to require authentication - redirects to login if not authenticated
|
||||
* Useful for client-side checks in addition to server-side guards
|
||||
*/
|
||||
export function useRequireAuth(): AuthContextType {
|
||||
const auth = useAuth();
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
|
||||
useEffect(() => {
|
||||
// Only redirect if we're sure user is not authenticated (not just loading)
|
||||
if (!auth.isLoading && !auth.isAuthenticated) {
|
||||
router.push(buildLoginUrl(pathname || "/workspace"));
|
||||
}
|
||||
}, [auth.isAuthenticated, auth.isLoading, router, pathname]);
|
||||
|
||||
return auth;
|
||||
}
|
||||
34
frontend/src/core/auth/gateway-config.ts
Normal file
34
frontend/src/core/auth/gateway-config.ts
Normal file
@ -0,0 +1,34 @@
|
||||
import { z } from "zod";
|
||||
|
||||
const gatewayConfigSchema = z.object({
|
||||
internalGatewayUrl: z.string().url(),
|
||||
trustedOrigins: z.array(z.string()).min(1),
|
||||
});
|
||||
|
||||
export type GatewayConfig = z.infer<typeof gatewayConfigSchema>;
|
||||
|
||||
let _cached: GatewayConfig | null = null;
|
||||
|
||||
export function getGatewayConfig(): GatewayConfig {
|
||||
if (_cached) return _cached;
|
||||
|
||||
const isDev = process.env.NODE_ENV === "development";
|
||||
|
||||
const rawUrl = process.env.DEER_FLOW_INTERNAL_GATEWAY_BASE_URL?.trim();
|
||||
const internalGatewayUrl =
|
||||
rawUrl?.replace(/\/+$/, "") ??
|
||||
(isDev ? "http://localhost:8001" : undefined);
|
||||
|
||||
const rawOrigins = process.env.DEER_FLOW_TRUSTED_ORIGINS?.trim();
|
||||
const trustedOrigins = rawOrigins
|
||||
? rawOrigins
|
||||
.split(",")
|
||||
.map((s) => s.trim())
|
||||
.filter(Boolean)
|
||||
: isDev
|
||||
? ["http://localhost:3000"]
|
||||
: undefined;
|
||||
|
||||
_cached = gatewayConfigSchema.parse({ internalGatewayUrl, trustedOrigins });
|
||||
return _cached;
|
||||
}
|
||||
55
frontend/src/core/auth/proxy-policy.ts
Normal file
55
frontend/src/core/auth/proxy-policy.ts
Normal file
@ -0,0 +1,55 @@
|
||||
export interface ProxyPolicy {
|
||||
/** Allowed upstream path prefixes */
|
||||
readonly allowedPaths: readonly string[];
|
||||
/** Request headers to strip before forwarding */
|
||||
readonly strippedRequestHeaders: ReadonlySet<string>;
|
||||
/** Response headers to strip before returning */
|
||||
readonly strippedResponseHeaders: ReadonlySet<string>;
|
||||
/** Credential mode: which cookie to forward */
|
||||
readonly credential: { readonly type: "cookie"; readonly name: string };
|
||||
/** Timeout in ms */
|
||||
readonly timeoutMs: number;
|
||||
/** CSRF: required for non-GET/HEAD */
|
||||
readonly csrf: boolean;
|
||||
}
|
||||
|
||||
export const LANGGRAPH_COMPAT_POLICY: ProxyPolicy = {
|
||||
allowedPaths: [
|
||||
"threads",
|
||||
"runs",
|
||||
"assistants",
|
||||
"store",
|
||||
"models",
|
||||
"mcp",
|
||||
"skills",
|
||||
"memory",
|
||||
],
|
||||
strippedRequestHeaders: new Set([
|
||||
"host",
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"transfer-encoding",
|
||||
"te",
|
||||
"trailer",
|
||||
"upgrade",
|
||||
"authorization",
|
||||
"x-api-key",
|
||||
"origin",
|
||||
"referer",
|
||||
"proxy-authorization",
|
||||
"proxy-authenticate",
|
||||
]),
|
||||
strippedResponseHeaders: new Set([
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"transfer-encoding",
|
||||
"te",
|
||||
"trailer",
|
||||
"upgrade",
|
||||
"content-length",
|
||||
"set-cookie",
|
||||
]),
|
||||
credential: { type: "cookie", name: "access_token" },
|
||||
timeoutMs: 120_000,
|
||||
csrf: true,
|
||||
};
|
||||
57
frontend/src/core/auth/server.ts
Normal file
57
frontend/src/core/auth/server.ts
Normal file
@ -0,0 +1,57 @@
|
||||
import { cookies } from "next/headers";
|
||||
|
||||
import { getGatewayConfig } from "./gateway-config";
|
||||
import { type AuthResult, userSchema } from "./types";
|
||||
|
||||
const SSR_AUTH_TIMEOUT_MS = 5_000;
|
||||
|
||||
/**
|
||||
* Fetch the authenticated user from the gateway using the request's cookies.
|
||||
* Returns a tagged AuthResult — callers use exhaustive switch, no try/catch.
|
||||
*/
|
||||
export async function getServerSideUser(): Promise<AuthResult> {
|
||||
const cookieStore = await cookies();
|
||||
const sessionCookie = cookieStore.get("access_token");
|
||||
|
||||
let internalGatewayUrl: string;
|
||||
try {
|
||||
internalGatewayUrl = getGatewayConfig().internalGatewayUrl;
|
||||
} catch (err) {
|
||||
return { tag: "config_error", message: String(err) };
|
||||
}
|
||||
|
||||
if (!sessionCookie) return { tag: "unauthenticated" };
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), SSR_AUTH_TIMEOUT_MS);
|
||||
|
||||
try {
|
||||
const res = await fetch(`${internalGatewayUrl}/api/v1/auth/me`, {
|
||||
headers: { Cookie: `access_token=${sessionCookie.value}` },
|
||||
cache: "no-store",
|
||||
signal: controller.signal,
|
||||
});
|
||||
clearTimeout(timeout); // Clear immediately — covers all response branches
|
||||
|
||||
if (res.ok) {
|
||||
const parsed = userSchema.safeParse(await res.json());
|
||||
if (!parsed.success) {
|
||||
console.error("[SSR auth] Malformed /auth/me response:", parsed.error);
|
||||
return { tag: "gateway_unavailable" };
|
||||
}
|
||||
if (parsed.data.needs_setup) {
|
||||
return { tag: "needs_setup", user: parsed.data };
|
||||
}
|
||||
return { tag: "authenticated", user: parsed.data };
|
||||
}
|
||||
if (res.status === 401 || res.status === 403) {
|
||||
return { tag: "unauthenticated" };
|
||||
}
|
||||
console.error(`[SSR auth] /api/v1/auth/me responded ${res.status}`);
|
||||
return { tag: "gateway_unavailable" };
|
||||
} catch (err) {
|
||||
clearTimeout(timeout);
|
||||
console.error("[SSR auth] Failed to reach gateway:", err);
|
||||
return { tag: "gateway_unavailable" };
|
||||
}
|
||||
}
|
||||
72
frontend/src/core/auth/types.ts
Normal file
72
frontend/src/core/auth/types.ts
Normal file
@ -0,0 +1,72 @@
|
||||
import { z } from "zod";
|
||||
|
||||
// ── User schema (single source of truth) ──────────────────────────
|
||||
|
||||
export const userSchema = z.object({
|
||||
id: z.string(),
|
||||
email: z.string().email(),
|
||||
system_role: z.enum(["admin", "user"]),
|
||||
needs_setup: z.boolean().optional().default(false),
|
||||
});
|
||||
|
||||
export type User = z.infer<typeof userSchema>;
|
||||
|
||||
// ── SSR auth result (tagged union) ────────────────────────────────
|
||||
|
||||
export type AuthResult =
|
||||
| { tag: "authenticated"; user: User }
|
||||
| { tag: "needs_setup"; user: User }
|
||||
| { tag: "unauthenticated" }
|
||||
| { tag: "gateway_unavailable" }
|
||||
| { tag: "config_error"; message: string };
|
||||
|
||||
export function assertNever(x: never): never {
|
||||
throw new Error(`Unexpected auth result: ${JSON.stringify(x)}`);
|
||||
}
|
||||
|
||||
export function buildLoginUrl(returnPath: string): string {
|
||||
return `/login?next=${encodeURIComponent(returnPath)}`;
|
||||
}
|
||||
|
||||
// ── Backend error response parsing ────────────────────────────────
|
||||
|
||||
const AUTH_ERROR_CODES = [
|
||||
"invalid_credentials",
|
||||
"token_expired",
|
||||
"token_invalid",
|
||||
"user_not_found",
|
||||
"email_already_exists",
|
||||
"provider_not_found",
|
||||
"not_authenticated",
|
||||
] as const;
|
||||
|
||||
export type AuthErrorCode = (typeof AUTH_ERROR_CODES)[number];
|
||||
|
||||
export interface AuthErrorResponse {
|
||||
code: AuthErrorCode;
|
||||
message: string;
|
||||
}
|
||||
|
||||
const authErrorSchema = z.object({
|
||||
code: z.enum(AUTH_ERROR_CODES),
|
||||
message: z.string(),
|
||||
});
|
||||
|
||||
export function parseAuthError(data: unknown): AuthErrorResponse {
|
||||
// Try top-level {code, message} first
|
||||
const parsed = authErrorSchema.safeParse(data);
|
||||
if (parsed.success) return parsed.data;
|
||||
|
||||
// Unwrap FastAPI's {detail: {code, message}} envelope
|
||||
if (typeof data === "object" && data !== null && "detail" in data) {
|
||||
const detail = (data as Record<string, unknown>).detail;
|
||||
const nested = authErrorSchema.safeParse(detail);
|
||||
if (nested.success) return nested.data;
|
||||
// Legacy string-detail responses
|
||||
if (typeof detail === "string") {
|
||||
return { code: "invalid_credentials", message: detail };
|
||||
}
|
||||
}
|
||||
|
||||
return { code: "invalid_credentials", message: "Authentication failed" };
|
||||
}
|
||||
@ -7,12 +7,6 @@ export const env = createEnv({
|
||||
* isn't built with invalid env vars.
|
||||
*/
|
||||
server: {
|
||||
BETTER_AUTH_SECRET:
|
||||
process.env.NODE_ENV === "production"
|
||||
? z.string()
|
||||
: z.string().optional(),
|
||||
BETTER_AUTH_GITHUB_CLIENT_ID: z.string().optional(),
|
||||
BETTER_AUTH_GITHUB_CLIENT_SECRET: z.string().optional(),
|
||||
GITHUB_OAUTH_TOKEN: z.string().optional(),
|
||||
NODE_ENV: z
|
||||
.enum(["development", "test", "production"])
|
||||
@ -35,10 +29,6 @@ export const env = createEnv({
|
||||
* middlewares) or client-side so we need to destruct manually.
|
||||
*/
|
||||
runtimeEnv: {
|
||||
BETTER_AUTH_SECRET: process.env.BETTER_AUTH_SECRET,
|
||||
BETTER_AUTH_GITHUB_CLIENT_ID: process.env.BETTER_AUTH_GITHUB_CLIENT_ID,
|
||||
BETTER_AUTH_GITHUB_CLIENT_SECRET:
|
||||
process.env.BETTER_AUTH_GITHUB_CLIENT_SECRET,
|
||||
NODE_ENV: process.env.NODE_ENV,
|
||||
|
||||
NEXT_PUBLIC_BACKEND_BASE_URL: process.env.NEXT_PUBLIC_BACKEND_BASE_URL,
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
import { createAuthClient } from "better-auth/react";
|
||||
|
||||
export const authClient = createAuthClient();
|
||||
|
||||
export type Session = typeof authClient.$Infer.Session;
|
||||
@ -1,9 +0,0 @@
|
||||
import { betterAuth } from "better-auth";
|
||||
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
},
|
||||
});
|
||||
|
||||
export type Session = typeof auth.$Infer.Session;
|
||||
@ -1 +0,0 @@
|
||||
export { auth } from "./config";
|
||||
@ -1,8 +0,0 @@
|
||||
import { headers } from "next/headers";
|
||||
import { cache } from "react";
|
||||
|
||||
import { auth } from ".";
|
||||
|
||||
export const getSession = cache(async () =>
|
||||
auth.api.getSession({ headers: await headers() }),
|
||||
);
|
||||
Loading…
x
Reference in New Issue
Block a user