diff --git a/backend/packages/harness/deerflow/agents/middlewares/view_image_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/view_image_middleware.py index 2870624c3..37432cd9a 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/view_image_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/view_image_middleware.py @@ -1,22 +1,19 @@ """Middleware for injecting image details into conversation before LLM call.""" import logging -from typing import NotRequired, override +from typing import override -from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langgraph.runtime import Runtime -from deerflow.agents.thread_state import ViewedImageData +from deerflow.agents.thread_state import ThreadState logger = logging.getLogger(__name__) -class ViewImageMiddlewareState(AgentState): - """Compatible with the `ThreadState` schema.""" - - viewed_images: NotRequired[dict[str, ViewedImageData] | None] +class ViewImageMiddlewareState(ThreadState): + """Reuse the thread state so reducer-backed keys keep their annotations.""" class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]): diff --git a/backend/tests/test_create_deerflow_agent.py b/backend/tests/test_create_deerflow_agent.py index 2b6c74370..03fee2055 100644 --- a/backend/tests/test_create_deerflow_agent.py +++ b/backend/tests/test_create_deerflow_agent.py @@ -1,11 +1,14 @@ """Tests for create_deerflow_agent SDK entry point.""" +from typing import get_type_hints from unittest.mock import MagicMock, patch import pytest from deerflow.agents.factory import create_deerflow_agent from deerflow.agents.features import Next, Prev, RuntimeFeatures +from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware +from deerflow.agents.thread_state import ThreadState def _make_mock_model(): @@ -127,6 +130,13 @@ def test_vision_injects_view_image_tool(mock_create_agent): assert "view_image" in tool_names +def test_view_image_middleware_preserves_viewed_images_reducer(): + middleware_hints = get_type_hints(ViewImageMiddleware.state_schema, include_extras=True) + thread_hints = get_type_hints(ThreadState, include_extras=True) + + assert middleware_hints["viewed_images"] == thread_hints["viewed_images"] + + # --------------------------------------------------------------------------- # 8. Subagent feature auto-injects task_tool # ---------------------------------------------------------------------------