mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(client): support custom middleware injection (#1520)
* feat(client): support custom middleware injection Add support for custom middleware, allowing custom middleware list to be passed when initializing DeerFlowClient. These middleware will be injected after the default middleware when creating the agent, extending the agent's functionality. * feat: inject custom middlewares before ClarificationMiddleware to preserve ordering - Add `custom_middlewares` param to `_build_middlewares` - Inject custom middlewares right before `ClarificationMiddleware` to keep it as the last in the chain - Remove unsafe `.extend()` in `client.py` - Update tests in `test_client.py` and `test_lead_agent_model_resolution.py` to assert correct injection ordering
This commit is contained in:
parent
89183ae76a
commit
481494b9c0
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
@ -205,12 +205,13 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
|
||||
Args:
|
||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
||||
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
@ -260,6 +261,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||
middlewares.append(LoopDetectionMiddleware())
|
||||
|
||||
# Inject custom middlewares before ClarificationMiddleware
|
||||
if custom_middlewares:
|
||||
middlewares.extend(custom_middlewares)
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
@ -22,12 +22,13 @@ import mimetypes
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@ -116,6 +117,7 @@ class DeerFlowClient:
|
||||
subagent_enabled: bool = False,
|
||||
plan_mode: bool = False,
|
||||
agent_name: str | None = None,
|
||||
middlewares: Sequence[AgentMiddleware] | None = None,
|
||||
):
|
||||
"""Initialize the client.
|
||||
|
||||
@ -131,6 +133,7 @@ class DeerFlowClient:
|
||||
subagent_enabled: Enable subagent delegation.
|
||||
plan_mode: Enable TodoList middleware for plan mode.
|
||||
agent_name: Name of the agent to use.
|
||||
middlewares: Optional list of custom middlewares to inject into the agent.
|
||||
"""
|
||||
if config_path is not None:
|
||||
reload_app_config(config_path)
|
||||
@ -145,6 +148,7 @@ class DeerFlowClient:
|
||||
self._subagent_enabled = subagent_enabled
|
||||
self._plan_mode = plan_mode
|
||||
self._agent_name = agent_name
|
||||
self._middlewares = list(middlewares) if middlewares else []
|
||||
|
||||
# Lazy agent — created on first call, recreated when config changes.
|
||||
self._agent = None
|
||||
@ -217,7 +221,7 @@ class DeerFlowClient:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
||||
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name),
|
||||
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
|
||||
"system_prompt": apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
max_concurrent_subagents=max_concurrent_subagents,
|
||||
|
||||
@ -63,13 +63,22 @@ class TestClientInit:
|
||||
assert client._agent is None
|
||||
|
||||
def test_custom_params(self, mock_app_config):
|
||||
mock_middleware = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent")
|
||||
c = DeerFlowClient(
|
||||
model_name="gpt-4",
|
||||
thinking_enabled=False,
|
||||
subagent_enabled=True,
|
||||
plan_mode=True,
|
||||
agent_name="test-agent",
|
||||
middlewares=[mock_middleware]
|
||||
)
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
assert c._plan_mode is True
|
||||
assert c._agent_name == "test-agent"
|
||||
assert c._middlewares == [mock_middleware]
|
||||
|
||||
def test_invalid_agent_name(self, mock_app_config):
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
@ -413,6 +422,33 @@ class TestEnsureAgent:
|
||||
|
||||
assert mock_create_agent.call_args.kwargs["checkpointer"] is mock_checkpointer
|
||||
|
||||
def test_injects_custom_middlewares(self, client):
|
||||
mock_agent = MagicMock()
|
||||
mock_custom_middleware = MagicMock()
|
||||
client._middlewares = [mock_custom_middleware]
|
||||
config = client._get_runnable_config("t1")
|
||||
|
||||
mock_clarification = MagicMock()
|
||||
mock_clarification.__class__.__name__ = "ClarificationMiddleware"
|
||||
|
||||
def fake_build_middlewares(*args, **kwargs):
|
||||
custom = kwargs.get("custom_middlewares") or []
|
||||
return [MagicMock()] + custom + [mock_clarification]
|
||||
|
||||
with (
|
||||
patch("deerflow.client.create_chat_model"),
|
||||
patch("deerflow.client.create_agent", return_value=mock_agent) as mock_create_agent,
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
called_middlewares = mock_create_agent.call_args.kwargs["middleware"]
|
||||
assert len(called_middlewares) == 3
|
||||
assert called_middlewares[-2] is mock_custom_middleware
|
||||
assert called_middlewares[-1] is mock_clarification
|
||||
|
||||
def test_skips_default_checkpointer_when_unconfigured(self, client):
|
||||
mock_agent = MagicMock()
|
||||
config = client._get_runnable_config("t1")
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
@ -133,9 +135,13 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="vision-model",
|
||||
custom_middlewares=[MagicMock()]
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user