diff --git a/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py index 2955848b7..3f9ab74ad 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py @@ -105,11 +105,16 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]): thread_id = cfg.get("configurable", {}).get("thread_id") return thread_id - def _write_audit(self, thread_id: str | None, command: str, verdict: str) -> None: + _AUDIT_COMMAND_LIMIT = 200 + + def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None: + audited_command = command + if truncate and len(command) > self._AUDIT_COMMAND_LIMIT: + audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)" record = { "timestamp": datetime.now(UTC).isoformat(), "thread_id": thread_id or "unknown", - "command": command, + "command": audited_command, "verdict": verdict, } logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False)) @@ -139,23 +144,52 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]): status=result.status, ) + # ------------------------------------------------------------------ + # Input sanitisation + # ------------------------------------------------------------------ + + # Normal bash commands rarely exceed a few hundred characters. 10 000 is + # well above any legitimate use case yet a tiny fraction of Linux ARG_MAX. + # Anything longer is almost certainly a payload injection or base64-encoded + # attack string. + _MAX_COMMAND_LENGTH = 10_000 + + def _validate_input(self, command: str) -> str | None: + """Return ``None`` if *command* is acceptable, else a rejection reason.""" + if not command.strip(): + return "empty command" + if len(command) > self._MAX_COMMAND_LENGTH: + return "command too long" + if "\x00" in command: + return "null byte detected" + return None + # ------------------------------------------------------------------ # Core logic (shared between sync and async paths) # ------------------------------------------------------------------ - def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str]: + def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]: """ - Returns (command, thread_id, verdict). + Returns (command, thread_id, verdict, reject_reason). verdict is 'block', 'warn', or 'pass'. + reject_reason is non-None only for input sanitisation rejections. """ args = request.tool_call.get("args", {}) - command: str = args.get("command", "") + raw_command = args.get("command") + command = raw_command if isinstance(raw_command, str) else "" thread_id = self._get_thread_id(request) - # ① classify command + # ① input sanitisation — reject malformed input before regex analysis + reject_reason = self._validate_input(command) + if reject_reason: + self._write_audit(thread_id, command, "block", truncate=True) + logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason) + return command, thread_id, "block", reject_reason + + # ② classify command verdict = _classify_command(command) - # ② audit log + # ③ audit log self._write_audit(thread_id, command, verdict) if verdict == "block": @@ -163,7 +197,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]): elif verdict == "warn": logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command) - return command, thread_id, verdict + return command, thread_id, verdict, None # ------------------------------------------------------------------ # wrap_tool_call hooks @@ -178,9 +212,10 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]): if request.tool_call.get("name") != "bash": return handler(request) - command, _, verdict = self._pre_process(request) + command, _, verdict, reject_reason = self._pre_process(request) if verdict == "block": - return self._build_block_message(request, "security violation detected") + reason = reject_reason or "security violation detected" + return self._build_block_message(request, reason) result = handler(request) if verdict == "warn": result = self._append_warn_to_result(result, command) @@ -195,9 +230,10 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]): if request.tool_call.get("name") != "bash": return await handler(request) - command, _, verdict = self._pre_process(request) + command, _, verdict, reject_reason = self._pre_process(request) if verdict == "block": - return self._build_block_message(request, "security violation detected") + reason = reject_reason or "security violation detected" + return self._build_block_message(request, reason) result = await handler(request) if verdict == "warn": result = self._append_warn_to_result(result, command) diff --git a/backend/tests/test_sandbox_audit_middleware.py b/backend/tests/test_sandbox_audit_middleware.py index e98298156..6a1d4b244 100644 --- a/backend/tests/test_sandbox_audit_middleware.py +++ b/backend/tests/test_sandbox_audit_middleware.py @@ -1,5 +1,6 @@ """Tests for SandboxAuditMiddleware - command classification and audit logging.""" +import unittest.mock from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -134,6 +135,98 @@ class TestClassifyCommand: assert _classify_command(cmd) == "pass", f"Expected 'pass' for: {cmd!r}" +# --------------------------------------------------------------------------- +# _validate_input unit tests (input sanitisation) +# --------------------------------------------------------------------------- + + +class TestValidateInput: + def setup_method(self): + self.mw = SandboxAuditMiddleware() + + def test_empty_string_rejected(self): + assert self.mw._validate_input("") == "empty command" + + def test_whitespace_only_rejected(self): + assert self.mw._validate_input(" \t\n ") == "empty command" + + def test_normal_command_accepted(self): + assert self.mw._validate_input("ls -la") is None + + def test_command_at_max_length_accepted(self): + cmd = "a" * 10_000 + assert self.mw._validate_input(cmd) is None + + def test_command_exceeding_max_length_rejected(self): + cmd = "a" * 10_001 + assert self.mw._validate_input(cmd) == "command too long" + + def test_null_byte_rejected(self): + assert self.mw._validate_input("ls\x00; rm -rf /") == "null byte detected" + + def test_null_byte_at_start_rejected(self): + assert self.mw._validate_input("\x00ls") == "null byte detected" + + def test_null_byte_at_end_rejected(self): + assert self.mw._validate_input("ls\x00") == "null byte detected" + + +class TestInputSanitisationBlocksInWrapToolCall: + """Verify that input sanitisation rejections flow through wrap_tool_call correctly.""" + + def setup_method(self): + self.mw = SandboxAuditMiddleware() + + def test_empty_command_blocked_with_reason(self): + request = _make_request("") + handler = _make_handler() + result = self.mw.wrap_tool_call(request, handler) + assert not handler.called + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "empty command" in result.content.lower() + + def test_null_byte_command_blocked_with_reason(self): + request = _make_request("echo\x00rm -rf /") + handler = _make_handler() + result = self.mw.wrap_tool_call(request, handler) + assert not handler.called + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "null byte" in result.content.lower() + + def test_oversized_command_blocked_with_reason(self): + request = _make_request("a" * 10_001) + handler = _make_handler() + result = self.mw.wrap_tool_call(request, handler) + assert not handler.called + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "command too long" in result.content.lower() + + def test_none_command_coerced_to_empty(self): + """args.get('command') returning None should be coerced to str and rejected as empty.""" + request = _make_request("") + # Simulate None value by patching args directly + request.tool_call["args"]["command"] = None + handler = _make_handler() + result = self.mw.wrap_tool_call(request, handler) + assert not handler.called + assert isinstance(result, ToolMessage) + assert result.status == "error" + + def test_oversized_command_audit_log_truncated(self): + """Oversized commands should be truncated in audit logs to prevent log amplification.""" + big_cmd = "x" * 10_001 + request = _make_request(big_cmd) + handler = _make_handler() + with unittest.mock.patch.object(self.mw, "_write_audit", wraps=self.mw._write_audit) as spy: + self.mw.wrap_tool_call(request, handler) + spy.assert_called_once() + _, kwargs = spy.call_args + assert kwargs.get("truncate") is True + + # --------------------------------------------------------------------------- # SandboxAuditMiddleware.wrap_tool_call integration tests # --------------------------------------------------------------------------- @@ -301,6 +394,63 @@ class TestSandboxAuditMiddlewareAwrapToolCall: assert result == handler_mock.return_value +# --------------------------------------------------------------------------- +# Input sanitisation via awrap_tool_call (async path) +# --------------------------------------------------------------------------- + + +class TestInputSanitisationBlocksInAwrapToolCall: + """Verify that input sanitisation rejections flow through awrap_tool_call correctly.""" + + def setup_method(self): + self.mw = SandboxAuditMiddleware() + + async def _call_async(self, request): + handler_mock = _make_handler() + + async def async_handler(req): + return handler_mock(req) + + result = await self.mw.awrap_tool_call(request, async_handler) + return result, handler_mock.called + + @pytest.mark.anyio + async def test_empty_command_blocked_with_reason(self): + request = _make_request("") + result, called = await self._call_async(request) + assert not called + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "empty command" in result.content.lower() + + @pytest.mark.anyio + async def test_null_byte_command_blocked_with_reason(self): + request = _make_request("echo\x00rm -rf /") + result, called = await self._call_async(request) + assert not called + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "null byte" in result.content.lower() + + @pytest.mark.anyio + async def test_oversized_command_blocked_with_reason(self): + request = _make_request("a" * 10_001) + result, called = await self._call_async(request) + assert not called + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "command too long" in result.content.lower() + + @pytest.mark.anyio + async def test_none_command_coerced_to_empty(self): + request = _make_request("") + request.tool_call["args"]["command"] = None + result, called = await self._call_async(request) + assert not called + assert isinstance(result, ToolMessage) + assert result.status == "error" + + # --------------------------------------------------------------------------- # Precision / recall summary (asserted metrics for benchmark reporting) # ---------------------------------------------------------------------------