mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(sandbox): add input sanitisation guard to SandboxAuditMiddleware (#1872)
* fix(sandbox): add L2 input sanitisation to SandboxAuditMiddleware Add _validate_input() to reject malformed bash commands before regex classification: empty commands, oversized commands (>10 000 chars), and null bytes that could cause detection/execution layer inconsistency. * fix(sandbox): address Copilot review — type guard, log truncation, reject reason - Coerce None/non-string command to str before validation - Truncate oversized commands in audit logs to prevent log amplification - Propagate reject_reason through _pre_process() to block message - Remove L2 label from comments and test class names * fix(sandbox): isinstance type guard + async input sanitisation tests Address review comments: - Replace str() coercion with isinstance(raw_command, str) guard so non-string truthy values (0, [], False) fall back to empty string instead of passing validation as "0"/"[]"/"False". - Add TestInputSanitisationBlocksInAwrapToolCall with 4 async tests covering empty, null-byte, oversized, and None command via awrap_tool_call path.
This commit is contained in:
parent
1ced6e977c
commit
055e4df049
@ -105,11 +105,16 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
||||
return thread_id
|
||||
|
||||
def _write_audit(self, thread_id: str | None, command: str, verdict: str) -> None:
|
||||
_AUDIT_COMMAND_LIMIT = 200
|
||||
|
||||
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
|
||||
audited_command = command
|
||||
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
|
||||
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
|
||||
record = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"thread_id": thread_id or "unknown",
|
||||
"command": command,
|
||||
"command": audited_command,
|
||||
"verdict": verdict,
|
||||
}
|
||||
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))
|
||||
@ -139,23 +144,52 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
status=result.status,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Input sanitisation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Normal bash commands rarely exceed a few hundred characters. 10 000 is
|
||||
# well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
|
||||
# Anything longer is almost certainly a payload injection or base64-encoded
|
||||
# attack string.
|
||||
_MAX_COMMAND_LENGTH = 10_000
|
||||
|
||||
def _validate_input(self, command: str) -> str | None:
|
||||
"""Return ``None`` if *command* is acceptable, else a rejection reason."""
|
||||
if not command.strip():
|
||||
return "empty command"
|
||||
if len(command) > self._MAX_COMMAND_LENGTH:
|
||||
return "command too long"
|
||||
if "\x00" in command:
|
||||
return "null byte detected"
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core logic (shared between sync and async paths)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str]:
|
||||
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
|
||||
"""
|
||||
Returns (command, thread_id, verdict).
|
||||
Returns (command, thread_id, verdict, reject_reason).
|
||||
verdict is 'block', 'warn', or 'pass'.
|
||||
reject_reason is non-None only for input sanitisation rejections.
|
||||
"""
|
||||
args = request.tool_call.get("args", {})
|
||||
command: str = args.get("command", "")
|
||||
raw_command = args.get("command")
|
||||
command = raw_command if isinstance(raw_command, str) else ""
|
||||
thread_id = self._get_thread_id(request)
|
||||
|
||||
# ① classify command
|
||||
# ① input sanitisation — reject malformed input before regex analysis
|
||||
reject_reason = self._validate_input(command)
|
||||
if reject_reason:
|
||||
self._write_audit(thread_id, command, "block", truncate=True)
|
||||
logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
|
||||
return command, thread_id, "block", reject_reason
|
||||
|
||||
# ② classify command
|
||||
verdict = _classify_command(command)
|
||||
|
||||
# ② audit log
|
||||
# ③ audit log
|
||||
self._write_audit(thread_id, command, verdict)
|
||||
|
||||
if verdict == "block":
|
||||
@ -163,7 +197,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
elif verdict == "warn":
|
||||
logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
|
||||
|
||||
return command, thread_id, verdict
|
||||
return command, thread_id, verdict, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# wrap_tool_call hooks
|
||||
@ -178,9 +212,10 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
if request.tool_call.get("name") != "bash":
|
||||
return handler(request)
|
||||
|
||||
command, _, verdict = self._pre_process(request)
|
||||
command, _, verdict, reject_reason = self._pre_process(request)
|
||||
if verdict == "block":
|
||||
return self._build_block_message(request, "security violation detected")
|
||||
reason = reject_reason or "security violation detected"
|
||||
return self._build_block_message(request, reason)
|
||||
result = handler(request)
|
||||
if verdict == "warn":
|
||||
result = self._append_warn_to_result(result, command)
|
||||
@ -195,9 +230,10 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
if request.tool_call.get("name") != "bash":
|
||||
return await handler(request)
|
||||
|
||||
command, _, verdict = self._pre_process(request)
|
||||
command, _, verdict, reject_reason = self._pre_process(request)
|
||||
if verdict == "block":
|
||||
return self._build_block_message(request, "security violation detected")
|
||||
reason = reject_reason or "security violation detected"
|
||||
return self._build_block_message(request, reason)
|
||||
result = await handler(request)
|
||||
if verdict == "warn":
|
||||
result = self._append_warn_to_result(result, command)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Tests for SandboxAuditMiddleware - command classification and audit logging."""
|
||||
|
||||
import unittest.mock
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -134,6 +135,98 @@ class TestClassifyCommand:
|
||||
assert _classify_command(cmd) == "pass", f"Expected 'pass' for: {cmd!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_input unit tests (input sanitisation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateInput:
|
||||
def setup_method(self):
|
||||
self.mw = SandboxAuditMiddleware()
|
||||
|
||||
def test_empty_string_rejected(self):
|
||||
assert self.mw._validate_input("") == "empty command"
|
||||
|
||||
def test_whitespace_only_rejected(self):
|
||||
assert self.mw._validate_input(" \t\n ") == "empty command"
|
||||
|
||||
def test_normal_command_accepted(self):
|
||||
assert self.mw._validate_input("ls -la") is None
|
||||
|
||||
def test_command_at_max_length_accepted(self):
|
||||
cmd = "a" * 10_000
|
||||
assert self.mw._validate_input(cmd) is None
|
||||
|
||||
def test_command_exceeding_max_length_rejected(self):
|
||||
cmd = "a" * 10_001
|
||||
assert self.mw._validate_input(cmd) == "command too long"
|
||||
|
||||
def test_null_byte_rejected(self):
|
||||
assert self.mw._validate_input("ls\x00; rm -rf /") == "null byte detected"
|
||||
|
||||
def test_null_byte_at_start_rejected(self):
|
||||
assert self.mw._validate_input("\x00ls") == "null byte detected"
|
||||
|
||||
def test_null_byte_at_end_rejected(self):
|
||||
assert self.mw._validate_input("ls\x00") == "null byte detected"
|
||||
|
||||
|
||||
class TestInputSanitisationBlocksInWrapToolCall:
|
||||
"""Verify that input sanitisation rejections flow through wrap_tool_call correctly."""
|
||||
|
||||
def setup_method(self):
|
||||
self.mw = SandboxAuditMiddleware()
|
||||
|
||||
def test_empty_command_blocked_with_reason(self):
|
||||
request = _make_request("")
|
||||
handler = _make_handler()
|
||||
result = self.mw.wrap_tool_call(request, handler)
|
||||
assert not handler.called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert "empty command" in result.content.lower()
|
||||
|
||||
def test_null_byte_command_blocked_with_reason(self):
|
||||
request = _make_request("echo\x00rm -rf /")
|
||||
handler = _make_handler()
|
||||
result = self.mw.wrap_tool_call(request, handler)
|
||||
assert not handler.called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert "null byte" in result.content.lower()
|
||||
|
||||
def test_oversized_command_blocked_with_reason(self):
|
||||
request = _make_request("a" * 10_001)
|
||||
handler = _make_handler()
|
||||
result = self.mw.wrap_tool_call(request, handler)
|
||||
assert not handler.called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert "command too long" in result.content.lower()
|
||||
|
||||
def test_none_command_coerced_to_empty(self):
|
||||
"""args.get('command') returning None should be coerced to str and rejected as empty."""
|
||||
request = _make_request("")
|
||||
# Simulate None value by patching args directly
|
||||
request.tool_call["args"]["command"] = None
|
||||
handler = _make_handler()
|
||||
result = self.mw.wrap_tool_call(request, handler)
|
||||
assert not handler.called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
|
||||
def test_oversized_command_audit_log_truncated(self):
|
||||
"""Oversized commands should be truncated in audit logs to prevent log amplification."""
|
||||
big_cmd = "x" * 10_001
|
||||
request = _make_request(big_cmd)
|
||||
handler = _make_handler()
|
||||
with unittest.mock.patch.object(self.mw, "_write_audit", wraps=self.mw._write_audit) as spy:
|
||||
self.mw.wrap_tool_call(request, handler)
|
||||
spy.assert_called_once()
|
||||
_, kwargs = spy.call_args
|
||||
assert kwargs.get("truncate") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SandboxAuditMiddleware.wrap_tool_call integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -301,6 +394,63 @@ class TestSandboxAuditMiddlewareAwrapToolCall:
|
||||
assert result == handler_mock.return_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input sanitisation via awrap_tool_call (async path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInputSanitisationBlocksInAwrapToolCall:
|
||||
"""Verify that input sanitisation rejections flow through awrap_tool_call correctly."""
|
||||
|
||||
def setup_method(self):
|
||||
self.mw = SandboxAuditMiddleware()
|
||||
|
||||
async def _call_async(self, request):
|
||||
handler_mock = _make_handler()
|
||||
|
||||
async def async_handler(req):
|
||||
return handler_mock(req)
|
||||
|
||||
result = await self.mw.awrap_tool_call(request, async_handler)
|
||||
return result, handler_mock.called
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_command_blocked_with_reason(self):
|
||||
request = _make_request("")
|
||||
result, called = await self._call_async(request)
|
||||
assert not called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert "empty command" in result.content.lower()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_null_byte_command_blocked_with_reason(self):
|
||||
request = _make_request("echo\x00rm -rf /")
|
||||
result, called = await self._call_async(request)
|
||||
assert not called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert "null byte" in result.content.lower()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_oversized_command_blocked_with_reason(self):
|
||||
request = _make_request("a" * 10_001)
|
||||
result, called = await self._call_async(request)
|
||||
assert not called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert "command too long" in result.content.lower()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_none_command_coerced_to_empty(self):
|
||||
request = _make_request("")
|
||||
request.tool_call["args"]["command"] = None
|
||||
result, called = await self._call_async(request)
|
||||
assert not called
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Precision / recall summary (asserted metrics for benchmark reporting)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user