# Mirror of https://github.com/bytedance/deer-flow.git (synced 2026-05-10 18:58:21 +00:00).
# Last change: fix(uploads): add Windows support for safe symlink-protected uploads;
#              fix(uploads): update tests and translate comments.
# File info: 311 lines, 11 KiB, Python.
"""Shared upload management logic.
|
|
|
|
Pure business logic — no FastAPI/HTTP dependencies.
|
|
Both Gateway and Client delegate to these functions.
|
|
"""
|
|
|
|
import errno
|
|
import os
|
|
import re
|
|
import stat
|
|
from pathlib import Path
|
|
from urllib.parse import quote
|
|
|
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
|
from deerflow.runtime.user_context import get_effective_user_id
|
|
|
|
|
|
class PathTraversalError(ValueError):
    """Signal that a path does not stay inside the directory it must be confined to.

    Subclasses ``ValueError``, so handlers that catch ``ValueError`` also
    catch this error.
    """
|
|
|
|
|
|
class UnsafeUploadPathError(ValueError):
    """Signal that an upload destination cannot be used as a safe regular file.

    Subclasses ``ValueError``, so handlers that catch ``ValueError`` also
    catch this error.
    """
|
|
|
|
|
|
# thread_id must consist only of ASCII letters, digits, hyphens, underscores, or dots.
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")


def validate_thread_id(thread_id: str) -> None:
    """Reject thread IDs containing characters unsafe for filesystem paths.

    Args:
        thread_id: Candidate thread identifier.

    Raises:
        ValueError: If thread_id is empty, is one of the directory
            references ``"."`` / ``".."``, or contains any character
            outside ``[a-zA-Z0-9._-]``.
    """
    # "." and ".." satisfy the character class but name the current/parent
    # directory rather than a thread, so they are rejected explicitly.
    if not thread_id or thread_id in {".", ".."}:
        raise ValueError(f"Invalid thread_id: {thread_id!r}")
    # fullmatch() rather than match(): with match(), the trailing ``$`` also
    # matches just before a final newline, so "abc\n" would be accepted.
    if _SAFE_THREAD_ID.fullmatch(thread_id) is None:
        raise ValueError(f"Invalid thread_id: {thread_id!r}")
|
|
|
|
|
|
def get_uploads_dir(thread_id: str) -> Path:
    """Compute the uploads directory path for *thread_id* without creating it.

    The thread ID is validated first; the path is then looked up via
    ``get_paths().sandbox_uploads_dir`` for the effective user.

    Raises:
        ValueError: If *thread_id* fails ``validate_thread_id``.
    """
    validate_thread_id(thread_id)
    paths = get_paths()
    owner_id = get_effective_user_id()
    return paths.sandbox_uploads_dir(thread_id, user_id=owner_id)
|
|
|
|
|
|
def ensure_uploads_dir(thread_id: str) -> Path:
    """Resolve the uploads directory for *thread_id* and make sure it exists.

    Missing parent directories are created as well; an already existing
    directory is left untouched.

    Returns:
        The uploads directory path.
    """
    uploads_dir = get_uploads_dir(thread_id)
    uploads_dir.mkdir(parents=True, exist_ok=True)
    return uploads_dir
|
|
|
|
|
|
def normalize_filename(filename: str) -> str:
    """Sanitize a filename by extracting its basename.

    Strips any directory components and rejects traversal patterns.

    Args:
        filename: Raw filename from user input (may contain path components).

    Returns:
        Safe filename (basename only).

    Raises:
        ValueError: If filename is empty, reduces to an empty name or to
            ``"."`` / ``".."``, contains a backslash or a NUL byte, or is
            longer than 255 bytes when encoded as UTF-8.
    """
    if not filename:
        raise ValueError("Filename is empty")
    safe = Path(filename).name
    if not safe or safe in {".", ".."}:
        raise ValueError(f"Filename is unsafe: {filename!r}")
    # Reject backslashes — on Linux Path.name keeps them as literal chars,
    # but they indicate a Windows-style path that should be stripped or rejected.
    if "\\" in safe:
        raise ValueError(f"Filename contains backslash: {filename!r}")
    # A NUL byte can never be part of a real file name; reject it here with a
    # clear message instead of letting a later os-level call fail on it.
    if "\x00" in safe:
        raise ValueError(f"Filename contains NUL byte: {filename!r}")
    # The limit is enforced on the UTF-8 encoded size, so the message reports
    # bytes as well (a character count would understate multi-byte names).
    encoded_size = len(safe.encode("utf-8"))
    if encoded_size > 255:
        raise ValueError(f"Filename too long: {encoded_size} bytes (max 255)")
    return safe
|
|
|
|
|
|
def claim_unique_filename(name: str, seen: set[str]) -> str:
    """Reserve a filename that does not collide with any name in *seen*.

    If *name* is free it is used as is. Otherwise ``_1``, ``_2``, ... is
    inserted between the stem and the final suffix until an unused name is
    found. The chosen name is recorded in *seen* before returning.

    Args:
        name: Candidate filename.
        seen: Names already taken; updated in place with the result.

    Returns:
        The reserved filename, which is now a member of *seen*.
    """
    chosen = name
    if chosen in seen:
        parsed = Path(name)
        base, ext = parsed.stem, parsed.suffix
        attempt = 1
        # Probe base_1.ext, base_2.ext, ... until a free name turns up.
        while (chosen := f"{base}_{attempt}{ext}") in seen:
            attempt += 1
    seen.add(chosen)
    return chosen
