feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095)

This commit is contained in:
Jin 2026-04-12 17:48:40 +08:00 committed by GitHub
parent 979a461af5
commit 4d4ddb3d3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 356 additions and 2 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import logging
import threading
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
@ -19,6 +20,8 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
@ -67,6 +70,80 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
    """Set up the middleware and its circuit-breaker bookkeeping.

    The breaker threshold and recovery timeout are taken from the application
    config when it can be loaded; otherwise the class-level defaults remain.
    """
    super().__init__(**kwargs)

    try:
        breaker_cfg = get_app_config().circuit_breaker
        self.circuit_failure_threshold = breaker_cfg.failure_threshold
        self.circuit_recovery_timeout_sec = breaker_cfg.recovery_timeout_sec
    except (FileNotFoundError, RuntimeError):
        # No usable app config (e.g. test environments): keep class defaults.
        pass

    # The breaker starts closed, with no failures counted and no probe running.
    self._circuit_state = "closed"
    self._circuit_failure_count = 0
    self._circuit_open_until = 0.0
    self._circuit_probe_in_flight = False
    # Guards every read/write of the four state fields above.
    self._circuit_lock = threading.Lock()
def _check_circuit(self) -> bool:
    """Decide whether the current model call must fast-fail.

    Returns:
        True when the call must be rejected without contacting the provider:
        the circuit is open and its deadline has not passed, or it is
        half-open and another probe still holds a live lease.
        False when the call may proceed. A False returned in the half-open
        state makes the caller the single probe request.

    A granted probe holds a lease of ``circuit_recovery_timeout_sec`` seconds,
    stored in ``_circuit_open_until``. If the probe has not reported a result
    by then, it is treated as abandoned and the next caller becomes the probe.

    NOTE(review): in the call sites visible in this file the probe flag is
    released only by ``_record_success``, by ``_record_failure`` (invoked for
    retriable errors only) and by the ``GraphBubbleUp`` handlers. Without the
    lease, a probe that ends in a non-retriable error or in a ``BaseException``
    such as task cancellation would leave the flag set and every later call
    would be rejected indefinitely.
    """
    with self._circuit_lock:
        now = time.time()

        if self._circuit_state == "open":
            if now < self._circuit_open_until:
                return True
            # Recovery window elapsed: move to half-open and let one probe in.
            self._circuit_state = "half_open"
            self._circuit_probe_in_flight = False

        if self._circuit_state != "half_open":
            return False

        if self._circuit_probe_in_flight:
            if now < self._circuit_open_until:
                # A probe is running and its lease is still valid.
                return True
            logger.warning("Circuit breaker probe never reported a result; allowing a new probe.")

        # Grant (or re-grant) the probe and start its lease.
        self._circuit_probe_in_flight = True
        self._circuit_open_until = now + self.circuit_recovery_timeout_sec
        return False
def _record_success(self) -> None:
    """Return the breaker to its pristine closed state after a successful call.

    An info line is logged only when there was something to recover from
    (a non-closed state or a non-zero failure count).
    """
    with self._circuit_lock:
        was_degraded = self._circuit_failure_count > 0 or self._circuit_state != "closed"
        if was_degraded:
            logger.info("Circuit breaker reset (Closed). LLM service recovered.")

        self._circuit_state = "closed"
        self._circuit_probe_in_flight = False
        self._circuit_failure_count = 0
        self._circuit_open_until = 0.0
def _record_failure(self) -> None:
    """Register one failed model call and open the circuit when warranted.

    A failure while half-open (the probe failed) re-opens the circuit at once
    and leaves the failure counter untouched. Otherwise the counter is bumped;
    once it reaches ``circuit_failure_threshold`` the open deadline is set (or
    pushed back) and the state becomes "open".
    """
    with self._circuit_lock:
        cooldown = self.circuit_recovery_timeout_sec

        if self._circuit_state == "half_open":
            self._circuit_open_until = time.time() + cooldown
            self._circuit_state = "open"
            self._circuit_probe_in_flight = False
            logger.error(
                "Circuit breaker probe failed (Open). Will probe again after %ds.",
                cooldown,
            )
            return

        self._circuit_failure_count += 1
        if self._circuit_failure_count < self.circuit_failure_threshold:
            return

        self._circuit_open_until = time.time() + cooldown
        if self._circuit_state == "open":
            # Already open: only the deadline moves; no duplicate log line.
            return

        self._circuit_state = "open"
        self._circuit_probe_in_flight = False
        logger.error(
            "Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
            self.circuit_failure_threshold,
            cooldown,
        )
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
@ -104,6 +181,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_circuit_breaker_message(self) -> str:
    """Return the user-facing text used when a call is rejected by the breaker."""
    return (
        "The configured LLM provider is currently unavailable due to continuous failures. "
        "Circuit breaker is engaged to protect the system. "
        "Please wait a moment before trying again."
    )
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
@ -138,12 +218,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
try:
return handler(request)
response = handler(request)
self._record_success()
return response
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
@ -166,6 +254,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc),
exc_info=exc,
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
@override
@ -174,12 +264,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
try:
return await handler(request)
response = await handler(request)
self._record_success()
return response
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
@ -202,6 +300,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc),
exc_info=exc,
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))

