From 4d4ddb3d3f396a4cd605e3c85800eb889844e7df Mon Sep 17 00:00:00 2001 From: Jin Date: Sun, 12 Apr 2026 17:48:40 +0800 Subject: [PATCH] feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095) --- .../llm_error_handling_middleware.py | 104 +++++++- .../harness/deerflow/config/app_config.py | 12 + .../test_llm_error_handling_middleware.py | 225 ++++++++++++++++++ config.example.yaml | 17 ++ 4 files changed, 356 insertions(+), 2 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py index e1a3af714..0c20c7286 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import logging +import threading import time from collections.abc import Awaitable, Callable from email.utils import parsedate_to_datetime @@ -19,6 +20,8 @@ from langchain.agents.middleware.types import ( from langchain_core.messages import AIMessage from langgraph.errors import GraphBubbleUp +from deerflow.config import get_app_config + logger = logging.getLogger(__name__) _RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504} @@ -67,6 +70,80 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): retry_base_delay_ms: int = 1000 retry_cap_delay_ms: int = 8000 + circuit_failure_threshold: int = 5 + circuit_recovery_timeout_sec: int = 60 + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + # Load Circuit Breaker configs from app config if available, fall back to defaults + try: + app_config = get_app_config() + self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold + self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec + except (FileNotFoundError, RuntimeError): + # Gracefully fall back to class defaults in test environments + pass + + # Circuit Breaker state + self._circuit_lock = threading.Lock() + self._circuit_failure_count = 0 + self._circuit_open_until = 0.0 + self._circuit_state = "closed" + self._circuit_probe_in_flight = False + + def _check_circuit(self) -> bool: + """Returns True if circuit is OPEN (fast fail), False otherwise.""" + with self._circuit_lock: + now = time.time() + + if self._circuit_state == "open": + if now < self._circuit_open_until: + return True + self._circuit_state = "half_open" + self._circuit_probe_in_flight = False + + if self._circuit_state == "half_open": + if self._circuit_probe_in_flight: + return True + self._circuit_probe_in_flight = True + return False + + return False + + def _record_success(self) -> None: + with self._circuit_lock: + if self._circuit_state != "closed" or self._circuit_failure_count > 0: + logger.info("Circuit breaker reset (Closed). LLM service recovered.") + self._circuit_failure_count = 0 + self._circuit_open_until = 0.0 + self._circuit_state = "closed" + self._circuit_probe_in_flight = False + + def _record_failure(self) -> None: + with self._circuit_lock: + if self._circuit_state == "half_open": + self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec + self._circuit_state = "open" + self._circuit_probe_in_flight = False + logger.error( + "Circuit breaker probe failed (Open). Will probe again after %ds.", + self.circuit_recovery_timeout_sec, + ) + return + + self._circuit_failure_count += 1 + if self._circuit_failure_count >= self.circuit_failure_threshold: + self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec + if self._circuit_state != "open": + self._circuit_state = "open" + self._circuit_probe_in_flight = False + logger.error( + "Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.", + self.circuit_failure_threshold, + self.circuit_recovery_timeout_sec, + ) + def _classify_error(self, exc: BaseException) -> tuple[bool, str]: detail = _extract_error_detail(exc) lowered = detail.lower() @@ -104,6 +181,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily" return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s." + def _build_circuit_breaker_message(self) -> str: + return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again." + def _build_user_message(self, exc: BaseException, reason: str) -> str: detail = _extract_error_detail(exc) if reason == "quota": @@ -138,12 +218,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: + if self._check_circuit(): + return AIMessage(content=self._build_circuit_breaker_message()) + attempt = 1 while True: try: - return handler(request) + response = handler(request) + self._record_success() + return response except GraphBubbleUp: # Preserve LangGraph control-flow signals (interrupt/pause/resume). + with self._circuit_lock: + if self._circuit_state == "half_open": + self._circuit_probe_in_flight = False raise except Exception as exc: retriable, reason = self._classify_error(exc) @@ -166,6 +254,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): _extract_error_detail(exc), exc_info=exc, ) + if retriable: + self._record_failure() return AIMessage(content=self._build_user_message(exc, reason)) @override @@ -174,12 +264,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: + if self._check_circuit(): + return AIMessage(content=self._build_circuit_breaker_message()) + attempt = 1 while True: try: - return await handler(request) + response = await handler(request) + self._record_success() + return response except GraphBubbleUp: # Preserve LangGraph control-flow signals (interrupt/pause/resume). + with self._circuit_lock: + if self._circuit_state == "half_open": + self._circuit_probe_in_flight = False raise except Exception as exc: retriable, reason = self._classify_error(exc) @@ -202,6 +300,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): _extract_error_detail(exc), exc_info=exc, ) + if retriable: + self._record_failure() return AIMessage(content=self._build_user_message(exc, reason)) diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index e1ffbf847..df526029c 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -30,6 +30,13 @@ load_dotenv() logger = logging.getLogger(__name__) +class CircuitBreakerConfig(BaseModel): + """Configuration for the LLM Circuit Breaker.""" + + failure_threshold: int = Field(default=5, description="Number of consecutive failures before tripping the circuit") + recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit") + + def _default_config_candidates() -> tuple[Path, ...]: """Return deterministic config.yaml locations without relying on cwd.""" backend_dir = Path(__file__).resolve().parents[4] @@ -55,6 +62,7 @@ class AppConfig(BaseModel): memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") + circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") model_config = ConfigDict(extra="allow", frozen=False) checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") @@ -129,6 +137,10 @@ class AppConfig(BaseModel): if "guardrails" in config_data: load_guardrails_config_from_dict(config_data["guardrails"]) + # Load circuit_breaker config if present + if "circuit_breaker" in config_data: + config_data["circuit_breaker"] = config_data["circuit_breaker"] + # Load checkpointer config if present if "checkpointer" in config_data: load_checkpointer_config_from_dict(config_data["checkpointer"]) diff --git a/backend/tests/test_llm_error_handling_middleware.py b/backend/tests/test_llm_error_handling_middleware.py index 9c3077e31..13b730aa3 100644 --- a/backend/tests/test_llm_error_handling_middleware.py +++ b/backend/tests/test_llm_error_handling_middleware.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio from types import SimpleNamespace +from typing import Any import pytest from langchain_core.messages import AIMessage @@ -134,3 +135,227 @@ def test_async_model_call_propagates_graph_bubble_up() -> None: with pytest.raises(GraphBubbleUp): asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) + + +def test_circuit_half_open_graph_bubble_up_resets_probe() -> None: + """Verify that GraphBubbleUp in half_open state resets probe_in_flight.""" + middleware = _build_middleware() + + # Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True + middleware._circuit_state = "half_open" + middleware._circuit_probe_in_flight = False + # Call _check_circuit() once to simulate the probe being allowed through + assert middleware._check_circuit() is False + assert middleware._circuit_probe_in_flight is True + + # Step 2: Now trigger handler that raises GraphBubbleUp + def handler(_request) -> AIMessage: + raise GraphBubbleUp() + + # Mock _check_circuit() to return False (since we already did the probe check) + import unittest.mock + + with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False): + with pytest.raises(GraphBubbleUp): + middleware.wrap_model_call(SimpleNamespace(), handler) + + # Verify probe_in_flight was reset, state should remain half_open + assert middleware._circuit_probe_in_flight is False + assert middleware._circuit_state == "half_open" + + +@pytest.mark.anyio +async def test_async_circuit_half_open_graph_bubble_up_resets_probe() -> None: + """Verify that GraphBubbleUp in half_open state resets probe_in_flight (async version).""" + middleware = _build_middleware() + + # Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True + middleware._circuit_state = "half_open" + middleware._circuit_probe_in_flight = False + # Call _check_circuit() once to simulate the probe being allowed through + assert middleware._check_circuit() is False + assert middleware._circuit_probe_in_flight is True + + # Step 2: Now trigger handler that raises GraphBubbleUp + async def handler(_request) -> AIMessage: + raise GraphBubbleUp() + + # Mock _check_circuit() to return False (since we already did the probe check) + import unittest.mock + + with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False): + with pytest.raises(GraphBubbleUp): + await middleware.awrap_model_call(SimpleNamespace(), handler) + + # Verify probe_in_flight was reset, state should remain half_open + assert middleware._circuit_probe_in_flight is False + assert middleware._circuit_state == "half_open" + + +# ---------- Circuit Breaker Tests ---------- + + +def transient_failing_handler(request: Any) -> Any: + raise FakeError("Server Error", status_code=502) # Used for transient error + + +def quota_failing_handler(request: Any) -> Any: + raise FakeError("Quota exceeded", body={"error": {"code": "insufficient_quota"}}) # Used for quota error + + +def success_handler(request: Any) -> Any: + return AIMessage(content="Success") + + +def mock_classify_retriable(exc: BaseException) -> tuple[bool, str]: + return True, "transient" + + +def mock_classify_non_retriable(exc: BaseException) -> tuple[bool, str]: + return False, "quota" + + +def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that circuit breaker trips, fast fails, correctly transitions to Half-Open, and recovers or re-opens.""" + + # Mock time.sleep to avoid slow tests during retry loops (Speed up from ~4s to 0.1s) + waits: list[float] = [] + monkeypatch.setattr("time.sleep", lambda d: waits.append(d)) + + # Mock time.time to decouple from private implementation details and enable time travel + current_time = 1000.0 + monkeypatch.setattr("time.time", lambda: current_time) + + middleware = LLMErrorHandlingMiddleware() + middleware.circuit_failure_threshold = 3 + middleware.circuit_recovery_timeout_sec = 10 + monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable) + + request: Any = {"messages": []} + + # --- 0. Test initial state & Success --- + # Success handler does not increase count. If it's already 0, it stays 0. + middleware.wrap_model_call(request, success_handler) + assert middleware._circuit_failure_count == 0 + assert middleware._check_circuit() is False + + # --- 1. Trip the circuit --- + # Fails 3 overall calls. Threshold (3) is reached. + middleware.wrap_model_call(request, transient_failing_handler) + assert middleware._circuit_failure_count == 1 + middleware.wrap_model_call(request, transient_failing_handler) + assert middleware._circuit_failure_count == 2 + middleware.wrap_model_call(request, transient_failing_handler) + assert middleware._circuit_failure_count == 3 + assert middleware._check_circuit() is True # Circuit is OPEN + + # --- 2. Fast Fail --- + # 2nd call: fast fail immediately without calling handler. + # Count should not increase during OPEN state. + result = middleware.wrap_model_call(request, success_handler) + assert result.content == middleware._build_circuit_breaker_message() + assert middleware._circuit_failure_count == 3 + + # --- 3. Half-Open -> Fail -> Re-Open --- + # Time travel 11 seconds (timeout is 10s). Current time becomes 1011.0 + current_time += 11.0 + + # Verify that the timeout was set EXACTLY relative to current_time + timeout_sec + assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec + + # Fails again! The request will go through the 3-attempt retry loop again. + middleware.wrap_model_call(request, transient_failing_handler) + assert middleware._circuit_failure_count == middleware.circuit_failure_threshold + assert middleware._circuit_state == "open" # Re-OPENed + + # --- 4. Half-Open -> Success -> Reset --- + # Time travel another 11 seconds + current_time += 11.0 + + # Succeeds this time! Should completely reset. + result = middleware.wrap_model_call(request, success_handler) + assert result.content == "Success" + assert middleware._circuit_failure_count == 0 # Fully RESET! + assert middleware._check_circuit() is False + + +def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that circuit breaker ignores business errors like Quota or Auth.""" + waits: list[float] = [] + monkeypatch.setattr("time.sleep", lambda d: waits.append(d)) + + middleware = LLMErrorHandlingMiddleware() + middleware.circuit_failure_threshold = 3 + monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable) + + request: Any = {"messages": []} + + for _ in range(3): + middleware.wrap_model_call(request, quota_failing_handler) + + assert middleware._circuit_failure_count == 0 + assert middleware._check_circuit() is False + + +@pytest.mark.anyio +async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify async version of circuit breaker correctly handles state transitions.""" + waits: list[float] = [] + + async def fake_sleep(d: float) -> None: + waits.append(d) + + monkeypatch.setattr(asyncio, "sleep", fake_sleep) + + current_time = 1000.0 + monkeypatch.setattr("time.time", lambda: current_time) + + middleware = LLMErrorHandlingMiddleware() + middleware.circuit_failure_threshold = 3 + middleware.circuit_recovery_timeout_sec = 10 + monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable) + + async def async_failing_handler(request: Any) -> Any: + raise FakeError("Server Error", status_code=502) + + request: Any = {"messages": []} + + # --- 1. Trip the circuit --- + # Fails 3 overall calls. Threshold (3) is reached. + await middleware.awrap_model_call(request, async_failing_handler) + assert middleware._circuit_failure_count == 1 + await middleware.awrap_model_call(request, async_failing_handler) + assert middleware._circuit_failure_count == 2 + await middleware.awrap_model_call(request, async_failing_handler) + assert middleware._circuit_failure_count == 3 + assert middleware._check_circuit() is True + + # --- 2. Fast Fail --- + # 2nd call: fast fail immediately without calling handler + async def async_success_handler(request: Any) -> Any: + return AIMessage(content="Success") + + result = await middleware.awrap_model_call(request, async_success_handler) + assert result.content == middleware._build_circuit_breaker_message() + assert middleware._circuit_failure_count == 3 # Unchanged + + # --- 3. Half-Open -> Fail -> Re-Open --- + # Time travel 11 seconds + current_time += 11.0 + + # Verify timeout formula + assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec + + # Fails again! The request goes through the 3-attempt retry loop. + await middleware.awrap_model_call(request, async_failing_handler) + assert middleware._circuit_failure_count == middleware.circuit_failure_threshold + assert middleware._circuit_state == "open" # Re-OPENed + + # --- 4. Half-Open -> Success -> Reset --- + # Time travel another 11 seconds + current_time += 11.0 + + result = await middleware.awrap_model_call(request, async_success_handler) + assert result.content == "Success" + assert middleware._circuit_failure_count == 0 # RESET + assert middleware._check_circuit() is False diff --git a/config.example.yaml b/config.example.yaml index ac65b6e42..89d8e8a85 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -883,3 +883,20 @@ checkpointer: # use: my_package:MyGuardrailProvider # config: # key: value + +# ============================================================================ +# Circuit Breaker Configuration +# ============================================================================ +# Circuit breaker for LLM calls prevents repeated requests to a failing provider. +# When the failure threshold is reached, subsequent calls fast-fail until recovery. +# +# This is useful for: +# - Avoiding rate-limit bans during provider outages +# - Reducing resource exhaustion from retry loops +# - Gracefully degrading when LLM services are unavailable + +# circuit_breaker: +# # Number of consecutive failures before opening the circuit (default: 5) +# failure_threshold: 5 +# # Time in seconds before attempting to recover (default: 60) +# recovery_timeout_sec: 60