mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-15 21:23:41 +00:00
test: add blocking IO detector (#2924)
* test: add blocking IO detector * test: add blocking IO probe option * test: harden blocking IO probe lifecycle * test: move blocking io detector to support
This commit is contained in:
parent
eab7ae3d62
commit
6e8e6a969b
@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
|
||||
issues when unit-testing lightweight config/registry code in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@ -11,11 +13,16 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
|
||||
|
||||
# Make 'app' and 'deerflow' importable from any working directory
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
|
||||
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
|
||||
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
|
||||
|
||||
# Break the circular import chain that exists in production code:
|
||||
# deerflow.subagents.__init__
|
||||
# -> .executor (SubagentExecutor, SubagentResult)
|
||||
@ -56,6 +63,92 @@ def provisioner_module():
|
||||
return module
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def blocking_io_detector():
|
||||
"""Fail a focused test if blocking calls run on the event loop thread."""
|
||||
with detect_blocking_io(fail_on_exit=True) as detector:
|
||||
yield detector
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
group = parser.getgroup("blocking-io")
|
||||
group.addoption(
|
||||
"--detect-blocking-io",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Collect blocking calls made while an asyncio event loop is running and report a summary.",
|
||||
)
|
||||
group.addoption(
|
||||
"--detect-blocking-io-fail",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set a failing exit status when --detect-blocking-io records violations.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe")
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
if _blocking_io_probe_enabled(session.config):
|
||||
_blocking_io_probe.clear()
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
|
||||
def pytest_runtest_call(item: pytest.Item):
|
||||
if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item):
|
||||
yield
|
||||
return
|
||||
|
||||
detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
|
||||
detector.__enter__()
|
||||
setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
|
||||
def pytest_runtest_teardown(item: pytest.Item):
|
||||
yield
|
||||
|
||||
detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
|
||||
if detector is None:
|
||||
return
|
||||
|
||||
try:
|
||||
detector.__exit__(None, None, None)
|
||||
_blocking_io_probe.record(item.nodeid, detector.violations)
|
||||
finally:
|
||||
delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session) -> None:
|
||||
if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK:
|
||||
session.exitstatus = pytest.ExitCode.TESTS_FAILED
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
|
||||
if not _blocking_io_probe_enabled(terminalreporter.config):
|
||||
return
|
||||
|
||||
header, *details = _blocking_io_probe.format_summary().splitlines()
|
||||
terminalreporter.write_sep("=", header)
|
||||
for line in details:
|
||||
terminalreporter.write_line(line)
|
||||
|
||||
|
||||
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
|
||||
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
|
||||
|
||||
|
||||
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
|
||||
return bool(config.getoption("--detect-blocking-io-fail"))
|
||||
|
||||
|
||||
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
|
||||
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
1
backend/tests/support/__init__.py
Normal file
1
backend/tests/support/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Shared test support helpers."""
|
||||
1
backend/tests/support/detectors/__init__.py
Normal file
1
backend/tests/support/detectors/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Runtime and static detectors used by tests."""
|
||||
287
backend/tests/support/detectors/blocking_io.py
Normal file
287
backend/tests/support/detectors/blocking_io.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""Test helper for detecting blocking calls on an asyncio event loop.
|
||||
|
||||
The detector is intentionally test-only. It monkeypatches a small set of
|
||||
well-known blocking entry points and their already-loaded module-level aliases,
|
||||
then records calls only when they happen on a thread that is currently running
|
||||
an asyncio event loop. Aliases captured in closures or default arguments remain
|
||||
out of scope.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from collections.abc import Callable, Iterable, Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
BlockingCallable = Callable[..., Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockingCallSpec:
|
||||
"""Describes one blocking callable to wrap during a detector run."""
|
||||
|
||||
name: str
|
||||
target: str
|
||||
record_on_iteration: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockingCall:
|
||||
"""One blocking call observed on an asyncio event loop thread."""
|
||||
|
||||
name: str
|
||||
target: str
|
||||
stack: tuple[traceback.FrameSummary, ...]
|
||||
|
||||
|
||||
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
|
||||
BlockingCallSpec("time.sleep", "time:sleep"),
|
||||
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
|
||||
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
|
||||
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
|
||||
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
|
||||
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
|
||||
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
|
||||
)
|
||||
|
||||
|
||||
def _is_event_loop_thread() -> bool:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return False
|
||||
return loop.is_running()
|
||||
|
||||
|
||||
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
|
||||
module_name, attr_path = target.split(":", maxsplit=1)
|
||||
owner: object = importlib.import_module(module_name)
|
||||
parts = attr_path.split(".")
|
||||
for part in parts[:-1]:
|
||||
owner = getattr(owner, part)
|
||||
|
||||
attr_name = parts[-1]
|
||||
original = getattr(owner, attr_name)
|
||||
return owner, attr_name, original
|
||||
|
||||
|
||||
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
|
||||
return tuple(frame for frame in stack if frame.filename != __file__)
|
||||
|
||||
|
||||
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
|
||||
"""Record blocking calls made from async runtime code.
|
||||
|
||||
By default the detector reports violations but does not fail on context
|
||||
exit. Tests can set ``fail_on_exit=True`` or call
|
||||
``assert_no_blocking_calls()`` explicitly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
|
||||
*,
|
||||
fail_on_exit: bool = False,
|
||||
patch_loaded_aliases: bool = True,
|
||||
stack_limit: int = 12,
|
||||
) -> None:
|
||||
self._specs = tuple(specs)
|
||||
self._fail_on_exit = fail_on_exit
|
||||
self._patch_loaded_aliases_enabled = patch_loaded_aliases
|
||||
self._stack_limit = stack_limit
|
||||
self._patches: list[tuple[object, str, BlockingCallable]] = []
|
||||
self._patch_keys: set[tuple[int, str]] = set()
|
||||
self.violations: list[BlockingCall] = []
|
||||
self._active = False
|
||||
|
||||
def __enter__(self) -> BlockingIODetector:
|
||||
try:
|
||||
self._active = True
|
||||
alias_replacements: dict[int, BlockingCallable] = {}
|
||||
for spec in self._specs:
|
||||
owner, attr_name, original = _resolve_target(spec.target)
|
||||
wrapper = self._wrap(spec, original)
|
||||
self._patch_attribute(owner, attr_name, original, wrapper)
|
||||
alias_replacements[id(original)] = wrapper
|
||||
|
||||
if self._patch_loaded_aliases_enabled:
|
||||
self._patch_loaded_module_aliases(alias_replacements)
|
||||
except Exception:
|
||||
self._restore()
|
||||
self._active = False
|
||||
raise
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback_value: TracebackType | None,
|
||||
) -> bool | None:
|
||||
self._restore()
|
||||
self._active = False
|
||||
if exc_type is None and self._fail_on_exit:
|
||||
self.assert_no_blocking_calls()
|
||||
return None
|
||||
|
||||
def _restore(self) -> None:
|
||||
for owner, attr_name, original in reversed(self._patches):
|
||||
setattr(owner, attr_name, original)
|
||||
self._patches.clear()
|
||||
self._patch_keys.clear()
|
||||
|
||||
def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
|
||||
key = (id(owner), attr_name)
|
||||
if key in self._patch_keys:
|
||||
return
|
||||
setattr(owner, attr_name, replacement)
|
||||
self._patches.append((owner, attr_name, original))
|
||||
self._patch_keys.add(key)
|
||||
|
||||
def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
|
||||
for module in tuple(sys.modules.values()):
|
||||
namespace = getattr(module, "__dict__", None)
|
||||
if not isinstance(namespace, dict):
|
||||
continue
|
||||
|
||||
for attr_name, value in tuple(namespace.items()):
|
||||
replacement = replacements_by_id.get(id(value))
|
||||
if replacement is not None:
|
||||
self._patch_attribute(module, attr_name, value, replacement)
|
||||
|
||||
def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
|
||||
@wraps(original)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
if spec.record_on_iteration:
|
||||
result = original(*args, **kwargs)
|
||||
return self._wrap_iteration(spec, result)
|
||||
self._record_if_blocking(spec)
|
||||
return original(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
|
||||
iterator = iter(iterable)
|
||||
reported = False
|
||||
|
||||
while True:
|
||||
if not reported:
|
||||
reported = self._record_if_blocking(spec)
|
||||
try:
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
|
||||
if self._active and _is_event_loop_thread():
|
||||
stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
|
||||
self.violations.append(BlockingCall(spec.name, spec.target, stack))
|
||||
return True
|
||||
return False
|
||||
|
||||
def assert_no_blocking_calls(self) -> None:
|
||||
if self.violations:
|
||||
raise AssertionError(format_blocking_calls(self.violations))
|
||||
|
||||
|
||||
class BlockingIOProbe:
|
||||
"""Collect detector output across tests and format a compact summary."""
|
||||
|
||||
def __init__(self, project_root: Path) -> None:
|
||||
self._project_root = project_root.resolve()
|
||||
self._observed: list[tuple[str, BlockingCall]] = []
|
||||
|
||||
@property
|
||||
def violation_count(self) -> int:
|
||||
return len(self._observed)
|
||||
|
||||
@property
|
||||
def test_count(self) -> int:
|
||||
return len({nodeid for nodeid, _violation in self._observed})
|
||||
|
||||
def clear(self) -> None:
|
||||
self._observed.clear()
|
||||
|
||||
def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
|
||||
for violation in violations:
|
||||
self._observed.append((nodeid, violation))
|
||||
|
||||
def format_summary(self, *, limit: int = 30) -> str:
|
||||
if not self._observed:
|
||||
return "blocking io probe: no violations"
|
||||
|
||||
call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
|
||||
for _nodeid, violation in self._observed:
|
||||
frame = self._local_call_site(violation.stack)
|
||||
if frame is None:
|
||||
call_sites[(violation.name, "<unknown>", 0, "<unknown>", "")] += 1
|
||||
continue
|
||||
|
||||
call_sites[
|
||||
(
|
||||
violation.name,
|
||||
self._relative(frame.filename),
|
||||
frame.lineno,
|
||||
frame.name,
|
||||
(frame.line or "").strip(),
|
||||
)
|
||||
] += 1
|
||||
|
||||
lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
|
||||
for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
|
||||
lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _relative(self, filename: str) -> str:
|
||||
try:
|
||||
return str(Path(filename).resolve().relative_to(self._project_root))
|
||||
except ValueError:
|
||||
return filename
|
||||
|
||||
def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
|
||||
local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")]
|
||||
if local_frames:
|
||||
return local_frames[-1]
|
||||
|
||||
test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename]
|
||||
return test_frames[-1] if test_frames else None
|
||||
|
||||
|
||||
def detect_blocking_io(
|
||||
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
|
||||
*,
|
||||
fail_on_exit: bool = False,
|
||||
patch_loaded_aliases: bool = True,
|
||||
stack_limit: int = 12,
|
||||
) -> BlockingIODetector:
|
||||
"""Create a detector context manager for a focused test scope."""
|
||||
|
||||
return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
|
||||
|
||||
|
||||
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
|
||||
"""Format detector output with enough stack context to locate call sites."""
|
||||
|
||||
lines = ["Blocking calls were executed on an asyncio event loop thread:"]
|
||||
for index, violation in enumerate(violations, start=1):
|
||||
lines.append(f"{index}. {violation.name} ({violation.target})")
|
||||
lines.extend(_format_stack(violation.stack))
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
|
||||
for frame in stack:
|
||||
location = f"{frame.filename}:{frame.lineno}"
|
||||
lines = [f" at {frame.name} ({location})"]
|
||||
if frame.line:
|
||||
lines.append(f" {frame.line.strip()}")
|
||||
yield from lines
|
||||
190
backend/tests/test_blocking_io_detector.py
Normal file
190
backend/tests/test_blocking_io_detector.py
Normal file
@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from os import walk as imported_walk
|
||||
from pathlib import Path
|
||||
from time import sleep as imported_sleep
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from support.detectors.blocking_io import (
|
||||
BlockingCallSpec,
|
||||
BlockingIOProbe,
|
||||
detect_blocking_io,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
|
||||
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
|
||||
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
|
||||
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
|
||||
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
|
||||
|
||||
|
||||
async def test_records_time_sleep_on_event_loop() -> None:
|
||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
||||
time.sleep(0)
|
||||
|
||||
assert [violation.name for violation in detector.violations] == ["time.sleep"]
|
||||
|
||||
|
||||
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
|
||||
original_alias = imported_sleep
|
||||
|
||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
||||
imported_sleep(0)
|
||||
|
||||
assert imported_sleep is original_alias
|
||||
assert [violation.name for violation in detector.violations] == ["time.sleep"]
|
||||
|
||||
|
||||
async def test_can_disable_loaded_alias_patching() -> None:
|
||||
with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
|
||||
imported_sleep(0)
|
||||
|
||||
assert detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
|
||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
||||
await asyncio.to_thread(time.sleep, 0)
|
||||
|
||||
assert detector.violations == []
|
||||
|
||||
|
||||
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
|
||||
await asyncio.to_thread(time.sleep, 0)
|
||||
|
||||
assert blocking_io_detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
|
||||
def call_sleep() -> list[str]:
|
||||
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
|
||||
time.sleep(0)
|
||||
return [violation.name for violation in detector.violations]
|
||||
|
||||
assert await asyncio.to_thread(call_sleep) == []
|
||||
|
||||
|
||||
async def test_fail_on_exit_includes_call_site() -> None:
|
||||
with pytest.raises(AssertionError) as exc_info:
|
||||
with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
|
||||
time.sleep(0)
|
||||
|
||||
message = str(exc_info.value)
|
||||
assert "time.sleep" in message
|
||||
assert "test_fail_on_exit_includes_call_site" in message
|
||||
|
||||
|
||||
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
|
||||
return f"{method}:{url}"
|
||||
|
||||
monkeypatch.setattr(requests.sessions.Session, "request", fake_request)
|
||||
|
||||
with detect_blocking_io(REQUESTS_ONLY) as detector:
|
||||
assert requests.get("https://example.invalid") == "get:https://example.invalid"
|
||||
|
||||
assert [violation.name for violation in detector.violations] == ["requests.Session.request"]
|
||||
|
||||
|
||||
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
|
||||
return httpx.Response(200, request=httpx.Request(method, url))
|
||||
|
||||
monkeypatch.setattr(httpx.Client, "request", fake_request)
|
||||
|
||||
with detect_blocking_io(HTTPX_ONLY) as detector:
|
||||
with httpx.Client() as client:
|
||||
response = client.get("https://example.invalid")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert [violation.name for violation in detector.violations] == ["httpx.Client.request"]
|
||||
|
||||
|
||||
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
|
||||
(tmp_path / "nested").mkdir()
|
||||
|
||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
||||
assert list(os.walk(tmp_path))
|
||||
|
||||
assert [violation.name for violation in detector.violations] == ["os.walk"]
|
||||
|
||||
|
||||
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
|
||||
(tmp_path / "nested").mkdir()
|
||||
original_alias = imported_walk
|
||||
|
||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
||||
assert list(imported_walk(tmp_path))
|
||||
|
||||
assert imported_walk is original_alias
|
||||
assert [violation.name for violation in detector.violations] == ["os.walk"]
|
||||
|
||||
|
||||
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
|
||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
||||
walker = os.walk(tmp_path)
|
||||
|
||||
assert list(walker)
|
||||
assert detector.violations == []
|
||||
|
||||
|
||||
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
|
||||
(tmp_path / "nested").mkdir()
|
||||
|
||||
with detect_blocking_io(OS_WALK_ONLY) as detector:
|
||||
walker = os.walk(tmp_path)
|
||||
assert await asyncio.to_thread(lambda: list(walker))
|
||||
|
||||
assert detector.violations == []
|
||||
|
||||
|
||||
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
|
||||
path = tmp_path / "data.txt"
|
||||
path.write_text("content", encoding="utf-8")
|
||||
|
||||
with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
|
||||
assert path.read_text(encoding="utf-8") == "content"
|
||||
|
||||
assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"]
|
||||
|
||||
|
||||
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
|
||||
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
|
||||
path = tmp_path / "data.txt"
|
||||
path.write_text("content", encoding="utf-8")
|
||||
|
||||
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
|
||||
assert path.read_text(encoding="utf-8") == "content"
|
||||
|
||||
probe.record("tests/test_example.py::test_example", detector.violations)
|
||||
summary = probe.format_summary()
|
||||
|
||||
assert "blocking io probe: 1 violations across 1 tests" in summary
|
||||
assert "pathlib.Path.read_text" in summary
|
||||
|
||||
|
||||
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
|
||||
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
|
||||
|
||||
assert probe.format_summary() == "blocking io probe: no violations"
|
||||
|
||||
path = tmp_path / "data.txt"
|
||||
path.write_text("content", encoding="utf-8")
|
||||
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
|
||||
assert path.read_text(encoding="utf-8") == "content"
|
||||
|
||||
probe.record("tests/test_example.py::test_example", detector.violations)
|
||||
assert probe.violation_count == 1
|
||||
|
||||
probe.clear()
|
||||
|
||||
assert probe.violation_count == 0
|
||||
assert probe.format_summary() == "blocking io probe: no violations"
|
||||
22
backend/tests/test_blocking_io_probe_integration.py
Normal file
22
backend/tests/test_blocking_io_probe_integration.py
Normal file
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
ORIGINAL_SLEEP = time.sleep
|
||||
|
||||
|
||||
def replacement_sleep(seconds: float) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(time, "sleep", replacement_sleep)
|
||||
assert time.sleep is replacement_sleep
|
||||
|
||||
|
||||
@pytest.mark.no_blocking_io_probe
|
||||
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
|
||||
assert time.sleep is ORIGINAL_SLEEP
|
||||
assert getattr(time.sleep, "__wrapped__", None) is None
|
||||
Loading…
x
Reference in New Issue
Block a user