fix(harness): constrain view_image to thread data paths (#2557)

* fix(harness): constrain view_image to thread data paths

Fixes #2530

* fix(harness): address view_image review findings

* style(harness): format view_image changes

* fix(harness): address view_image review comments
This commit is contained in:
DanielWalnut 2026-04-28 11:13:17 +08:00 committed by GitHub
parent 9dc25987e0
commit af8c0cfb78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 282 additions and 32 deletions

View File

@ -254,9 +254,11 @@ def _assemble_from_features(
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
chain.append(ViewImageMiddleware())
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
if feat.sandbox is not False:
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
# --- [11] Subagent ---
if feat.subagent is not False:

View File

@ -548,7 +548,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
This function is a security gate it checks whether *path* may be
accessed and raises on violation. It does **not** resolve the virtual
path to a host path; callers are responsible for resolution via
``_resolve_and_validate_user_data_path`` or ``_resolve_skills_path``.
``resolve_and_validate_user_data_path`` or ``_resolve_skills_path``.
Allowed virtual-path families:
- ``/mnt/user-data/*`` always allowed (read + write)
@ -635,6 +635,11 @@ def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState
return str(resolved)
def resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState) -> str:
"""Resolve a /mnt/user-data virtual path and validate it stays in bounds."""
return _resolve_and_validate_user_data_path(path, thread_data)
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None:
"""Validate absolute paths in local-sandbox bash commands.

View File

@ -8,7 +8,42 @@ from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
f"{VIRTUAL_PATH_PREFIX}/workspace",
f"{VIRTUAL_PATH_PREFIX}/uploads",
f"{VIRTUAL_PATH_PREFIX}/outputs",
)
_ALLOWED_IMAGE_VIRTUAL_ROOTS_TEXT = ", ".join(_ALLOWED_IMAGE_VIRTUAL_ROOTS)
_MAX_IMAGE_BYTES = 20 * 1024 * 1024
_EXTENSION_TO_MIME = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
}
def _is_allowed_image_virtual_path(image_path: str) -> bool:
return any(image_path == root or image_path.startswith(f"{root}/") for root in _ALLOWED_IMAGE_VIRTUAL_ROOTS)
def _detect_image_mime(image_data: bytes) -> str | None:
if image_data.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
if len(image_data) >= 12 and image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
return "image/webp"
return None
def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None) -> str:
from deerflow.sandbox.tools import mask_local_paths_in_output
return mask_local_paths_in_output(f"{type(error).__name__}: {error}", thread_data)
@tool("view_image", parse_docstring=True)
@ -29,22 +64,39 @@ def view_image_tool(
- For multiple files at once (use present_files instead)
Args:
image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp.
image_path: Absolute /mnt/user-data virtual path to the image file. Common formats supported: jpg, jpeg, png, webp.
"""
from deerflow.sandbox.tools import get_thread_data, replace_virtual_path
from deerflow.sandbox.exceptions import SandboxRuntimeError
from deerflow.sandbox.tools import (
get_thread_data,
resolve_and_validate_user_data_path,
validate_local_tool_path,
)
# Replace virtual path with actual path
# /mnt/user-data/* paths are mapped to thread-specific directories
thread_data = get_thread_data(runtime)
actual_path = replace_virtual_path(image_path, thread_data)
# Validate that the path is absolute
path = Path(actual_path)
if not path.is_absolute():
if not _is_allowed_image_virtual_path(image_path):
return Command(
update={"messages": [ToolMessage(f"Error: Path must be absolute, got: {image_path}", tool_call_id=tool_call_id)]},
update={
"messages": [
ToolMessage(
f"Error: Only image paths under {_ALLOWED_IMAGE_VIRTUAL_ROOTS_TEXT} are allowed",
tool_call_id=tool_call_id,
)
]
},
)
try:
validate_local_tool_path(image_path, thread_data, read_only=True)
actual_path = resolve_and_validate_user_data_path(image_path, thread_data)
except (PermissionError, SandboxRuntimeError) as e:
return Command(
update={"messages": [ToolMessage(f"Error: {str(e)}", tool_call_id=tool_call_id)]},
)
path = Path(actual_path)
# Validate that the file exists
if not path.exists():
return Command(
@ -58,34 +110,49 @@ def view_image_tool(
)
# Validate image extension
valid_extensions = {".jpg", ".jpeg", ".png", ".webp"}
if path.suffix.lower() not in valid_extensions:
expected_mime_type = _EXTENSION_TO_MIME.get(path.suffix.lower())
if expected_mime_type is None:
return Command(
update={"messages": [ToolMessage(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(valid_extensions)}", tool_call_id=tool_call_id)]},
update={"messages": [ToolMessage(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(_EXTENSION_TO_MIME)}", tool_call_id=tool_call_id)]},
)
# Detect MIME type from file extension
mime_type, _ = mimetypes.guess_type(actual_path)
if mime_type is None:
# Fallback to default MIME types for common image formats
extension_to_mime = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
}
mime_type = extension_to_mime.get(path.suffix.lower(), "application/octet-stream")
mime_type = expected_mime_type
try:
image_size = path.stat().st_size
except OSError as e:
return Command(
update={"messages": [ToolMessage(f"Error reading image metadata: {_sanitize_image_error(e, thread_data)}", tool_call_id=tool_call_id)]},
)
if image_size > _MAX_IMAGE_BYTES:
return Command(
update={"messages": [ToolMessage(f"Error: Image file is too large: {image_size} bytes. Maximum supported size is {_MAX_IMAGE_BYTES} bytes", tool_call_id=tool_call_id)]},
)
# Read image file and convert to base64
try:
with open(actual_path, "rb") as f:
image_data = f.read()
image_base64 = base64.b64encode(image_data).decode("utf-8")
except Exception as e:
return Command(
update={"messages": [ToolMessage(f"Error reading image file: {str(e)}", tool_call_id=tool_call_id)]},
update={"messages": [ToolMessage(f"Error reading image file: {_sanitize_image_error(e, thread_data)}", tool_call_id=tool_call_id)]},
)
detected_mime_type = _detect_image_mime(image_data)
if detected_mime_type is None:
return Command(
update={"messages": [ToolMessage("Error: File contents do not match a supported image format", tool_call_id=tool_call_id)]},
)
if detected_mime_type != expected_mime_type:
return Command(
update={"messages": [ToolMessage(f"Error: Image contents are {detected_mime_type}, but file extension indicates {expected_mime_type}", tool_call_id=tool_call_id)]},
)
mime_type = detected_mime_type
image_base64 = base64.b64encode(image_data).decode("utf-8")
# Update viewed_images in state
# The merge_viewed_images reducer will handle merging with existing images
new_viewed_images = {image_path: {"base64": image_base64, "mime_type": mime_type}}

View File

@ -116,10 +116,22 @@ def test_middleware_and_features_conflict():
# ---------------------------------------------------------------------------
# 7. Vision feature auto-injects view_image_tool
# 7. Vision feature auto-injects view_image_tool when thread data is available
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_vision_injects_view_image_tool(mock_create_agent):
mock_create_agent.return_value = MagicMock()
feat = RuntimeFeatures(vision=True, sandbox=True)
create_deerflow_agent(_make_mock_model(), features=feat)
call_kwargs = mock_create_agent.call_args[1]
tool_names = [t.name for t in call_kwargs["tools"]]
assert "view_image" in tool_names
@patch("deerflow.agents.factory.create_agent")
def test_vision_without_sandbox_does_not_inject_view_image_tool(mock_create_agent):
mock_create_agent.return_value = MagicMock()
feat = RuntimeFeatures(vision=True, sandbox=False)
@ -127,7 +139,7 @@ def test_vision_injects_view_image_tool(mock_create_agent):
call_kwargs = mock_create_agent.call_args[1]
tool_names = [t.name for t in call_kwargs["tools"]]
assert "view_image" in tool_names
assert "view_image" not in tool_names
def test_view_image_middleware_preserves_viewed_images_reducer():
@ -301,11 +313,11 @@ def test_always_on_error_handling(mock_create_agent):
# ---------------------------------------------------------------------------
# 17. Vision with custom middleware still injects tool
# 17. Vision with custom middleware follows thread-data availability
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_vision_custom_middleware_still_injects_tool(mock_create_agent):
"""Custom vision middleware still gets the view_image_tool auto-injected."""
def test_vision_custom_middleware_without_sandbox_does_not_inject_tool(mock_create_agent):
"""Custom vision middleware without thread data does not get view_image_tool auto-injected."""
from langchain.agents.middleware import AgentMiddleware
mock_create_agent.return_value = MagicMock()
@ -319,7 +331,7 @@ def test_vision_custom_middleware_still_injects_tool(mock_create_agent):
call_kwargs = mock_create_agent.call_args[1]
tool_names = [t.name for t in call_kwargs["tools"]]
assert "view_image" in tool_names
assert "view_image" not in tool_names
# ===========================================================================

View File

@ -0,0 +1,164 @@
import base64
import importlib
import os
from pathlib import Path
from types import SimpleNamespace
import pytest
from deerflow.tools.builtins.view_image_tool import view_image_tool
view_image_module = importlib.import_module("deerflow.tools.builtins.view_image_tool")
PNG_BYTES = base64.b64decode("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==")
def _make_thread_data(tmp_path: Path) -> dict[str, str]:
user_data = tmp_path / "threads" / "thread-1" / "user-data"
workspace = user_data / "workspace"
uploads = user_data / "uploads"
outputs = user_data / "outputs"
for directory in (workspace, uploads, outputs):
directory.mkdir(parents=True)
return {
"workspace_path": str(workspace),
"uploads_path": str(uploads),
"outputs_path": str(outputs),
}
def _make_runtime(thread_data: dict[str, str]) -> SimpleNamespace:
return SimpleNamespace(
state={"thread_data": thread_data},
context={"thread_id": "thread-1"},
config={},
)
def _message_content(result) -> str:
return result.update["messages"][0].content
def test_view_image_rejects_external_absolute_path(tmp_path: Path) -> None:
thread_data = _make_thread_data(tmp_path)
outside_image = tmp_path / "outside.png"
outside_image.write_bytes(PNG_BYTES)
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path=str(outside_image),
tool_call_id="tc-external",
)
assert "Only image paths under /mnt/user-data" in _message_content(result)
assert "viewed_images" not in result.update
def test_view_image_reads_virtual_uploads_path(tmp_path: Path) -> None:
thread_data = _make_thread_data(tmp_path)
image_path = Path(thread_data["uploads_path"]) / "sample.png"
image_path.write_bytes(PNG_BYTES)
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path="/mnt/user-data/uploads/sample.png",
tool_call_id="tc-uploads",
)
assert _message_content(result) == "Successfully read image"
viewed_image = result.update["viewed_images"]["/mnt/user-data/uploads/sample.png"]
assert viewed_image["base64"] == base64.b64encode(PNG_BYTES).decode("utf-8")
assert viewed_image["mime_type"] == "image/png"
def test_view_image_rejects_spoofed_extension(tmp_path: Path) -> None:
thread_data = _make_thread_data(tmp_path)
image_path = Path(thread_data["uploads_path"]) / "not-really.png"
image_path.write_bytes(b"not an image")
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path="/mnt/user-data/uploads/not-really.png",
tool_call_id="tc-spoofed",
)
assert "contents do not match" in _message_content(result)
assert "viewed_images" not in result.update
def test_view_image_rejects_mismatched_magic_bytes(tmp_path: Path) -> None:
thread_data = _make_thread_data(tmp_path)
image_path = Path(thread_data["uploads_path"]) / "jpeg-named-png.png"
image_path.write_bytes(b"\xff\xd8\xff\xe0fake-jpeg")
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path="/mnt/user-data/uploads/jpeg-named-png.png",
tool_call_id="tc-mismatch",
)
assert "file extension indicates image/png" in _message_content(result)
assert "viewed_images" not in result.update
def test_view_image_rejects_oversized_image(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
thread_data = _make_thread_data(tmp_path)
image_path = Path(thread_data["uploads_path"]) / "sample.png"
image_path.write_bytes(PNG_BYTES)
monkeypatch.setattr(view_image_module, "_MAX_IMAGE_BYTES", len(PNG_BYTES) - 1)
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path="/mnt/user-data/uploads/sample.png",
tool_call_id="tc-oversized",
)
assert "Image file is too large" in _message_content(result)
assert "viewed_images" not in result.update
def test_view_image_sanitizes_read_errors(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
thread_data = _make_thread_data(tmp_path)
image_path = Path(thread_data["uploads_path"]) / "sample.png"
image_path.write_bytes(PNG_BYTES)
def _open(*args, **kwargs):
raise PermissionError(f"permission denied: {image_path}")
monkeypatch.setattr("builtins.open", _open)
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path="/mnt/user-data/uploads/sample.png",
tool_call_id="tc-read-error",
)
message = _message_content(result)
assert "Error reading image file" in message
assert str(image_path) not in message
assert str(Path(thread_data["uploads_path"])) not in message
assert "/mnt/user-data/uploads/sample.png" in message
assert "viewed_images" not in result.update
@pytest.mark.skipif(os.name == "nt", reason="symlink semantics differ on Windows")
def test_view_image_rejects_uploads_symlink_escape(tmp_path: Path) -> None:
thread_data = _make_thread_data(tmp_path)
outside_image = tmp_path / "outside-target.png"
outside_image.write_bytes(PNG_BYTES)
link_path = Path(thread_data["uploads_path"]) / "escape.png"
try:
link_path.symlink_to(outside_image)
except OSError as exc:
pytest.skip(f"symlink creation failed: {exc}")
result = view_image_tool.func(
runtime=_make_runtime(thread_data),
image_path="/mnt/user-data/uploads/escape.png",
tool_call_id="tc-symlink",
)
assert "path traversal" in _message_content(result)
assert "viewed_images" not in result.update