deer-flow/backend/tests/test_llm_error_handling_middleware.py
3a672b39c7
Fix/1681 llm call retry handling (#1683)
* fix(runtime): handle llm call errors gracefully

* fix(runtime): preserve graph control flow in llm retry middleware

---------

Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
2026-04-02 10:12:17 +08:00

137 lines
4.0 KiB
Python

from __future__ import annotations
import asyncio
from types import SimpleNamespace
import pytest
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware,
)
class FakeError(Exception):
def __init__(
self,
message: str,
*,
status_code: int | None = None,
code: str | None = None,
headers: dict[str, str] | None = None,
body: dict | None = None,
) -> None:
super().__init__(message)
self.status_code = status_code
self.code = code
self.body = body
self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
middleware = LLMErrorHandlingMiddleware()
for key, value in attrs.items():
setattr(middleware, key, value)
return middleware
def test_async_model_call_retries_busy_provider_then_succeeds(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25)
attempts = 0
waits: list[float] = []
events: list[dict] = []
async def fake_sleep(delay: float) -> None:
waits.append(delay)
def fake_writer():
return events.append
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts < 3:
raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)")
return AIMessage(content="ok")
monkeypatch.setattr("asyncio.sleep", fake_sleep)
monkeypatch.setattr(
"langgraph.config.get_stream_writer",
fake_writer,
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert attempts == 3
assert waits == [0.025, 0.025]
assert [event["type"] for event in events] == ["llm_retry", "llm_retry"]
def test_async_model_call_returns_user_message_for_quota_errors() -> None:
middleware = _build_middleware(retry_max_attempts=3)
async def handler(_request) -> AIMessage:
raise FakeError(
"insufficient_quota: account balance is empty",
status_code=429,
code="insufficient_quota",
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert "out of quota" in str(result.content)
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10)
waits: list[float] = []
attempts = 0
def fake_sleep(delay: float) -> None:
waits.append(delay)
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts == 1:
raise FakeError(
"server busy",
status_code=503,
headers={"Retry-After": "2"},
)
return AIMessage(content="ok")
monkeypatch.setattr("time.sleep", fake_sleep)
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert waits == [2.0]
def test_sync_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
middleware.wrap_model_call(SimpleNamespace(), handler)
def test_async_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
async def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))