mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095)
This commit is contained in:
parent
979a461af5
commit
4d4ddb3d3f
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
@ -19,6 +20,8 @@ from langchain.agents.middleware.types import (
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
@ -67,6 +70,80 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
circuit_failure_threshold: int = 5
|
||||
circuit_recovery_timeout_sec: int = 60
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Load Circuit Breaker configs from app config if available, fall back to defaults
|
||||
try:
|
||||
app_config = get_app_config()
|
||||
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
|
||||
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
# Gracefully fall back to class defaults in test environments
|
||||
pass
|
||||
|
||||
# Circuit Breaker state
|
||||
self._circuit_lock = threading.Lock()
|
||||
self._circuit_failure_count = 0
|
||||
self._circuit_open_until = 0.0
|
||||
self._circuit_state = "closed"
|
||||
self._circuit_probe_in_flight = False
|
||||
|
||||
def _check_circuit(self) -> bool:
|
||||
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
|
||||
with self._circuit_lock:
|
||||
now = time.time()
|
||||
|
||||
if self._circuit_state == "open":
|
||||
if now < self._circuit_open_until:
|
||||
return True
|
||||
self._circuit_state = "half_open"
|
||||
self._circuit_probe_in_flight = False
|
||||
|
||||
if self._circuit_state == "half_open":
|
||||
if self._circuit_probe_in_flight:
|
||||
return True
|
||||
self._circuit_probe_in_flight = True
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _record_success(self) -> None:
|
||||
with self._circuit_lock:
|
||||
if self._circuit_state != "closed" or self._circuit_failure_count > 0:
|
||||
logger.info("Circuit breaker reset (Closed). LLM service recovered.")
|
||||
self._circuit_failure_count = 0
|
||||
self._circuit_open_until = 0.0
|
||||
self._circuit_state = "closed"
|
||||
self._circuit_probe_in_flight = False
|
||||
|
||||
def _record_failure(self) -> None:
|
||||
with self._circuit_lock:
|
||||
if self._circuit_state == "half_open":
|
||||
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
|
||||
self._circuit_state = "open"
|
||||
self._circuit_probe_in_flight = False
|
||||
logger.error(
|
||||
"Circuit breaker probe failed (Open). Will probe again after %ds.",
|
||||
self.circuit_recovery_timeout_sec,
|
||||
)
|
||||
return
|
||||
|
||||
self._circuit_failure_count += 1
|
||||
if self._circuit_failure_count >= self.circuit_failure_threshold:
|
||||
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
|
||||
if self._circuit_state != "open":
|
||||
self._circuit_state = "open"
|
||||
self._circuit_probe_in_flight = False
|
||||
logger.error(
|
||||
"Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
|
||||
self.circuit_failure_threshold,
|
||||
self.circuit_recovery_timeout_sec,
|
||||
)
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
@ -104,6 +181,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_circuit_breaker_message(self) -> str:
|
||||
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
@ -138,12 +218,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
if self._check_circuit():
|
||||
return AIMessage(content=self._build_circuit_breaker_message())
|
||||
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
response = handler(request)
|
||||
self._record_success()
|
||||
return response
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
with self._circuit_lock:
|
||||
if self._circuit_state == "half_open":
|
||||
self._circuit_probe_in_flight = False
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
@ -166,6 +254,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
if retriable:
|
||||
self._record_failure()
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
@ -174,12 +264,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
if self._check_circuit():
|
||||
return AIMessage(content=self._build_circuit_breaker_message())
|
||||
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
response = await handler(request)
|
||||
self._record_success()
|
||||
return response
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
with self._circuit_lock:
|
||||
if self._circuit_state == "half_open":
|
||||
self._circuit_probe_in_flight = False
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
@ -202,6 +300,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
if retriable:
|
||||
self._record_failure()
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
|
||||
@ -30,6 +30,13 @@ load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitBreakerConfig(BaseModel):
|
||||
"""Configuration for the LLM Circuit Breaker."""
|
||||
|
||||
failure_threshold: int = Field(default=5, description="Number of consecutive failures before tripping the circuit")
|
||||
recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
|
||||
"""Return deterministic config.yaml locations without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
@ -55,6 +62,7 @@ class AppConfig(BaseModel):
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
||||
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
|
||||
@ -129,6 +137,10 @@ class AppConfig(BaseModel):
|
||||
if "guardrails" in config_data:
|
||||
load_guardrails_config_from_dict(config_data["guardrails"])
|
||||
|
||||
# Load circuit_breaker config if present
|
||||
if "circuit_breaker" in config_data:
|
||||
config_data["circuit_breaker"] = config_data["circuit_breaker"]
|
||||
|
||||
# Load checkpointer config if present
|
||||
if "checkpointer" in config_data:
|
||||
load_checkpointer_config_from_dict(config_data["checkpointer"])
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
@ -134,3 +135,227 @@ def test_async_model_call_propagates_graph_bubble_up() -> None:
|
||||
|
||||
with pytest.raises(GraphBubbleUp):
|
||||
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
|
||||
|
||||
|
||||
def test_circuit_half_open_graph_bubble_up_resets_probe() -> None:
|
||||
"""Verify that GraphBubbleUp in half_open state resets probe_in_flight."""
|
||||
middleware = _build_middleware()
|
||||
|
||||
# Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True
|
||||
middleware._circuit_state = "half_open"
|
||||
middleware._circuit_probe_in_flight = False
|
||||
# Call _check_circuit() once to simulate the probe being allowed through
|
||||
assert middleware._check_circuit() is False
|
||||
assert middleware._circuit_probe_in_flight is True
|
||||
|
||||
# Step 2: Now trigger handler that raises GraphBubbleUp
|
||||
def handler(_request) -> AIMessage:
|
||||
raise GraphBubbleUp()
|
||||
|
||||
# Mock _check_circuit() to return False (since we already did the probe check)
|
||||
import unittest.mock
|
||||
|
||||
with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False):
|
||||
with pytest.raises(GraphBubbleUp):
|
||||
middleware.wrap_model_call(SimpleNamespace(), handler)
|
||||
|
||||
# Verify probe_in_flight was reset, state should remain half_open
|
||||
assert middleware._circuit_probe_in_flight is False
|
||||
assert middleware._circuit_state == "half_open"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_circuit_half_open_graph_bubble_up_resets_probe() -> None:
|
||||
"""Verify that GraphBubbleUp in half_open state resets probe_in_flight (async version)."""
|
||||
middleware = _build_middleware()
|
||||
|
||||
# Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True
|
||||
middleware._circuit_state = "half_open"
|
||||
middleware._circuit_probe_in_flight = False
|
||||
# Call _check_circuit() once to simulate the probe being allowed through
|
||||
assert middleware._check_circuit() is False
|
||||
assert middleware._circuit_probe_in_flight is True
|
||||
|
||||
# Step 2: Now trigger handler that raises GraphBubbleUp
|
||||
async def handler(_request) -> AIMessage:
|
||||
raise GraphBubbleUp()
|
||||
|
||||
# Mock _check_circuit() to return False (since we already did the probe check)
|
||||
import unittest.mock
|
||||
|
||||
with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False):
|
||||
with pytest.raises(GraphBubbleUp):
|
||||
await middleware.awrap_model_call(SimpleNamespace(), handler)
|
||||
|
||||
# Verify probe_in_flight was reset, state should remain half_open
|
||||
assert middleware._circuit_probe_in_flight is False
|
||||
assert middleware._circuit_state == "half_open"
|
||||
|
||||
|
||||
# ---------- Circuit Breaker Tests ----------
|
||||
|
||||
|
||||
def transient_failing_handler(request: Any) -> Any:
|
||||
raise FakeError("Server Error", status_code=502) # Used for transient error
|
||||
|
||||
|
||||
def quota_failing_handler(request: Any) -> Any:
|
||||
raise FakeError("Quota exceeded", body={"error": {"code": "insufficient_quota"}}) # Used for quota error
|
||||
|
||||
|
||||
def success_handler(request: Any) -> Any:
|
||||
return AIMessage(content="Success")
|
||||
|
||||
|
||||
def mock_classify_retriable(exc: BaseException) -> tuple[bool, str]:
|
||||
return True, "transient"
|
||||
|
||||
|
||||
def mock_classify_non_retriable(exc: BaseException) -> tuple[bool, str]:
|
||||
return False, "quota"
|
||||
|
||||
|
||||
def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Verify that circuit breaker trips, fast fails, correctly transitions to Half-Open, and recovers or re-opens."""
|
||||
|
||||
# Mock time.sleep to avoid slow tests during retry loops (Speed up from ~4s to 0.1s)
|
||||
waits: list[float] = []
|
||||
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
|
||||
|
||||
# Mock time.time to decouple from private implementation details and enable time travel
|
||||
current_time = 1000.0
|
||||
monkeypatch.setattr("time.time", lambda: current_time)
|
||||
|
||||
middleware = LLMErrorHandlingMiddleware()
|
||||
middleware.circuit_failure_threshold = 3
|
||||
middleware.circuit_recovery_timeout_sec = 10
|
||||
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
|
||||
|
||||
request: Any = {"messages": []}
|
||||
|
||||
# --- 0. Test initial state & Success ---
|
||||
# Success handler does not increase count. If it's already 0, it stays 0.
|
||||
middleware.wrap_model_call(request, success_handler)
|
||||
assert middleware._circuit_failure_count == 0
|
||||
assert middleware._check_circuit() is False
|
||||
|
||||
# --- 1. Trip the circuit ---
|
||||
# Fails 3 overall calls. Threshold (3) is reached.
|
||||
middleware.wrap_model_call(request, transient_failing_handler)
|
||||
assert middleware._circuit_failure_count == 1
|
||||
middleware.wrap_model_call(request, transient_failing_handler)
|
||||
assert middleware._circuit_failure_count == 2
|
||||
middleware.wrap_model_call(request, transient_failing_handler)
|
||||
assert middleware._circuit_failure_count == 3
|
||||
assert middleware._check_circuit() is True # Circuit is OPEN
|
||||
|
||||
# --- 2. Fast Fail ---
|
||||
# 2nd call: fast fail immediately without calling handler.
|
||||
# Count should not increase during OPEN state.
|
||||
result = middleware.wrap_model_call(request, success_handler)
|
||||
assert result.content == middleware._build_circuit_breaker_message()
|
||||
assert middleware._circuit_failure_count == 3
|
||||
|
||||
# --- 3. Half-Open -> Fail -> Re-Open ---
|
||||
# Time travel 11 seconds (timeout is 10s). Current time becomes 1011.0
|
||||
current_time += 11.0
|
||||
|
||||
# Verify that the timeout was set EXACTLY relative to current_time + timeout_sec
|
||||
assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec
|
||||
|
||||
# Fails again! The request will go through the 3-attempt retry loop again.
|
||||
middleware.wrap_model_call(request, transient_failing_handler)
|
||||
assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
|
||||
assert middleware._circuit_state == "open" # Re-OPENed
|
||||
|
||||
# --- 4. Half-Open -> Success -> Reset ---
|
||||
# Time travel another 11 seconds
|
||||
current_time += 11.0
|
||||
|
||||
# Succeeds this time! Should completely reset.
|
||||
result = middleware.wrap_model_call(request, success_handler)
|
||||
assert result.content == "Success"
|
||||
assert middleware._circuit_failure_count == 0 # Fully RESET!
|
||||
assert middleware._check_circuit() is False
|
||||
|
||||
|
||||
def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Verify that circuit breaker ignores business errors like Quota or Auth."""
|
||||
waits: list[float] = []
|
||||
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
|
||||
|
||||
middleware = LLMErrorHandlingMiddleware()
|
||||
middleware.circuit_failure_threshold = 3
|
||||
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
|
||||
|
||||
request: Any = {"messages": []}
|
||||
|
||||
for _ in range(3):
|
||||
middleware.wrap_model_call(request, quota_failing_handler)
|
||||
|
||||
assert middleware._circuit_failure_count == 0
|
||||
assert middleware._check_circuit() is False
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Verify async version of circuit breaker correctly handles state transitions."""
|
||||
waits: list[float] = []
|
||||
|
||||
async def fake_sleep(d: float) -> None:
|
||||
waits.append(d)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
|
||||
|
||||
current_time = 1000.0
|
||||
monkeypatch.setattr("time.time", lambda: current_time)
|
||||
|
||||
middleware = LLMErrorHandlingMiddleware()
|
||||
middleware.circuit_failure_threshold = 3
|
||||
middleware.circuit_recovery_timeout_sec = 10
|
||||
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
|
||||
|
||||
async def async_failing_handler(request: Any) -> Any:
|
||||
raise FakeError("Server Error", status_code=502)
|
||||
|
||||
request: Any = {"messages": []}
|
||||
|
||||
# --- 1. Trip the circuit ---
|
||||
# Fails 3 overall calls. Threshold (3) is reached.
|
||||
await middleware.awrap_model_call(request, async_failing_handler)
|
||||
assert middleware._circuit_failure_count == 1
|
||||
await middleware.awrap_model_call(request, async_failing_handler)
|
||||
assert middleware._circuit_failure_count == 2
|
||||
await middleware.awrap_model_call(request, async_failing_handler)
|
||||
assert middleware._circuit_failure_count == 3
|
||||
assert middleware._check_circuit() is True
|
||||
|
||||
# --- 2. Fast Fail ---
|
||||
# 2nd call: fast fail immediately without calling handler
|
||||
async def async_success_handler(request: Any) -> Any:
|
||||
return AIMessage(content="Success")
|
||||
|
||||
result = await middleware.awrap_model_call(request, async_success_handler)
|
||||
assert result.content == middleware._build_circuit_breaker_message()
|
||||
assert middleware._circuit_failure_count == 3 # Unchanged
|
||||
|
||||
# --- 3. Half-Open -> Fail -> Re-Open ---
|
||||
# Time travel 11 seconds
|
||||
current_time += 11.0
|
||||
|
||||
# Verify timeout formula
|
||||
assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec
|
||||
|
||||
# Fails again! The request goes through the 3-attempt retry loop.
|
||||
await middleware.awrap_model_call(request, async_failing_handler)
|
||||
assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
|
||||
assert middleware._circuit_state == "open" # Re-OPENed
|
||||
|
||||
# --- 4. Half-Open -> Success -> Reset ---
|
||||
# Time travel another 11 seconds
|
||||
current_time += 11.0
|
||||
|
||||
result = await middleware.awrap_model_call(request, async_success_handler)
|
||||
assert result.content == "Success"
|
||||
assert middleware._circuit_failure_count == 0 # RESET
|
||||
assert middleware._check_circuit() is False
|
||||
|
||||
@ -883,3 +883,20 @@ checkpointer:
|
||||
# use: my_package:MyGuardrailProvider
|
||||
# config:
|
||||
# key: value
|
||||
|
||||
# ============================================================================
|
||||
# Circuit Breaker Configuration
|
||||
# ============================================================================
|
||||
# Circuit breaker for LLM calls prevents repeated requests to a failing provider.
|
||||
# When the failure threshold is reached, subsequent calls fast-fail until recovery.
|
||||
#
|
||||
# This is useful for:
|
||||
# - Avoiding rate-limit bans during provider outages
|
||||
# - Reducing resource exhaustion from retry loops
|
||||
# - Gracefully degrading when LLM services are unavailable
|
||||
|
||||
# circuit_breaker:
|
||||
# # Number of consecutive failures before opening the circuit (default: 5)
|
||||
# failure_threshold: 5
|
||||
# # Time in seconds before attempting to recover (default: 60)
|
||||
# recovery_timeout_sec: 60
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user