From 6e8e6a969be803227aa71cb6ba4c5d116910b4b7 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Wed, 13 May 2026 23:56:06 +0800 Subject: [PATCH] test: add blocking IO detector (#2924) * test: add blocking IO detector * test: add blocking IO probe option * test: harden blocking IO probe lifecycle * test: move blocking io detector to support --- backend/tests/conftest.py | 93 ++++++ backend/tests/support/__init__.py | 1 + backend/tests/support/detectors/__init__.py | 1 + .../tests/support/detectors/blocking_io.py | 287 ++++++++++++++++++ backend/tests/test_blocking_io_detector.py | 190 ++++++++++++ .../test_blocking_io_probe_integration.py | 22 ++ 6 files changed, 594 insertions(+) create mode 100644 backend/tests/support/__init__.py create mode 100644 backend/tests/support/detectors/__init__.py create mode 100644 backend/tests/support/detectors/blocking_io.py create mode 100644 backend/tests/test_blocking_io_detector.py create mode 100644 backend/tests/test_blocking_io_probe_integration.py diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index a357a3962..9bc8d4884 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import issues when unit-testing lightweight config/registry code in isolation. 
""" +from __future__ import annotations + import importlib.util import sys from pathlib import Path @@ -11,11 +13,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io # Make 'app' and 'deerflow' importable from any working directory sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts")) +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT) +_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector" + # Break the circular import chain that exists in production code: # deerflow.subagents.__init__ # -> .executor (SubagentExecutor, SubagentResult) @@ -56,6 +63,92 @@ def provisioner_module(): return module +@pytest.fixture() +def blocking_io_detector(): + """Fail a focused test if blocking calls run on the event loop thread.""" + with detect_blocking_io(fail_on_exit=True) as detector: + yield detector + + +def pytest_addoption(parser: pytest.Parser) -> None: + group = parser.getgroup("blocking-io") + group.addoption( + "--detect-blocking-io", + action="store_true", + default=False, + help="Collect blocking calls made while an asyncio event loop is running and report a summary.", + ) + group.addoption( + "--detect-blocking-io-fail", + action="store_true", + default=False, + help="Set a failing exit status when --detect-blocking-io records violations.", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe") + + +def pytest_sessionstart(session: pytest.Session) -> None: + if _blocking_io_probe_enabled(session.config): + _blocking_io_probe.clear() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_call(item: pytest.Item): + if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item): + yield + return + + 
detector = detect_blocking_io(fail_on_exit=False, stack_limit=18) + detector.__enter__() + setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector) + yield + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item: pytest.Item): + yield + + detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None) + if detector is None: + return + + try: + detector.__exit__(None, None, None) + _blocking_io_probe.record(item.nodeid, detector.violations) + finally: + delattr(item, _BLOCKING_IO_DETECTOR_ATTR) + + +def pytest_sessionfinish(session: pytest.Session) -> None: + if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK: + session.exitstatus = pytest.ExitCode.TESTS_FAILED + + +def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None: + if not _blocking_io_probe_enabled(terminalreporter.config): + return + + header, *details = _blocking_io_probe.format_summary().splitlines() + terminalreporter.write_sep("=", header) + for line in details: + terminalreporter.write_line(line) + + +def _blocking_io_probe_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_fail_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_probe_skipped(item: pytest.Item) -> bool: + return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None + + # --------------------------------------------------------------------------- # Auto-set user context for every test unless marked no_auto_user # --------------------------------------------------------------------------- diff --git a/backend/tests/support/__init__.py b/backend/tests/support/__init__.py new file mode 100644 index 000000000..38361eaf5 --- /dev/null +++ b/backend/tests/support/__init__.py @@ -0,0 +1 @@ 
# Reconstructed from a newline-collapsed git patch. This span of the patch also
# creates two one-line package markers:
#   backend/tests/support/__init__.py           -> """Shared test support helpers."""
#   backend/tests/support/detectors/__init__.py -> """Runtime and static detectors used by tests."""
# Everything below is backend/tests/support/detectors/blocking_io.py, up to and
# including ``detect_blocking_io``. ``format_blocking_calls`` and
# ``_format_stack`` follow it in the same module.
"""Test helper for detecting blocking calls on an asyncio event loop.

The detector is intentionally test-only. It monkeypatches a small set of
well-known blocking entry points and their already-loaded module-level aliases,
then records calls only when they happen on a thread that is currently running
an asyncio event loop. Aliases captured in closures or default arguments remain
out of scope.
"""

from __future__ import annotations

import asyncio
import importlib
import sys
import traceback
from collections import Counter
from collections.abc import Callable, Iterable, Iterator
from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from types import FrameType, TracebackType
from typing import Any

BlockingCallable = Callable[..., Any]

# Attribute stamped on every wrapper a detector installs. It is dunder-shaped on
# purpose: ``unittest.mock`` objects raise AttributeError for unknown dunder
# names instead of fabricating a child mock, so probing arbitrary objects for
# this attribute stays safe.
_WRAPPER_OWNER_ATTR = "__blocking_io_detector__"


@dataclass(frozen=True)
class BlockingCallSpec:
    """Describes one blocking callable to wrap during a detector run."""

    # Label reported for violations, e.g. "time.sleep".
    name: str
    # Location in "module:attr.path" form; see _resolve_target().
    target: str
    # When True, the violation is recorded while the returned iterable is being
    # advanced rather than when the callable itself is invoked.
    record_on_iteration: bool = False


@dataclass(frozen=True)
class BlockingCall:
    """One blocking call observed on an asyncio event loop thread."""

    # Copied from the BlockingCallSpec that matched.
    name: str
    target: str
    # Caller frames, oldest first, with the detector's own frames removed.
    stack: tuple[traceback.FrameSummary, ...]


DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
    BlockingCallSpec("time.sleep", "time:sleep"),
    BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
    BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
    BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
    BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
    BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
    BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
)


def _is_event_loop_thread() -> bool:
    """Return True when the calling thread is currently running an asyncio loop."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises when no loop runs on this thread.
        return False
    return loop.is_running()


def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
    """Resolve ``"module:attr.path"`` to ``(owner, attribute name, current value)``.

    Raises:
        ValueError: if ``target`` has no ``:`` separator.
        ImportError: if the module part cannot be imported.
        AttributeError: if any attribute along the path is missing.
    """
    module_name, separator, attr_path = target.partition(":")
    if not separator:
        raise ValueError(f"blocking call target must use 'module:attr.path' form, got {target!r}")

    owner: object = importlib.import_module(module_name)
    *parent_names, attr_name = attr_path.split(".")
    for parent_name in parent_names:
        owner = getattr(owner, parent_name)
    return owner, attr_name, getattr(owner, attr_name)


def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
    """Drop frames that belong to this module so reports show only caller code."""
    return tuple(frame for frame in stack if frame.filename != __file__)


def _caller_frame_outside_detector() -> FrameType | None:
    """Return the innermost frame on the current stack not defined in this module."""
    frame: FrameType | None = sys._getframe(1)
    while frame is not None and frame.f_code.co_filename == __file__:
        frame = frame.f_back
    return frame


def _unwrap_stale_wrappers(candidate: BlockingCallable) -> BlockingCallable:
    """Peel off wrappers left behind by detectors that are no longer active.

    A detector can capture another detector's wrapper as its "original". If the
    other detector has exited by the time this one restores, putting that
    wrapper back would leave the target wrapped permanently.
    """
    while True:
        owner = getattr(candidate, _WRAPPER_OWNER_ATTR, None)
        if not isinstance(owner, BlockingIODetector) or owner._active:
            return candidate
        # functools.wraps copies __dict__, so a foreign decorator applied on top
        # of a wrapper inherits the owner attribute. Only unwrap objects this
        # owner really created.
        if not any(candidate is known for known in owner._wrappers):
            return candidate
        candidate = candidate.__wrapped__


class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
    """Record blocking calls made from async runtime code.

    By default the detector reports violations but does not fail on context
    exit. Tests can set ``fail_on_exit=True`` or call
    ``assert_no_blocking_calls()`` explicitly.

    On exit, a patched attribute is put back only if it still holds the wrapper
    this detector installed. If something else replaced or restored the
    attribute in the meantime, that later value is left alone.
    """

    def __init__(
        self,
        specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
        *,
        fail_on_exit: bool = False,
        patch_loaded_aliases: bool = True,
        stack_limit: int = 12,
    ) -> None:
        """Configure the detector; nothing is patched until ``__enter__``.

        Args:
            specs: Callables to wrap.
            fail_on_exit: Raise ``AssertionError`` on a clean context exit when
                violations were recorded.
            patch_loaded_aliases: Also replace module-level names in
                ``sys.modules`` that are bound to a wrapped callable.
            stack_limit: Maximum number of caller frames kept per violation.
        """
        self._specs = tuple(specs)
        self._fail_on_exit = fail_on_exit
        self._patch_loaded_aliases_enabled = patch_loaded_aliases
        self._stack_limit = stack_limit
        # (owner, attribute name, value before patching, wrapper installed)
        self._patches: list[tuple[object, str, BlockingCallable, BlockingCallable]] = []
        self._patch_keys: set[tuple[int, str]] = set()
        # Every wrapper this detector created; consulted by _unwrap_stale_wrappers().
        self._wrappers: list[BlockingCallable] = []
        self.violations: list[BlockingCall] = []
        self._active = False

    def __enter__(self) -> BlockingIODetector:
        """Install the wrappers and start recording."""
        try:
            self._active = True
            alias_replacements: dict[int, BlockingCallable] = {}
            for spec in self._specs:
                owner, attr_name, original = _resolve_target(spec.target)
                wrapper = self._wrap(spec, original)
                self._patch_attribute(owner, attr_name, original, wrapper)
                alias_replacements[id(original)] = wrapper

            if self._patch_loaded_aliases_enabled:
                self._patch_loaded_module_aliases(alias_replacements)
        except Exception:
            # Undo whatever was patched before the failure.
            self._restore()
            self._active = False
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback_value: TracebackType | None,
    ) -> bool | None:
        """Remove the wrappers; optionally fail when violations were recorded."""
        try:
            self._restore()
        finally:
            # Stop recording even if restoring an attribute raised.
            self._active = False
        if exc_type is None and self._fail_on_exit:
            self.assert_no_blocking_calls()
        return None

    def _restore(self) -> None:
        """Undo this detector's patches without clobbering later changes."""
        for owner, attr_name, original, replacement in reversed(self._patches):
            # Only undo a patch that is still in place. If a fixture or a nested
            # detector changed the attribute afterwards, writing the saved value
            # back would resurrect state that its owner already cleaned up.
            if getattr(owner, attr_name, None) is replacement:
                setattr(owner, attr_name, _unwrap_stale_wrappers(original))
        self._patches.clear()
        self._patch_keys.clear()

    def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
        """Replace ``owner.attr_name`` once and remember how to undo it."""
        key = (id(owner), attr_name)
        if key in self._patch_keys:
            return
        setattr(owner, attr_name, replacement)
        self._patches.append((owner, attr_name, original, replacement))
        self._patch_keys.add(key)

    def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
        """Patch module-level names that are bound to an already wrapped callable."""
        # Snapshot both collections: patching mutates module namespaces.
        for module in tuple(sys.modules.values()):
            namespace = getattr(module, "__dict__", None)
            if not isinstance(namespace, dict):
                continue

            for attr_name, value in tuple(namespace.items()):
                replacement = replacements_by_id.get(id(value))
                if replacement is not None:
                    self._patch_attribute(module, attr_name, value, replacement)

    def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
        """Build the recording wrapper installed in place of ``original``."""

        @wraps(original)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if spec.record_on_iteration:
                return self._wrap_iteration(spec, original(*args, **kwargs))
            self._record_if_blocking(spec)
            return original(*args, **kwargs)

        # Set after @wraps, which may have copied another detector's marker.
        setattr(wrapper, _WRAPPER_OWNER_ATTR, self)
        self._wrappers.append(wrapper)
        return wrapper

    def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
        """Yield from ``iterable``, recording one violation once it is advanced on a loop thread."""
        iterator = iter(iterable)
        reported = False

        while True:
            if not reported:
                reported = self._record_if_blocking(spec)
            # Keep the try body to next() alone so a StopIteration thrown in by
            # the consumer at the yield is not mistaken for exhaustion.
            try:
                item = next(iterator)
            except StopIteration:
                return
            yield item

    def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
        """Record a violation if active on an event loop thread; return whether one was recorded."""
        if not (self._active and _is_event_loop_thread()):
            return False

        # Start the capture at the first frame outside this module, so
        # ``stack_limit`` counts caller frames only.
        frame = _caller_frame_outside_detector()
        stack = traceback.extract_stack(frame, limit=self._stack_limit) if frame is not None else []
        self.violations.append(BlockingCall(spec.name, spec.target, _trim_detector_frames(stack)))
        return True

    def assert_no_blocking_calls(self) -> None:
        """Raise ``AssertionError`` describing every recorded violation, if any."""
        if self.violations:
            raise AssertionError(format_blocking_calls(self.violations))