View File

@ -30,6 +30,13 @@ load_dotenv()
logger = logging.getLogger(__name__)
class CircuitBreakerConfig(BaseModel):
    """Configuration for the LLM Circuit Breaker.

    Values are range-checked when the config is loaded, so a typo such as a
    zero or negative threshold is rejected up front instead of producing a
    breaker that trips on the very first failure or never stays open.
    """

    # At least one failure is needed before the circuit can trip.
    failure_threshold: int = Field(
        default=5,
        ge=1,
        description="Number of consecutive failures before tripping the circuit",
    )
    # Zero means "probe again immediately"; negative durations are meaningless.
    recovery_timeout_sec: int = Field(
        default=60,
        ge=0,
        description="Time in seconds before attempting to recover the circuit",
    )
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
@ -55,6 +62,7 @@ class AppConfig(BaseModel):
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@ -129,6 +137,10 @@ class AppConfig(BaseModel):
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
from types import SimpleNamespace
from typing import Any
import pytest
from langchain_core.messages import AIMessage
@ -134,3 +135,227 @@ def test_async_model_call_propagates_graph_bubble_up() -> None:
with pytest.raises(GraphBubbleUp):
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
def test_circuit_half_open_graph_bubble_up_resets_probe() -> None:
    """A GraphBubbleUp raised by the half-open probe must release the probe slot."""
    import unittest.mock

    def bubbling_handler(_request) -> AIMessage:
        raise GraphBubbleUp()

    mw = _build_middleware()

    # Put the breaker in half-open and claim the probe slot the regular way.
    mw._circuit_state = "half_open"
    mw._circuit_probe_in_flight = False
    assert mw._check_circuit() is False
    assert mw._circuit_probe_in_flight is True

    # The slot is already claimed above, so the gate is stubbed to let the call through.
    with unittest.mock.patch.object(mw, "_check_circuit", return_value=False), pytest.raises(GraphBubbleUp):
        mw.wrap_model_call(SimpleNamespace(), bubbling_handler)

    # The slot is free again while the breaker itself stays half-open.
    assert mw._circuit_probe_in_flight is False
    assert mw._circuit_state == "half_open"
@pytest.mark.anyio
async def test_async_circuit_half_open_graph_bubble_up_resets_probe() -> None:
    """Async variant: a GraphBubbleUp from the half-open probe releases the probe slot."""
    import unittest.mock

    async def bubbling_handler(_request) -> AIMessage:
        raise GraphBubbleUp()

    mw = _build_middleware()

    # Put the breaker in half-open and claim the probe slot the regular way.
    mw._circuit_state = "half_open"
    mw._circuit_probe_in_flight = False
    assert mw._check_circuit() is False
    assert mw._circuit_probe_in_flight is True

    # The slot is already claimed above, so the gate is stubbed to let the call through.
    with unittest.mock.patch.object(mw, "_check_circuit", return_value=False), pytest.raises(GraphBubbleUp):
        await mw.awrap_model_call(SimpleNamespace(), bubbling_handler)

    # The slot is free again while the breaker itself stays half-open.
    assert mw._circuit_probe_in_flight is False
    assert mw._circuit_state == "half_open"
# ---------- Circuit Breaker Tests ----------
def transient_failing_handler(request: Any) -> Any:
    """Handler stub that always raises a FakeError carrying HTTP status 502."""
    error = FakeError("Server Error", status_code=502)
    raise error
def quota_failing_handler(request: Any) -> Any:
    """Handler stub that always raises a FakeError with an ``insufficient_quota`` body."""
    error_body = {"error": {"code": "insufficient_quota"}}
    raise FakeError("Quota exceeded", body=error_body)
def success_handler(request: Any) -> Any:
    """Handler stub that always answers with an AIMessage whose content is "Success"."""
    reply = AIMessage(content="Success")
    return reply
def mock_classify_retriable(exc: BaseException) -> tuple[bool, str]:
    """Classifier stub: report every exception as retriable with reason "transient"."""
    retriable, reason = True, "transient"
    return retriable, reason
def mock_classify_non_retriable(exc: BaseException) -> tuple[bool, str]:
    """Classifier stub: report every exception as non-retriable with reason "quota"."""
    retriable, reason = False, "quota"
    return retriable, reason
