deer-flow/backend/tests/test_llm_error_handling_middleware.py
Huixin615 88e36d9686
fix(#3189): prevent write_file streaming timeout on long reports (#3195)
* fix(#3189): prevent write_file streaming timeout on long reports

Adds a layered defense against StreamChunkTimeoutError caused by oversized
single-shot write_file tool calls:

- factory: default stream_chunk_timeout to 240s for OpenAI-compatible
  clients (overridable via ModelConfig.stream_chunk_timeout in config.yaml)
- sandbox/tools: server-side 80 KB length guard on non-append write_file
  calls (configurable via DEERFLOW_WRITE_FILE_MAX_BYTES env var, 0 disables);
  rejects oversized payloads with a structured error pointing the model at
  str_replace or append=True
- middleware: classify StreamChunkTimeoutError as transient but cap retries
  at 1 via per-exception _RETRY_BUDGET_OVERRIDES (same-payload retry on a
  chunk-gap timeout buffers the same way upstream; full 3-attempt loop
  would stack 6-12 min of dead air)
- middleware: surface an actionable user-facing message for stream-drop
  exceptions instead of leaking the raw langchain stack
- prompts: add a routing-style File Editing Workflow hint to both lead_agent
  and general_purpose subagent prompts, pointing the model at str_replace
  for incremental edits (mirrors Claude Code's Edit / Codex's apply_patch)
- tests: behavioural coverage for size guard, retry budget override,
  stream-drop user message, factory default injection

Refs #3189

* fix(#3189): drop stream_chunk_timeout for non-OpenAI providers

Address CR feedback on PR #3195:

- factory: pop `stream_chunk_timeout` from kwargs for any model_use_path other than `langchain_openai:ChatOpenAI` instead of returning early. `ModelConfig.stream_chunk_timeout` is part of the shared schema, so a user-supplied value on a non-OpenAI provider would otherwise be forwarded to its constructor and raise `TypeError: unexpected keyword argument`.

- factory: rewrite docstring to describe the actual `exclude_none=True` behaviour (explicit null is excluded and falls back to the default) instead of the misleading "None falling out via exclude_none=True keeps its value".

- tests: add regression coverage asserting the kwarg is stripped before reaching a non-OpenAI provider's constructor.

Refs: bytedance#3189

* fix(#3189): restrict stream-drop user copy to StreamChunkTimeoutError only

Per CR on #3195: narrow _STREAM_DROP_EXCEPTIONS to StreamChunkTimeoutError. Generic httpx RemoteProtocolError / ReadError fall back to the standard 'temporarily unavailable' copy, since they routinely fire on transient network blips where the 'split the output' guidance is misleading. Retry/backoff classification is unchanged — both remain transient/retriable. Tests updated to reflect new copy, plus a symmetric regression test for ReadError.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-06-07 17:47:11 +08:00

681 lines
26 KiB
Python

from __future__ import annotations
import asyncio
from types import SimpleNamespace
from typing import Any
import pytest
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_app_config() -> AppConfig:
"""Minimal AppConfig for middleware tests; circuit_breaker uses defaults."""
return AppConfig(sandbox=SandboxConfig(use="test"))
class FakeError(Exception):
def __init__(
self,
message: str,
*,
status_code: int | None = None,
code: str | None = None,
headers: dict[str, str] | None = None,
body: dict | None = None,
) -> None:
super().__init__(message)
self.status_code = status_code
self.code = code
self.body = body
self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
middleware = LLMErrorHandlingMiddleware(app_config=_make_app_config())
for key, value in attrs.items():
setattr(middleware, key, value)
return middleware
def test_async_model_call_retries_busy_provider_then_succeeds(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25)
attempts = 0
waits: list[float] = []
events: list[dict] = []
async def fake_sleep(delay: float) -> None:
waits.append(delay)
def fake_writer():
return events.append
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts < 3:
raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)")
return AIMessage(content="ok")
monkeypatch.setattr("asyncio.sleep", fake_sleep)
monkeypatch.setattr(
"langgraph.config.get_stream_writer",
fake_writer,
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert attempts == 3
assert waits == [0.025, 0.025]
assert [event["type"] for event in events] == ["llm_retry", "llm_retry"]
def test_async_model_call_returns_user_message_for_quota_errors() -> None:
middleware = _build_middleware(retry_max_attempts=3)
async def handler(_request) -> AIMessage:
raise FakeError(
"insufficient_quota: account balance is empty",
status_code=429,
code="insufficient_quota",
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert "out of quota" in str(result.content)
assert result.additional_kwargs["deerflow_error_fallback"] is True
assert result.additional_kwargs["error_reason"] == "quota"
assert result.additional_kwargs["error_type"] == "FakeError"
def test_async_model_call_marks_transient_retry_exhaustion_as_error_fallback(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=25, retry_cap_delay_ms=25)
async def fake_sleep(_delay: float) -> None:
return None
async def handler(_request) -> AIMessage:
raise FakeError("Connection error.", status_code=503)
monkeypatch.setattr("asyncio.sleep", fake_sleep)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert "temporarily unavailable" in str(result.content)
assert result.additional_kwargs["deerflow_error_fallback"] is True
assert result.additional_kwargs["error_reason"] == "transient"
assert result.additional_kwargs["error_detail"] == "Connection error."
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10)
waits: list[float] = []
attempts = 0
def fake_sleep(delay: float) -> None:
waits.append(delay)
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts == 1:
raise FakeError(
"server busy",
status_code=503,
headers={"Retry-After": "2"},
)
return AIMessage(content="ok")
monkeypatch.setattr("time.sleep", fake_sleep)
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert waits == [2.0]
def test_sync_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
middleware.wrap_model_call(SimpleNamespace(), handler)
def test_async_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
async def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
def test_circuit_half_open_graph_bubble_up_resets_probe() -> None:
"""Verify that GraphBubbleUp in half_open state resets probe_in_flight."""
middleware = _build_middleware()
# Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True
middleware._circuit_state = "half_open"
middleware._circuit_probe_in_flight = False
# Call _check_circuit() once to simulate the probe being allowed through
assert middleware._check_circuit() is False
assert middleware._circuit_probe_in_flight is True
# Step 2: Now trigger handler that raises GraphBubbleUp
def handler(_request) -> AIMessage:
raise GraphBubbleUp()
# Mock _check_circuit() to return False (since we already did the probe check)
import unittest.mock
with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False):
with pytest.raises(GraphBubbleUp):
middleware.wrap_model_call(SimpleNamespace(), handler)
# Verify probe_in_flight was reset, state should remain half_open
assert middleware._circuit_probe_in_flight is False
assert middleware._circuit_state == "half_open"
@pytest.mark.anyio
async def test_async_circuit_half_open_graph_bubble_up_resets_probe() -> None:
"""Verify that GraphBubbleUp in half_open state resets probe_in_flight (async version)."""
middleware = _build_middleware()
# Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True
middleware._circuit_state = "half_open"
middleware._circuit_probe_in_flight = False
# Call _check_circuit() once to simulate the probe being allowed through
assert middleware._check_circuit() is False
assert middleware._circuit_probe_in_flight is True
# Step 2: Now trigger handler that raises GraphBubbleUp
async def handler(_request) -> AIMessage:
raise GraphBubbleUp()
# Mock _check_circuit() to return False (since we already did the probe check)
import unittest.mock
with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False):
with pytest.raises(GraphBubbleUp):
await middleware.awrap_model_call(SimpleNamespace(), handler)
# Verify probe_in_flight was reset, state should remain half_open
assert middleware._circuit_probe_in_flight is False
assert middleware._circuit_state == "half_open"
# ---------- Circuit Breaker Tests ----------
def transient_failing_handler(request: Any) -> Any:
raise FakeError("Server Error", status_code=502) # Used for transient error
def quota_failing_handler(request: Any) -> Any:
raise FakeError("Quota exceeded", body={"error": {"code": "insufficient_quota"}}) # Used for quota error
def success_handler(request: Any) -> Any:
return AIMessage(content="Success")
def mock_classify_retriable(exc: BaseException) -> tuple[bool, str]:
return True, "transient"
def mock_classify_non_retriable(exc: BaseException) -> tuple[bool, str]:
return False, "quota"
def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that circuit breaker trips, fast fails, correctly transitions to Half-Open, and recovers or re-opens."""
# Mock time.sleep to avoid slow tests during retry loops (Speed up from ~4s to 0.1s)
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
# Mock time.time to decouple from private implementation details and enable time travel
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
request: Any = {"messages": []}
# --- 0. Test initial state & Success ---
# Success handler does not increase count. If it's already 0, it stays 0.
middleware.wrap_model_call(request, success_handler)
assert middleware._circuit_failure_count == 0
assert middleware._check_circuit() is False
# --- 1. Trip the circuit ---
# Fails 3 overall calls. Threshold (3) is reached.
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == 1
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == 2
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == 3
assert middleware._check_circuit() is True # Circuit is OPEN
# --- 2. Fast Fail ---
# 2nd call: fast fail immediately without calling handler.
# Count should not increase during OPEN state.
result = middleware.wrap_model_call(request, success_handler)
assert result.content == middleware._build_circuit_breaker_message()
assert middleware._circuit_failure_count == 3
# --- 3. Half-Open -> Fail -> Re-Open ---
# Time travel 11 seconds (timeout is 10s). Current time becomes 1011.0
current_time += 11.0
# Verify that the timeout was set EXACTLY relative to current_time + timeout_sec
assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec
# Fails again! The request will go through the 3-attempt retry loop again.
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
assert middleware._circuit_state == "open" # Re-OPENed
# --- 4. Half-Open -> Success -> Reset ---
# Time travel another 11 seconds
current_time += 11.0
# Succeeds this time! Should completely reset.
result = middleware.wrap_model_call(request, success_handler)
assert result.content == "Success"
assert middleware._circuit_failure_count == 0 # Fully RESET!
assert middleware._check_circuit() is False
def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that circuit breaker ignores business errors like Quota or Auth."""
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
middleware = _build_middleware(circuit_failure_threshold=3)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
request: Any = {"messages": []}
for _ in range(3):
middleware.wrap_model_call(request, quota_failing_handler)
assert middleware._circuit_failure_count == 0
assert middleware._check_circuit() is False
# ---------- ReadError / RemoteProtocolError retriable classification ----------
class _ReadError(Exception):
"""Local stand-in for httpx.ReadError — same class name, no httpx dependency."""
class _RemoteProtocolError(Exception):
"""Local stand-in for httpx.RemoteProtocolError — same class name, no httpx dependency."""
_ReadError.__name__ = "ReadError"
_RemoteProtocolError.__name__ = "RemoteProtocolError"
def test_classify_error_read_error_is_retriable() -> None:
middleware = _build_middleware()
exc = _ReadError("Connection dropped mid-stream")
exc.__class__.__name__ = "ReadError"
retriable, reason = middleware._classify_error(exc)
assert retriable is True
assert reason == "transient"
def test_classify_error_remote_protocol_error_is_retriable() -> None:
middleware = _build_middleware()
exc = _RemoteProtocolError("Server closed connection unexpectedly")
exc.__class__.__name__ = "RemoteProtocolError"
retriable, reason = middleware._classify_error(exc)
assert retriable is True
assert reason == "transient"
def test_sync_read_error_triggers_retry_loop(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=10, retry_cap_delay_ms=10)
attempts = 0
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
raise _ReadError("Connection dropped mid-stream")
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
# ReadError is a generic connection drop, not a chunk-gap timeout, so
# it must fall back to the legacy transient copy rather than the
# specialized "split the work into smaller steps" guidance (#3195 CR).
assert "temporarily unavailable" in result.content
assert "streaming response was interrupted" not in result.content
assert attempts == 3 # exhausted all retries
assert len(waits) == 2 # slept between attempts 1→2 and 2→3
@pytest.mark.anyio
async def test_async_read_error_triggers_retry_loop(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=10, retry_cap_delay_ms=10)
attempts = 0
waits: list[float] = []
async def fake_sleep(d: float) -> None:
waits.append(d)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
raise _ReadError("Connection dropped mid-stream")
result = await middleware.awrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
# ReadError is a generic connection drop, not a chunk-gap timeout, so
# it must fall back to the legacy transient copy rather than the
# specialized "split the work into smaller steps" guidance (#3195 CR).
assert "temporarily unavailable" in result.content
assert "streaming response was interrupted" not in result.content
assert attempts == 3 # exhausted all retries
assert len(waits) == 2 # slept between attempts 1→2 and 2→3
@pytest.mark.anyio
async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify async version of circuit breaker correctly handles state transitions."""
waits: list[float] = []
async def fake_sleep(d: float) -> None:
waits.append(d)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
async def async_failing_handler(request: Any) -> Any:
raise FakeError("Server Error", status_code=502)
request: Any = {"messages": []}
# --- 1. Trip the circuit ---
# Fails 3 overall calls. Threshold (3) is reached.
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == 1
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == 2
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == 3
assert middleware._check_circuit() is True
# --- 2. Fast Fail ---
# 2nd call: fast fail immediately without calling handler
async def async_success_handler(request: Any) -> Any:
return AIMessage(content="Success")
result = await middleware.awrap_model_call(request, async_success_handler)
assert result.content == middleware._build_circuit_breaker_message()
assert middleware._circuit_failure_count == 3 # Unchanged
# --- 3. Half-Open -> Fail -> Re-Open ---
# Time travel 11 seconds
current_time += 11.0
# Verify timeout formula
assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec
# Fails again! The request goes through the 3-attempt retry loop.
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
assert middleware._circuit_state == "open" # Re-OPENed
# --- 4. Half-Open -> Success -> Reset ---
# Time travel another 11 seconds
current_time += 11.0
result = await middleware.awrap_model_call(request, async_success_handler)
assert result.content == "Success"
assert middleware._circuit_failure_count == 0 # RESET
assert middleware._check_circuit() is False
class _StreamChunkTimeoutError(Exception):
"""Local stand-in for langchain_openai's StreamChunkTimeoutError —
matched by class name, no langchain-openai import needed (mirrors
how this file already stubs httpx.ReadError / RemoteProtocolError).
"""
_StreamChunkTimeoutError.__name__ = "StreamChunkTimeoutError"
def test_classify_error_stream_chunk_timeout_is_retriable() -> None:
"""StreamChunkTimeoutError must be classified as transient/retriable."""
middleware = _build_middleware()
exc = _StreamChunkTimeoutError("No streaming chunk received for 120.0s (model=mimo-v2.5, chunks_received=58).")
exc.__class__.__name__ = "StreamChunkTimeoutError"
retriable, reason = middleware._classify_error(exc)
assert retriable is True
assert reason == "transient"
def test_sync_stream_chunk_timeout_retries_once(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Sync handler raising StreamChunkTimeoutError is retried exactly once —
the per-exception override caps it at 2 total attempts (1 first call + 1
retry) even when retry_max_attempts=3.
Same-payload retry on a chunk-gap timeout buffers the same way upstream;
a full 3-attempt loop would stack 6-12 minutes of dead air before
surfacing failure. We keep one cheap reconnect for genuine transient TCP
blips, then surface the failure so the model can re-plan on its next turn.
"""
middleware = _build_middleware(
retry_max_attempts=3,
retry_base_delay_ms=10,
retry_cap_delay_ms=10,
)
attempts = 0
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
raise _StreamChunkTimeoutError("No streaming chunk received for 120.0s")
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert "streaming response was interrupted" in result.content
# Override caps StreamChunkTimeoutError at 2 attempts (1 first call + 1 retry).
assert attempts == 2
# Exactly one sleep between the first attempt and the single retry.
assert len(waits) == 1
@pytest.mark.anyio
async def test_async_stream_chunk_timeout_retries_once(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Async mirror of the sync test: StreamChunkTimeoutError is capped at
2 attempts (1 first call + 1 retry) so we don't stack 6-12 minutes of
dead air on a same-payload buffering failure.
"""
middleware = _build_middleware(
retry_max_attempts=3,
retry_base_delay_ms=10,
retry_cap_delay_ms=10,
)
attempts = 0
waits: list[float] = []
async def fake_sleep(d: float) -> None:
waits.append(d)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
raise _StreamChunkTimeoutError("No streaming chunk received for 120.0s")
result = await middleware.awrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert "streaming response was interrupted" in result.content
assert attempts == 2
# Exactly one sleep between the first attempt and the single retry.
assert len(waits) == 1
def test_max_attempts_for_returns_override_for_stream_chunk_timeout() -> None:
"""StreamChunkTimeoutError must use the tightened budget (2 = "keep one retry"),
not the default of 3."""
middleware = _build_middleware(retry_max_attempts=3)
exc = _StreamChunkTimeoutError("upstream stalled")
exc.__class__.__name__ = "StreamChunkTimeoutError"
assert middleware._max_attempts_for(exc) == 2
def test_max_attempts_for_falls_back_to_default_for_unlisted_exception() -> None:
"""ReadError / RemoteProtocolError keep the full retry budget — only
StreamChunkTimeoutError pays for stalling upstream for `stream_chunk_timeout`
seconds per attempt, so only it gets the tighter cap.
"""
middleware = _build_middleware(retry_max_attempts=3)
read_err = _ReadError("conn reset")
read_err.__class__.__name__ = "ReadError"
proto_err = _RemoteProtocolError("peer closed")
proto_err.__class__.__name__ = "RemoteProtocolError"
assert middleware._max_attempts_for(read_err) == 3
assert middleware._max_attempts_for(proto_err) == 3
assert middleware._max_attempts_for(FakeError("boom")) == 3
def test_max_attempts_for_override_never_exceeds_user_cap() -> None:
"""If the operator lowered retry_max_attempts below the override default,
the user-configured cap wins — overrides only ever *tighten*, never loosen.
"""
middleware = _build_middleware(retry_max_attempts=1)
exc = _StreamChunkTimeoutError("upstream stalled")
exc.__class__.__name__ = "StreamChunkTimeoutError"
assert middleware._max_attempts_for(exc) == 1
def test_user_message_for_stream_chunk_timeout_mentions_split_or_shorten() -> None:
"""When the retry budget for StreamChunkTimeoutError is exhausted, the user
message must guide the user toward splitting / shortening the request
instead of suggesting a generic retry. This is the actionable advice
Reviewer B asked for in the follow-up CR (issue #3189).
"""
middleware = _build_middleware()
exc = _StreamChunkTimeoutError("No streaming chunk received for 120.0s")
exc.__class__.__name__ = "StreamChunkTimeoutError"
message = middleware._build_user_message(exc, reason="transient")
assert "streaming response was interrupted" in message
assert "split" in message or "shorten" in message
# The old generic "streaming response was interrupted" wording must NOT appear here,
# otherwise the actionable guidance is buried.
assert "temporarily unavailable" not in message
def test_user_message_for_remote_protocol_error_uses_generic_transient_copy() -> None:
"""RemoteProtocolError is a generic connection drop that can fire on
transient network blips with perfectly normal payloads. The
"split the work into smaller steps" guidance only applies when the
upstream chunk-gap watchdog fires (StreamChunkTimeoutError), so
RemoteProtocolError must fall back to the legacy transient copy.
Regression guard for the #3195 CR feedback.
"""
middleware = _build_middleware()
exc = _RemoteProtocolError("Server closed connection unexpectedly")
exc.__class__.__name__ = "RemoteProtocolError"
message = middleware._build_user_message(exc, reason="transient")
assert "temporarily unavailable" in message
assert "streaming response was interrupted" not in message
def test_user_message_for_read_error_uses_generic_transient_copy() -> None:
"""httpx.ReadError is symmetric to RemoteProtocolError: a generic
connection drop that must NOT receive the "split the work" guidance.
Regression guard for the #3195 CR feedback.
"""
middleware = _build_middleware()
exc = FakeError("connection dropped mid-stream")
exc.__class__.__name__ = "ReadError"
message = middleware._build_user_message(exc, reason="transient")
assert "temporarily unavailable" in message
assert "streaming response was interrupted" not in message
def test_user_message_for_generic_transient_keeps_legacy_copy() -> None:
"""Generic transient errors (HTTP 503, 'cluster busy', etc.) must keep
the original 'streaming response was interrupted' message — only stream-drop
exceptions get the new specialized copy. This prevents regression on
callers who already rely on the legacy wording.
"""
middleware = _build_middleware()
exc = FakeError("server busy", status_code=503)
message = middleware._build_user_message(exc, reason="transient")
assert "temporarily unavailable" in message
assert "streaming response was interrupted" not in message
def test_user_message_for_quota_unchanged() -> None:
"""Sanity check: the quota / auth branches must remain untouched by the
stream-drop refactor.
"""
middleware = _build_middleware()
exc = FakeError("insufficient_quota", status_code=429, code="insufficient_quota")
message = middleware._build_user_message(exc, reason="quota")
assert "out of quota" in message
assert "streaming response was interrupted" not in message