class BlockingIOProbe:
    """Collect detector output across tests and format a compact summary."""

    def __init__(self, project_root: Path) -> None:
        """Remember the resolved ``project_root`` used to classify stack frames."""
        self._project_root = project_root.resolve()
        self._observed: list[tuple[str, BlockingCall]] = []

    @property
    def violation_count(self) -> int:
        """Total number of recorded violations."""
        return len(self._observed)

    @property
    def test_count(self) -> int:
        """Number of distinct test node ids with at least one violation."""
        return len({nodeid for nodeid, _violation in self._observed})

    def clear(self) -> None:
        """Forget everything recorded so far."""
        self._observed.clear()

    def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
        """Attribute ``violations`` to the test identified by ``nodeid``."""
        self._observed.extend((nodeid, violation) for violation in violations)

    def format_summary(self, *, limit: int = 30) -> str:
        """Return a header plus up to ``limit`` call sites, most frequent first."""
        if not self._observed:
            return "blocking io probe: no violations"

        call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
        for _nodeid, violation in self._observed:
            frame = self._local_call_site(violation.stack)
            if frame is None:
                # No project frame in the captured stack: group by call name only.
                call_sites[(violation.name, "", 0, "", "")] += 1
                continue

            call_sites[
                (
                    violation.name,
                    self._relative(frame.filename),
                    frame.lineno,
                    frame.name,
                    (frame.line or "").strip(),
                )
            ] += 1

        lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
        for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
            lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
        return "\n".join(lines)

    def _relative(self, filename: str) -> str:
        """Return ``filename`` relative to the project root, or unchanged when outside it."""
        try:
            return str(Path(filename).resolve().relative_to(self._project_root))
        except ValueError:
            return filename

    def _project_parts(self, filename: str) -> tuple[str, ...] | None:
        """Return path components below the project root, or None for foreign frames.

        Frames outside the project root or under a ``.venv`` directory are
        foreign. Matching whole path components keeps the check
        separator-independent and avoids substring hits on sibling directories.
        """
        try:
            parts = Path(filename).relative_to(self._project_root).parts
        except ValueError:
            # Also covers pseudo filenames such as "<frozen ...>" or "<string>".
            return None
        return None if ".venv" in parts else parts

    def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
        """Pick the innermost project frame, preferring code outside ``tests/``."""
        project_frames = [(frame, parts) for frame in stack if (parts := self._project_parts(frame.filename)) is not None]
        non_test_frames = [frame for frame, parts in project_frames if parts[:1] != ("tests",)]
        if non_test_frames:
            return non_test_frames[-1]
        return project_frames[-1][0] if project_frames else None