|
|
|
|
|
|
def validate_path_traversal(path: Path, base: Path) -> None:
    """Check that *path*, once resolved, is located under the resolved *base*.

    Args:
        path: Path to check.
        base: Directory that *path* must stay within.

    Raises:
        PathTraversalError: If the resolved *path* is not relative to the
            resolved *base* (or resolving raises ``ValueError``).
    """
    try:
        resolved_path = path.resolve()
        resolved_base = base.resolve()
        # relative_to() raises ValueError when resolved_path is not under resolved_base.
        resolved_path.relative_to(resolved_base)
    except ValueError:
        # "from None" hides the internal ValueError from the traceback.
        raise PathTraversalError("Path traversal detected") from None
|
|
|
|
|
|
def _lstat_or_none(path: Path) -> os.stat_result | None:
    """Return ``os.lstat(path)``, or ``None`` when nothing exists at *path*."""
    try:
        return os.lstat(path)
    except FileNotFoundError:
        return None


def _adopt_validated_fd(fd: int, safe_name: str, *, allow_zero_links: bool) -> object:
    """Validate an opened descriptor, truncate it, and wrap it for binary writes.

    Args:
        fd: Descriptor returned by ``os.open`` for the upload destination.
        safe_name: Normalized filename, used only in error messages.
        allow_zero_links: When False the descriptor must report exactly one
            hard link; when True only link counts greater than one are rejected.

    Returns:
        A file object opened in ``"wb"`` mode that owns *fd*.

    Raises:
        UnsafeUploadPathError: If *fd* is not a regular file or fails the
            link-count requirement. *fd* is closed before the error propagates.
    """
    try:
        opened_stat = os.fstat(fd)
        if allow_zero_links:
            links_ok = opened_stat.st_nlink <= 1
        else:
            links_ok = opened_stat.st_nlink == 1
        if not stat.S_ISREG(opened_stat.st_mode) or not links_ok:
            raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
        # Truncate only after validation so a rejected target is never modified.
        os.ftruncate(fd, 0)
        fh = os.fdopen(fd, "wb")
        # Ownership moved to fh; disarm the cleanup below.
        fd = -1
    finally:
        if fd >= 0:
            os.close(fd)
    return fh


def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
    """Open an upload destination for safe streaming writes.

    Upload directories may be mounted into local sandboxes. A sandbox process can
    therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
    follows that link and can overwrite files outside the uploads directory with
    gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
    on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
    and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
    not eliminate all races but makes exploitation significantly harder. Path-traversal
    validation prevents escapes from *base_dir* in both cases.

    Args:
        base_dir: Directory the upload must be written into.
        filename: Raw filename; it is reduced to a safe basename first.

    Returns:
        ``(dest, fh)`` where *dest* is the destination path and *fh* is a
        binary writable file object, truncated to zero length. The caller is
        responsible for closing *fh*.

    Raises:
        ValueError: If *filename* is rejected by ``normalize_filename``.
        PathTraversalError: If the destination resolves outside *base_dir*.
        UnsafeUploadPathError: If the destination exists and is not a regular
            file, has extra hard links, or cannot be opened safely.
        OSError: For any other failure of ``os.open``.
    """
    safe_name = normalize_filename(filename)
    dest = base_dir / safe_name

    # Cheap early rejection of symlinks, directories, FIFOs, ... already at dest.
    st = _lstat_or_none(dest)
    if st is not None and not stat.S_ISREG(st.st_mode):
        raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")

    validate_path_traversal(dest, base_dir)

    if hasattr(os, "O_NOFOLLOW"):
        # POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink
        # (FreeBSD reports EMLINK for the same condition).
        flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
        if hasattr(os, "O_NONBLOCK"):
            # NOTE(review): presumably so that opening a FIFO placed at dest
            # fails fast (ENXIO) instead of blocking — confirm.
            flags |= os.O_NONBLOCK

        unsafe_errnos = {errno.ELOOP, errno.EMLINK, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}
        try:
            fd = os.open(dest, flags, 0o600)
        except OSError as exc:
            if exc.errno in unsafe_errnos:
                raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
            raise

        return dest, _adopt_validated_fd(fd, safe_name, allow_zero_links=False)

    # Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
    # to narrow the TOCTOU window, then fstat after open() as a further defence.
    # Note: a narrow race window remains between the pre-open lstat and open(); the
    # path-traversal check mitigates escapes from base_dir but cannot prevent an
    # attacker who can atomically replace dest with a symlink after the check.
    if st is not None and st.st_nlink > 1:
        raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")

    flags = os.O_WRONLY | os.O_CREAT
    if hasattr(os, "O_BINARY"):
        flags |= os.O_BINARY

    pre_open_st = _lstat_or_none(dest)
    if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
        raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
    if pre_open_st is not None and pre_open_st.st_nlink > 1:
        raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")

    try:
        fd = os.open(dest, flags, 0o600)
    except OSError as exc:
        if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
            raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
        raise

    return dest, _adopt_validated_fd(fd, safe_name, allow_zero_links=True)
