mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
parent
7c87dc5bca
commit
f514e35a36
@ -3,6 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from hashlib import sha256
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
@ -36,6 +37,13 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
|
||||
state_schema = ClarificationMiddlewareState
|
||||
|
||||
def _stable_message_id(self, tool_call_id: str, formatted_message: str) -> str:
|
||||
"""Build a deterministic message ID so retried clarification calls replace, not append."""
|
||||
if tool_call_id:
|
||||
return f"clarification:{tool_call_id}"
|
||||
digest = sha256(formatted_message.encode("utf-8")).hexdigest()[:16]
|
||||
return f"clarification:{digest}"
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters.
|
||||
|
||||
@ -131,6 +139,7 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
# Create a ToolMessage with the formatted question
|
||||
# This will be added to the message history
|
||||
tool_message = ToolMessage(
|
||||
id=self._stable_message_id(tool_call_id, formatted_message),
|
||||
content=formatted_message,
|
||||
tool_call_id=tool_call_id,
|
||||
name="ask_clarification",
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
"""Tests for ClarificationMiddleware, focusing on options type coercion."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
|
||||
@ -118,3 +120,60 @@ class TestFormatClarificationMessage:
|
||||
assert "2. 2" in result
|
||||
assert "3. True" in result
|
||||
assert "4. None" in result
|
||||
|
||||
|
||||
class TestClarificationCommandIdempotency:
|
||||
"""Clarification tool-call retries should not duplicate messages in state."""
|
||||
|
||||
def test_repeated_tool_call_uses_stable_message_id(self, middleware):
|
||||
request = SimpleNamespace(
|
||||
tool_call={
|
||||
"name": "ask_clarification",
|
||||
"id": "call-clarify-1",
|
||||
"args": {
|
||||
"question": "Which environment should I use?",
|
||||
"clarification_type": "approach_choice",
|
||||
"options": ["dev", "prod"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
|
||||
second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
|
||||
|
||||
first_message = first.update["messages"][0]
|
||||
second_message = second.update["messages"][0]
|
||||
|
||||
assert first_message.id == "clarification:call-clarify-1"
|
||||
assert second_message.id == first_message.id
|
||||
assert second_message.tool_call_id == first_message.tool_call_id
|
||||
|
||||
merged = add_messages(add_messages([], [first_message]), [second_message])
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].id == "clarification:call-clarify-1"
|
||||
assert merged[0].content == first_message.content
|
||||
|
||||
def test_missing_tool_call_id_still_gets_stable_message_id(self, middleware):
|
||||
request = SimpleNamespace(
|
||||
tool_call={
|
||||
"name": "ask_clarification",
|
||||
"args": {
|
||||
"question": "Which environment should I use?",
|
||||
"clarification_type": "missing_info",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
|
||||
second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
|
||||
|
||||
first_message = first.update["messages"][0]
|
||||
second_message = second.update["messages"][0]
|
||||
|
||||
assert first_message.id.startswith("clarification:")
|
||||
assert second_message.id == first_message.id
|
||||
|
||||
merged = add_messages(add_messages([], [first_message]), [second_message])
|
||||
|
||||
assert len(merged) == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user