From 2b33bfd78f8076c968d1f0560a685341559949ca Mon Sep 17 00:00:00 2001 From: greatmengqi Date: Wed, 8 Apr 2026 13:32:39 +0800 Subject: [PATCH] security(auth): wire @require_permission(owner_check=True) on isolation routes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the require_permission decorator to all 28 routes that take a {thread_id} path parameter. Combined with the strict middleware (previous commit), this gives the double-layer protection that AUTH_TEST_PLAN test 7.5.9 documents: Layer 1 (AuthMiddleware): cookie + JWT validation, rejects junk cookies and stamps request.state.user Layer 2 (@require_permission with owner_check=True): per-resource ownership verification via ThreadMetaStore.check_access — returns 404 if a different user owns the thread The decorator's owner_check branch is rewritten to use the SQL thread_meta_repo (the 2.0-rc persistence layer) instead of the LangGraph store path that PR #1728 used (_store_get / get_store in routers/threads.py). The inject_record convenience is dropped — no caller in 2.0 needs the LangGraph blob, and the SQL repo has a different shape. Routes decorated (28 total): - threads.py: delete, patch, get, get-state, post-state, post-history - thread_runs.py: post-runs, post-runs-stream, post-runs-wait, list_runs, get_run, cancel_run, join_run, stream_existing_run, list_thread_messages, list_run_messages, list_run_events, thread_token_usage - feedback.py: create, list, stats, delete - uploads.py: upload (added Request param), list, delete - artifacts.py: get_artifact - suggestions.py: generate (renamed body parameter to avoid conflict with FastAPI Request) Test fixes: - test_suggestions_router.py: bypass the decorator via __wrapped__ (the unit tests cover parsing logic, not auth — no point spinning up a thread_meta_repo just to test JSON unwrapping) - test_auth_middleware.py 4 fake-cookie tests: already updated in the previous commit (745bf432) Tests: 293 passed (auth + persistence + isolation + suggestions) Lint: clean --- backend/app/gateway/authz.py | 40 +++++++++++++--------- backend/app/gateway/routers/artifacts.py | 2 ++ backend/app/gateway/routers/feedback.py | 5 +++ backend/app/gateway/routers/suggestions.py | 14 ++++---- backend/app/gateway/routers/thread_runs.py | 13 +++++++ backend/app/gateway/routers/threads.py | 7 ++++ backend/app/gateway/routers/uploads.py | 11 ++++-- backend/tests/test_suggestions_router.py | 16 ++++++--- 8 files changed, 79 insertions(+), 29 deletions(-) diff --git a/backend/app/gateway/authz.py b/backend/app/gateway/authz.py index 015f747c3..1750bad3b 100644 --- a/backend/app/gateway/authz.py +++ b/backend/app/gateway/authz.py @@ -231,28 +231,36 @@ def require_permission( detail=f"Permission denied: {resource}:{action}", ) - # Owner check for thread-specific resources + # Owner check for thread-specific resources. + # + # 2.0-rc moved thread metadata into the SQL persistence layer + # (``threads_meta`` table). We verify ownership via + # ``ThreadMetaStore.check_access`` instead of the LangGraph + # store path that the original PR #1728 used. ``check_access`` + # returns True for missing rows (untracked legacy thread) and + # for rows whose ``owner_id`` is NULL (shared / pre-auth data), + # so this is a strict-deny check rather than strict-allow: + # only an *existing* row with a *different* owner_id triggers + # 404. + # + # ``inject_record`` is no longer supported — it was a + # convenience for handlers that wanted the LangGraph store + # blob; the SQL repo would need a different shape and no + # caller in 2.0 needs it. if owner_check: thread_id = kwargs.get("thread_id") if thread_id is None: raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") - # Get thread and verify ownership - from app.gateway.routers.threads import _store_get, get_store + from app.gateway.deps import get_thread_meta_repo - store = get_store(request) - if store is not None: - record = await _store_get(store, thread_id) - if record: - owner_id = record.get("metadata", {}).get(owner_filter_key) - if owner_id and owner_id != str(auth.user.id): - raise HTTPException( - status_code=404, - detail=f"Thread {thread_id} not found", - ) - # Inject record if requested - if inject_record: - kwargs["thread_record"] = record + thread_meta_repo = get_thread_meta_repo(request) + allowed = await thread_meta_repo.check_access(thread_id, str(auth.user.id)) + if not allowed: + raise HTTPException( + status_code=404, + detail=f"Thread {thread_id} not found", + ) return await func(*args, **kwargs) diff --git a/backend/app/gateway/routers/artifacts.py b/backend/app/gateway/routers/artifacts.py index a58fd5c0b..78ea5fa00 100644 --- a/backend/app/gateway/routers/artifacts.py +++ b/backend/app/gateway/routers/artifacts.py @@ -7,6 +7,7 @@ from urllib.parse import quote from fastapi import APIRouter, HTTPException, Request from fastapi.responses import FileResponse, PlainTextResponse, Response +from app.gateway.authz import require_permission from app.gateway.path_utils import resolve_thread_virtual_path logger = logging.getLogger(__name__) @@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte summary="Get Artifact File", description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.", ) +@require_permission("threads", "read", owner_check=True) async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response: """Get an artifact file by its path. diff --git a/backend/app/gateway/routers/feedback.py b/backend/app/gateway/routers/feedback.py index 579b29a9e..449c87c97 100644 --- a/backend/app/gateway/routers/feedback.py +++ b/backend/app/gateway/routers/feedback.py @@ -12,6 +12,7 @@ from typing import Any from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field +from app.gateway.authz import require_permission from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store logger = logging.getLogger(__name__) @@ -53,6 +54,7 @@ class FeedbackStatsResponse(BaseModel): @router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) +@require_permission("threads", "write", owner_check=True) async def create_feedback( thread_id: str, run_id: str, @@ -85,6 +87,7 @@ async def create_feedback( @router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse]) +@require_permission("threads", "read", owner_check=True) async def list_feedback( thread_id: str, run_id: str, @@ -96,6 +99,7 @@ async def list_feedback( @router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse) +@require_permission("threads", "read", owner_check=True) async def feedback_stats( thread_id: str, run_id: str, @@ -107,6 +111,7 @@ async def feedback_stats( @router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}") +@require_permission("threads", "delete", owner_check=True) async def delete_feedback( thread_id: str, run_id: str, diff --git a/backend/app/gateway/routers/suggestions.py b/backend/app/gateway/routers/suggestions.py index ac54e674d..0da5e4322 100644 --- a/backend/app/gateway/routers/suggestions.py +++ b/backend/app/gateway/routers/suggestions.py @@ -1,10 +1,11 @@ import json import logging -from fastapi import APIRouter +from fastapi import APIRouter, Request from langchain_core.messages import HumanMessage, SystemMessage from pydantic import BaseModel, Field +from app.gateway.authz import require_permission from deerflow.models import create_chat_model logger = logging.getLogger(__name__) @@ -98,12 +99,13 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str: summary="Generate Follow-up Questions", description="Generate short follow-up questions a user might ask next, based on recent conversation context.", ) -async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse: - if not request.messages: +@require_permission("threads", "read", owner_check=True) +async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse: + if not body.messages: return SuggestionsResponse(suggestions=[]) - n = request.n - conversation = _format_conversation(request.messages) + n = body.n + conversation = _format_conversation(body.messages) if not conversation: return SuggestionsResponse(suggestions=[]) @@ -120,7 +122,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" try: - model = create_chat_model(name=request.model_name, thinking_enabled=False) + model = create_chat_model(name=body.model_name, thinking_enabled=False) response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)]) raw = _extract_response_text(response.content) suggestions = _parse_json_string_list(raw) or [] diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index a26bdfbf3..d139dafd1 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -19,6 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel, Field +from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.services import sse_consumer, start_run from deerflow.runtime import RunRecord, serialize_channel_values @@ -93,6 +94,7 @@ def _record_to_response(record: RunRecord) -> RunResponse: @router.post("/{thread_id}/runs", response_model=RunResponse) +@require_permission("runs", "create", owner_check=True) async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse: """Create a background run (returns immediately).""" record = await start_run(body, thread_id, request) @@ -100,6 +102,7 @@ async def create_run(thread_id: str, body: RunCreateRequest, request: Request) - @router.post("/{thread_id}/runs/stream") +@require_permission("runs", "create", owner_check=True) async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse: """Create a run and stream events via SSE. @@ -127,6 +130,7 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) - @router.post("/{thread_id}/runs/wait", response_model=dict) +@require_permission("runs", "create", owner_check=True) async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict: """Create a run and block until it completes, returning the final state.""" record = await start_run(body, thread_id, request) @@ -152,6 +156,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> @router.get("/{thread_id}/runs", response_model=list[RunResponse]) +@require_permission("runs", "read", owner_check=True) async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: """List all runs for a thread.""" run_mgr = get_run_manager(request) @@ -160,6 +165,7 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: @router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse) +@require_permission("runs", "read", owner_check=True) async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: """Get details of a specific run.""" run_mgr = get_run_manager(request) @@ -170,6 +176,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: @router.post("/{thread_id}/runs/{run_id}/cancel") +@require_permission("runs", "cancel", owner_check=True) async def cancel_run( thread_id: str, run_id: str, @@ -207,6 +214,7 @@ async def cancel_run( @router.get("/{thread_id}/runs/{run_id}/join") +@require_permission("runs", "read", owner_check=True) async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse: """Join an existing run's SSE stream.""" bridge = get_stream_bridge(request) @@ -227,6 +235,7 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe @router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None) +@require_permission("runs", "read", owner_check=True) async def stream_existing_run( thread_id: str, run_id: str, @@ -274,6 +283,7 @@ async def stream_existing_run( @router.get("/{thread_id}/messages") +@require_permission("runs", "read", owner_check=True) async def list_thread_messages( thread_id: str, request: Request, @@ -287,6 +297,7 @@ async def list_thread_messages( @router.get("/{thread_id}/runs/{run_id}/messages") +@require_permission("runs", "read", owner_check=True) async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]: """Return displayable messages for a specific run.""" event_store = get_run_event_store(request) @@ -294,6 +305,7 @@ async def list_run_messages(thread_id: str, run_id: str, request: Request) -> li @router.get("/{thread_id}/runs/{run_id}/events") +@require_permission("runs", "read", owner_check=True) async def list_run_events( thread_id: str, run_id: str, @@ -308,6 +320,7 @@ async def list_run_events( @router.get("/{thread_id}/token-usage") +@require_permission("threads", "read", owner_check=True) async def thread_token_usage(thread_id: str, request: Request) -> dict: """Thread-level token usage aggregation.""" run_store = get_run_store(request) diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index 487bf5413..431108af0 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -20,6 +20,7 @@ from typing import Any from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field +from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer from app.gateway.utils import sanitize_log_param from deerflow.config.paths import Paths, get_paths @@ -165,6 +166,7 @@ def _derive_thread_status(checkpoint_tuple) -> str: @router.delete("/{thread_id}", response_model=ThreadDeleteResponse) +@require_permission("threads", "delete", owner_check=True) async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse: """Delete local persisted filesystem data for a thread. @@ -293,6 +295,7 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th @router.patch("/{thread_id}", response_model=ThreadResponse) +@require_permission("threads", "write", owner_check=True) async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: """Merge metadata into a thread record.""" from app.gateway.deps import get_thread_meta_repo @@ -320,6 +323,7 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques @router.get("/{thread_id}", response_model=ThreadResponse) +@require_permission("threads", "read", owner_check=True) async def get_thread(thread_id: str, request: Request) -> ThreadResponse: """Get thread info. @@ -376,6 +380,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse: @router.get("/{thread_id}/state", response_model=ThreadStateResponse) +@require_permission("threads", "read", owner_check=True) async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse: """Get the latest state snapshot for a thread. @@ -425,6 +430,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo @router.post("/{thread_id}/state", response_model=ThreadStateResponse) +@require_permission("threads", "write", owner_check=True) async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse: """Update thread state (e.g. for human-in-the-loop resume or title rename). @@ -514,6 +520,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re @router.post("/{thread_id}/history", response_model=list[HistoryEntry]) +@require_permission("threads", "read", owner_check=True) async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]: """Get checkpoint history for a thread. diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py index 9d9d0c9bc..22f9c89b8 100644 --- a/backend/app/gateway/routers/uploads.py +++ b/backend/app/gateway/routers/uploads.py @@ -4,9 +4,10 @@ import logging import os import stat -from fastapi import APIRouter, File, HTTPException, UploadFile +from fastapi import APIRouter, File, HTTPException, Request, UploadFile from pydantic import BaseModel +from app.gateway.authz import require_permission from deerflow.config.paths import get_paths from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.uploads.manager import ( @@ -54,8 +55,10 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None: @router.post("", response_model=UploadResponse) +@require_permission("threads", "write", owner_check=True) async def upload_files( thread_id: str, + request: Request, files: list[UploadFile] = File(...), ) -> UploadResponse: """Upload multiple files to a thread's uploads directory.""" @@ -133,7 +136,8 @@ async def upload_files( @router.get("/list", response_model=dict) -async def list_uploaded_files(thread_id: str) -> dict: +@require_permission("threads", "read", owner_check=True) +async def list_uploaded_files(thread_id: str, request: Request) -> dict: """List all files in a thread's uploads directory.""" try: uploads_dir = get_uploads_dir(thread_id) @@ -151,7 +155,8 @@ async def list_uploaded_files(thread_id: str) -> dict: @router.delete("/{filename}") -async def delete_uploaded_file(thread_id: str, filename: str) -> dict: +@require_permission("threads", "delete", owner_check=True) +async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict: """Delete a file from a thread's uploads directory.""" try: uploads_dir = get_uploads_dir(thread_id) diff --git a/backend/tests/test_suggestions_router.py b/backend/tests/test_suggestions_router.py index fee07dd44..ea9eb41df 100644 --- a/backend/tests/test_suggestions_router.py +++ b/backend/tests/test_suggestions_router.py @@ -46,7 +46,9 @@ def test_generate_suggestions_parses_and_limits(monkeypatch): fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - result = asyncio.run(suggestions.generate_suggestions("t1", req)) + # Bypass the require_permission decorator (which needs request + + # thread_meta_repo) — these tests cover the parsing logic. + result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == ["Q1", "Q2", "Q3"] @@ -64,7 +66,9 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch): fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - result = asyncio.run(suggestions.generate_suggestions("t1", req)) + # Bypass the require_permission decorator (which needs request + + # thread_meta_repo) — these tests cover the parsing logic. + result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == ["Q1", "Q2"] @@ -82,7 +86,9 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch): fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - result = asyncio.run(suggestions.generate_suggestions("t1", req)) + # Bypass the require_permission decorator (which needs request + + # thread_meta_repo) — these tests cover the parsing logic. + result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == ["Q1", "Q2"] @@ -97,6 +103,8 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch): fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom")) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - result = asyncio.run(suggestions.generate_suggestions("t1", req)) + # Bypass the require_permission decorator (which needs request + + # thread_meta_repo) — these tests cover the parsing logic. + result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == []