yangyufan 0d1053ca44
fix(uploads): add Windows support for safe symlink-protected uploads (#2794)
* fix(uploads): add Windows support for safe symlink-protected uploads

* fix(uploads): update tests and translate comments;
2026-05-09 18:21:54 +08:00

311 lines
11 KiB
Python

"""Shared upload management logic.

Pure business logic — no FastAPI/HTTP dependencies.
Both Gateway and Client delegate to these functions.

Covers per-thread upload directory lookup, thread-ID and filename
validation, symlink-resistant file writes, directory listing, safe
deletion, and construction of virtual paths / artifact URLs for uploads.
"""
import errno
import os
import re
import stat
from pathlib import Path
from urllib.parse import quote
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
class PathTraversalError(ValueError):
    """Signals that a candidate path resolves outside the directory it must stay within."""
class UnsafeUploadPathError(ValueError):
    """Signals that an upload target is a symlink, special file, or otherwise not a safe regular file."""
# thread_id must be alphanumeric, hyphens, underscores, or dots only.
# Anchored with ``\Z`` rather than ``$``: in Python regexes ``$`` also matches
# just before a trailing newline, so ``.match("abc\n")`` would otherwise succeed.
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+\Z")
def validate_thread_id(thread_id: str) -> None:
    """Reject thread IDs that are unsafe to use as a filesystem path component.

    A valid ID is non-empty, consists only of ASCII letters, digits, ``.``,
    ``_`` and ``-``, and is not the special directory name ``.`` or ``..``.

    Raises:
        ValueError: If thread_id is empty, is ``.``/``..``, or contains unsafe
            characters (including a trailing newline).
    """
    # fullmatch (not match): with ``match`` a ``$`` anchor also matches just
    # before a trailing newline, letting e.g. "abc\n" through.
    # "." and ".." pass the character class but would act as path traversal
    # components once joined into a directory path.
    if not thread_id or thread_id in {".", ".."} or not _SAFE_THREAD_ID.fullmatch(thread_id):
        raise ValueError(f"Invalid thread_id: {thread_id!r}")
def get_uploads_dir(thread_id: str) -> Path:
    """Compute the per-thread uploads directory for the effective user.

    Pure lookup: nothing is created on disk.

    Raises:
        ValueError: If *thread_id* fails :func:`validate_thread_id`.
    """
    validate_thread_id(thread_id)
    paths = get_paths()
    return paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
def ensure_uploads_dir(thread_id: str) -> Path:
    """Like :func:`get_uploads_dir`, but also create the directory (and parents) if missing."""
    uploads_dir = get_uploads_dir(thread_id)
    uploads_dir.mkdir(exist_ok=True, parents=True)
    return uploads_dir
def normalize_filename(filename: str) -> str:
    """Sanitize a filename by extracting its basename.

    Strips any directory components and rejects traversal patterns,
    backslashes, NUL bytes, and names whose UTF-8 encoding exceeds 255 bytes.

    Args:
        filename: Raw filename from user input (may contain path components).

    Returns:
        Safe filename (basename only).

    Raises:
        ValueError: If filename is empty, resolves to a traversal pattern,
            contains a backslash or NUL byte, or is longer than 255 bytes.
    """
    if not filename:
        raise ValueError("Filename is empty")
    safe = Path(filename).name
    if not safe or safe in {".", ".."}:
        raise ValueError(f"Filename is unsafe: {filename!r}")
    # Reject backslashes — on Linux Path.name keeps them as literal chars,
    # but they indicate a Windows-style path that should be stripped or rejected.
    if "\\" in safe:
        raise ValueError(f"Filename contains backslash: {filename!r}")
    # NUL can never be part of an on-disk name; fail here with a clear message
    # rather than later inside os.open() with "embedded null byte".
    if "\x00" in safe:
        raise ValueError(f"Filename contains NUL byte: {filename!r}")
    byte_len = len(safe.encode("utf-8"))
    if byte_len > 255:
        # The limit is on encoded bytes, so report bytes rather than characters
        # (a 100-character CJK name can exceed it).
        raise ValueError(f"Filename too long: {byte_len} bytes (max 255)")
    return safe
def claim_unique_filename(name: str, seen: set[str]) -> str:
    """Reserve *name* in *seen*, renaming to ``<stem>_<N><suffix>`` if already taken.

    N starts at 1 and increases until an unused name is found. The chosen
    name is inserted into *seen* before returning, so callers never need to.

    Args:
        name: Preferred filename.
        seen: Names already reserved; updated in place.

    Returns:
        The reserved (possibly renamed) filename.
    """
    if name in seen:
        original = Path(name)
        stem, suffix = original.stem, original.suffix
        n = 1
        while True:
            candidate = f"{stem}_{n}{suffix}"
            if candidate not in seen:
                break
            n += 1
        name = candidate
    seen.add(name)
    return name
def validate_path_traversal(path: Path, base: Path) -> None:
    """Ensure *path*, once fully resolved, lies within resolved *base*.

    Both paths are resolved (symlinks followed, ``..`` collapsed) before the
    containment test. Any ValueError raised along the way, including from
    resolution itself, is reported as a traversal.

    Raises:
        PathTraversalError: If *path* escapes *base*.
    """
    try:
        resolved = path.resolve()
        root = base.resolve()
        resolved.relative_to(root)
    except ValueError:
        raise PathTraversalError("Path traversal detected") from None
def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
    """Open an upload destination for safe streaming writes.

    Upload directories may be mounted into local sandboxes. A sandbox process can
    therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
    follows that link and can overwrite files outside the uploads directory with
    gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
    on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
    and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
    not eliminate all races but makes exploitation significantly harder. Path-traversal
    validation prevents escapes from *base_dir* in both cases.

    Args:
        base_dir: Directory the upload must be written into.
        filename: Raw client-supplied name; sanitized with :func:`normalize_filename`.

    Returns:
        ``(dest, fh)``: *dest* is ``base_dir / safe_name`` and *fh* is the binary
        write handle from ``os.fdopen(fd, "wb")`` on a file truncated to zero
        length. The caller owns *fh* and must close it.

    Raises:
        ValueError: From :func:`normalize_filename` if *filename* is unsafe.
        UnsafeUploadPathError: If the destination is a symlink, directory, FIFO or
            other non-regular file, or a regular file that is hard-linked elsewhere
            (POSIX: link count != 1; Windows: link count > 1).
        PathTraversalError: If *dest* resolves outside *base_dir*.
        OSError: Any other failure from ``os.open`` / ``os.fstat`` / ``os.ftruncate``.
    """
    safe_name = normalize_filename(filename)
    dest = base_dir / safe_name
    # First lstat (does not follow links): an existing destination must already be a
    # plain regular file. Symlinks, directories, FIFOs, etc. are rejected up front.
    try:
        st = os.lstat(dest)
    except FileNotFoundError:
        st = None
    if st is not None and not stat.S_ISREG(st.st_mode):
        raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
    validate_path_traversal(dest, base_dir)

    has_nofollow = hasattr(os, "O_NOFOLLOW")
    if has_nofollow:
        # POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink.
        # No O_TRUNC: truncation is deferred until fstat() below has confirmed the
        # opened object is an exclusive regular file, so a rejected target is never clobbered.
        flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
        if hasattr(os, "O_NONBLOCK"):
            # If a FIFO is swapped in after the lstat check, a non-blocking write-only
            # open fails (ENXIO when there is no reader) instead of hanging the request.
            flags |= os.O_NONBLOCK
        try:
            fd = os.open(dest, flags, 0o600)
        except OSError as exc:
            # Translate "not a safe regular-file target" errnos into the module's
            # domain error; any other OSError propagates unchanged.
            # NOTE(review): some BSDs reportedly use a different errno (e.g. EMLINK on
            # FreeBSD) for O_NOFOLLOW on a symlink; that would still be refused but
            # would surface as a plain OSError — confirm if BSD hosts matter.
            if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
                raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
            raise
        try:
            # Validate what was actually opened: a regular file with exactly one link,
            # so the write cannot reach another path through a shared hard link.
            opened_stat = os.fstat(fd)
            if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
                raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
            os.ftruncate(fd, 0)
            fh = os.fdopen(fd, "wb")
            # Ownership of fd has moved to fh; stop the finally block from closing it.
            fd = -1
        finally:
            if fd >= 0:
                os.close(fd)
        return dest, fh

    # Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
    # to narrow the TOCTOU window, then fstat after open() as a further defence.
    # Note: a narrow race window remains between the pre-open lstat and open(); the
    # path-traversal check mitigates escapes from base_dir but cannot prevent an
    # attacker who can atomically replace dest with a symlink after the check.
    if st is not None and st.st_nlink > 1:
        raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
    flags = os.O_WRONLY | os.O_CREAT
    if hasattr(os, "O_BINARY"):
        # Disable newline translation for the raw descriptor.
        flags |= os.O_BINARY
    # Second lstat, as close to open() as possible: repeat the regular-file and
    # link-count checks in case dest changed since the first lstat.
    try:
        pre_open_st = os.lstat(dest)
    except FileNotFoundError:
        pre_open_st = None
    if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
        raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
    if pre_open_st is not None and pre_open_st.st_nlink > 1:
        raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
    try:
        fd = os.open(dest, flags, 0o600)
    except OSError as exc:
        if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
            raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
        raise
    try:
        # Post-open check on the descriptor. Uses "> 1" rather than the POSIX
        # "!= 1"; NOTE(review): presumably to tolerate a 0 link count reported by
        # some Windows filesystems — confirm.
        opened_stat = os.fstat(fd)
        if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink > 1:
            raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
        os.ftruncate(fd, 0)
        fh = os.fdopen(fd, "wb")
        # Ownership of fd has moved to fh; stop the finally block from closing it.
        fd = -1
    finally:
        if fd >= 0:
            os.close(fd)
    return dest, fh
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
    """Store *data* as *filename* in *base_dir* via :func:`open_upload_file_no_symlink`.

    A symlink already sitting at the destination is refused rather than followed.
    The handle is closed even if the write fails.

    Returns:
        The destination path that was written.
    """
    target, handle = open_upload_file_no_symlink(base_dir, filename)
    try:
        handle.write(data)
    finally:
        handle.close()
    return target
def list_files_in_dir(directory: Path) -> dict:
    """List regular files (not directories or symlinks) directly inside *directory*.

    Symlinks are not followed. An entry is included only if it is itself a
    regular file, and its size/mtime are read from the entry, not from any
    link target. The scan is non-recursive.

    Args:
        directory: Directory to scan.

    Returns:
        Dict with "files" list (sorted by name) and "count".
        Each file entry has ``filename``, ``size`` as *int* (bytes), ``path``,
        ``extension`` and ``modified`` (``st_mtime``). Call
        :func:`enrich_file_listing` to stringify sizes and add
        virtual / artifact URLs. A missing directory yields an empty listing.
    """
    if not directory.is_dir():
        return {"files": [], "count": 0}
    try:
        scanner = os.scandir(directory)
    except FileNotFoundError:
        # Directory removed between the is_dir() check and the scan.
        return {"files": [], "count": 0}
    files = []
    with scanner as entries:
        for entry in sorted(entries, key=lambda e: e.name):
            if not entry.is_file(follow_symlinks=False):
                continue
            try:
                st = entry.stat(follow_symlinks=False)
            except FileNotFoundError:
                # The uploads directory can be shared with a sandbox process. A file
                # removed between scandir() and stat() is omitted from the listing
                # instead of failing the whole request.
                continue
            files.append(
                {
                    "filename": entry.name,
                    "size": st.st_size,
                    "path": entry.path,
                    "extension": Path(entry.name).suffix,
                    "modified": st.st_mtime,
                }
            )
    return {"files": files, "count": len(files)}
def delete_file_safe(base_dir: Path, filename: str, *, convertible_extensions: set[str] | None = None) -> dict:
    """Delete a file inside *base_dir* after path-traversal validation.

    The entry named by *filename* is the one removed. If that entry is a
    symlink, the link itself is unlinked rather than followed, so a link planted
    in the directory cannot redirect the deletion to the file it points at.

    If *convertible_extensions* is provided and the file's extension matches,
    the companion ``.md`` file is also removed (if it exists).

    Args:
        base_dir: Directory containing the file.
        filename: Name of file to delete.
        convertible_extensions: Lowercase extensions (e.g. ``{".pdf", ".docx"}``)
            whose companion markdown should be cleaned up.

    Returns:
        Dict with success and message.

    Raises:
        FileNotFoundError: If the file does not exist (or is not a regular file,
            or a link to one).
        PathTraversalError: If path traversal is detected.
    """
    # Deliberately NOT resolved: resolve() would follow a final-component symlink,
    # and unlink() would then delete the link's target instead of the named entry.
    file_path = base_dir / filename
    # The fully resolved target must stay inside base_dir (rejects "../x" and links that escape).
    validate_path_traversal(file_path, base_dir)
    if not file_path.is_file():
        raise FileNotFoundError(f"File not found: {filename}")
    # The entry being unlinked must also live inside base_dir. A name such as
    # "../elsewhere/link" can resolve into base_dir while the link itself sits outside it.
    validate_path_traversal(file_path.parent, base_dir)
    file_path.unlink()
    # Clean up companion markdown generated during upload conversion.
    if convertible_extensions and file_path.suffix.lower() in convertible_extensions:
        file_path.with_suffix(".md").unlink(missing_ok=True)
    return {"success": True, "message": f"Deleted {filename}"}
def upload_artifact_url(thread_id: str, filename: str) -> str:
    """Return the artifact API URL for an uploaded file of a thread.

    *filename* is fully percent-encoded (``safe=''``, so ``/`` is encoded too),
    which keeps spaces, ``#``, ``?`` and similar characters URL-safe.
    """
    encoded_name = quote(filename, safe='')
    return f"/api/threads/{thread_id}/artifacts{VIRTUAL_PATH_PREFIX}/uploads/{encoded_name}"
def upload_virtual_path(filename: str) -> str:
    """Return the virtual (in-sandbox) path of *filename* under the uploads folder; no encoding is applied."""
    uploads_root = f"{VIRTUAL_PATH_PREFIX}/uploads"
    return f"{uploads_root}/{filename}"
def enrich_file_listing(result: dict, thread_id: str) -> dict:
    """Decorate each entry of a :func:`list_files_in_dir` result for API output.

    For every entry in ``result["files"]``: ``size`` becomes a ``str``, and
    ``virtual_path`` / ``artifact_url`` are derived from ``filename``.
    *result* is modified in place; the same object is returned for chaining.
    """
    for entry in result["files"]:
        name = entry["filename"]
        entry["size"] = str(entry["size"])
        entry["virtual_path"] = upload_virtual_path(name)
        entry["artifact_url"] = upload_artifact_url(thread_id, name)
    return result