Fix/1681 llm call retry handling (#1683)

* fix(runtime): handle llm call errors gracefully

* fix(runtime): preserve graph control flow in llm retry middleware

---------

Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
2026-04-02 10:12:17 +08:00 committed by GitHub
parent df5339b5d0
commit 3a672b39c7
4 changed files with 428 additions and 0 deletions


@@ -0,0 +1,275 @@
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import (
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
logger = logging.getLogger(__name__)
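# HTTP status codes treated as transient: timeout (408), conflict (409),
# too-early (425), rate limit (429), and common 5xx gateway/server errors.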
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
_BUSY_PATTERNS = (
"server busy",
"temporarily unavailable",
"try again later",
"please retry",
"please try again",
"overloaded",
"high demand",
"rate limit",
"负载较高",
"服务繁忙",
"稍后重试",
"请稍后重试",
)
_QUOTA_PATTERNS = (
"insufficient_quota",
"quota",
"billing",
"credit",
"payment",
"余额不足",
"超出限额",
"额度不足",
"欠费",
)
_AUTH_PATTERNS = (
"authentication",
"unauthorized",
"invalid api key",
"invalid_api_key",
"permission",
"forbidden",
"access denied",
"无权",
"未授权",
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages."""
retry_max_attempts: int = 3
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
error_code = _extract_error_code(exc)
status_code = _extract_status_code(exc)
        if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(
            str(error_code).lower(), _QUOTA_PATTERNS
        ):
return False, "quota"
if _matches_any(lowered, _AUTH_PATTERNS):
return False, "auth"
exc_name = exc.__class__.__name__
if exc_name in {
"APITimeoutError",
"APIConnectionError",
"InternalServerError",
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
return True, "transient"
if _matches_any(lowered, _BUSY_PATTERNS):
return True, "busy"
return False, "generic"
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
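        # Prefer the provider's Retry-After hint when present; otherwise use
        # exponential backoff: retry_base_delay_ms * 2^(attempt - 1), capped at
        # retry_cap_delay_ms (defaults: 1s, 2s, 4s, then capped at 8s).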
retry_after = _extract_retry_after_ms(exc)
if retry_after is not None:
return retry_after
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
return min(backoff, self.retry_cap_delay_ms)
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
seconds = max(1, round(wait_ms / 1000))
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}:
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
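            # Custom events written here reach clients that consume the
            # graph's stream with stream_mode="custom".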
writer = get_stream_writer()
writer(
{
"type": "llm_retry",
"attempt": attempt,
"max_attempts": self.retry_max_attempts,
"wait_ms": wait_ms,
"reason": reason,
"message": self._build_retry_message(attempt, wait_ms, reason),
}
)
except Exception:
logger.debug("Failed to emit llm_retry event", exc_info=True)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
attempt = 1
while True:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
time.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
attempt = 1
while True:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
await asyncio.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
return any(pattern in detail for pattern in patterns)
def _extract_error_code(exc: BaseException) -> Any:
for attr in ("code", "error_code"):
value = getattr(exc, attr, None)
if value not in (None, ""):
return value
body = getattr(exc, "body", None)
if isinstance(body, dict):
error = body.get("error")
if isinstance(error, dict):
for key in ("code", "type"):
value = error.get(key)
if value not in (None, ""):
return value
return None
def _extract_status_code(exc: BaseException) -> int | None:
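    # Duck-typed lookup: covers SDK errors exposing status_code/status directly,
    # plus exceptions carrying an httpx-style response object.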
for attr in ("status_code", "status"):
value = getattr(exc, attr, None)
if isinstance(value, int):
return value
response = getattr(exc, "response", None)
status = getattr(response, "status_code", None)
return status if isinstance(status, int) else None
def _extract_retry_after_ms(exc: BaseException) -> int | None:
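    # Providers hint backoff via Retry-After (seconds or an HTTP-date) or
    # retry-after-ms (milliseconds); both forms are normalized to milliseconds.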
response = getattr(exc, "response", None)
headers = getattr(response, "headers", None)
if headers is None:
return None
    if not hasattr(headers, "get"):
        return None
    raw = None
    header_name = ""
    for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
        raw = headers.get(key)
        if raw:
            header_name = key
            break
    if not raw:
        return None
try:
multiplier = 1 if "ms" in header_name.lower() else 1000
return max(0, int(float(raw) * multiplier))
except (TypeError, ValueError):
try:
target = parsedate_to_datetime(str(raw))
delta = target.timestamp() - time.time()
return max(0, int(delta * 1000))
except (TypeError, ValueError, OverflowError):
return None
def _extract_error_detail(exc: BaseException) -> str:
detail = str(exc).strip()
if detail:
return detail
message = getattr(exc, "message", None)
if isinstance(message, str) and message.strip():
return message.strip()
return exc.__class__.__name__
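For orientation, a minimal sketch of the classification taxonomy above. FakeHTTPError is a hypothetical stand-in for a provider SDK exception, and _classify_error is a private helper of this file, so this is illustrative only:

from deerflow.agents.middlewares.llm_error_handling_middleware import (
    LLMErrorHandlingMiddleware,
)

class FakeHTTPError(Exception):
    """Hypothetical provider error carrying an optional HTTP status."""

    def __init__(self, message: str, status_code: int | None = None) -> None:
        super().__init__(message)
        self.status_code = status_code

mw = LLMErrorHandlingMiddleware()
assert mw._classify_error(FakeHTTPError("server busy")) == (True, "busy")
assert mw._classify_error(FakeHTTPError("boom", status_code=503)) == (True, "transient")
assert mw._classify_error(FakeHTTPError("insufficient_quota")) == (False, "quota")
assert mw._classify_error(FakeHTTPError("invalid api key")) == (False, "auth")
assert mw._classify_error(FakeHTTPError("mystery failure")) == (False, "generic")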


@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
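For reference, a hedged sketch of how a middleware list like this is typically consumed. It assumes langchain 1.x's create_agent, which accepts a middleware sequence; the model string and empty tool list are placeholders, not part of this commit:

from langchain.agents import create_agent

from deerflow.agents.middlewares.llm_error_handling_middleware import (
    LLMErrorHandlingMiddleware,
)

# Placeholder model and tools; the runtime builds these from configuration.
agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[LLMErrorHandlingMiddleware()],
)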


@@ -0,0 +1,136 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
import pytest
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware,
)
class FakeError(Exception):
def __init__(
self,
message: str,
*,
status_code: int | None = None,
code: str | None = None,
headers: dict[str, str] | None = None,
body: dict | None = None,
) -> None:
super().__init__(message)
self.status_code = status_code
self.code = code
self.body = body
        self.response = (
            SimpleNamespace(status_code=status_code, headers=headers or {})
            if status_code is not None or headers
            else None
        )
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
middleware = LLMErrorHandlingMiddleware()
for key, value in attrs.items():
setattr(middleware, key, value)
return middleware
def test_async_model_call_retries_busy_provider_then_succeeds(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25)
attempts = 0
waits: list[float] = []
events: list[dict] = []
async def fake_sleep(delay: float) -> None:
waits.append(delay)
    def fake_writer():
        # Stand-in for langgraph.config.get_stream_writer: the middleware calls
        # it to obtain a writer, then invokes the writer with each event dict.
        return events.append
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts < 3:
raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)")
return AIMessage(content="ok")
monkeypatch.setattr("asyncio.sleep", fake_sleep)
monkeypatch.setattr(
"langgraph.config.get_stream_writer",
fake_writer,
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert attempts == 3
assert waits == [0.025, 0.025]
assert [event["type"] for event in events] == ["llm_retry", "llm_retry"]
def test_async_model_call_returns_user_message_for_quota_errors() -> None:
middleware = _build_middleware(retry_max_attempts=3)
async def handler(_request) -> AIMessage:
raise FakeError(
"insufficient_quota: account balance is empty",
status_code=429,
code="insufficient_quota",
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert "out of quota" in str(result.content)
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10)
waits: list[float] = []
attempts = 0
def fake_sleep(delay: float) -> None:
waits.append(delay)
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts == 1:
raise FakeError(
"server busy",
status_code=503,
headers={"Retry-After": "2"},
)
return AIMessage(content="ok")
monkeypatch.setattr("time.sleep", fake_sleep)
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert waits == [2.0]
def test_sync_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
middleware.wrap_model_call(SimpleNamespace(), handler)
def test_async_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
async def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
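A sketch of one more check in the same style (not part of this commit): auth failures are non-retriable and should surface the credentials message immediately.

def test_sync_model_call_returns_user_message_for_auth_errors() -> None:
    middleware = _build_middleware(retry_max_attempts=3)

    def handler(_request) -> AIMessage:
        raise FakeError("invalid api key", status_code=401)

    result = middleware.wrap_model_call(SimpleNamespace(), handler)
    assert isinstance(result, AIMessage)
    assert "authentication or access is invalid" in str(result.content)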


@@ -170,6 +170,20 @@ export function useThreadStream({
message: AIMessage;
};
updateSubtask({ id: e.task_id, latestMessage: e.message });
return;
}
if (
typeof event === "object" &&
event !== null &&
"type" in event &&
event.type === "llm_retry" &&
"message" in event &&
typeof event.message === "string" &&
event.message.trim()
) {
const e = event as { type: "llm_retry"; message: string };
toast(e.message);
}
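      // Sketch of the payload (assumption: fields mirror the dict built in
      // LLMErrorHandlingMiddleware._emit_retry_event on the Python side):
      //   { type: "llm_retry"; attempt: number; max_attempts: number;
      //     wait_ms: number; reason: string; message: string }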
},
onError(error) {