mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(backend): fix uploads for mounted sandbox providers (#2199)
* fix uploads for mounted sandbox providers * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
c43a45ea40
commit
55bc09ac33
@ -8,7 +8,7 @@ from fastapi import APIRouter, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
delete_file_safe,
|
||||
@ -53,6 +53,10 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
||||
|
||||
|
||||
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
|
||||
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
|
||||
|
||||
|
||||
@router.post("", response_model=UploadResponse)
|
||||
async def upload_files(
|
||||
thread_id: str,
|
||||
@ -70,8 +74,11 @@ async def upload_files(
|
||||
uploaded_files = []
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
||||
sandbox = None
|
||||
if sync_to_sandbox:
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
@ -90,7 +97,7 @@ async def upload_files(
|
||||
|
||||
virtual_path = upload_virtual_path(safe_filename)
|
||||
|
||||
if sandbox_id != "local":
|
||||
if sync_to_sandbox and sandbox is not None:
|
||||
_make_file_sandbox_writable(file_path)
|
||||
sandbox.update_file(virtual_path, content)
|
||||
|
||||
@ -110,7 +117,7 @@ async def upload_files(
|
||||
if md_path:
|
||||
md_virtual_path = upload_virtual_path(md_path.name)
|
||||
|
||||
if sandbox_id != "local":
|
||||
if sync_to_sandbox and sandbox is not None:
|
||||
_make_file_sandbox_writable(md_path)
|
||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||
|
||||
|
||||
@ -119,6 +119,16 @@ class AioSandboxProvider(SandboxProvider):
|
||||
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
||||
self._start_idle_checker()
|
||||
|
||||
@property
|
||||
def uses_thread_data_mounts(self) -> bool:
|
||||
"""Whether thread workspace/uploads/outputs are visible via mounts.
|
||||
|
||||
Local container backends bind-mount the thread data directories, so files
|
||||
written by the gateway are already visible when the sandbox starts.
|
||||
Remote backends may require explicit file sync.
|
||||
"""
|
||||
return isinstance(self._backend, LocalContainerBackend)
|
||||
|
||||
# ── Factory methods ──────────────────────────────────────────────────
|
||||
|
||||
def _create_backend(self) -> SandboxBackend:
|
||||
|
||||
@ -11,6 +11,8 @@ _singleton: LocalSandbox | None = None
|
||||
|
||||
|
||||
class LocalSandboxProvider(SandboxProvider):
|
||||
uses_thread_data_mounts = True
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the local sandbox provider with path mappings."""
|
||||
self._path_mappings = self._setup_path_mappings()
|
||||
|
||||
@ -8,6 +8,8 @@ from deerflow.sandbox.sandbox import Sandbox
|
||||
class SandboxProvider(ABC):
|
||||
"""Abstract base class for sandbox providers"""
|
||||
|
||||
uses_thread_data_mounts: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
@ -14,6 +14,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
provider.acquire.return_value = "local"
|
||||
sandbox = MagicMock()
|
||||
provider.get.return_value = sandbox
|
||||
@ -34,11 +35,33 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
|
||||
sandbox.update_file.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
|
||||
with (
|
||||
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
|
||||
result = asyncio.run(uploads.upload_files("thread-mounted", files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
|
||||
provider.acquire.assert_not_called()
|
||||
provider.get.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = False
|
||||
provider.acquire.return_value = "aio-1"
|
||||
sandbox = MagicMock()
|
||||
provider.get.return_value = sandbox
|
||||
@ -75,6 +98,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = False
|
||||
provider.acquire.return_value = "aio-1"
|
||||
sandbox = MagicMock()
|
||||
provider.get.return_value = sandbox
|
||||
@ -104,6 +128,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
provider.acquire.return_value = "local"
|
||||
sandbox = MagicMock()
|
||||
provider.get.return_value = sandbox
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user