def detect_blocking_io(
    specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
    *,
    fail_on_exit: bool = False,
    patch_loaded_aliases: bool = True,
    stack_limit: int = 12,
) -> BlockingIODetector:
    """Create a detector context manager for a focused test scope."""

    return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
+def format_blocking_calls(violations: Iterable[BlockingCall]) -> str: + """Format detector output with enough stack context to locate call sites.""" + + lines = ["Blocking calls were executed on an asyncio event loop thread:"] + for index, violation in enumerate(violations, start=1): + lines.append(f"{index}. {violation.name} ({violation.target})") + lines.extend(_format_stack(violation.stack)) + return "\n".join(lines) + + +def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]: + for frame in stack: + location = f"{frame.filename}:{frame.lineno}" + lines = [f" at {frame.name} ({location})"] + if frame.line: + lines.append(f" {frame.line.strip()}") + yield from lines diff --git a/backend/tests/test_blocking_io_detector.py b/backend/tests/test_blocking_io_detector.py new file mode 100644 index 000000000..af44d746d --- /dev/null +++ b/backend/tests/test_blocking_io_detector.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio +import os +import time +from os import walk as imported_walk +from pathlib import Path +from time import sleep as imported_sleep + +import httpx +import pytest +import requests +from support.detectors.blocking_io import ( + BlockingCallSpec, + BlockingIOProbe, + detect_blocking_io, +) + +pytestmark = pytest.mark.asyncio + + +TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),) +REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),) +HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),) +OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),) +PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),) + + +async def test_records_time_sleep_on_event_loop() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def 
test_records_already_imported_sleep_alias_on_event_loop() -> None: + original_alias = imported_sleep + + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + imported_sleep(0) + + assert imported_sleep is original_alias + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_can_disable_loaded_alias_patching() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector: + imported_sleep(0) + + assert detector.violations == [] + + +async def test_does_not_record_time_sleep_offloaded_to_thread() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + await asyncio.to_thread(time.sleep, 0) + + assert detector.violations == [] + + +async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None: + await asyncio.to_thread(time.sleep, 0) + + assert blocking_io_detector.violations == [] + + +async def test_does_not_record_sync_call_without_running_event_loop() -> None: + def call_sleep() -> list[str]: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + return [violation.name for violation in detector.violations] + + assert await asyncio.to_thread(call_sleep) == [] + + +async def test_fail_on_exit_includes_call_site() -> None: + with pytest.raises(AssertionError) as exc_info: + with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True): + time.sleep(0) + + message = str(exc_info.value) + assert "time.sleep" in message + assert "test_fail_on_exit_includes_call_site" in message + + +async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str: + return f"{method}:{url}" + + monkeypatch.setattr(requests.sessions.Session, "request", fake_request) + + with detect_blocking_io(REQUESTS_ONLY) as detector: + assert requests.get("https://example.invalid") == "get:https://example.invalid" + + assert [violation.name 
for violation in detector.violations] == ["requests.Session.request"] + + +async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response: + return httpx.Response(200, request=httpx.Request(method, url)) + + monkeypatch.setattr(httpx.Client, "request", fake_request) + + with detect_blocking_io(HTTPX_ONLY) as detector: + with httpx.Client() as client: + response = client.get("https://example.invalid") + + assert response.status_code == 200 + assert [violation.name for violation in detector.violations] == ["httpx.Client.request"] + + +async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(os.walk(tmp_path)) + + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + original_alias = imported_walk + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(imported_walk(tmp_path)) + + assert imported_walk is original_alias + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None: + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + + assert list(walker) + assert detector.violations == [] + + +async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + assert await asyncio.to_thread(lambda: list(walker)) + + assert detector.violations == [] + + +async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None: + path = tmp_path / "data.txt" + 
path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector: + assert path.read_text(encoding="utf-8") == "content" + + assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"] + + +async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + summary = probe.format_summary() + + assert "blocking io probe: 1 violations across 1 tests" in summary + assert "pathlib.Path.read_text" in summary + + +async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + + assert probe.format_summary() == "blocking io probe: no violations" + + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + assert probe.violation_count == 1 + + probe.clear() + + assert probe.violation_count == 0 + assert probe.format_summary() == "blocking io probe: no violations" diff --git a/backend/tests/test_blocking_io_probe_integration.py b/backend/tests/test_blocking_io_probe_integration.py new file mode 100644 index 000000000..af7a31b9d --- /dev/null +++ b/backend/tests/test_blocking_io_probe_integration.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import time + +import pytest + +ORIGINAL_SLEEP = time.sleep + + +def replacement_sleep(seconds: float) -> None: + return None + + +def 
test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(time, "sleep", replacement_sleep) + assert time.sleep is replacement_sleep + + +@pytest.mark.no_blocking_io_probe +def test_probe_restores_original_after_monkeypatch_teardown() -> None: + assert time.sleep is ORIGINAL_SLEEP + assert getattr(time.sleep, "__wrapped__", None) is None