fix(sandbox): add input sanitisation guard to SandboxAuditMiddleware (#1872)

* fix(sandbox): add L2 input sanitisation to SandboxAuditMiddleware

Add _validate_input() to reject malformed bash commands before regex
classification: empty commands, oversized commands (>10 000 chars), and
null bytes that could cause detection/execution layer inconsistency.

* fix(sandbox): address Copilot review — type guard, log truncation, reject reason

- Coerce None/non-string command to str before validation
- Truncate oversized commands in audit logs to prevent log amplification
- Propagate reject_reason through _pre_process() to block message
- Remove L2 label from comments and test class names

* fix(sandbox): isinstance type guard + async input sanitisation tests

Address review comments:
- Replace str() coercion with isinstance(raw_command, str) guard so
  non-string truthy values (0, [], False) fall back to empty string
  instead of passing validation as "0"/"[]"/"False".
- Add TestInputSanitisationBlocksInAwrapToolCall with 4 async tests
  covering empty, null-byte, oversized, and None command via
  awrap_tool_call path.
This commit is contained in:
KKK 2026-04-06 17:21:58 +08:00 committed by GitHub
parent 1ced6e977c
commit 055e4df049
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 198 additions and 12 deletions

View File

@ -105,11 +105,16 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
thread_id = cfg.get("configurable", {}).get("thread_id") thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id return thread_id
def _write_audit(self, thread_id: str | None, command: str, verdict: str) -> None: _AUDIT_COMMAND_LIMIT = 200
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
audited_command = command
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
record = { record = {
"timestamp": datetime.now(UTC).isoformat(), "timestamp": datetime.now(UTC).isoformat(),
"thread_id": thread_id or "unknown", "thread_id": thread_id or "unknown",
"command": command, "command": audited_command,
"verdict": verdict, "verdict": verdict,
} }
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False)) logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))
@ -139,23 +144,52 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
status=result.status, status=result.status,
) )
# ------------------------------------------------------------------
# Input sanitisation
# ------------------------------------------------------------------
# Normal bash commands rarely exceed a few hundred characters. 10 000 is
# well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
# Anything longer is almost certainly a payload injection or base64-encoded
# attack string.
_MAX_COMMAND_LENGTH = 10_000
def _validate_input(self, command: str) -> str | None:
"""Return ``None`` if *command* is acceptable, else a rejection reason."""
if not command.strip():
return "empty command"
if len(command) > self._MAX_COMMAND_LENGTH:
return "command too long"
if "\x00" in command:
return "null byte detected"
return None
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Core logic (shared between sync and async paths) # Core logic (shared between sync and async paths)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str]: def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
""" """
Returns (command, thread_id, verdict). Returns (command, thread_id, verdict, reject_reason).
verdict is 'block', 'warn', or 'pass'. verdict is 'block', 'warn', or 'pass'.
reject_reason is non-None only for input sanitisation rejections.
""" """
args = request.tool_call.get("args", {}) args = request.tool_call.get("args", {})
command: str = args.get("command", "") raw_command = args.get("command")
command = raw_command if isinstance(raw_command, str) else ""
thread_id = self._get_thread_id(request) thread_id = self._get_thread_id(request)
# ① classify command # ① input sanitisation — reject malformed input before regex analysis
reject_reason = self._validate_input(command)
if reject_reason:
self._write_audit(thread_id, command, "block", truncate=True)
logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
return command, thread_id, "block", reject_reason
# ② classify command
verdict = _classify_command(command) verdict = _classify_command(command)
# ② audit log # audit log
self._write_audit(thread_id, command, verdict) self._write_audit(thread_id, command, verdict)
if verdict == "block": if verdict == "block":
@ -163,7 +197,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
elif verdict == "warn": elif verdict == "warn":
logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command) logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
return command, thread_id, verdict return command, thread_id, verdict, None
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# wrap_tool_call hooks # wrap_tool_call hooks
@ -178,9 +212,10 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
if request.tool_call.get("name") != "bash": if request.tool_call.get("name") != "bash":
return handler(request) return handler(request)
command, _, verdict = self._pre_process(request) command, _, verdict, reject_reason = self._pre_process(request)
if verdict == "block": if verdict == "block":
return self._build_block_message(request, "security violation detected") reason = reject_reason or "security violation detected"
return self._build_block_message(request, reason)
result = handler(request) result = handler(request)
if verdict == "warn": if verdict == "warn":
result = self._append_warn_to_result(result, command) result = self._append_warn_to_result(result, command)
@ -195,9 +230,10 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
if request.tool_call.get("name") != "bash": if request.tool_call.get("name") != "bash":
return await handler(request) return await handler(request)
command, _, verdict = self._pre_process(request) command, _, verdict, reject_reason = self._pre_process(request)
if verdict == "block": if verdict == "block":
return self._build_block_message(request, "security violation detected") reason = reject_reason or "security violation detected"
return self._build_block_message(request, reason)
result = await handler(request) result = await handler(request)
if verdict == "warn": if verdict == "warn":
result = self._append_warn_to_result(result, command) result = self._append_warn_to_result(result, command)

