mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-01 22:38:23 +00:00
* refactor: thread app_config through middleware factories Continues the incremental config-refactor sequence (#2611 root, #2612 lead path) one layer deeper into the middleware factories. Two ambient lookups inside _build_runtime_middlewares are eliminated and the LLMErrorHandling band-aid removed: - _build_runtime_middlewares / build_lead_runtime_middlewares / build_subagent_runtime_middlewares now require app_config: AppConfig. - get_guardrails_config() inside the factory is replaced with app_config.guardrails (semantically identical — same default-factory GuardrailsConfig — verified by direct equality check). - LLMErrorHandlingMiddleware.__init__ now requires app_config and reads circuit_breaker fields directly. The class-level circuit_failure_threshold / circuit_recovery_timeout_sec defaults are removed along with the try/except (FileNotFoundError, RuntimeError): pass band-aid — the let-it-crash invariant the rest of the refactor enforces. Caller chain (already-resolved app_config sources): - _build_middlewares in lead_agent/agent.py: reorder so resolved_app_config = app_config or get_app_config() is computed BEFORE build_lead_runtime_middlewares is called, then passed as kwarg. - SubagentExecutor: optional app_config parameter (mirrors the lead-agent pattern); _create_agent does the same `or get_app_config()` fallback at agent-build time, so task_tool callers don't need to plumb app_config through yet (typed-context plumbing for tool runtimes is a separate refactor). Tests: - test_llm_error_handling_middleware: _make_app_config helper using AppConfig(sandbox=SandboxConfig(use="test")) — same minimal-config pattern conftest already uses. Three direct LLMErrorHandlingMiddleware() calls each followed by post-construction circuit_breaker mutation fold cleanly into _build_middleware(circuit_failure_threshold=..., circuit_recovery_timeout_sec=...). 
Verification: - tests/test_llm_error_handling_middleware.py — 14 passed - tests/test_subagent_executor.py — 28 passed - tests/test_tool_error_handling_middleware.py — 6 passed - tests/test_task_tool_core_logic.py — 18 passed (verifies task_tool unchanged behavior) - Full suite: 2697 passed, 3 skipped. The single intermittent failure in tests/test_client_e2e.py::test_tool_call_produces_events is pre-existing LLM flakiness (the test asserts the model decided to call a tool; reproduces 1/3 on unchanged main as well). * fix: address middleware app config review comments * fix: satisfy app config annotation lint * test: cover explicit app config middleware wiring --------- Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
169 lines
5.4 KiB
Python
import sys
|
|
from types import ModuleType, SimpleNamespace
|
|
|
|
import pytest
|
|
from langchain_core.messages import ToolMessage
|
|
from langgraph.errors import GraphInterrupt
|
|
|
|
from deerflow.agents.middlewares.tool_error_handling_middleware import (
|
|
ToolErrorHandlingMiddleware,
|
|
build_subagent_runtime_middlewares,
|
|
)
|
|
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
|
|
from deerflow.config.guardrails_config import GuardrailsConfig
|
|
from deerflow.config.sandbox_config import SandboxConfig
|
|
|
|
|
|
def _module(name: str, **attrs):
    """Return a fresh ``ModuleType`` named *name* whose attributes are the given keyword arguments.

    Used to build fake modules that are dropped into ``sys.modules``.
    """
    fake = ModuleType(name)
    vars(fake).update(attrs)
    return fake
|
|
|
|
|
|
def _make_app_config() -> AppConfig:
    """Build a minimal ``AppConfig`` for these tests.

    Sets a "test" sandbox, disables guardrails, and sets the circuit breaker
    to failure_threshold=7 / recovery_timeout_sec=11.
    """
    sandbox = SandboxConfig(use="test")
    guardrails_off = GuardrailsConfig(enabled=False)
    breaker = CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11)
    return AppConfig(sandbox=sandbox, guardrails=guardrails_off, circuit_breaker=breaker)
