diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 1bcc0ee4c..fe743e448 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -1,7 +1,7 @@ import logging from langchain.agents import create_agent -from langchain.agents.middleware import SummarizationMiddleware +from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware from langchain_core.runnables import RunnableConfig from deerflow.agents.lead_agent.prompt import apply_prompt_template @@ -205,12 +205,13 @@ Being proactive with task management demonstrates thoroughness and ensures all r # ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM # ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages # ClarificationMiddleware should be last to intercept clarification requests after model calls -def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None): +def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None): """Build middleware chain based on runtime configuration. Args: config: Runtime configuration containing configurable options like is_plan_mode. agent_name: If provided, MemoryMiddleware will use per-agent memory storage. + custom_middlewares: Optional list of custom middlewares to inject into the chain. Returns: List of middleware instances. @@ -260,6 +261,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam # LoopDetectionMiddleware — detect and break repetitive tool call loops middlewares.append(LoopDetectionMiddleware()) + # Inject custom middlewares before ClarificationMiddleware + if custom_middlewares: + middlewares.extend(custom_middlewares) + # ClarificationMiddleware should always be last middlewares.append(ClarificationMiddleware()) return middlewares diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index c2893afc1..964d76964 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -22,12 +22,13 @@ import mimetypes import shutil import tempfile import uuid -from collections.abc import Generator +from collections.abc import Generator, Sequence from dataclasses import dataclass, field from pathlib import Path from typing import Any from langchain.agents import create_agent +from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableConfig @@ -116,6 +117,7 @@ class DeerFlowClient: subagent_enabled: bool = False, plan_mode: bool = False, agent_name: str | None = None, + middlewares: Sequence[AgentMiddleware] | None = None, ): """Initialize the client. @@ -131,6 +133,7 @@ class DeerFlowClient: subagent_enabled: Enable subagent delegation. plan_mode: Enable TodoList middleware for plan mode. agent_name: Name of the agent to use. + middlewares: Optional list of custom middlewares to inject into the agent. """ if config_path is not None: reload_app_config(config_path) @@ -145,6 +148,7 @@ class DeerFlowClient: self._subagent_enabled = subagent_enabled self._plan_mode = plan_mode self._agent_name = agent_name + self._middlewares = list(middlewares) if middlewares else [] # Lazy agent — created on first call, recreated when config changes. self._agent = None @@ -217,7 +221,7 @@ class DeerFlowClient: kwargs: dict[str, Any] = { "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), - "middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name), + "middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares), "system_prompt": apply_prompt_template( subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index 84e7955c6..1c5cdc53e 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -63,13 +63,22 @@ class TestClientInit: assert client._agent is None def test_custom_params(self, mock_app_config): + mock_middleware = MagicMock() with patch("deerflow.client.get_app_config", return_value=mock_app_config): - c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent") + c = DeerFlowClient( + model_name="gpt-4", + thinking_enabled=False, + subagent_enabled=True, + plan_mode=True, + agent_name="test-agent", + middlewares=[mock_middleware] + ) assert c._model_name == "gpt-4" assert c._thinking_enabled is False assert c._subagent_enabled is True assert c._plan_mode is True assert c._agent_name == "test-agent" + assert c._middlewares == [mock_middleware] def test_invalid_agent_name(self, mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config): @@ -413,6 +422,33 @@ class TestEnsureAgent: assert mock_create_agent.call_args.kwargs["checkpointer"] is mock_checkpointer + def test_injects_custom_middlewares(self, client): + mock_agent = MagicMock() + mock_custom_middleware = MagicMock() + client._middlewares = [mock_custom_middleware] + config = client._get_runnable_config("t1") + + mock_clarification = MagicMock() + mock_clarification.__class__.__name__ = "ClarificationMiddleware" + + def fake_build_middlewares(*args, **kwargs): + custom = kwargs.get("custom_middlewares") or [] + return [MagicMock()] + custom + [mock_clarification] + + with ( + patch("deerflow.client.create_chat_model"), + patch("deerflow.client.create_agent", return_value=mock_agent) as mock_create_agent, + patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares), + patch("deerflow.client.apply_prompt_template", return_value="prompt"), + patch.object(client, "_get_tools", return_value=[]), + ): + client._ensure_agent(config) + + called_middlewares = mock_create_agent.call_args.kwargs["middleware"] + assert len(called_middlewares) == 3 + assert called_middlewares[-2] is mock_custom_middleware + assert called_middlewares[-1] is mock_clarification + def test_skips_default_checkpointer_when_unconfigured(self, client): mock_agent = MagicMock() config = client._get_runnable_config("t1") diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index 79ec380be..b964f53ca 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -2,6 +2,8 @@ from __future__ import annotations +from unittest.mock import MagicMock + import pytest from deerflow.agents.lead_agent import agent as lead_agent_module @@ -133,9 +135,13 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch): middlewares = lead_agent_module._build_middlewares( {"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", + custom_middlewares=[MagicMock()] ) assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares) + # verify the custom middleware is injected correctly + assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock) + def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):