From 8939ccaed2f03ba831b16c83605d09b80f36e632 Mon Sep 17 00:00:00 2001
From: KiteEater <145987840+Kiteeater@users.noreply.github.com>
Date: Fri, 1 May 2026 20:19:30 +0800
Subject: [PATCH] fix(uploads): enforce streaming upload limits in gateway (#2589)

* fix: enforce gateway upload limits

* fix: acquire sandbox before upload writes

* Fix upload limit config wiring

* Sanitize upload size error filenames

* test: call upload routes unwrapped

* fix: guard upload limits endpoint

---------

Co-authored-by: Willem Jiang
---
 backend/app/gateway/routers/uploads.py | 124 +++++++++++-
 backend/docs/FILE_UPLOAD.md            |  22 ++-
 backend/tests/test_uploads_router.py   | 256 ++++++++++++++++++++++++-
 config.example.yaml                    |   5 +
 4 files changed, 393 insertions(+), 14 deletions(-)

diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py
index 0ecc2266a..604a6e154 100644
--- a/backend/app/gateway/routers/uploads.py
+++ b/backend/app/gateway/routers/uploads.py
@@ -30,6 +30,11 @@
 logger = logging.getLogger(__name__)
 
 router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
 
+UPLOAD_CHUNK_SIZE = 8192
+DEFAULT_MAX_FILES = 10
+DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
+DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
+
 class UploadResponse(BaseModel):
     """Response model for file upload."""
@@ -39,6 +44,14 @@
     message: str
 
 
+class UploadLimits(BaseModel):
+    """Application-level upload limits exposed to clients."""
+
+    max_files: int
+    max_file_size: int
+    max_total_size: int
+
+
 def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
     """Ensure uploaded files remain writable when mounted into non-local
     sandboxes.
@@ -69,6 +82,62 @@ def _get_uploads_config_value(app_config: AppConfig, key: str, default: object)
     return getattr(uploads_cfg, key, default)
 
 
+def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int:
+    try:
+        value = _get_uploads_config_value(app_config, key, None)
+        if value is None and legacy_key is not None:
+            value = _get_uploads_config_value(app_config, legacy_key, None)
+        if value is None:
+            value = default
+        limit = int(value)
+        if limit <= 0:
+            raise ValueError
+        return limit
+    except Exception:
+        logger.warning("Invalid uploads.%s value; falling back to %d", key, default)
+        return default
+
+
+def _get_upload_limits(app_config: AppConfig) -> UploadLimits:
+    return UploadLimits(
+        max_files=_get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count"),
+        max_file_size=_get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size"),
+        max_total_size=_get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE),
+    )
+
+
+def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
+    for path in reversed(paths):
+        try:
+            os.unlink(path)
+        except FileNotFoundError:
+            pass
+        except Exception:
+            logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
+
+
+async def _write_upload_file_streaming(
+    file: UploadFile,
+    file_path: os.PathLike[str] | str,
+    *,
+    display_filename: str,
+    max_single_file_size: int,
+    max_total_size: int,
+    total_size: int,
+) -> tuple[int, int]:
+    file_size = 0
+    with open(file_path, "wb") as output:
+        while chunk := await file.read(UPLOAD_CHUNK_SIZE):
+            file_size += len(chunk)
+            total_size += len(chunk)
+            if file_size > max_single_file_size:
+                raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
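+            # Per-file limit is enforced before the aggregate limit, so an oversized
+            # file reports its sanitized name rather than a generic total-size error.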
+            if total_size > max_total_size:
+                raise HTTPException(status_code=413, detail="Total upload size too large")
+            output.write(chunk)
+    return file_size, total_size
+
+
 def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
     """Return whether automatic host-side document conversion is enabled.
 
@@ -96,12 +165,19 @@ async def upload_files(
     if not files:
         raise HTTPException(status_code=400, detail="No files provided")
 
+    limits = _get_upload_limits(config)
+    if len(files) > limits.max_files:
+        raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}")
+
     try:
         uploads_dir = ensure_uploads_dir(thread_id)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
     uploaded_files = []
+    written_paths = []
+    sandbox_sync_targets = []
+    total_size = 0
 
     sandbox_provider = get_sandbox_provider()
     sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
@@ -109,6 +185,8 @@ async def upload_files(
     if sync_to_sandbox:
         sandbox_id = sandbox_provider.acquire(thread_id)
         sandbox = sandbox_provider.get(sandbox_id)
+        if sandbox is None:
+            raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
 
     auto_convert_documents = _auto_convert_documents_enabled(config)
     for file in files:
@@ -122,35 +200,41 @@ async def upload_files(
             continue
 
         try:
-            content = await file.read()
             file_path = uploads_dir / safe_filename
-            file_path.write_bytes(content)
+            written_paths.append(file_path)
+            file_size, total_size = await _write_upload_file_streaming(
+                file,
+                file_path,
+                display_filename=safe_filename,
+                max_single_file_size=limits.max_file_size,
+                max_total_size=limits.max_total_size,
+                total_size=total_size,
+            )
 
             virtual_path = upload_virtual_path(safe_filename)
-            if sync_to_sandbox and sandbox is not None:
-                _make_file_sandbox_writable(file_path)
-                sandbox.update_file(virtual_path, content)
+            if sync_to_sandbox:
+                sandbox_sync_targets.append((file_path, virtual_path))
 
             file_info = {
                 "filename": safe_filename,
-                "size": str(len(content)),
+                "size": str(file_size),
                 "path": str(sandbox_uploads / safe_filename),
                 "virtual_path": virtual_path,
                 "artifact_url": upload_artifact_url(thread_id, safe_filename),
             }
-            logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
+            logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
 
             file_ext = file_path.suffix.lower()
             if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
                 md_path = await convert_file_to_markdown(file_path)
                 if md_path:
+                    written_paths.append(md_path)
                     md_virtual_path = upload_virtual_path(md_path.name)
-                    if sync_to_sandbox and sandbox is not None:
-                        _make_file_sandbox_writable(md_path)
-                        sandbox.update_file(md_virtual_path, md_path.read_bytes())
+                    if sync_to_sandbox:
+                        sandbox_sync_targets.append((md_path, md_virtual_path))
 
                     file_info["markdown_file"] = md_path.name
                     file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
@@ -159,10 +243,19 @@ async def upload_files(
 
             uploaded_files.append(file_info)
 
+        except HTTPException as e:
+            _cleanup_uploaded_paths(written_paths)
+            raise e
         except Exception as e:
            logger.error(f"Failed to upload {file.filename}: {e}")
+            _cleanup_uploaded_paths(written_paths)
            raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
 
+    if sync_to_sandbox:
+        for file_path, virtual_path in sandbox_sync_targets:
+            _make_file_sandbox_writable(file_path)
+            sandbox.update_file(virtual_path, file_path.read_bytes())
+
     return UploadResponse(
         success=True,
         files=uploaded_files,
@@ -170,6 +263,17 @@ async def upload_files(
     )
 
 
+@router.get("/limits", response_model=UploadLimits)
+@require_permission("threads", "read", owner_check=True)
+async def get_upload_limits(
+    thread_id: str,
+    request: Request,
+    config: AppConfig = Depends(get_config),
+) -> UploadLimits:
+    """Return upload limits used by the gateway for this thread."""
+    return _get_upload_limits(config)
+
+
 @router.get("/list", response_model=dict)
 @require_permission("threads", "read", owner_check=True)
 async def list_uploaded_files(thread_id: str, request: Request) -> dict:
diff --git a/backend/docs/FILE_UPLOAD.md b/backend/docs/FILE_UPLOAD.md
index bf6962e4f..2b15b27e7 100644
--- a/backend/docs/FILE_UPLOAD.md
+++ b/backend/docs/FILE_UPLOAD.md
@@ -22,6 +22,8 @@ POST /api/threads/{thread_id}/uploads
 **Request body:** `multipart/form-data`
 - `files`: one or more files
 
+The gateway enforces application-level limits on upload size: by default at most 10 files, 50 MiB per file, and 100 MiB total per request. These can be adjusted through `uploads.max_files`, `uploads.max_file_size`, and `uploads.max_total_size` in `config.yaml`; the frontend reads the same limits to warn at file-selection time, and the backend returns `413 Payload Too Large` when a limit is exceeded.
+
 **Response:**
 ```json
 {
@@ -48,7 +50,23 @@ POST /api/threads/{thread_id}/uploads
 - `virtual_path`: the virtual path the agent uses inside the sandbox
 - `artifact_url`: the URL the frontend uses to fetch the file over HTTP
 
-### 2. List uploaded files
+### 2. Query upload limits
+```
+GET /api/threads/{thread_id}/uploads/limits
+```
+
+Returns the upload limits currently in effect at the gateway, so the frontend can warn and block before the user selects files.
+
+**Response:**
+```json
+{
+  "max_files": 10,
+  "max_file_size": 52428800,
+  "max_total_size": 104857600
+}
+```
+
+### 3. List uploaded files
 ```
 GET /api/threads/{thread_id}/uploads/list
 ```
@@ -71,7 +89,7 @@ GET /api/threads/{thread_id}/uploads/list
 }
 ```
 
-### 3. Delete a file
+### 4. Delete a file
 ```
 DELETE /api/threads/{thread_id}/uploads/{filename}
 ```
diff --git a/backend/tests/test_uploads_router.py b/backend/tests/test_uploads_router.py
index 7f9b442d0..a2538ec40 100644
--- a/backend/tests/test_uploads_router.py
+++ b/backend/tests/test_uploads_router.py
@@ -5,12 +5,35 @@ from pathlib import Path
 from types import SimpleNamespace
 from unittest.mock import AsyncMock, MagicMock, patch
 
-from _router_auth_helpers import call_unwrapped
-from fastapi import UploadFile
+import pytest
+from _router_auth_helpers import call_unwrapped, make_authed_test_app
+from fastapi import HTTPException, UploadFile
+from fastapi.testclient import TestClient
 
 from app.gateway.routers import uploads
 
 
+class ChunkedUpload:
+    def __init__(self, filename: str, chunks: list[bytes]):
+        self.filename = filename
+        self._chunks = list(chunks)
+        self.read_calls: list[int | None] = []
+
+    async def read(self, size: int | None = None) -> bytes:
+        self.read_calls.append(size)
+        if size is None:
+            raise AssertionError("upload must be read with an explicit chunk size")
+        if not self._chunks:
+            return b""
+        return self._chunks.pop(0)
+
+
+def _mounted_provider() -> MagicMock:
+    provider = MagicMock()
+    provider.uses_thread_data_mounts = True
+    return provider
+
+
 def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path):
     thread_uploads_dir = tmp_path / "uploads"
     thread_uploads_dir.mkdir(parents=True)
@@ -178,6 +201,173 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
     make_writable.assert_not_called()
 
 
+def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    provider = MagicMock()
+    provider.uses_thread_data_mounts = False
+    sandbox = MagicMock()
+    provider.get.return_value = sandbox
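+    # The acquire side effect below asserts the uploads directory is still
+    # empty, proving acquisition happens before any request bytes are written.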
+
+    def acquire_before_writes(thread_id: str) -> str:
+        assert list(thread_uploads_dir.iterdir()) == []
+        return "aio-1"
+
+    provider.acquire.side_effect = acquire_before_writes
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=provider),
+    ):
+        file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
+        result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
+
+    assert result.success is True
+    provider.acquire.assert_called_once_with("thread-aio")
+    sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/notes.txt", b"hello uploads")
+
+
+def test_upload_files_fails_before_writing_when_non_local_sandbox_unavailable(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    provider = MagicMock()
+    provider.uses_thread_data_mounts = False
+    provider.acquire.side_effect = RuntimeError("sandbox unavailable")
+    file = ChunkedUpload("notes.txt", [b"hello uploads"])
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=provider),
+    ):
+        with pytest.raises(RuntimeError, match="sandbox unavailable"):
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
+
+    assert list(thread_uploads_dir.iterdir()) == []
+    assert file.read_calls == []
+    provider.get.assert_not_called()
+
+
+def test_upload_files_rejects_too_many_files_before_writing(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
+        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=1, max_file_size=10, max_total_size=20)),
+    ):
+        files = [
+            ChunkedUpload("one.txt", [b"one"]),
+            ChunkedUpload("two.txt", [b"two"]),
+        ]
+        with pytest.raises(HTTPException) as exc_info:
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace()))
+
+    assert exc_info.value.status_code == 413
+    assert list(thread_uploads_dir.iterdir()) == []
+    assert files[0].read_calls == []
+    assert files[1].read_calls == []
+
+
+def test_upload_files_rejects_oversized_single_file_and_removes_partial_file(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    provider = _mounted_provider()
+    file = ChunkedUpload("big.txt", [b"123456"])
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=provider),
+        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=5, max_total_size=20)),
+    ):
+        with pytest.raises(HTTPException) as exc_info:
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
+
+    assert exc_info.value.status_code == 413
+    assert not (thread_uploads_dir / "big.txt").exists()
+    assert file.read_calls == [8192]
+    provider.acquire.assert_not_called()
+
+
+def test_upload_files_rejects_total_size_over_limit_and_cleans_request_files(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
+        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)),
+    ):
+        files = [
+            ChunkedUpload("first.txt", [b"123"]),
+            ChunkedUpload("second.txt", [b"456"]),
+        ]
+        with pytest.raises(HTTPException) as exc_info:
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace()))
+
+    assert exc_info.value.status_code == 413
+    assert not (thread_uploads_dir / "first.txt").exists()
+    assert not (thread_uploads_dir / "second.txt").exists()
+
+
+def test_upload_files_does_not_sync_non_local_sandbox_when_total_size_exceeds_limit(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    provider = MagicMock()
+    provider.uses_thread_data_mounts = False
+    provider.acquire.return_value = "aio-1"
+    sandbox = MagicMock()
+    provider.get.return_value = sandbox
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=provider),
+        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)),
+    ):
+        files = [
+            ChunkedUpload("first.txt", [b"123"]),
+            ChunkedUpload("second.txt", [b"456"]),
+        ]
+        with pytest.raises(HTTPException) as exc_info:
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=files, config=SimpleNamespace()))
+
+    assert exc_info.value.status_code == 413
+    provider.acquire.assert_called_once_with("thread-aio")
+    provider.get.assert_called_once_with("aio-1")
+    sandbox.update_file.assert_not_called()
+
+
+def test_upload_files_does_not_sync_non_local_sandbox_when_conversion_fails(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    provider = MagicMock()
+    provider.uses_thread_data_mounts = False
+    provider.acquire.return_value = "aio-1"
+    sandbox = MagicMock()
+    provider.get.return_value = sandbox
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=provider),
+        patch.object(uploads, "_auto_convert_documents_enabled", return_value=True),
+        patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=RuntimeError("conversion failed"))),
+    ):
+        file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
+        with pytest.raises(HTTPException) as exc_info:
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
+
+    assert exc_info.value.status_code == 500
+    provider.acquire.assert_called_once_with("thread-aio")
+    provider.get.assert_called_once_with("aio-1")
+    sandbox.update_file.assert_not_called()
+    assert not (thread_uploads_dir / "report.pdf").exists()
+
+
 def test_make_file_sandbox_writable_adds_write_bits_for_regular_files(tmp_path):
     file_path = tmp_path / "report.pdf"
     file_path.write_bytes(b"pdf-bytes")
@@ -286,3 +476,65 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
     assert uploads._auto_convert_documents_enabled(true_cfg) is True
     assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
     assert uploads._auto_convert_documents_enabled(string_false_cfg) is False
+
+
+def test_upload_limits_endpoint_reads_uploads_config():
+    cfg = MagicMock()
+    cfg.uploads = {
+        "max_files": 15,
+        "max_file_size": "1048576",
+        "max_total_size": 2097152,
+    }
+
+    result = asyncio.run(call_unwrapped(uploads.get_upload_limits, "thread-local", request=MagicMock(), config=cfg))
+
+    assert result.max_files == 15
+    assert result.max_file_size == 1048576
+    assert result.max_total_size == 2097152
+
+
+def test_upload_limits_endpoint_requires_thread_access():
+    cfg = MagicMock()
+    cfg.uploads = {}
+    app = make_authed_test_app(owner_check_passes=False)
+    app.state.config = cfg
+    app.include_router(uploads.router)
+
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-local/uploads/limits")
+
+    assert response.status_code == 404
+
+
+def test_upload_limits_accept_legacy_config_keys():
+    cfg = MagicMock()
+    cfg.uploads = {
+        "max_file_count": 7,
+        "max_single_file_size": 123,
+        "max_total_size": 456,
+    }
+
+    limits = uploads._get_upload_limits(cfg)
+
+    assert limits == uploads.UploadLimits(max_files=7, max_file_size=123, max_total_size=456)
+
+
+def test_upload_files_uses_configured_file_count_limit(tmp_path):
+    thread_uploads_dir = tmp_path / "uploads"
+    thread_uploads_dir.mkdir(parents=True)
+
+    cfg = MagicMock()
+    cfg.uploads = {"max_files": 1}
+
+    with (
+        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
+        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
+    ):
+        files = [
+            ChunkedUpload("one.txt", [b"one"]),
+            ChunkedUpload("two.txt", [b"two"]),
+        ]
+        with pytest.raises(HTTPException) as exc_info:
+            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=cfg))
+
+    assert exc_info.value.status_code == 413
diff --git a/config.example.yaml b/config.example.yaml
index a14ca5886..04ccd0b12 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -501,6 +501,11 @@ tool_search:
 # Option 1: Local Sandbox (Default)
 # Executes commands directly on the host machine
 uploads:
+  # Application-level upload limits enforced by the gateway and exposed to the
+  # frontend before file selection.
+  max_files: 10
+  max_file_size: 52428800 # 50 MiB
+  max_total_size: 104857600 # 100 MiB
   # Automatic Office/PDF conversion runs on the backend host before sandbox
   # isolation applies. Keep this disabled unless uploads come from a fully
   # trusted source and you intentionally accept host-side parser risk.
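
---

For reference, the intended client flow against the new endpoints can be sketched as follows. This is a minimal illustration, not part of the patch: the base URL, the thread id, and the bare unauthenticated calls are assumptions; real callers go through the gateway's normal authentication.

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed gateway address
THREAD_ID = "thread-local"          # assumed thread id

# Fetch the limits the gateway enforces so the client can pre-check uploads.
limits = requests.get(f"{BASE_URL}/api/threads/{THREAD_ID}/uploads/limits").json()

payload = [("files", ("notes.txt", b"hello uploads", "text/plain"))]
if len(payload) > limits["max_files"]:
    raise SystemExit(f"too many files: maximum is {limits['max_files']}")

resp = requests.post(f"{BASE_URL}/api/threads/{THREAD_ID}/uploads", files=payload)
if resp.status_code == 413:
    # Rejected mid-stream: a per-file or total size limit was exceeded.
    raise SystemExit(resp.json()["detail"])
resp.raise_for_status()
print(resp.json()["files"])
```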