mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
* fix(sandbox): strengthen regex coverage in SandboxAuditMiddleware
Expand high-risk patterns from 6 to 13 and medium-risk from 4 to 6,
closing several bypass vectors identified by cross-referencing Claude
Code's BashSecurity validator chain against DeerFlow's threat model.
High-risk additions:
- Generalised pipe-to-sh (replaces narrow curl|sh rule)
- Targeted command substitution ($() / backtick with dangerous executables)
- base64 decode piped to execution
- Overwrite system binaries (/usr/bin/, /bin/, /sbin/)
- Overwrite shell startup files (~/.bashrc, ~/.profile, etc.)
- /proc/*/environ leakage
- LD_PRELOAD / LD_LIBRARY_PATH hijack
- /dev/tcp/ bash built-in networking
Medium-risk additions:
- sudo/su (no-op under Docker root, warn only)
- PATH= modification (long attack chain, warn only)
Design decisions:
- Command substitution uses targeted matching (curl/wget/bash/sh/python/
ruby/perl/base64) rather than blanket block to avoid false positives
on safe usage like $(date) or `whoami`.
- Skipped encoding/obfuscation checks (hex, octal, Unicode homoglyphs)
as the ROI is low in the Docker sandbox — LLMs don't generate encoded commands,
and container isolation bounds the blast radius.
- Merged the separate pip and pip3 rules into a single `pip3?` pattern.
* feat(sandbox): compound command splitting and fork bomb detection
Split compound bash commands (&&, ||, ;) into sub-commands and classify
each independently — prevents dangerous commands hidden after safe
prefixes (e.g. "cd /workspace && rm -rf /") from bypassing detection.
- Add _split_compound_command() with shlex quote-aware splitting
- Add fork bomb detection patterns (classic and while-loop variants)
- Most severe verdict wins; block short-circuits
- 15 new tests covering compound commands, splitting, and fork bombs
* test(sandbox): add async tests for fork bomb and compound commands
Cover awrap_tool_call path for fork bomb detection (3 variants) and
compound command splitting (block/warn/pass scenarios).
* fix(sandbox): address Copilot review — no-whitespace operators, >>/etc/, whole-command scan
- _split_compound_command: replace shlex-based implementation with a
character-by-character quote/escape-aware scanner. shlex.split only
separates '&&' / '||' / ';' when they are surrounded by whitespace,
so payloads like 'rm -rf /&&echo ok' or 'safe;rm -rf /' bypassed the
previous splitter and therefore the per-sub-command classifier.
- _HIGH_RISK_PATTERNS: change r'>\s*/etc/' to r'>+\s*/etc/' so append
redirection ('>>/etc/hosts') is also blocked.
- _classify_command: run a whole-command high-risk scan *before*
splitting. Structural attacks like 'while true; do bash & done'
span multiple shell statements — splitting on ';' destroys the
pattern context, so the raw command must be scanned first.
- tests: add no-whitespace operator cases to TestSplitCompoundCommand
and test_compound_command_classification to lock in the bypass fix.
364 lines
14 KiB
Python
364 lines
14 KiB
Python
"""SandboxAuditMiddleware - bash command security auditing."""
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
import shlex
|
||
from collections.abc import Awaitable, Callable
|
||
from datetime import UTC, datetime
|
||
from typing import override
|
||
|
||
from langchain.agents.middleware import AgentMiddleware
|
||
from langchain_core.messages import ToolMessage
|
||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||
from langgraph.types import Command
|
||
|
||
from deerflow.agents.thread_state import ThreadState
|
||
|
||
# Module-level logger; audit records and block/warn notices are emitted here.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Command classification rules
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Each pattern is compiled once at import time.
|
||
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||
# --- original rules (retained) ---
|
||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
|
||
re.compile(r"dd\s+if="),
|
||
re.compile(r"mkfs"),
|
||
re.compile(r"cat\s+/etc/shadow"),
|
||
re.compile(r">+\s*/etc/"),
|
||
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
|
||
re.compile(r"\|\s*(ba)?sh\b"),
|
||
# --- command substitution (targeted – only dangerous executables) ---
|
||
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
|
||
# --- base64 decode piped to execution ---
|
||
re.compile(r"base64\s+.*-d.*\|"),
|
||
# --- overwrite system binaries ---
|
||
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
|
||
# --- overwrite shell startup files ---
|
||
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
|
||
# --- process environment leakage ---
|
||
re.compile(r"/proc/[^/]+/environ"),
|
||
# --- dynamic linker hijack (one-step escalation) ---
|
||
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
|
||
# --- bash built-in networking (bypasses tool allowlists) ---
|
||
re.compile(r"/dev/tcp/"),
|
||
# --- fork bomb ---
|
||
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
|
||
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
|
||
]
|
||
|
||
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||
re.compile(r"chmod\s+777"),
|
||
re.compile(r"pip3?\s+install"),
|
||
re.compile(r"apt(-get)?\s+install"),
|
||
# sudo/su: no-op under Docker root; warn so LLM is aware
|
||
re.compile(r"\b(sudo|su)\b"),
|
||
# PATH modification: long attack chain, warn rather than block
|
||
re.compile(r"\bPATH\s*="),
|
||
]
|
||
|
||
|
||
def _split_compound_command(command: str) -> list[str]:
|
||
"""Split a compound command into sub-commands (quote-aware).
|
||
|
||
Scans the raw command string so unquoted shell control operators are
|
||
recognised even when they are not surrounded by whitespace
|
||
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
|
||
quotes are ignored. If the command ends with an unclosed quote or a
|
||
dangling escape, return the whole command unchanged (fail-closed —
|
||
safer to classify the unsplit string than silently drop parts).
|
||
"""
|
||
parts: list[str] = []
|
||
current: list[str] = []
|
||
in_single_quote = False
|
||
in_double_quote = False
|
||
escaping = False
|
||
index = 0
|
||
|
||
while index < len(command):
|
||
char = command[index]
|
||
|
||
if escaping:
|
||
current.append(char)
|
||
escaping = False
|
||
index += 1
|
||
continue
|
||
|
||
if char == "\\" and not in_single_quote:
|
||
current.append(char)
|
||
escaping = True
|
||
index += 1
|
||
continue
|
||
|
||
if char == "'" and not in_double_quote:
|
||
in_single_quote = not in_single_quote
|
||
current.append(char)
|
||
index += 1
|
||
continue
|
||
|
||
if char == '"' and not in_single_quote:
|
||
in_double_quote = not in_double_quote
|
||
current.append(char)
|
||
index += 1
|
||
continue
|
||
|
||
if not in_single_quote and not in_double_quote:
|
||
if command.startswith("&&", index) or command.startswith("||", index):
|
||
part = "".join(current).strip()
|
||
if part:
|
||
parts.append(part)
|
||
current = []
|
||
index += 2
|
||
continue
|
||
if char == ";":
|
||
part = "".join(current).strip()
|
||
if part:
|
||
parts.append(part)
|
||
current = []
|
||
index += 1
|
||
continue
|
||
|
||
current.append(char)
|
||
index += 1
|
||
|
||
# Unclosed quote or dangling escape → fail-closed, return whole command
|
||
if in_single_quote or in_double_quote or escaping:
|
||
return [command]
|
||
|
||
part = "".join(current).strip()
|
||
if part:
|
||
parts.append(part)
|
||
return parts if parts else [command]
|
||
|
||
|
||
def _classify_single_command(command: str) -> str:
    """Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'.

    Both the whitespace-collapsed text and the shell-style de-quoted text
    (``shlex.split`` joined with spaces) are scanned. This catches
    quote-based obfuscation such as ``r'm' -rf /`` or ``'pip' install x``.
    High-risk patterns are checked on both forms before any medium-risk
    pattern, so a block always wins over a warn.
    """
    normalized = " ".join(command.split())
    if any(pattern.search(normalized) for pattern in _HIGH_RISK_PATTERNS):
        return "block"

    # Also try shlex-parsed tokens for high-risk detection
    try:
        dequoted = " ".join(shlex.split(command))
    except ValueError:
        # shlex.split fails on unclosed quotes — treat as suspicious
        return "block"

    if any(pattern.search(dequoted) for pattern in _HIGH_RISK_PATTERNS):
        return "block"

    # Medium-risk rules get the same de-quoting treatment as high-risk ones.
    for candidate in (normalized, dequoted):
        if any(pattern.search(candidate) for pattern in _MEDIUM_RISK_PATTERNS):
            return "warn"

    return "pass"
