deer-flow/backend/tests/test_aio_sandbox.py
Ryker_Feng 5dc2d6cbf5
fix(sandbox): close AioSandbox HTTP client during provider teardown (#2872) (#3245)
* fix(sandbox): close AioSandbox HTTP client during provider teardown (#2872)

AioSandbox allocates a host-side agent_sandbox client (wrapping an
httpx.Client) in __init__, but AioSandboxProvider.release/destroy/shutdown
only popped provider state and tore down the backend container — the
client/transport owned by each cached AioSandbox was never explicitly
closed, accumulating unreclaimed sockets in long-running services.

- Add AioSandbox.close(): best-effort, idempotent close of the wrapped
  httpx_client (falls back to top-level client.close()); errors are
  logged but never raised so backend cleanup is never blocked.
- AioSandboxProvider.release()/destroy() now close the cached AioSandbox
  before dropping it; shutdown() inherits this via destroy().

* fix(sandbox): close the real httpx.Client owned by AioSandbox (#2872)

The previous close() only walked one level (wrapper.httpx_client), which resolves to the Fern-generated HttpClient wrapper that has no close(). The real socket-owning httpx.Client lives one level deeper at _client_wrapper.httpx_client.httpx_client, so the close path never fired and host-side sockets still leaked.

Resolve the real httpx.Client with graceful degradation; clear self._client under the lock for use-after-close and concurrent double-close safety; mark provider release()/destroy() try/except as defense-in-depth; rewrite TestClose against the real nested structure to lock down the original no-op bug.
2026-06-02 22:55:59 +08:00

394 lines
14 KiB
Python

"""Tests for AioSandbox: command serialization (#1433), ErrorObservation retry, no_change_timeout forwarding, file writes/downloads, and client teardown (#2872)."""
import threading
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture()
def sandbox():
    """Build an AioSandbox while AioSandboxClient is patched out.

    The import and the construction both happen inside the patch context so
    the instance is created against the patched client class.
    """
    client_target = "deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"
    with patch(client_target):
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
class TestExecuteCommandSerialization:
    """Concurrent exec_command calls must be executed one at a time."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Threads released together must not overlap inside exec_command."""
        import time

        call_log = []
        start_gate = threading.Barrier(3)

        def slow_exec(command, **kwargs):
            # Record entry and exit around a short sleep so that any overlap
            # between threads shows up as interleaved log entries.
            call_log.append(("enter", command))
            time.sleep(0.05)
            call_log.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        def run(cmd):
            # Hold every thread at the barrier so they all hit the lock at once.
            start_gate.wait()
            sandbox.execute_command(cmd)

        sandbox._client.shell.exec_command = slow_exec

        pool = [
            threading.Thread(target=run, args=(f"cmd-{n}",))
            for n in range(3)
        ]
        for th in pool:
            th.start()
        for th in pool:
            th.join()

        # Serialized execution means each "enter" is immediately followed by
        # an "exit" in the log, with no other entry in between.
        enter_positions = [pos for pos, entry in enumerate(call_log) if entry[0] == "enter"]
        exit_positions = [pos for pos, entry in enumerate(call_log) if entry[0] == "exit"]

        assert len(enter_positions) == 3
        assert len(exit_positions) == 3
        for entered_at, exited_at in zip(enter_positions, exit_positions):
            assert exited_at == entered_at + 1, f"Interleaved execution detected: {call_log}"
class TestErrorObservationRetry:
    """ErrorObservation output must trigger a retry on a fresh session."""

    def test_retry_on_error_observation(self, sandbox):
        """An ErrorObservation reply is retried and the second reply is returned."""
        error_text = "'ErrorObservation' object has no attribute 'exit_code'"
        seen = []

        def mock_exec(command, **kwargs):
            seen.append(command)
            # Only the very first call reports the ErrorObservation failure.
            text = error_text if len(seen) == 1 else "success"
            return SimpleNamespace(data=SimpleNamespace(output=text))

        sandbox._client.shell.exec_command = mock_exec

        outcome = sandbox.execute_command("echo hello")

        assert outcome == "success"
        assert len(seen) == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """Only the retry call carries an ``id`` keyword argument."""
        error_text = "'ErrorObservation' object has no attribute 'exit_code'"
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            text = error_text if len(calls) == 1 else "ok"
            return SimpleNamespace(data=SimpleNamespace(output=text))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("test")

        assert len(calls) == 2
        first_kwargs, retry_kwargs = calls
        assert "id" not in first_kwargs
        assert "id" in retry_kwargs
        assert len(retry_kwargs["id"]) == 36  # UUID format

    def test_no_retry_on_clean_output(self, sandbox):
        """Clean output is returned after exactly one exec_command call."""
        seen = []

        def mock_exec(command, **kwargs):
            seen.append(command)
            return SimpleNamespace(data=SimpleNamespace(output="all good"))

        sandbox._client.shell.exec_command = mock_exec

        outcome = sandbox.execute_command("echo hello")

        assert outcome == "all good"
        assert len(seen) == 1
class TestListDirSerialization:
    """list_dir must take the sandbox lock as well."""

    def test_list_dir_uses_lock(self, sandbox):
        """The lock is reported as held while exec_command runs for list_dir."""
        canned_reply = SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))
        lock_states = []

        def tracking_exec(command, **kwargs):
            # Sample the lock state at the moment the client is invoked.
            lock_states.append(sandbox._lock.locked())
            return canned_reply

        sandbox._client.shell.exec_command = tracking_exec

        listing = sandbox.list_dir("/test")

        assert listing == ["/a", "/b"]
        assert lock_states == [True], "list_dir must hold the lock during exec_command"
class TestNoChangeTimeout:
    """Every exec_command call must receive the no_change_timeout kwarg."""

    def test_execute_command_passes_no_change_timeout(self, sandbox):
        """execute_command forwards the default no_change_timeout."""
        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            return SimpleNamespace(data=SimpleNamespace(output="ok"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("echo hello")

        assert len(calls) == 1
        assert calls[0].get("no_change_timeout") == expected

    def test_retry_passes_no_change_timeout(self, sandbox):
        """Both the first attempt and the ErrorObservation retry carry the timeout."""
        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        error_text = "'ErrorObservation' object has no attribute 'exit_code'"
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            # The first reply forces the retry path; the second one succeeds.
            text = error_text if len(calls) == 1 else "ok"
            return SimpleNamespace(data=SimpleNamespace(output=text))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("echo hello")

        assert len(calls) == 2
        for recorded in calls:
            assert recorded.get("no_change_timeout") == expected

    def test_list_dir_passes_no_change_timeout(self, sandbox):
        """list_dir forwards the default no_change_timeout."""
        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            return SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.list_dir("/test")

        assert len(calls) == 1
        assert calls[0].get("no_change_timeout") == expected
class TestConcurrentFileWrites:
    """Concurrent file writes must not drop each other's updates."""

    def test_append_should_preserve_both_parallel_writes(self, sandbox):
        """Two parallel appends must both end up in the stored content."""
        storage = {"content": "seed\n"}
        readers = {"active": 0}
        guard = threading.Lock()
        overlap_detected = threading.Event()
        start_gate = threading.Barrier(2)

        def overlapping_read_file(path):
            # Take a snapshot under the guard, then linger briefly outside it
            # so a second unsynchronized reader could capture the same stale
            # content.
            with guard:
                readers["active"] += 1
                snapshot = storage["content"]
                if readers["active"] == 2:
                    overlap_detected.set()
            overlap_detected.wait(0.05)
            with guard:
                readers["active"] -= 1
            return snapshot

        def write_back(*, file, content, **kwargs):
            storage["content"] = content
            return SimpleNamespace(data=SimpleNamespace())

        def writer(payload: str):
            start_gate.wait()
            sandbox.write_file("/tmp/shared.log", payload, append=True)

        sandbox.read_file = overlapping_read_file
        sandbox._client.file.write_file = write_back

        workers = [threading.Thread(target=writer, args=(chunk,)) for chunk in ("A\n", "B\n")]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # Either ordering is acceptable as long as neither append is lost.
        assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
class TestDownloadFile:
    """Behavioral tests for AioSandbox.download_file."""

    def test_returns_concatenated_bytes(self, sandbox):
        """Chunks yielded by the client are joined into one bytes value."""
        target = "/mnt/user-data/outputs/file.bin"
        sandbox._client.file.download_file = MagicMock(return_value=[b"hel", b"lo"])

        payload = sandbox.download_file(target)

        assert payload == b"hello"
        sandbox._client.file.download_file.assert_called_once_with(path="/mnt/user-data/outputs/file.bin")

    def test_returns_empty_bytes_for_empty_file(self, sandbox):
        """An iterator that yields nothing produces b''."""
        sandbox._client.file.download_file = MagicMock(return_value=iter([]))

        payload = sandbox.download_file("/mnt/user-data/outputs/empty.bin")

        assert payload == b""

    def test_uses_lock_during_download(self, sandbox):
        """The sandbox lock is held while the client download is invoked."""
        lock_states = []

        def tracking_download(path):
            # Sample the lock state at the moment the client is invoked.
            lock_states.append(sandbox._lock.locked())
            return iter([b"data"])

        sandbox._client.file.download_file = tracking_download

        sandbox.download_file("/mnt/user-data/outputs/file.bin")

        assert lock_states == [True], "download_file must hold the lock during client call"

    def test_raises_oserror_on_client_error(self, sandbox):
        """A non-OSError raised by the client surfaces as OSError."""
        failing = MagicMock(side_effect=RuntimeError("network error"))
        sandbox._client.file.download_file = failing

        with pytest.raises(OSError, match="network error"):
            sandbox.download_file("/mnt/user-data/outputs/file.bin")

    def test_preserves_oserror_from_client(self, sandbox):
        """An OSError raised by the client still reaches the caller as OSError."""
        failing = MagicMock(side_effect=OSError("disk error"))
        sandbox._client.file.download_file = failing

        with pytest.raises(OSError, match="disk error"):
            sandbox.download_file("/mnt/user-data/outputs/file.bin")

    def test_rejects_path_outside_virtual_prefix_and_logs_error(self, sandbox, caplog):
        """A path outside the allowed prefix is refused, logged, and never sent to the client."""
        sandbox._client.file.download_file = MagicMock()

        with caplog.at_level("ERROR"), pytest.raises(PermissionError, match="must be under"):
            sandbox.download_file("/etc/passwd")

        assert "outside allowed directory" in caplog.text
        sandbox._client.file.download_file.assert_not_called()

    @pytest.mark.parametrize(
        "path",
        [
            "/mnt/workspace/../../etc/passwd",
            "../secret",
            "/a/b/../../../etc/shadow",
        ],
    )
    def test_rejects_path_traversal(self, sandbox, path):
        """Paths containing '..' are refused before the client is called."""
        sandbox._client.file.download_file = MagicMock()

        with pytest.raises(PermissionError, match="path traversal"):
            sandbox.download_file(path)

        sandbox._client.file.download_file.assert_not_called()

    def test_single_chunk(self, sandbox):
        """A response consisting of one chunk is returned unchanged."""
        sandbox._client.file.download_file = MagicMock(return_value=[b"single-chunk"])

        payload = sandbox.download_file("/mnt/user-data/outputs/single.bin")

        assert payload == b"single-chunk"
class TestClose:
    """AioSandbox.close() must tear down the host-side HTTP client (#2872)."""

    @staticmethod
    def _wire_nested_client(sandbox):
        """Install a two-level client chain on *sandbox* and return the innermost mock.

        The chain mirrors the Fern layout described by the original test:
            Sandbox._client_wrapper.httpx_client  -> Fern HttpClient (no close())
                .httpx_client                     -> httpx.Client (the real owner)
        Only the innermost object exposes ``close()``.
        """
        innermost = MagicMock(spec=["close"])
        middle_layer = SimpleNamespace(httpx_client=innermost)  # deliberately has no close()
        sandbox._client._client_wrapper = SimpleNamespace(httpx_client=middle_layer)
        return innermost

    def test_close_calls_real_nested_httpx_client(self, sandbox):
        """close() reaches the innermost client and closes it.

        The intermediate layer has no close(), so a one-level lookup (the
        original bug) would silently close nothing.
        """
        innermost = self._wire_nested_client(sandbox)

        sandbox.close()

        innermost.close.assert_called_once_with()

    def test_close_clears_client_reference(self, sandbox):
        """close() drops the client reference and marks the sandbox closed."""
        self._wire_nested_client(sandbox)

        sandbox.close()

        assert sandbox._client is None
        assert sandbox._closed is True

    def test_close_is_idempotent(self, sandbox):
        """Repeated close() calls close the underlying client only once."""
        innermost = self._wire_nested_client(sandbox)

        for _ in range(3):
            sandbox.close()

        assert innermost.close.call_count == 1

    def test_close_swallows_exceptions(self, sandbox, caplog):
        """A failing client close is logged as a warning and not raised."""
        innermost = self._wire_nested_client(sandbox)
        innermost.close.side_effect = RuntimeError("teardown boom")

        with caplog.at_level("WARNING"):
            sandbox.close()

        assert "Error closing AioSandbox client" in caplog.text

    def test_close_falls_back_to_client_close(self, sandbox):
        """Without a reachable nested client, the client's own close() is used."""
        # This stub exposes nothing except a top-level close().
        top_level_only = MagicMock(spec=["close"])
        sandbox._client = top_level_only

        sandbox.close()

        top_level_only.close.assert_called_once_with()

    def test_close_when_no_close_attr_does_not_raise(self, sandbox):
        """A client with no close attribute at all must not make close() fail."""
        sandbox._client = SimpleNamespace()  # neither close nor _client_wrapper

        sandbox.close()  # must not raise

        assert sandbox._client is None