|
|
|
|
|
|
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
    """Store *data* as an upload, refusing a destination that is a symlink.

    The destination is opened through ``open_upload_file_no_symlink``; the
    returned stream is always closed, including when the write fails.

    Returns:
        Path of the file that was written.
    """
    target, stream = open_upload_file_no_symlink(base_dir, filename)
    with stream as out:
        out.write(data)
    return target
|
|
|
|
|
|
def list_files_in_dir(directory: Path) -> dict:
    """Collect metadata for the regular files directly inside *directory*.

    Sub-directories and symlinks are skipped (entries are tested and stat'ed
    without following symlinks).

    Args:
        directory: Directory to scan.

    Returns:
        ``{"files": [...], "count": N}`` with entries ordered by filename.
        Each entry carries ``filename``, ``size`` (int, bytes), ``path``,
        ``extension`` and ``modified`` (the ``st_mtime`` value). When
        *directory* is not an existing directory the listing is empty.
        Use :func:`enrich_file_listing` to stringify sizes and add
        virtual / artifact URLs.
    """
    if not directory.is_dir():
        return {"files": [], "count": 0}

    def _describe(item: os.DirEntry) -> dict:
        # One stat call per entry, on the entry itself rather than a link target.
        info = item.stat(follow_symlinks=False)
        return {
            "filename": item.name,
            "size": info.st_size,
            "path": item.path,
            "extension": Path(item.name).suffix,
            "modified": info.st_mtime,
        }

    with os.scandir(directory) as scan:
        by_name = sorted(scan, key=lambda item: item.name)
        listing = [_describe(item) for item in by_name if item.is_file(follow_symlinks=False)]
    return {"files": listing, "count": len(listing)}
|
|
|
|
|
|
def delete_file_safe(base_dir: Path, filename: str, *, convertible_extensions: set[str] | None = None) -> dict:
    """Remove *filename* from *base_dir* once it is confirmed to live inside it.

    The target is resolved (symlinks followed) before the containment check.
    When *convertible_extensions* is given and contains the target's
    lower-cased suffix, the sibling file with the same stem and a ``.md``
    suffix is removed too, if present.

    Args:
        base_dir: Directory containing the file.
        filename: Name of file to delete.
        convertible_extensions: Lowercase extensions (e.g. ``{".pdf", ".docx"}``)
            whose companion markdown should be cleaned up.

    Returns:
        Dict with ``success`` and ``message`` keys.

    Raises:
        FileNotFoundError: If the resolved target is not an existing file.
        PathTraversalError: If the resolved target is outside *base_dir*.
    """
    target = (base_dir / filename).resolve()
    validate_path_traversal(target, base_dir)

    if target.is_file():
        target.unlink()
    else:
        raise FileNotFoundError(f"File not found: {filename}")

    # Drop the markdown companion produced when the upload was converted.
    has_companion = bool(convertible_extensions) and target.suffix.lower() in convertible_extensions
    if has_companion:
        companion = target.with_suffix(".md")
        companion.unlink(missing_ok=True)

    return {"success": True, "message": f"Deleted {filename}"}
|
|
|
|
|
|
def upload_artifact_url(thread_id: str, filename: str) -> str:
    """Compose the artifact URL of an uploaded file belonging to a thread.

    *filename* is fully percent-encoded (no characters are exempt), so
    spaces, ``/``, ``#``, ``?`` and similar cannot alter the URL structure.
    *thread_id* is inserted verbatim.
    """
    encoded_name = quote(filename, safe="")
    return f"/api/threads/{thread_id}/artifacts{VIRTUAL_PATH_PREFIX}/uploads/{encoded_name}"
|
|
|
|
|
|
def upload_virtual_path(filename: str) -> str:
    """Compose the virtual path of *filename* under the uploads directory.

    The result is ``VIRTUAL_PATH_PREFIX`` + ``/uploads/`` + *filename*;
    *filename* is inserted verbatim (no encoding).
    """
    uploads_root = f"{VIRTUAL_PATH_PREFIX}/uploads"
    return f"{uploads_root}/{filename}"
|
|
|
|
|
|
def enrich_file_listing(result: dict, thread_id: str) -> dict:
    """Decorate every entry of a file listing with display/URL fields.

    For each item in ``result["files"]`` the ``size`` value is converted to
    ``str`` and ``virtual_path`` / ``artifact_url`` keys are added, derived
    from the entry's ``filename``.

    Args:
        result: Listing dict with a ``"files"`` list; modified in place.
        thread_id: Thread whose artifact URLs should be generated.

    Returns:
        The same *result* object, for call chaining.
    """
    for entry in result["files"]:
        name = entry["filename"]
        entry["size"] = str(entry["size"])
        entry["virtual_path"] = upload_virtual_path(name)
        entry["artifact_url"] = upload_artifact_url(thread_id, name)
    return result
|