|
||
|
||
|
||
def _classify_command(command: str) -> str:
    """Grade *command* as ``'block'``, ``'warn'`` or ``'pass'``.

    Two passes are made:

    1. The *whole* command (whitespace-collapsed) is checked against the
       high-risk patterns first. Structural attacks such as
       ``while true; do bash & done`` or ``:(){ :|:& };:`` span several
       shell statements. Splitting on ``;`` first would destroy the context
       those patterns rely on.
    2. The command is then split into sub-commands (``cmd1 && cmd2 ; cmd3``)
       and each one is graded on its own. The most severe grade wins, and
       the first ``'block'`` ends the scan immediately.
    """
    collapsed = " ".join(command.split())
    if any(rule.search(collapsed) for rule in _HIGH_RISK_PATTERNS):
        return "block"

    saw_warning = False
    for segment in _split_compound_command(command):
        grade = _classify_single_command(segment)
        if grade == "block":
            return "block"  # nothing is more severe — stop early
        if grade == "warn":
            saw_warning = True

    return "warn" if saw_warning else "pass"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Middleware
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
    """Bash command security auditing middleware.

    For every ``bash`` tool call:

    1. **Command classification**: regex + shlex analysis grades commands as
       high-risk (block), medium-risk (warn), or safe (pass).
    2. **Audit log**: every bash call is recorded as a structured JSON entry
       via the standard logger (visible in langgraph.log).

    High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
    the handler is not called and an error ``ToolMessage`` is returned so the
    agent loop can continue gracefully.

    Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
    normally; a warning is appended to the tool result so the LLM is aware.

    Tool calls whose name is not ``bash`` are passed straight to the handler.
    """

    state_schema = ThreadState

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_thread_id(self, request: ToolCallRequest) -> str | None:
        """Best-effort lookup of the thread id used in audit records.

        Checks ``runtime.context["thread_id"]`` first, then
        ``runtime.config["configurable"]["thread_id"]``. Returns ``None``
        when neither is present or a container has an unexpected shape,
        instead of raising.
        """
        runtime = getattr(request, "runtime", None)  # ToolRuntime; may be None-like in tests
        if runtime is None:
            return None
        ctx = getattr(runtime, "context", None) or {}
        thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
        if thread_id is None:
            cfg = getattr(runtime, "config", None) or {}
            # ``configurable`` may be missing or explicitly None.
            configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
            if isinstance(configurable, dict):
                thread_id = configurable.get("thread_id")
        return thread_id

    # Max characters of a command kept in an audit record when truncating.
    _AUDIT_COMMAND_LIMIT = 200

    def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
        """Log one structured JSON audit record for a bash call.

        With ``truncate=True`` the command is cut to ``_AUDIT_COMMAND_LIMIT``
        characters and annotated with its original length.
        """
        audited_command = command
        if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
            audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
        record = {
            "timestamp": datetime.now(UTC).isoformat(),
            "thread_id": thread_id or "unknown",
            "command": audited_command,
            "verdict": verdict,
        }
        logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))

    def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
        """Return the error ``ToolMessage`` sent back in place of running the command."""
        tool_call_id = str(request.tool_call.get("id") or "missing_id")
        return ToolMessage(
            content=f"Command blocked: {reason}. Please use a safer alternative approach.",
            tool_call_id=tool_call_id,
            name="bash",
            status="error",
        )

    def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
        """Append a warning note to the tool result for medium-risk commands.

        ``Command`` results are returned unchanged. For a ``ToolMessage`` the
        warning is appended to string content, or added as an extra text part
        for list content.
        """
        if not isinstance(result, ToolMessage):
            return result
        warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
        if isinstance(result.content, list):
            new_content = [*result.content, {"type": "text", "text": warning}]
        else:
            new_content = str(result.content) + warning
        # Copy the handler's message instead of rebuilding it by hand. This
        # keeps id, artifact, additional_kwargs, response_metadata, etc.,
        # which a freshly constructed ToolMessage would silently drop.
        return result.model_copy(update={"content": new_content})

    # ------------------------------------------------------------------
    # Input sanitisation
    # ------------------------------------------------------------------

    # Normal bash commands rarely exceed a few hundred characters. 10 000 is
    # well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
    # Anything longer is almost certainly a payload injection or base64-encoded
    # attack string.
    _MAX_COMMAND_LENGTH = 10_000

    def _validate_input(self, command: str) -> str | None:
        """Return ``None`` if *command* is acceptable, else a rejection reason."""
        if not command.strip():
            return "empty command"
        if len(command) > self._MAX_COMMAND_LENGTH:
            return "command too long"
        if "\x00" in command:
            return "null byte detected"
        return None

    # ------------------------------------------------------------------
    # Core logic (shared between sync and async paths)
    # ------------------------------------------------------------------

    def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
        """
        Returns (command, thread_id, verdict, reject_reason).
        verdict is 'block', 'warn', or 'pass'.
        reject_reason is non-None only for input sanitisation rejections.

        A missing, ``None`` or non-dict ``args``, or a non-string ``command``,
        is treated as an empty command and therefore rejected (not raised).
        """
        args = request.tool_call.get("args")
        raw_command = args.get("command") if isinstance(args, dict) else None
        command = raw_command if isinstance(raw_command, str) else ""
        thread_id = self._get_thread_id(request)

        # ① input sanitisation — reject malformed input before regex analysis
        reject_reason = self._validate_input(command)
        if reject_reason:
            self._write_audit(thread_id, command, "block", truncate=True)
            logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
            return command, thread_id, "block", reject_reason

        # ② classify command
        verdict = _classify_command(command)

        # ③ audit log
        self._write_audit(thread_id, command, verdict)

        if verdict == "block":
            logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
        elif verdict == "warn":
            logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)

        return command, thread_id, verdict, None

    # ------------------------------------------------------------------
    # wrap_tool_call hooks
    # ------------------------------------------------------------------

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Audit sync ``bash`` calls: block, run-and-warn, or run unchanged."""
        if request.tool_call.get("name") != "bash":
            return handler(request)

        command, _, verdict, reject_reason = self._pre_process(request)
        if verdict == "block":
            reason = reject_reason or "security violation detected"
            return self._build_block_message(request, reason)
        result = handler(request)
        if verdict == "warn":
            result = self._append_warn_to_result(result, command)
        return result

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async counterpart of :meth:`wrap_tool_call` with identical policy."""
        if request.tool_call.get("name") != "bash":
            return await handler(request)

        command, _, verdict, reject_reason = self._pre_process(request)
        if verdict == "block":
            reason = reject_reason or "security violation detected"
            return self._build_block_message(request, reason)
        result = await handler(request)
        if verdict == "warn":
            result = self._append_warn_to_result(result, command)
        return result
|