View File

@ -1,5 +1,6 @@
"""Tests for SandboxAuditMiddleware - command classification and audit logging.""" """Tests for SandboxAuditMiddleware - command classification and audit logging."""
import unittest.mock
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -134,6 +135,98 @@ class TestClassifyCommand:
assert _classify_command(cmd) == "pass", f"Expected 'pass' for: {cmd!r}" assert _classify_command(cmd) == "pass", f"Expected 'pass' for: {cmd!r}"
# ---------------------------------------------------------------------------
# _validate_input unit tests (input sanitisation)
# ---------------------------------------------------------------------------
class TestValidateInput:
def setup_method(self):
self.mw = SandboxAuditMiddleware()
def test_empty_string_rejected(self):
assert self.mw._validate_input("") == "empty command"
def test_whitespace_only_rejected(self):
assert self.mw._validate_input(" \t\n ") == "empty command"
def test_normal_command_accepted(self):
assert self.mw._validate_input("ls -la") is None
def test_command_at_max_length_accepted(self):
cmd = "a" * 10_000
assert self.mw._validate_input(cmd) is None
def test_command_exceeding_max_length_rejected(self):
cmd = "a" * 10_001
assert self.mw._validate_input(cmd) == "command too long"
def test_null_byte_rejected(self):
assert self.mw._validate_input("ls\x00; rm -rf /") == "null byte detected"
def test_null_byte_at_start_rejected(self):
assert self.mw._validate_input("\x00ls") == "null byte detected"
def test_null_byte_at_end_rejected(self):
assert self.mw._validate_input("ls\x00") == "null byte detected"
class TestInputSanitisationBlocksInWrapToolCall:
"""Verify that input sanitisation rejections flow through wrap_tool_call correctly."""
def setup_method(self):
self.mw = SandboxAuditMiddleware()
def test_empty_command_blocked_with_reason(self):
request = _make_request("")
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "empty command" in result.content.lower()
def test_null_byte_command_blocked_with_reason(self):
request = _make_request("echo\x00rm -rf /")
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "null byte" in result.content.lower()
def test_oversized_command_blocked_with_reason(self):
request = _make_request("a" * 10_001)
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "command too long" in result.content.lower()
def test_none_command_coerced_to_empty(self):
"""args.get('command') returning None should be coerced to str and rejected as empty."""
request = _make_request("")
# Simulate None value by patching args directly
request.tool_call["args"]["command"] = None
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
def test_oversized_command_audit_log_truncated(self):
"""Oversized commands should be truncated in audit logs to prevent log amplification."""
big_cmd = "x" * 10_001
request = _make_request(big_cmd)
handler = _make_handler()
with unittest.mock.patch.object(self.mw, "_write_audit", wraps=self.mw._write_audit) as spy:
self.mw.wrap_tool_call(request, handler)
spy.assert_called_once()
_, kwargs = spy.call_args
assert kwargs.get("truncate") is True
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# SandboxAuditMiddleware.wrap_tool_call integration tests # SandboxAuditMiddleware.wrap_tool_call integration tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -301,6 +394,63 @@ class TestSandboxAuditMiddlewareAwrapToolCall:
assert result == handler_mock.return_value assert result == handler_mock.return_value
# ---------------------------------------------------------------------------
# Input sanitisation via awrap_tool_call (async path)
# ---------------------------------------------------------------------------
class TestInputSanitisationBlocksInAwrapToolCall:
"""Verify that input sanitisation rejections flow through awrap_tool_call correctly."""
def setup_method(self):
self.mw = SandboxAuditMiddleware()
async def _call_async(self, request):
handler_mock = _make_handler()
async def async_handler(req):
return handler_mock(req)
result = await self.mw.awrap_tool_call(request, async_handler)
return result, handler_mock.called
@pytest.mark.anyio
async def test_empty_command_blocked_with_reason(self):
request = _make_request("")
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "empty command" in result.content.lower()
@pytest.mark.anyio
async def test_null_byte_command_blocked_with_reason(self):
request = _make_request("echo\x00rm -rf /")
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "null byte" in result.content.lower()
@pytest.mark.anyio
async def test_oversized_command_blocked_with_reason(self):
request = _make_request("a" * 10_001)
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "command too long" in result.content.lower()
@pytest.mark.anyio
async def test_none_command_coerced_to_empty(self):
request = _make_request("")
request.tool_call["args"]["command"] = None
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Precision / recall summary (asserted metrics for benchmark reporting) # Precision / recall summary (asserted metrics for benchmark reporting)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------