diff --git a/backend/packages/harness/deerflow/agents/middlewares/clarification_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/clarification_middleware.py index 9e0c2b259..385508f0f 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/clarification_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/clarification_middleware.py @@ -3,6 +3,7 @@ import json import logging from collections.abc import Callable +from hashlib import sha256 from typing import override from langchain.agents import AgentState @@ -36,6 +37,13 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]): state_schema = ClarificationMiddlewareState + def _stable_message_id(self, tool_call_id: str, formatted_message: str) -> str: + """Build a deterministic message ID so retried clarification calls replace, not append.""" + if tool_call_id: + return f"clarification:{tool_call_id}" + digest = sha256(formatted_message.encode("utf-8")).hexdigest()[:16] + return f"clarification:{digest}" + def _is_chinese(self, text: str) -> bool: """Check if text contains Chinese characters. @@ -131,6 +139,7 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]): # Create a ToolMessage with the formatted question # This will be added to the message history tool_message = ToolMessage( + id=self._stable_message_id(tool_call_id, formatted_message), content=formatted_message, tool_call_id=tool_call_id, name="ask_clarification", diff --git a/backend/tests/test_clarification_middleware.py b/backend/tests/test_clarification_middleware.py index 9a8118996..565b09beb 100644 --- a/backend/tests/test_clarification_middleware.py +++ b/backend/tests/test_clarification_middleware.py @@ -1,8 +1,10 @@ """Tests for ClarificationMiddleware, focusing on options type coercion.""" import json +from types import SimpleNamespace import pytest +from langgraph.graph.message import add_messages from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware @@ -118,3 +120,60 @@ class TestFormatClarificationMessage: assert "2. 2" in result assert "3. True" in result assert "4. None" in result + + +class TestClarificationCommandIdempotency: + """Clarification tool-call retries should not duplicate messages in state.""" + + def test_repeated_tool_call_uses_stable_message_id(self, middleware): + request = SimpleNamespace( + tool_call={ + "name": "ask_clarification", + "id": "call-clarify-1", + "args": { + "question": "Which environment should I use?", + "clarification_type": "approach_choice", + "options": ["dev", "prod"], + }, + } + ) + + first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called")) + second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called")) + + first_message = first.update["messages"][0] + second_message = second.update["messages"][0] + + assert first_message.id == "clarification:call-clarify-1" + assert second_message.id == first_message.id + assert second_message.tool_call_id == first_message.tool_call_id + + merged = add_messages(add_messages([], [first_message]), [second_message]) + + assert len(merged) == 1 + assert merged[0].id == "clarification:call-clarify-1" + assert merged[0].content == first_message.content + + def test_missing_tool_call_id_still_gets_stable_message_id(self, middleware): + request = SimpleNamespace( + tool_call={ + "name": "ask_clarification", + "args": { + "question": "Which environment should I use?", + "clarification_type": "missing_info", + }, + } + ) + + first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called")) + second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called")) + + first_message = first.update["messages"][0] + second_message = second.update["messages"][0] + + assert first_message.id.startswith("clarification:") + assert second_message.id == first_message.id + + merged = add_messages(add_messages([], [first_message]), [second_message]) + + assert len(merged) == 1