mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(sandbox): serialize concurrent exec_command calls in AioSandbox (#1435)
* fix(sandbox): serialize concurrent exec_command calls in AioSandbox

  The AIO sandbox container maintains a single persistent shell session that corrupts when multiple exec_command requests arrive concurrently (e.g. when ToolNode issues parallel tool_calls). The corrupted session returns 'ErrorObservation' strings as output, cascading into subsequent commands.

  Add a threading.Lock to AioSandbox to serialize shell commands. As a secondary defense, detect ErrorObservation in output and retry with a fresh session ID.

  Fixes #1433

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(sandbox): address Copilot review findings

  - Fix shell injection in list_dir: use shlex.quote(path) to escape user-provided paths in the find command
  - Narrow ErrorObservation retry condition from broad substring match to the specific corruption signature to prevent false retries
  - Improve test_lock_prevents_concurrent_execution: use threading.Barrier to ensure all workers contend for the lock simultaneously
  - Improve test_list_dir_uses_lock: assert lock.locked() is True during exec_command to verify lock acquisition

* style: auto-format with ruff

---------

Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
aae59a8ba8
commit
a3bfea631c
@ -1,5 +1,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
import shlex
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
from agent_sandbox import Sandbox as AioSandboxClient
|
from agent_sandbox import Sandbox as AioSandboxClient
|
||||||
|
|
||||||
@ -7,11 +10,15 @@ from deerflow.sandbox.sandbox import Sandbox
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
|
||||||
|
|
||||||
|
|
||||||
class AioSandbox(Sandbox):
|
class AioSandbox(Sandbox):
|
||||||
"""Sandbox implementation using the agent-infra/sandbox Docker container.
|
"""Sandbox implementation using the agent-infra/sandbox Docker container.
|
||||||
|
|
||||||
This sandbox connects to a running AIO sandbox container via HTTP API.
|
This sandbox connects to a running AIO sandbox container via HTTP API.
|
||||||
|
A threading lock serializes shell commands to prevent concurrent requests
|
||||||
|
from corrupting the container's single persistent session (see #1433).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id: str, base_url: str, home_dir: str | None = None):
|
def __init__(self, id: str, base_url: str, home_dir: str | None = None):
|
||||||
@ -26,6 +33,7 @@ class AioSandbox(Sandbox):
|
|||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
||||||
self._home_dir = home_dir
|
self._home_dir = home_dir
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_url(self) -> str:
|
def base_url(self) -> str:
|
||||||
@ -42,19 +50,34 @@ class AioSandbox(Sandbox):
|
|||||||
def execute_command(self, command: str) -> str:
    """Run a shell command inside the sandbox and return its output.

    The whole call runs while holding ``self._lock`` so that only one
    shell request from this instance is in flight at a time; the sandbox's
    persistent shell session is corrupted by concurrent exec_command
    calls. If the output still carries the ErrorObservation corruption
    signature (for instance when another process shares the sandbox),
    the command is issued once more under a freshly generated session id.

    Args:
        command: The command to execute.

    Returns:
        The command output, ``"(no output)"`` when the output is empty,
        or ``"Error: <exception>"`` if the client call raised.
    """
    with self._lock:
        try:
            response = self._client.shell.exec_command(command=command)
            text = response.data.output if response.data else ""

            session_corrupted = bool(text) and _ERROR_OBSERVATION_SIGNATURE in text
            if session_corrupted:
                logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
                # A brand-new session id sidesteps the corrupted shell session.
                response = self._client.shell.exec_command(command=command, id=str(uuid.uuid4()))
                text = response.data.output if response.data else ""

            return text or "(no output)"
        except Exception as e:
            logger.error(f"Failed to execute command in sandbox: {e}")
            return f"Error: {e}"
|
||||||
|
|
||||||
def read_file(self, path: str) -> str:
|
def read_file(self, path: str) -> str:
|
||||||
"""Read the content of a file in the sandbox.
|
"""Read the content of a file in the sandbox.
|
||||||
@ -82,17 +105,16 @@ class AioSandbox(Sandbox):
|
|||||||
Returns:
|
Returns:
|
||||||
The contents of the directory.
|
The contents of the directory.
|
||||||
"""
|
"""
|
||||||
try:
|
with self._lock:
|
||||||
# Use shell command to list directory with depth limit
|
try:
|
||||||
# The -L flag limits the depth for the tree command
|
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
|
||||||
result = self._client.shell.exec_command(command=f"find {path} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
|
output = result.data.output if result.data else ""
|
||||||
output = result.data.output if result.data else ""
|
if output:
|
||||||
if output:
|
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
||||||
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
return []
|
||||||
return []
|
except Exception as e:
|
||||||
except Exception as e:
|
logger.error(f"Failed to list directory in sandbox: {e}")
|
||||||
logger.error(f"Failed to list directory in sandbox: {e}")
|
return []
|
||||||
return []
|
|
||||||
|
|
||||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||||
"""Write content to a file in the sandbox.
|
"""Write content to a file in the sandbox.
|
||||||
|
|||||||
133
backend/tests/test_aio_sandbox.py
Normal file
133
backend/tests/test_aio_sandbox.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
"""Tests for AioSandbox concurrent command serialization (#1433)."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def sandbox():
    """Build an AioSandbox whose HTTP client class is replaced by a mock."""
    # Patch the client class first so constructing the sandbox below never
    # creates a real client; the import happens inside the patch context.
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteCommandSerialization:
    """Verify that concurrent exec_command calls are serialized."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Concurrent threads should not overlap inside execute_command."""
        import time

        worker_count = 3
        call_log = []
        # The timeout keeps a crashed worker from hanging the suite: the
        # remaining threads get BrokenBarrierError instead of waiting forever.
        barrier = threading.Barrier(worker_count, timeout=5)

        def slow_exec(command, **kwargs):
            # Record entry/exit around a short sleep so overlapping calls
            # would show up as interleaved log entries.
            call_log.append(("enter", command))
            time.sleep(0.05)
            call_log.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def worker(cmd):
            barrier.wait()  # ensure all threads contend for the lock simultaneously
            sandbox.execute_command(cmd)

        threads = [
            threading.Thread(target=worker, args=(f"cmd-{i}",), daemon=True)
            for i in range(worker_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5)
        assert not any(t.is_alive() for t in threads), "worker threads did not finish in time"

        # Verify serialization: the log must be a strict sequence of
        # (enter, cmd), (exit, cmd) pairs with no interleaving, and each
        # exit must belong to the command that just entered.
        assert len(call_log) == 2 * worker_count, f"Unexpected call log: {call_log}"
        for entered, exited in zip(call_log[0::2], call_log[1::2]):
            assert entered[0] == "enter" and exited[0] == "exit", f"Interleaved execution detected: {call_log}"
            assert entered[1] == exited[1], f"Interleaved execution detected: {call_log}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorObservationRetry:
    """Verify ErrorObservation detection and fresh-session retry."""

    def test_retry_on_error_observation(self, sandbox):
        """When output contains ErrorObservation, retry with a fresh session."""
        # First response carries the corruption signature, second is healthy.
        corrupted = SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
        healthy = SimpleNamespace(data=SimpleNamespace(output="success"))
        exec_mock = MagicMock(side_effect=[corrupted, healthy])
        sandbox._client.shell.exec_command = exec_mock

        result = sandbox.execute_command("echo hello")

        assert result == "success"
        assert exec_mock.call_count == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """The retry call should include a new session id kwarg."""
        corrupted = SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
        healthy = SimpleNamespace(data=SimpleNamespace(output="ok"))
        exec_mock = MagicMock(side_effect=[corrupted, healthy])
        sandbox._client.shell.exec_command = exec_mock

        sandbox.execute_command("test")

        assert exec_mock.call_count == 2
        first_kwargs = exec_mock.call_args_list[0].kwargs
        retry_kwargs = exec_mock.call_args_list[1].kwargs
        assert "id" not in first_kwargs
        assert "id" in retry_kwargs
        assert len(retry_kwargs["id"]) == 36  # UUID format

    def test_no_retry_on_clean_output(self, sandbox):
        """Normal output should not trigger a retry."""
        exec_mock = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="all good")))
        sandbox._client.shell.exec_command = exec_mock

        result = sandbox.execute_command("echo hello")

        assert result == "all good"
        assert exec_mock.call_count == 1
|
||||||
|
|
||||||
|
def test_no_retry_on_clean_output(self, sandbox):
|
||||||
|
"""Normal output should not trigger a retry."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_exec(command, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return SimpleNamespace(data=SimpleNamespace(output="all good"))
|
||||||
|
|
||||||
|
sandbox._client.shell.exec_command = mock_exec
|
||||||
|
|
||||||
|
result = sandbox.execute_command("echo hello")
|
||||||
|
assert result == "all good"
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestListDirSerialization:
    """Verify that list_dir also acquires the lock."""

    def test_list_dir_uses_lock(self, sandbox):
        """list_dir should hold the lock during execution."""
        observed_lock_states = []
        canned_response = SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))

        def record_lock_state(command, **kwargs):
            # Sample the lock at the moment the client call is made.
            observed_lock_states.append(sandbox._lock.locked())
            return canned_response

        sandbox._client.shell.exec_command = record_lock_state

        listing = sandbox.list_dir("/test")

        assert listing == ["/a", "/b"]
        assert observed_lock_states == [True], "list_dir must hold the lock during exec_command"
|
||||||
Loading…
x
Reference in New Issue
Block a user