From 79cc2279174e24d019300becb2b3b372e516d943 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Sun, 31 May 2026 16:42:13 +0200 Subject: [PATCH] fix(middleware): fix LLM fallback run status (#3321) * Fix LLM fallback run status * optimize LLM fallback marker extraction in streaming path --- .../llm_error_handling_middleware.py | 44 +++- .../harness/deerflow/runtime/journal.py | 22 ++ .../harness/deerflow/runtime/runs/worker.py | 88 ++++++++ .../test_llm_error_handling_middleware.py | 25 +++ backend/tests/test_run_worker_rollback.py | 188 +++++++++++++++++- 5 files changed, 362 insertions(+), 5 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py index ef23e08f1..a489d90d0 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/llm_error_handling_middleware.py @@ -177,6 +177,24 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): def _build_circuit_breaker_message(self) -> str: return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again." + def _build_error_fallback_message( + self, + content: str, + *, + error_type: str, + reason: str, + detail: str, + ) -> AIMessage: + return AIMessage( + content=content, + additional_kwargs={ + "deerflow_error_fallback": True, + "error_type": error_type, + "error_reason": reason, + "error_detail": detail, + }, + ) + def _build_user_message(self, exc: BaseException, reason: str) -> str: detail = _extract_error_detail(exc) if reason == "quota": @@ -187,6 +205,14 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): return "The configured LLM provider is temporarily unavailable after multiple retries. 
Please wait a moment and continue the conversation." return f"LLM request failed: {detail}" + def _build_user_fallback_message(self, exc: BaseException, reason: str) -> AIMessage: + return self._build_error_fallback_message( + self._build_user_message(exc, reason), + error_type=type(exc).__name__, + reason=reason, + detail=_extract_error_detail(exc), + ) + def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None: try: from langgraph.config import get_stream_writer @@ -212,7 +238,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: if self._check_circuit(): - return AIMessage(content=self._build_circuit_breaker_message()) + return self._build_error_fallback_message( + self._build_circuit_breaker_message(), + error_type="CircuitBreakerOpen", + reason="circuit_open", + detail="LLM circuit breaker is open", + ) attempt = 1 while True: @@ -249,7 +280,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): ) if retriable: self._record_failure() - return AIMessage(content=self._build_user_message(exc, reason)) + return self._build_user_fallback_message(exc, reason) @override async def awrap_model_call( @@ -258,7 +289,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: if self._check_circuit(): - return AIMessage(content=self._build_circuit_breaker_message()) + return self._build_error_fallback_message( + self._build_circuit_breaker_message(), + error_type="CircuitBreakerOpen", + reason="circuit_open", + detail="LLM circuit breaker is open", + ) attempt = 1 while True: @@ -295,7 +331,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]): ) if retriable: self._record_failure() - return AIMessage(content=self._build_user_message(exc, reason)) + return self._build_user_fallback_message(exc, reason) def _matches_any(detail: str, patterns: 
tuple[str, ...]) -> bool: diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index a12ebd98b..b65e1c0bb 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -86,6 +86,8 @@ class RunJournal(BaseCallbackHandler): self._last_ai_msg: str | None = None self._first_human_msg: str | None = None self._msg_count = 0 + self._had_llm_error_fallback = False + self._llm_error_fallback_message: str | None = None # Latency tracking self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time @@ -256,6 +258,18 @@ class RunJournal(BaseCallbackHandler): # Token usage from message usage = getattr(message, "usage_metadata", None) usage_dict = dict(usage) if usage else {} + additional_kwargs = getattr(message, "additional_kwargs", None) or {} + if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"): + self._had_llm_error_fallback = True + detail = additional_kwargs.get("error_detail") + reason = additional_kwargs.get("error_reason") + fallback_text = self._message_text(message).strip() + if isinstance(detail, str) and detail.strip(): + self._llm_error_fallback_message = detail.strip() + elif isinstance(reason, str) and reason.strip(): + self._llm_error_fallback_message = reason.strip() + elif fallback_text: + self._llm_error_fallback_message = fallback_text[:2000] # Resolve call index call_index = self._llm_call_index @@ -569,3 +583,11 @@ class RunJournal(BaseCallbackHandler): "last_ai_message": self._last_ai_msg, "first_human_message": self._first_human_msg, } + + @property + def had_llm_error_fallback(self) -> bool: + return self._had_llm_error_fallback + + @property + def llm_error_fallback_message(self) -> str | None: + return self._llm_error_fallback_message diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py 
index d84b3edf9..204c74e4c 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -150,6 +150,7 @@ async def run_agent( pre_run_checkpoint_id: str | None = None pre_run_snapshot: dict[str, Any] | None = None snapshot_capture_failed = False + llm_error_fallback_message: str | None = None journal = None @@ -312,6 +313,7 @@ async def run_agent( if record.abort_event.is_set(): logger.info("Run %s abort requested — stopping", run_id) break + llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk) sse_event = _lg_mode_to_sse_event(single_mode) await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode)) else: @@ -330,6 +332,7 @@ async def run_agent( if mode is None: continue + llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk) sse_event = _lg_mode_to_sse_event(mode) await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode)) @@ -352,6 +355,12 @@ async def run_agent( logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True) else: await run_manager.set_status(run_id, RunStatus.interrupted) + elif llm_error_fallback_message or (journal is not None and journal.had_llm_error_fallback): + error_msg = llm_error_fallback_message + if error_msg is None and journal is not None: + error_msg = journal.llm_error_fallback_message + error_msg = error_msg or "LLM provider failed after retries" + await run_manager.set_status(run_id, RunStatus.error, error=error_msg) else: await run_manager.set_status(run_id, RunStatus.success) @@ -554,6 +563,85 @@ def _lg_mode_to_sse_event(mode: str) -> str: return mode +def _error_fallback_message_from_metadata(metadata: dict[str, Any], content: Any) -> str: + detail = metadata.get("error_detail") + if isinstance(detail, str) and detail.strip(): + return detail.strip() + reason = metadata.get("error_reason") + if isinstance(reason, 
str) and reason.strip(): + return reason.strip() + if isinstance(content, str) and content.strip(): + return content.strip()[:2000] + return "LLM provider failed after retries" + + +def _try_extract_from_message(obj: Any) -> str | None: + """Try to extract fallback marker from a single message object or dict.""" + additional_kwargs = getattr(obj, "additional_kwargs", None) + if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"): + return _error_fallback_message_from_metadata(additional_kwargs, getattr(obj, "content", None)) + + if isinstance(obj, dict): + nested_kwargs = obj.get("additional_kwargs") + if isinstance(nested_kwargs, dict) and nested_kwargs.get("deerflow_error_fallback"): + return _error_fallback_message_from_metadata(nested_kwargs, obj.get("content")) + return None + + +def _extract_llm_error_fallback_message(value: Any) -> str | None: + """Find LLM fallback markers in streamed LangGraph chunks. + + Error fallback messages returned by model-call middleware are not guaranteed + to pass through LLM end callbacks, but they do appear in graph state chunks. + """ + # Fast path: large state chunks produced by stream_mode="values" have a + # top-level "messages" list. Scanning only that list avoids expensive deep + # recursion into large state dicts. + if isinstance(value, dict): + messages = value.get("messages") + if isinstance(messages, (list, tuple)): + for msg in messages: + result = _try_extract_from_message(msg) + if result is not None: + return result + # Fallback marker is attached to an AI message in the messages + # channel; it will never appear elsewhere in a values chunk. + return None + # No top-level "messages" — this is likely an "updates" chunk (small + # dict keyed by node name). Fall through to deep walk, which is cheap + # for these payloads. + + # Deep walk for updates / messages / tuple / list modes. Payloads are + # small, so full recursion is acceptable here. 
+ seen: set[int] = set() + + def walk(obj: Any) -> str | None: + oid = id(obj) + if oid in seen: + return None + seen.add(oid) + + result = _try_extract_from_message(obj) + if result is not None: + return result + + if isinstance(obj, dict): + for item in obj.values(): + result = walk(item) + if result is not None: + return result + return None + + if isinstance(obj, (list, tuple, set)): + for item in obj: + result = walk(item) + if result is not None: + return result + return None + + return walk(value) + + def _extract_human_message(graph_input: dict) -> HumanMessage | None: """Extract or construct a HumanMessage from graph_input for event recording. diff --git a/backend/tests/test_llm_error_handling_middleware.py b/backend/tests/test_llm_error_handling_middleware.py index 1ab395cd1..9d2c0fa77 100644 --- a/backend/tests/test_llm_error_handling_middleware.py +++ b/backend/tests/test_llm_error_handling_middleware.py @@ -94,6 +94,31 @@ def test_async_model_call_returns_user_message_for_quota_errors() -> None: assert isinstance(result, AIMessage) assert "out of quota" in str(result.content) + assert result.additional_kwargs["deerflow_error_fallback"] is True + assert result.additional_kwargs["error_reason"] == "quota" + assert result.additional_kwargs["error_type"] == "FakeError" + + +def test_async_model_call_marks_transient_retry_exhaustion_as_error_fallback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=25, retry_cap_delay_ms=25) + + async def fake_sleep(_delay: float) -> None: + return None + + async def handler(_request) -> AIMessage: + raise FakeError("Connection error.", status_code=503) + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + + result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) + + assert isinstance(result, AIMessage) + assert "temporarily unavailable" in str(result.content) + assert result.additional_kwargs["deerflow_error_fallback"] is True 
+ assert result.additional_kwargs["error_reason"] == "transient" + assert result.additional_kwargs["error_detail"] == "Connection error." def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index 5a8ec71f7..00844d88d 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -3,12 +3,22 @@ from types import SimpleNamespace from unittest.mock import AsyncMock, call import pytest +from langchain_core.messages import AIMessage from langgraph.checkpoint.base import empty_checkpoint from langgraph.checkpoint.memory import InMemorySaver from deerflow.runtime.runs.manager import RunManager from deerflow.runtime.runs.schemas import RunStatus -from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent +from deerflow.runtime.runs.worker import ( + RunContext, + _agent_factory_supports_app_config, + _build_runtime_context, + _extract_llm_error_fallback_message, + _install_runtime_context, + _rollback_to_pre_run_checkpoint, + _try_extract_from_message, + run_agent, +) class FakeCheckpointer: @@ -95,6 +105,52 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory(): bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60) +@pytest.mark.anyio +async def test_run_agent_marks_llm_error_fallback_as_error_status(): + run_manager = RunManager() + record = await run_manager.create("thread-1") + bridge = SimpleNamespace( + publish=AsyncMock(), + publish_end=AsyncMock(), + cleanup=AsyncMock(), + ) + + class DummyAgent: + async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False): + yield { + "messages": [ + AIMessage( + content="The configured LLM provider is temporarily unavailable after multiple retries.", + 
additional_kwargs={ + "deerflow_error_fallback": True, + "error_type": "APIConnectionError", + "error_reason": "transient", + "error_detail": "Connection error.", + }, + ) + ] + } + + def factory(*, config): + return DummyAgent() + + await run_agent( + bridge, + run_manager, + record, + ctx=RunContext(checkpointer=None), + agent_factory=factory, + graph_input={}, + config={}, + ) + + fetched = await run_manager.get(record.run_id) + assert fetched is not None + assert fetched.status == RunStatus.error + assert fetched.error == "Connection error." + bridge.publish_end.assert_awaited_once_with(record.run_id) + + @pytest.mark.anyio async def test_run_agent_defaults_root_run_name_from_assistant_id(): run_manager = RunManager() @@ -486,3 +542,133 @@ def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_f monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom"))) assert _agent_factory_supports_app_config(BrokenCallable()) is False + + +# --------------------------------------------------------------------------- +# _extract_llm_error_fallback_message coverage +# --------------------------------------------------------------------------- + + +def test_try_extract_from_message_finds_fallback_on_message_object(): + msg = AIMessage( + content="fallback", + additional_kwargs={ + "deerflow_error_fallback": True, + "error_detail": "Connection error.", + "error_reason": "transient", + }, + ) + assert _try_extract_from_message(msg) == "Connection error." + + +def test_try_extract_from_message_finds_fallback_on_dict(): + msg = { + "content": "fallback", + "additional_kwargs": { + "deerflow_error_fallback": True, + "error_detail": "Quota exceeded.", + }, + } + assert _try_extract_from_message(msg) == "Quota exceeded." 
+ + +def test_try_extract_from_message_returns_none_for_normal_message(): + msg = AIMessage(content="hello") + assert _try_extract_from_message(msg) is None + + +def test_extract_llm_error_fallback_message_large_state_chunk_no_fallback(): + """Normal-size state dict without fallback markers must not raise and should return None.""" + large_state = { + "messages": [ + AIMessage(content="Hello!"), + {"role": "user", "content": "Hi there"}, + ], + "foo": "x" * 10_000, + "bar": {"nested": {"deep": {"data": list(range(1000))}}}, + "baz": [{"id": i, "payload": "y" * 1000} for i in range(500)], + } + assert _extract_llm_error_fallback_message(large_state) is None + + +def test_extract_llm_error_fallback_message_finds_fallback_in_messages_list(): + state = { + "messages": [ + AIMessage(content="Hello!"), + AIMessage( + content="Unavailable.", + additional_kwargs={ + "deerflow_error_fallback": True, + "error_detail": "Connection error.", + }, + ), + ], + "other_state": "large_value" * 1000, + } + assert _extract_llm_error_fallback_message(state) == "Connection error." + + +def test_extract_llm_error_fallback_message_finds_fallback_in_raw_message(): + msg = AIMessage( + content="Unavailable.", + additional_kwargs={ + "deerflow_error_fallback": True, + "error_reason": "quota", + }, + ) + assert _extract_llm_error_fallback_message(msg) == "quota" + + +def test_extract_llm_error_fallback_message_finds_fallback_in_tuple(): + item = ( + "messages", + AIMessage( + content="Unavailable.", + additional_kwargs={ + "deerflow_error_fallback": True, + "error_detail": "Circuit open.", + }, + ), + ) + assert _extract_llm_error_fallback_message(item) == "Circuit open." 
+ + +def test_extract_llm_error_fallback_message_returns_none_for_empty_values(): + assert _extract_llm_error_fallback_message({}) is None + assert _extract_llm_error_fallback_message([]) is None + assert _extract_llm_error_fallback_message(None) is None + assert _extract_llm_error_fallback_message("string") is None + + +def test_extract_llm_error_fallback_message_finds_fallback_in_updates_mode(): + """stream_mode='updates' yields dicts keyed by node name (e.g. {'call_model': {...}}). + Fallback marker is nested inside the node's state update, not at the top level.""" + update_chunk = { + "call_model": { + "messages": [ + AIMessage( + content="Unavailable.", + additional_kwargs={ + "deerflow_error_fallback": True, + "error_detail": "Connection error.", + }, + ) + ] + } + } + assert _extract_llm_error_fallback_message(update_chunk) == "Connection error." + + +def test_extract_llm_error_fallback_message_updates_mode_no_fallback(): + """Normal updates chunk without any fallback should return None safely.""" + update_chunk = { + "__interrupt__": [ + { + "value": "ask_human", + "resumable": True, + "ns": ["agent"], + "when": "during", + } + ] + } + assert _extract_llm_error_fallback_message(update_chunk) is None