deer-flow/backend/tests/test_aio_sandbox.py
Varian_米泽 a2cb38f62b
fix: prevent concurrent subagent file write conflicts in sandbox tools (#1714)
* fix: prevent concurrent subagent file write conflicts

Serialize same-path str_replace operations in sandbox tools

Guard AioSandbox write_file/update_file with the existing sandbox lock

Add regression tests for concurrent str_replace and append races

Verify with the full backend test suite and ruff lint checks

* fix(sandbox): Fix the concurrency issue of file operations on the same path in isolated sandboxes.

Ensure that different sandbox instances use independent locks for file operations on the same virtual path to avoid concurrency conflicts. Change the lock key from a single path to a composite key of (sandbox.id, path), and add tests to verify the concurrent safety of isolated sandboxes.

* feat(sandbox): Extract file operation lock logic to standalone module and fix concurrency issues

Extract the file-operation lock logic from tools.py into a separate file_operation_lock.py module.
Fix data race issues during concurrent str_replace and write_file operations.
2026-04-02 15:39:41 +08:00

184 lines
6.1 KiB
Python

"""Tests for AioSandbox concurrent command serialization (#1433)."""
import threading
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture()
def sandbox():
    """Return an ``AioSandbox`` built while ``AioSandboxClient`` is patched.

    The client class is replaced by a mock for the duration of the
    constructor call, so the sandbox under test is wired to that mock
    rather than to the real client class.
    """
    client_target = "deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"
    with patch(client_target):
        # Import under the patch, then construct the sandbox before the
        # patch is undone on leaving the ``with`` block.
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
class TestExecuteCommandSerialization:
    """Check that overlapping ``execute_command`` calls run one at a time."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Three threads released together must not interleave in the client call."""
        import time

        events = []  # ("enter" | "exit", command) tuples in the order they happened
        start_gate = threading.Barrier(3)

        def slow_exec(command, **kwargs):
            # Stay inside the call long enough for the other threads to
            # reach it if nothing is holding them back.
            events.append(("enter", command))
            time.sleep(0.05)
            events.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def worker(cmd):
            start_gate.wait()  # release all workers at the same moment
            sandbox.execute_command(cmd)

        threads = [
            threading.Thread(target=worker, args=(f"cmd-{n}",))
            for n in range(3)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Serialized execution means the log strictly alternates
        # enter/exit: every "enter" is immediately followed by an "exit".
        enter_positions = [pos for pos, entry in enumerate(events) if entry[0] == "enter"]
        exit_positions = [pos for pos, entry in enumerate(events) if entry[0] == "exit"]

        assert len(enter_positions) == 3
        assert len(exit_positions) == 3
        for entered_at, exited_at in zip(enter_positions, exit_positions):
            assert exited_at == entered_at + 1, f"Interleaved execution detected: {events}"
class TestErrorObservationRetry:
    """Check ErrorObservation detection and the retry that follows it."""

    # Output text used by the stubs below to simulate a failed first call.
    _ERROR_OUTPUT = "'ErrorObservation' object has no attribute 'exit_code'"

    def test_retry_on_error_observation(self, sandbox):
        """An ErrorObservation reply is retried and the second reply is returned."""
        invocations = []

        def mock_exec(command, **kwargs):
            invocations.append(command)
            if len(invocations) == 1:
                text = self._ERROR_OUTPUT
            else:
                text = "success"
            return SimpleNamespace(data=SimpleNamespace(output=text))

        sandbox._client.shell.exec_command = mock_exec

        outcome = sandbox.execute_command("echo hello")

        assert outcome == "success"
        assert len(invocations) == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """Only the retry call carries an ``id`` keyword argument."""
        recorded_kwargs = []

        def mock_exec(command, **kwargs):
            recorded_kwargs.append(kwargs)
            text = self._ERROR_OUTPUT if len(recorded_kwargs) == 1 else "ok"
            return SimpleNamespace(data=SimpleNamespace(output=text))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("test")

        assert len(recorded_kwargs) == 2
        first_kwargs, retry_kwargs = recorded_kwargs
        assert "id" not in first_kwargs
        assert "id" in retry_kwargs
        assert len(retry_kwargs["id"]) == 36  # UUID format

    def test_no_retry_on_clean_output(self, sandbox):
        """Ordinary output is returned after a single client call."""
        invocations = []

        def mock_exec(command, **kwargs):
            invocations.append(command)
            return SimpleNamespace(data=SimpleNamespace(output="all good"))

        sandbox._client.shell.exec_command = mock_exec

        outcome = sandbox.execute_command("echo hello")

        assert outcome == "all good"
        assert len(invocations) == 1
class TestListDirSerialization:
    """Check that ``list_dir`` runs its client call with the sandbox lock held."""

    def test_list_dir_uses_lock(self, sandbox):
        """The lock must report ``locked()`` at the moment exec_command is invoked."""
        observed_lock_states = []
        reply = SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))
        canned_exec = MagicMock(return_value=reply)

        def tracking_exec(command, **kwargs):
            # Sample the lock from inside the client call, then delegate.
            observed_lock_states.append(sandbox._lock.locked())
            return canned_exec(command, **kwargs)

        sandbox._client.shell.exec_command = tracking_exec

        listing = sandbox.list_dir("/test")

        assert listing == ["/a", "/b"]
        assert observed_lock_states == [True], "list_dir must hold the lock during exec_command"
class TestConcurrentFileWrites:
    """Check that concurrent file writes do not drop each other's updates."""

    def test_append_should_preserve_both_parallel_writes(self, sandbox):
        """Two simultaneous appends must both end up in the stored content."""
        storage = {"content": "seed\n"}
        readers_in_flight = 0
        counter_guard = threading.Lock()
        both_reading = threading.Event()

        def overlapping_read_file(path):
            nonlocal readers_in_flight
            with counter_guard:
                readers_in_flight += 1
                seen = storage["content"]
                if readers_in_flight == 2:
                    both_reading.set()

            # Linger outside the guard: returns at once if a second reader
            # has shown up, otherwise after the 50 ms timeout.
            both_reading.wait(0.05)

            with counter_guard:
                readers_in_flight -= 1

            return seen

        def write_back(*, file, content, **kwargs):
            # Last writer wins, so a stale snapshot would erase the other append.
            storage["content"] = content
            return SimpleNamespace(data=SimpleNamespace())

        # NOTE(review): presumably write_file(append=True) reads through
        # sandbox.read_file and writes through the client -- confirm
        # against the AioSandbox implementation.
        sandbox.read_file = overlapping_read_file
        sandbox._client.file.write_file = write_back

        start_gate = threading.Barrier(2)

        def writer(payload: str):
            start_gate.wait()  # release both writers together
            sandbox.write_file("/tmp/shared.log", payload, append=True)

        threads = [threading.Thread(target=writer, args=(chunk,)) for chunk in ("A\n", "B\n")]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Either ordering is fine as long as neither payload was lost.
        assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}