def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
    """Drive the sync breaker through trip, fast-fail, failed probe and full recovery."""
    threshold, timeout_sec = 3, 10

    # Retry back-off must not slow the test: record the delays instead of sleeping.
    waits: list[float] = []
    monkeypatch.setattr("time.sleep", waits.append)

    # A hand-cranked clock lets the test jump past the recovery timeout.
    clock = {"now": 1000.0}
    monkeypatch.setattr("time.time", lambda: clock["now"])

    middleware = LLMErrorHandlingMiddleware()
    middleware.circuit_failure_threshold = threshold
    middleware.circuit_recovery_timeout_sec = timeout_sec
    monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
    request: Any = {"messages": []}

    # Closed: a success leaves the failure counter at zero.
    middleware.wrap_model_call(request, success_handler)
    assert middleware._circuit_failure_count == 0
    assert middleware._check_circuit() is False

    # Closed -> open: each failed call counts once until the threshold trips the breaker.
    for expected_failures in range(1, threshold + 1):
        middleware.wrap_model_call(request, transient_failing_handler)
        assert middleware._circuit_failure_count == expected_failures
    assert middleware._check_circuit() is True
    tripped_at = clock["now"]

    # Open: calls are rejected without reaching the handler and without counting.
    result = middleware.wrap_model_call(request, success_handler)
    assert result.content == middleware._build_circuit_breaker_message()
    assert middleware._circuit_failure_count == threshold

    # Jump past the recovery timeout; the deadline was trip time + timeout.
    clock["now"] += 11.0
    assert middleware._circuit_open_until == tripped_at + middleware.circuit_recovery_timeout_sec

    # Half-open -> open: the probe fails, so the breaker re-opens.
    middleware.wrap_model_call(request, transient_failing_handler)
    assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
    assert middleware._circuit_state == "open"

    # Half-open -> closed: after another timeout the probe succeeds and everything resets.
    clock["now"] += 11.0
    result = middleware.wrap_model_call(request, success_handler)
    assert result.content == "Success"
    assert middleware._circuit_failure_count == 0
    assert middleware._check_circuit() is False
def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """Errors classified as non-retriable (here: quota) must not count toward the breaker."""
    recorded_sleeps: list[float] = []
    monkeypatch.setattr("time.sleep", recorded_sleeps.append)

    middleware = LLMErrorHandlingMiddleware()
    middleware.circuit_failure_threshold = 3
    monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
    request: Any = {"messages": []}

    # As many failing calls as the threshold would need, all non-retriable.
    for _ in range(middleware.circuit_failure_threshold):
        middleware.wrap_model_call(request, quota_failing_handler)

    assert middleware._circuit_failure_count == 0
    assert middleware._check_circuit() is False
@pytest.mark.anyio
async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
    """Drive the async breaker through trip, fast-fail, failed probe and full recovery."""
    threshold, timeout_sec = 3, 10

    async def async_failing_handler(request: Any) -> Any:
        raise FakeError("Server Error", status_code=502)

    async def async_success_handler(request: Any) -> Any:
        return AIMessage(content="Success")

    # Retry back-off must not slow the test: record the delays instead of sleeping.
    waits: list[float] = []

    async def fake_sleep(d: float) -> None:
        waits.append(d)

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    # A hand-cranked clock lets the test jump past the recovery timeout.
    clock = {"now": 1000.0}
    monkeypatch.setattr("time.time", lambda: clock["now"])

    middleware = LLMErrorHandlingMiddleware()
    middleware.circuit_failure_threshold = threshold
    middleware.circuit_recovery_timeout_sec = timeout_sec
    monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
    request: Any = {"messages": []}

    # Closed -> open: each failed call counts once until the threshold trips the breaker.
    for expected_failures in range(1, threshold + 1):
        await middleware.awrap_model_call(request, async_failing_handler)
        assert middleware._circuit_failure_count == expected_failures
    assert middleware._check_circuit() is True
    tripped_at = clock["now"]

    # Open: calls are rejected without reaching the handler and without counting.
    result = await middleware.awrap_model_call(request, async_success_handler)
    assert result.content == middleware._build_circuit_breaker_message()
    assert middleware._circuit_failure_count == threshold

    # Jump past the recovery timeout; the deadline was trip time + timeout.
    clock["now"] += 11.0
    assert middleware._circuit_open_until == tripped_at + middleware.circuit_recovery_timeout_sec

    # Half-open -> open: the probe fails, so the breaker re-opens.
    await middleware.awrap_model_call(request, async_failing_handler)
    assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
    assert middleware._circuit_state == "open"

    # Half-open -> closed: after another timeout the probe succeeds and everything resets.
    clock["now"] += 11.0
    result = await middleware.awrap_model_call(request, async_success_handler)
    assert result.content == "Success"
    assert middleware._circuit_failure_count == 0
    assert middleware._check_circuit() is False

View File

@ -883,3 +883,20 @@ checkpointer:
# use: my_package:MyGuardrailProvider
# config:
# key: value
# ============================================================================
# Circuit Breaker Configuration
# ============================================================================
# Circuit breaker for LLM calls prevents repeated requests to a failing provider.
# When the failure threshold is reached, subsequent calls fast-fail until recovery.
#
# This is useful for:
# - Avoiding rate-limit bans during provider outages
# - Reducing resource exhaustion from retry loops
# - Gracefully degrading when LLM services are unavailable
# circuit_breaker:
# # Number of consecutive failures before opening the circuit (default: 5)
# failure_threshold: 5
# # Time in seconds before attempting to recover (default: 60)
# recovery_timeout_sec: 60