|
|
|
|
|
|
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
|
tool_call = {"name": name}
|
|
if tool_call_id is not None:
|
|
tool_call["id"] = tool_call_id
|
|
return SimpleNamespace(tool_call=tool_call)
|
|
|
|
|
|
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
    """The subagent factory must hand its explicit ``app_config`` to LLMErrorHandlingMiddleware.

    The factory's collaborator modules are replaced with lightweight fakes by
    seeding ``sys.modules``. This relies on the factory importing them at call
    time, so only the wiring is exercised, not the real middlewares.
    """
    captured: dict[str, object] = {}

    class FakeMiddleware:
        # Generic stand-in: records its constructor arguments and nothing else.
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        # Keyword-only, so the test also checks the factory passes app_config by name.
        def __init__(self, *, app_config):
            captured["app_config"] = app_config

    app_config = _make_app_config()

    # module path -> attributes the factory imports from it
    fake_modules = {
        "deerflow.agents.middlewares.llm_error_handling_middleware": {"LLMErrorHandlingMiddleware": FakeLLMErrorHandlingMiddleware},
        "deerflow.agents.middlewares.thread_data_middleware": {"ThreadDataMiddleware": FakeMiddleware},
        "deerflow.sandbox.middleware": {"SandboxMiddleware": FakeMiddleware},
        "deerflow.agents.middlewares.dangling_tool_call_middleware": {"DanglingToolCallMiddleware": FakeMiddleware},
        "deerflow.agents.middlewares.sandbox_audit_middleware": {"SandboxAuditMiddleware": FakeMiddleware},
    }
    for module_name, attrs in fake_modules.items():
        monkeypatch.setitem(sys.modules, module_name, _module(module_name, **attrs))

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)

    # .get() so a factory that never builds the LLM middleware fails with a
    # readable assertion rather than a bare KeyError.
    assert captured.get("app_config") is app_config, "LLMErrorHandlingMiddleware was not constructed with the explicit app_config"
    # Constructing the middleware is not enough; it must also be part of the returned chain.
    assert sum(isinstance(mw, FakeLLMErrorHandlingMiddleware) for mw in middlewares) == 1
    assert len(middlewares) == 6
    assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
|
|
|
|
|
def test_wrap_tool_call_passthrough_on_success():
    """A handler that succeeds must have its ToolMessage returned as the very same object."""
    middleware = ToolErrorHandlingMiddleware()
    success_msg = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")

    def _handler(_req):
        return success_msg

    returned = middleware.wrap_tool_call(_request(), _handler)

    assert returned is success_msg
|
|
|
|
|
|
def test_wrap_tool_call_returns_error_tool_message_on_exception():
    """A raising handler becomes an error ToolMessage that keeps the call's id and name."""
    middleware = ToolErrorHandlingMiddleware()
    failing_request = _request(name="web_search", tool_call_id="tc-42")

    def _failing_handler(_req):
        raise RuntimeError("network down")

    outcome = middleware.wrap_tool_call(failing_request, _failing_handler)

    assert isinstance(outcome, ToolMessage)
    assert (outcome.tool_call_id, outcome.name, outcome.status) == ("tc-42", "web_search", "error")
    # The message text names the tool and carries the original exception text.
    body = outcome.text
    assert "Tool 'web_search' failed" in body
    assert "network down" in body
|
|
|
|
|
|
def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
    """Without an ``id`` in the tool call, the error message uses the ``missing_tool_call_id`` placeholder."""
    middleware = ToolErrorHandlingMiddleware()
    id_less_request = _request(name="mcp_tool", tool_call_id=None)

    def _failing_handler(_req):
        raise ValueError("bad request")

    outcome = middleware.wrap_tool_call(id_less_request, _failing_handler)

    assert isinstance(outcome, ToolMessage)
    assert (outcome.tool_call_id, outcome.name, outcome.status) == ("missing_tool_call_id", "mcp_tool", "error")
|
|
|
|
|
|
def test_wrap_tool_call_reraises_graph_interrupt():
    """GraphInterrupt must propagate out of wrap_tool_call instead of becoming an error ToolMessage."""
    middleware = ToolErrorHandlingMiddleware()
    clarification_request = _request(name="ask_clarification", tool_call_id="tc-int")

    def _raise_interrupt(_req):
        raise GraphInterrupt(())

    with pytest.raises(GraphInterrupt):
        middleware.wrap_tool_call(clarification_request, _raise_interrupt)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
    """The async path also turns a raising handler into an error ToolMessage."""
    middleware = ToolErrorHandlingMiddleware()
    failing_request = _request(name="mcp_tool", tool_call_id="tc-async")

    async def _failing_handler(_req):
        raise TimeoutError("request timed out")

    outcome = await middleware.awrap_tool_call(failing_request, _failing_handler)

    assert isinstance(outcome, ToolMessage)
    assert (outcome.tool_call_id, outcome.name, outcome.status) == ("tc-async", "mcp_tool", "error")
    assert "request timed out" in outcome.text
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_tool_call_reraises_graph_interrupt():
    """GraphInterrupt must propagate out of awrap_tool_call instead of becoming an error ToolMessage."""
    middleware = ToolErrorHandlingMiddleware()
    clarification_request = _request(name="ask_clarification", tool_call_id="tc-int-async")

    async def _raise_interrupt(_req):
        raise GraphInterrupt(())

    with pytest.raises(GraphInterrupt):
        await middleware.awrap_tool_call(clarification_request, _raise_interrupt)
|