mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
* fix: prevent concurrent subagent file write conflicts
  - Serialize same-path str_replace operations in sandbox tools
  - Guard AioSandbox write_file/update_file with the existing sandbox lock
  - Add regression tests for concurrent str_replace and append races
  - Verify with backend full tests and ruff lint checks
* fix(sandbox): fix the concurrency issue of file operations on the same path in isolated sandboxes.
  Ensure that different sandbox instances use independent locks for file operations on the same
  virtual path to avoid concurrency conflicts. Change the lock key from a single path to a
  composite key of (sandbox.id, path), and add tests to verify the concurrent safety of isolated
  sandboxes.
* feat(sandbox): extract file-operation lock logic into a standalone module and fix concurrency issues.
  Extract the file-operation lock logic from tools.py into a separate file_operation_lock.py module.
  Fix data races during concurrent str_replace and write_file operations.
184 lines
6.1 KiB
Python
184 lines
6.1 KiB
Python
"""Tests for AioSandbox concurrent command serialization (#1433)."""
|
|
|
|
import threading
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
@pytest.fixture()
def sandbox():
    """Build an AioSandbox whose HTTP client class is replaced by a mock."""
    # Patch the client class *before* importing/constructing AioSandbox so the
    # instance never opens a real connection.
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
|
|
|
|
|
class TestExecuteCommandSerialization:
    """Verify that concurrent exec_command calls are serialized."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Concurrent threads should not overlap inside execute_command."""
        import time

        events = []
        start_gate = threading.Barrier(3)

        def slow_exec(command, **kwargs):
            # Log entry/exit around a deliberate delay so any overlap between
            # threads would be visible as interleaved events.
            events.append(("enter", command))
            time.sleep(0.05)
            events.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def run(cmd):
            start_gate.wait()  # ensure all threads contend for the lock simultaneously
            sandbox.execute_command(cmd)

        workers = [
            threading.Thread(target=run, args=(f"cmd-{n}",)) for n in range(3)
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

        # Serialized execution means the event log alternates strictly:
        # every "enter" is immediately followed by its own "exit".
        enter_positions = [i for i, (kind, _) in enumerate(events) if kind == "enter"]
        exit_positions = [i for i, (kind, _) in enumerate(events) if kind == "exit"]
        assert len(enter_positions) == 3
        assert len(exit_positions) == 3
        for e_pos, x_pos in zip(enter_positions, exit_positions):
            assert x_pos == e_pos + 1, f"Interleaved execution detected: {events}"
|
|
|
|
|
|
class TestErrorObservationRetry:
    """Verify ErrorObservation detection and fresh-session retry."""

    def test_retry_on_error_observation(self, sandbox):
        """When output contains ErrorObservation, retry with a fresh session."""
        attempts = []

        def mock_exec(command, **kwargs):
            attempts.append(command)
            if len(attempts) == 1:
                # First attempt returns the sandbox-side error marker.
                return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
            return SimpleNamespace(data=SimpleNamespace(output="success"))

        sandbox._client.shell.exec_command = mock_exec

        assert sandbox.execute_command("echo hello") == "success"
        assert len(attempts) == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """The retry call should include a new session id kwarg."""
        recorded = []

        def mock_exec(command, **kwargs):
            recorded.append(kwargs)
            if len(recorded) == 1:
                return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
            return SimpleNamespace(data=SimpleNamespace(output="ok"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("test")
        assert len(recorded) == 2
        assert "id" not in recorded[0]  # first call uses the default session
        assert "id" in recorded[1]  # retry supplies a fresh session id
        assert len(recorded[1]["id"]) == 36  # UUID format

    def test_no_retry_on_clean_output(self, sandbox):
        """Normal output should not trigger a retry."""
        invocations = []

        def mock_exec(command, **kwargs):
            invocations.append(command)
            return SimpleNamespace(data=SimpleNamespace(output="all good"))

        sandbox._client.shell.exec_command = mock_exec

        assert sandbox.execute_command("echo hello") == "all good"
        assert len(invocations) == 1
|
|
|
|
|
|
class TestListDirSerialization:
    """Verify that list_dir also acquires the lock."""

    def test_list_dir_uses_lock(self, sandbox):
        """list_dir should hold the lock during execution."""
        lock_states = []
        fake_exec = MagicMock(
            return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))
        )

        def spying_exec(command, **kwargs):
            # Record whether the sandbox lock is held at the exact moment the
            # underlying shell call runs.
            lock_states.append(sandbox._lock.locked())
            return fake_exec(command, **kwargs)

        sandbox._client.shell.exec_command = spying_exec

        assert sandbox.list_dir("/test") == ["/a", "/b"]
        assert lock_states == [True], "list_dir must hold the lock during exec_command"
|
|
|
|
|
|
class TestConcurrentFileWrites:
    """Verify file write paths do not lose concurrent updates."""

    def test_append_should_preserve_both_parallel_writes(self, sandbox):
        """Two parallel append writes must both survive (no lost update).

        The stubbed ``read_file`` lets two concurrent readers overlap on the
        same snapshot of the file, reproducing the classic read-modify-write
        race. If write_file's append path is properly serialized, both "A\\n"
        and "B\\n" end up in the final content (in either order).
        """
        # In-memory stand-in for the remote file; "seed\n" is the preexisting content.
        storage = {"content": "seed\n"}
        # Number of readers currently inside overlapping_read_file.
        active_reads = 0
        # Protects storage/active_reads bookkeeping inside the stub itself
        # (this is NOT the sandbox lock under test).
        state_lock = threading.Lock()
        # Set once two readers are observed inside the read at the same time.
        overlap_detected = threading.Event()

        def overlapping_read_file(path):
            # Stub read that deliberately widens the read-to-write window so
            # an unserialized append would clobber the other thread's update.
            nonlocal active_reads
            with state_lock:
                active_reads += 1
                snapshot = storage["content"]
                if active_reads == 2:
                    overlap_detected.set()

            # Linger briefly so a second (racing) reader can enter; times out
            # quickly when the sandbox correctly serializes the operations.
            overlap_detected.wait(0.05)

            with state_lock:
                active_reads -= 1

            return snapshot

        def write_back(*, file, content, **kwargs):
            # Stub for the client's write: last writer wins on storage.
            storage["content"] = content
            return SimpleNamespace(data=SimpleNamespace())

        sandbox.read_file = overlapping_read_file
        sandbox._client.file.write_file = write_back

        # Release both writers at the same instant to maximize contention.
        barrier = threading.Barrier(2)

        def writer(payload: str):
            barrier.wait()
            sandbox.write_file("/tmp/shared.log", payload, append=True)

        threads = [
            threading.Thread(target=writer, args=("A\n",)),
            threading.Thread(target=writer, args=("B\n",)),
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Both appends must be present; only their relative order may differ.
        assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
|