mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
Fix/1681 llm call retry handling (#1683)
* fix(runtime): handle llm call errors gracefully * fix(runtime): preserve graph control flow in llm retry middleware --------- Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
This commit is contained in:
parent
df5339b5d0
commit
3a672b39c7
@ -0,0 +1,275 @@
|
||||
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
_BUSY_PATTERNS = (
|
||||
"server busy",
|
||||
"temporarily unavailable",
|
||||
"try again later",
|
||||
"please retry",
|
||||
"please try again",
|
||||
"overloaded",
|
||||
"high demand",
|
||||
"rate limit",
|
||||
"负载较高",
|
||||
"服务繁忙",
|
||||
"稍后重试",
|
||||
"请稍后重试",
|
||||
)
|
||||
_QUOTA_PATTERNS = (
|
||||
"insufficient_quota",
|
||||
"quota",
|
||||
"billing",
|
||||
"credit",
|
||||
"payment",
|
||||
"余额不足",
|
||||
"超出限额",
|
||||
"额度不足",
|
||||
"欠费",
|
||||
)
|
||||
_AUTH_PATTERNS = (
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"无权",
|
||||
"未授权",
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||
|
||||
retry_max_attempts: int = 3
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
error_code = _extract_error_code(exc)
|
||||
status_code = _extract_status_code(exc)
|
||||
|
||||
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
|
||||
return False, "quota"
|
||||
if _matches_any(lowered, _AUTH_PATTERNS):
|
||||
return False, "auth"
|
||||
|
||||
exc_name = exc.__class__.__name__
|
||||
if exc_name in {
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"InternalServerError",
|
||||
}:
|
||||
return True, "transient"
|
||||
if status_code in _RETRIABLE_STATUS_CODES:
|
||||
return True, "transient"
|
||||
if _matches_any(lowered, _BUSY_PATTERNS):
|
||||
return True, "busy"
|
||||
|
||||
return False, "generic"
|
||||
|
||||
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
|
||||
retry_after = _extract_retry_after_ms(exc)
|
||||
if retry_after is not None:
|
||||
return retry_after
|
||||
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
|
||||
return min(backoff, self.retry_cap_delay_ms)
|
||||
|
||||
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
|
||||
seconds = max(1, round(wait_ms / 1000))
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
|
||||
if reason == "auth":
|
||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||
if reason in {"busy", "transient"}:
|
||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||
return f"LLM request failed: {detail}"
|
||||
|
||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
writer(
|
||||
{
|
||||
"type": "llm_retry",
|
||||
"attempt": attempt,
|
||||
"max_attempts": self.retry_max_attempts,
|
||||
"wait_ms": wait_ms,
|
||||
"reason": reason,
|
||||
"message": self._build_retry_message(attempt, wait_ms, reason),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to emit llm_retry event", exc_info=True)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
time.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||
return any(pattern in detail for pattern in patterns)
|
||||
|
||||
|
||||
def _extract_error_code(exc: BaseException) -> Any:
|
||||
for attr in ("code", "error_code"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if isinstance(body, dict):
|
||||
error = body.get("error")
|
||||
if isinstance(error, dict):
|
||||
for key in ("code", "type"):
|
||||
value = error.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_status_code(exc: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status"):
|
||||
value = getattr(exc, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
response = getattr(exc, "response", None)
|
||||
status = getattr(response, "status_code", None)
|
||||
return status if isinstance(status, int) else None
|
||||
|
||||
|
||||
def _extract_retry_after_ms(exc: BaseException) -> int | None:
|
||||
response = getattr(exc, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
raw = None
|
||||
header_name = ""
|
||||
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
|
||||
header_name = key
|
||||
if hasattr(headers, "get"):
|
||||
raw = headers.get(key)
|
||||
if raw:
|
||||
break
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
multiplier = 1 if "ms" in header_name.lower() else 1000
|
||||
return max(0, int(float(raw) * multiplier))
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
target = parsedate_to_datetime(str(raw))
|
||||
delta = target.timestamp() - time.time()
|
||||
return max(0, int(delta * 1000))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_detail(exc: BaseException) -> str:
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
return detail
|
||||
message = getattr(exc, "message", None)
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
return exc.__class__.__name__
|
||||
@ -72,6 +72,7 @@ def _build_runtime_middlewares(
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
@ -90,6 +91,8 @@ def _build_runtime_middlewares(
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
|
||||
136
backend/tests/test_llm_error_handling_middleware.py
Normal file
136
backend/tests/test_llm_error_handling_middleware.py
Normal file
@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import (
|
||||
LLMErrorHandlingMiddleware,
|
||||
)
|
||||
|
||||
|
||||
class FakeError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
status_code: int | None = None,
|
||||
code: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
body: dict | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.code = code
|
||||
self.body = body
|
||||
self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None
|
||||
|
||||
|
||||
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
|
||||
middleware = LLMErrorHandlingMiddleware()
|
||||
for key, value in attrs.items():
|
||||
setattr(middleware, key, value)
|
||||
return middleware
|
||||
|
||||
|
||||
def test_async_model_call_retries_busy_provider_then_succeeds(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25)
|
||||
attempts = 0
|
||||
waits: list[float] = []
|
||||
events: list[dict] = []
|
||||
|
||||
async def fake_sleep(delay: float) -> None:
|
||||
waits.append(delay)
|
||||
|
||||
def fake_writer():
|
||||
return events.append
|
||||
|
||||
async def handler(_request) -> AIMessage:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts < 3:
|
||||
raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)")
|
||||
return AIMessage(content="ok")
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
monkeypatch.setattr(
|
||||
"langgraph.config.get_stream_writer",
|
||||
fake_writer,
|
||||
)
|
||||
|
||||
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
|
||||
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content == "ok"
|
||||
assert attempts == 3
|
||||
assert waits == [0.025, 0.025]
|
||||
assert [event["type"] for event in events] == ["llm_retry", "llm_retry"]
|
||||
|
||||
|
||||
def test_async_model_call_returns_user_message_for_quota_errors() -> None:
|
||||
middleware = _build_middleware(retry_max_attempts=3)
|
||||
|
||||
async def handler(_request) -> AIMessage:
|
||||
raise FakeError(
|
||||
"insufficient_quota: account balance is empty",
|
||||
status_code=429,
|
||||
code="insufficient_quota",
|
||||
)
|
||||
|
||||
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
|
||||
|
||||
assert isinstance(result, AIMessage)
|
||||
assert "out of quota" in str(result.content)
|
||||
|
||||
|
||||
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10)
|
||||
waits: list[float] = []
|
||||
attempts = 0
|
||||
|
||||
def fake_sleep(delay: float) -> None:
|
||||
waits.append(delay)
|
||||
|
||||
def handler(_request) -> AIMessage:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise FakeError(
|
||||
"server busy",
|
||||
status_code=503,
|
||||
headers={"Retry-After": "2"},
|
||||
)
|
||||
return AIMessage(content="ok")
|
||||
|
||||
monkeypatch.setattr("time.sleep", fake_sleep)
|
||||
|
||||
result = middleware.wrap_model_call(SimpleNamespace(), handler)
|
||||
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content == "ok"
|
||||
assert waits == [2.0]
|
||||
|
||||
|
||||
def test_sync_model_call_propagates_graph_bubble_up() -> None:
|
||||
middleware = _build_middleware()
|
||||
|
||||
def handler(_request) -> AIMessage:
|
||||
raise GraphBubbleUp()
|
||||
|
||||
with pytest.raises(GraphBubbleUp):
|
||||
middleware.wrap_model_call(SimpleNamespace(), handler)
|
||||
|
||||
|
||||
def test_async_model_call_propagates_graph_bubble_up() -> None:
|
||||
middleware = _build_middleware()
|
||||
|
||||
async def handler(_request) -> AIMessage:
|
||||
raise GraphBubbleUp()
|
||||
|
||||
with pytest.raises(GraphBubbleUp):
|
||||
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
|
||||
@ -170,6 +170,20 @@ export function useThreadStream({
|
||||
message: AIMessage;
|
||||
};
|
||||
updateSubtask({ id: e.task_id, latestMessage: e.message });
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
typeof event === "object" &&
|
||||
event !== null &&
|
||||
"type" in event &&
|
||||
event.type === "llm_retry" &&
|
||||
"message" in event &&
|
||||
typeof event.message === "string" &&
|
||||
event.message.trim()
|
||||
) {
|
||||
const e = event as { type: "llm_retry"; message: string };
|
||||
toast(e.message);
|
||||
}
|
||||
},
|
||||
onError(error) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user