mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(channels): preserve clarification conversation history across follow-up turns
Pin channel-triggered runs to the root checkpoint namespace and ensure thread_id is always present in configurable run config so follow-up replies resume the same conversation state. Add regression coverage to channel tests: assert checkpoint_ns/thread_id are passed in wait and stream paths add an integration-style clarification flow test that verifies the second user reply continues prior context instead of starting a new session This addresses history loss after ask_clarification interruptions (issue #2425).
This commit is contained in:
parent
1ca2621285
commit
0f3c335d8a
@ -560,6 +560,15 @@ class ChannelManager:
|
||||
user_layer.get("config"),
|
||||
)
|
||||
|
||||
configurable = run_config.get("configurable")
|
||||
if not isinstance(configurable, Mapping):
|
||||
configurable = {}
|
||||
run_config["configurable"] = configurable
|
||||
# Pin channel-triggered runs to the root graph namespace so follow-up
|
||||
# turns continue from the same conversation checkpoint.
|
||||
configurable.setdefault("checkpoint_ns", "")
|
||||
configurable.setdefault("thread_id", thread_id)
|
||||
|
||||
run_context = _merge_dicts(
|
||||
DEFAULT_RUN_CONTEXT,
|
||||
self._default_session.get("context"),
|
||||
|
||||
@ -508,6 +508,8 @@ class TestChannelManager:
|
||||
assert call_args[0][0] == "test-thread-123" # thread_id
|
||||
assert call_args[0][1] == "lead_agent" # assistant_id
|
||||
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
|
||||
assert len(outbound_received) == 1
|
||||
assert outbound_received[0].text == "Hello from agent!"
|
||||
@ -556,12 +558,135 @@ class TestChannelManager:
|
||||
call_args = mock_client.runs.wait.call_args
|
||||
assert call_args[0][1] == "lead_agent"
|
||||
assert call_args[1]["config"]["recursion_limit"] == 55
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
assert call_args[1]["context"]["thinking_enabled"] is False
|
||||
assert call_args[1]["context"]["subagent_enabled"] is True
|
||||
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_clarification_follow_up_preserves_history(self):
|
||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
|
||||
outbound_received = []
|
||||
|
||||
async def capture_outbound(msg):
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
|
||||
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||
del assistant_id, context # unused in this test, kept for signature parity
|
||||
|
||||
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||
key = (thread_id, str(checkpoint_ns))
|
||||
history = history_by_checkpoint.setdefault(key, [])
|
||||
|
||||
human_text = input["messages"][0]["content"]
|
||||
history.append(human_text)
|
||||
|
||||
if len(history) == 1:
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[0]},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ask_clarification",
|
||||
"args": {"question": "Which environment should I use?"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "ask_clarification",
|
||||
"content": "Which environment should I use?",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[0]},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ask_clarification",
|
||||
"args": {"question": "Which environment should I use?"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "ask_clarification",
|
||||
"content": "Which environment should I use?",
|
||||
},
|
||||
{"type": "human", "content": history[1]},
|
||||
{"type": "ai", "content": "Got it. I will deploy to prod."},
|
||||
]
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[-1]},
|
||||
{"type": "ai", "content": "History missing; clarification repeated."},
|
||||
]
|
||||
}
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
|
||||
manager._client = mock_client
|
||||
|
||||
await manager.start()
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="Deploy my app",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="prod",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 2)
|
||||
await manager.stop()
|
||||
|
||||
assert outbound_received[0].text == "Which environment should I use?"
|
||||
assert outbound_received[1].text == "Got it. I will deploy to prod."
|
||||
|
||||
assert mock_client.runs.wait.call_count == 2
|
||||
first_call = mock_client.runs.wait.call_args_list[0]
|
||||
second_call = mock_client.runs.wait.call_args_list[1]
|
||||
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_uses_user_session_overrides(self):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
@ -1238,6 +1363,8 @@ class TestChannelManager:
|
||||
call_args = mock_client.runs.stream.call_args
|
||||
|
||||
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
assert call_args[1]["context"]["is_bootstrap"] is True
|
||||
|
||||
# Final message should be published
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user