From 3a672b39c798604abc5911371c56745260186324 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=82=96?= <168966994+luoxiao6645@users.noreply.github.com>
Date: Thu, 2 Apr 2026 10:12:17 +0800
Subject: [PATCH] Fix/1681 llm call retry handling (#1683)

* fix(runtime): handle llm call errors gracefully

* fix(runtime): preserve graph control flow in llm retry middleware

---------

Co-authored-by: luoxiao6645
---
 .../llm_error_handling_middleware.py          | 275 ++++++++++++++++++
 .../tool_error_handling_middleware.py         |   3 +
 .../test_llm_error_handling_middleware.py     | 136 +++++++++
 frontend/src/core/threads/hooks.ts            |  14 +
 4 files changed, 428 insertions(+)
 create mode 100644 backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py
 create mode 100644 backend/tests/test_llm_error_handling_middleware.py

diff --git a/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py
new file mode 100644
index 000000000..e1a3af714
--- /dev/null
+++ b/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py
@@ -0,0 +1,275 @@
+"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+from collections.abc import Awaitable, Callable
+from email.utils import parsedate_to_datetime
+from typing import Any, override
+
+from langchain.agents import AgentState
+from langchain.agents.middleware import AgentMiddleware
+from langchain.agents.middleware.types import (
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
+from langchain_core.messages import AIMessage
+from langgraph.errors import GraphBubbleUp
+
+logger = logging.getLogger(__name__)
+
+_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
+_BUSY_PATTERNS = (
+    "server busy",
+    "temporarily unavailable",
+    "try again later",
+    "please retry",
+    "please try again",
+    "overloaded",
+    "high demand",
+    "rate limit",
+    "负载较高",  # "load is high"
+    "服务繁忙",  # "service is busy"
+    "稍后重试",  # "retry later"
+    "请稍后重试",  # "please retry later"
+)
+_QUOTA_PATTERNS = (
+    "insufficient_quota",
+    "quota",
+    "billing",
+    "credit",
+    "payment",
+    "余额不足",  # "insufficient balance"
+    "超出限额",  # "limit exceeded"
+    "额度不足",  # "insufficient credit"
+    "欠费",  # "account in arrears"
+)
+_AUTH_PATTERNS = (
+    "authentication",
+    "unauthorized",
+    "invalid api key",
+    "invalid_api_key",
+    "permission",
+    "forbidden",
+    "access denied",
+    "无权",  # "no permission"
+    "未授权",  # "unauthorized"
+)
+
+
+class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
+    """Retry transient LLM errors and surface graceful assistant messages."""
+
+    retry_max_attempts: int = 3
+    retry_base_delay_ms: int = 1000
+    retry_cap_delay_ms: int = 8000
+
+    def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
+        detail = _extract_error_detail(exc)
+        lowered = detail.lower()
+        error_code = _extract_error_code(exc)
+        status_code = _extract_status_code(exc)
+
+        if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
+            return False, "quota"
+        if _matches_any(lowered, _AUTH_PATTERNS):
+            return False, "auth"
+
+        exc_name = exc.__class__.__name__
+        if exc_name in {
+            "APITimeoutError",
+            "APIConnectionError",
+            "InternalServerError",
+        }:
+            return True, "transient"
+        if status_code in _RETRIABLE_STATUS_CODES:
+            return True, "transient"
+        if _matches_any(lowered, _BUSY_PATTERNS):
+            return True, "busy"
+
+        return False, "generic"
+
+    def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
+        retry_after = _extract_retry_after_ms(exc)
+        if retry_after is not None:
+            return retry_after
+        backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
+        return min(backoff, self.retry_cap_delay_ms)
+
+    def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
+        seconds = max(1, round(wait_ms / 1000))
+        reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
+        return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
+
+    def _build_user_message(self, exc: BaseException, reason: str) -> str:
+        detail = _extract_error_detail(exc)
+        if reason == "quota":
+            return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
+        if reason == "auth":
+            return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
+        if reason in {"busy", "transient"}:
+            return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
+        return f"LLM request failed: {detail}"
+
+    def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
+        try:
+            from langgraph.config import get_stream_writer
+
+            writer = get_stream_writer()
+            writer(
+                {
+                    "type": "llm_retry",
+                    "attempt": attempt,
+                    "max_attempts": self.retry_max_attempts,
+                    "wait_ms": wait_ms,
+                    "reason": reason,
+                    "message": self._build_retry_message(attempt, wait_ms, reason),
+                }
+            )
+        except Exception:
+            logger.debug("Failed to emit llm_retry event", exc_info=True)
+
+    @override
+    def wrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        attempt = 1
+        while True:
+            try:
+                return handler(request)
+            except GraphBubbleUp:
+                # Preserve LangGraph control-flow signals (interrupt/pause/resume).
+                raise
+            except Exception as exc:
+                retriable, reason = self._classify_error(exc)
+                if retriable and attempt < self.retry_max_attempts:
+                    wait_ms = self._build_retry_delay_ms(attempt, exc)
+                    logger.warning(
+                        "Transient LLM error on attempt %d/%d; retrying in %dms: %s",
+                        attempt,
+                        self.retry_max_attempts,
+                        wait_ms,
+                        _extract_error_detail(exc),
+                    )
+                    self._emit_retry_event(attempt, wait_ms, reason)
+                    time.sleep(wait_ms / 1000)
+                    attempt += 1
+                    continue
+                logger.warning(
+                    "LLM call failed after %d attempt(s): %s",
+                    attempt,
+                    _extract_error_detail(exc),
+                    exc_info=exc,
+                )
+                return AIMessage(content=self._build_user_message(exc, reason))
+
+    @override
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        attempt = 1
+        while True:
+            try:
+                return await handler(request)
+            except GraphBubbleUp:
+                # Preserve LangGraph control-flow signals (interrupt/pause/resume).
+                raise
+            except Exception as exc:
+                retriable, reason = self._classify_error(exc)
+                if retriable and attempt < self.retry_max_attempts:
+                    wait_ms = self._build_retry_delay_ms(attempt, exc)
+                    logger.warning(
+                        "Transient LLM error on attempt %d/%d; retrying in %dms: %s",
+                        attempt,
+                        self.retry_max_attempts,
+                        wait_ms,
+                        _extract_error_detail(exc),
+                    )
+                    self._emit_retry_event(attempt, wait_ms, reason)
+                    await asyncio.sleep(wait_ms / 1000)
+                    attempt += 1
+                    continue
+                logger.warning(
+                    "LLM call failed after %d attempt(s): %s",
+                    attempt,
+                    _extract_error_detail(exc),
+                    exc_info=exc,
+                )
+                return AIMessage(content=self._build_user_message(exc, reason))
+
+
+def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
+    return any(pattern in detail for pattern in patterns)
+
+
+def _extract_error_code(exc: BaseException) -> Any:
+    for attr in ("code", "error_code"):
+        value = getattr(exc, attr, None)
+        if value not in (None, ""):
+            return value
+
+    body = getattr(exc, "body", None)
+    if isinstance(body, dict):
+        error = body.get("error")
+        if isinstance(error, dict):
+            for key in ("code", "type"):
+                value = error.get(key)
+                if value not in (None, ""):
+                    return value
+    return None
+
+
+def _extract_status_code(exc: BaseException) -> int | None:
+    for attr in ("status_code", "status"):
+        value = getattr(exc, attr, None)
+        if isinstance(value, int):
+            return value
+    response = getattr(exc, "response", None)
+    status = getattr(response, "status_code", None)
+    return status if isinstance(status, int) else None
+
+
+def _extract_retry_after_ms(exc: BaseException) -> int | None:
+    response = getattr(exc, "response", None)
+    headers = getattr(response, "headers", None)
+    if headers is None:
+        return None
+
+    raw = None
+    header_name = ""
+    for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
+        header_name = key
+        if hasattr(headers, "get"):
+            raw = headers.get(key)
+        if raw:
+            break
+    if not raw:
+        return None
+
+    try:
+        # "-ms" headers carry milliseconds; plain Retry-After carries seconds.
+        multiplier = 1 if "ms" in header_name.lower() else 1000
+        return max(0, int(float(raw) * multiplier))
+    except (TypeError, ValueError):
+        try:
+            # Retry-After may also be an HTTP date.
+            target = parsedate_to_datetime(str(raw))
+            delta = target.timestamp() - time.time()
+            return max(0, int(delta * 1000))
+        except (TypeError, ValueError, OverflowError):
+            return None
+
+
+def _extract_error_detail(exc: BaseException) -> str:
+    detail = str(exc).strip()
+    if detail:
+        return detail
+    message = getattr(exc, "message", None)
+    if isinstance(message, str) and message.strip():
+        return message.strip()
+    return exc.__class__.__name__
diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py
index 35a37f852..c3acd86cc 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py
@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
     lazy_init: bool = True,
 ) -> list[AgentMiddleware]:
     """Build shared base middlewares for agent execution."""
+    from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
     from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
     from deerflow.sandbox.middleware import SandboxMiddleware
 
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
 
     middlewares.append(DanglingToolCallMiddleware())
 
+    middlewares.append(LLMErrorHandlingMiddleware())
+
     # Guardrail middleware (if configured)
     from deerflow.config.guardrails_config import get_guardrails_config
 
diff --git a/backend/tests/test_llm_error_handling_middleware.py b/backend/tests/test_llm_error_handling_middleware.py
new file mode 100644
index 000000000..9c3077e31
--- /dev/null
+++ b/backend/tests/test_llm_error_handling_middleware.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import asyncio
+from types import SimpleNamespace
+
+import pytest
+from langchain_core.messages import AIMessage
+from langgraph.errors import GraphBubbleUp
+
+from deerflow.agents.middlewares.llm_error_handling_middleware import (
+    LLMErrorHandlingMiddleware,
+)
+
+
+class FakeError(Exception):
+    def __init__(
+        self,
+        message: str,
+        *,
+        status_code: int | None = None,
+        code: str | None = None,
+        headers: dict[str, str] | None = None,
+        body: dict | None = None,
+    ) -> None:
+        super().__init__(message)
+        self.status_code = status_code
+        self.code = code
+        self.body = body
+        self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None
+
+
+def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
+    middleware = LLMErrorHandlingMiddleware()
+    for key, value in attrs.items():
+        setattr(middleware, key, value)
+    return middleware
+
+
+def test_async_model_call_retries_busy_provider_then_succeeds(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25)
+    attempts = 0
+    waits: list[float] = []
+    events: list[dict] = []
+
+    async def fake_sleep(delay: float) -> None:
+        waits.append(delay)
+
+    def fake_writer():
+        return events.append
+
+    async def handler(_request) -> AIMessage:
+        nonlocal attempts
+        attempts += 1
+        if attempts < 3:
+            # "cluster load is high, please retry later; thanks for your patience"
+            raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)")
+        return AIMessage(content="ok")
+
+    monkeypatch.setattr("asyncio.sleep", fake_sleep)
+    monkeypatch.setattr(
+        "langgraph.config.get_stream_writer",
+        fake_writer,
+    )
+
+    result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
+
+    assert isinstance(result, AIMessage)
+    assert result.content == "ok"
+    assert attempts == 3
+    assert waits == [0.025, 0.025]
+    assert [event["type"] for event in events] == ["llm_retry", "llm_retry"]
+
+
+def test_async_model_call_returns_user_message_for_quota_errors() -> None:
+    middleware = _build_middleware(retry_max_attempts=3)
+
+    async def handler(_request) -> AIMessage:
+        raise FakeError(
+            "insufficient_quota: account balance is empty",
+            status_code=429,
+            code="insufficient_quota",
+        )
+
+    result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
+
+    assert isinstance(result, AIMessage)
+    assert "out of quota" in str(result.content)
+
+
+def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
+    middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10)
+    waits: list[float] = []
+    attempts = 0
+
+    def fake_sleep(delay: float) -> None:
+        waits.append(delay)
+
+    def handler(_request) -> AIMessage:
+        nonlocal attempts
+        attempts += 1
+        if attempts == 1:
+            raise FakeError(
+                "server busy",
+                status_code=503,
+                headers={"Retry-After": "2"},
+            )
+        return AIMessage(content="ok")
+
+    monkeypatch.setattr("time.sleep", fake_sleep)
+
+    result = middleware.wrap_model_call(SimpleNamespace(), handler)
+
+    assert isinstance(result, AIMessage)
+    assert result.content == "ok"
+    assert waits == [2.0]
+
+
+def test_sync_model_call_propagates_graph_bubble_up() -> None:
+    middleware = _build_middleware()
+
+    def handler(_request) -> AIMessage:
+        raise GraphBubbleUp()
+
+    with pytest.raises(GraphBubbleUp):
+        middleware.wrap_model_call(SimpleNamespace(), handler)
+
+
+def test_async_model_call_propagates_graph_bubble_up() -> None:
+    middleware = _build_middleware()
+
+    async def handler(_request) -> AIMessage:
+        raise GraphBubbleUp()
+
+    with pytest.raises(GraphBubbleUp):
+        asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts
index 99cfda94f..305879045 100644
--- a/frontend/src/core/threads/hooks.ts
+++ b/frontend/src/core/threads/hooks.ts
@@ -170,6 +170,20 @@ export function useThreadStream({
           message: AIMessage;
         };
         updateSubtask({ id: e.task_id, latestMessage: e.message });
+        return;
+      }
+
+      if (
+        typeof event === "object" &&
+        event !== null &&
+        "type" in event &&
+        event.type === "llm_retry" &&
+        "message" in event &&
+        typeof event.message === "string" &&
+        event.message.trim()
+      ) {
+        const e = event as { type: "llm_retry"; message: string };
+        toast(e.message);
       }
     },
     onError(error) {
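
Note (not part of the patch): when the provider sends no Retry-After hint, the
middleware falls back to plain exponential backoff capped by retry_cap_delay_ms.
A minimal standalone sketch of that schedule — the helper name backoff_ms is
invented here; the defaults mirror retry_base_delay_ms/retry_cap_delay_ms above:

    def backoff_ms(attempt: int, base_ms: int = 1000, cap_ms: int = 8000) -> int:
        # attempt is 1-based, matching _build_retry_delay_ms
        return min(base_ms * (2 ** max(0, attempt - 1)), cap_ms)

    # With retry_max_attempts = 3, the middleware sleeps only before the 2nd
    # and 3rd attempts; a 3rd failure becomes a user-facing AIMessage instead.
    print([backoff_ms(a) for a in (1, 2)])  # [1000, 2000]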