diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 286635d3f..e37078ba1 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -560,6 +560,15 @@ class ChannelManager: user_layer.get("config"), ) + configurable = run_config.get("configurable") + if not isinstance(configurable, Mapping): + configurable = {} + run_config["configurable"] = configurable + # Pin channel-triggered runs to the root graph namespace so follow-up + # turns continue from the same conversation checkpoint. + configurable.setdefault("checkpoint_ns", "") + configurable.setdefault("thread_id", thread_id) + run_context = _merge_dicts( DEFAULT_RUN_CONTEXT, self._default_session.get("context"), diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 7fc412653..b3c1870e3 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -508,6 +508,8 @@ class TestChannelManager: assert call_args[0][0] == "test-thread-123" # thread_id assert call_args[0][1] == "lead_agent" # assistant_id assert call_args[1]["input"]["messages"][0]["content"] == "hi" + assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == "" + assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123" assert len(outbound_received) == 1 assert outbound_received[0].text == "Hello from agent!" @@ -556,12 +558,135 @@ class TestChannelManager: call_args = mock_client.runs.wait.call_args assert call_args[0][1] == "lead_agent" assert call_args[1]["config"]["recursion_limit"] == 55 + assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == "" + assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123" assert call_args[1]["context"]["thinking_enabled"] is False assert call_args[1]["context"]["subagent_enabled"] is True assert call_args[1]["context"]["agent_name"] == "mobile-agent" _run(go()) + def test_clarification_follow_up_preserves_history(self): + """Conversation should continue after ask_clarification instead of resetting history.""" + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + + history_by_checkpoint: dict[tuple[str, str], list[str]] = {} + + async def _runs_wait(thread_id, assistant_id, *, input, config, context): + del assistant_id, context # unused in this test, kept for signature parity + + checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns") + key = (thread_id, str(checkpoint_ns)) + history = history_by_checkpoint.setdefault(key, []) + + human_text = input["messages"][0]["content"] + history.append(human_text) + + if len(history) == 1: + return { + "messages": [ + {"type": "human", "content": history[0]}, + { + "type": "ai", + "content": "", + "tool_calls": [ + { + "name": "ask_clarification", + "args": {"question": "Which environment should I use?"}, + } + ], + }, + { + "type": "tool", + "name": "ask_clarification", + "content": "Which environment should I use?", + }, + ] + } + + if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod": + return { + "messages": [ + {"type": "human", "content": history[0]}, + { + "type": "ai", + "content": "", + "tool_calls": [ + { + "name": "ask_clarification", + "args": {"question": "Which environment should I use?"}, + } + ], + }, + { + "type": "tool", + "name": "ask_clarification", + "content": "Which environment should I use?", + }, + {"type": "human", "content": history[1]}, + {"type": "ai", "content": "Got it. I will deploy to prod."}, + ] + } + + return { + "messages": [ + {"type": "human", "content": history[-1]}, + {"type": "ai", "content": "History missing; clarification repeated."}, + ] + } + + mock_client = MagicMock() + mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"}) + mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"}) + mock_client.runs.wait = AsyncMock(side_effect=_runs_wait) + manager._client = mock_client + + await manager.start() + + await bus.publish_inbound( + InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="Deploy my app", + ) + ) + await _wait_for(lambda: len(outbound_received) >= 1) + + await bus.publish_inbound( + InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="prod", + ) + ) + await _wait_for(lambda: len(outbound_received) >= 2) + await manager.stop() + + assert outbound_received[0].text == "Which environment should I use?" + assert outbound_received[1].text == "Got it. I will deploy to prod." + + assert mock_client.runs.wait.call_count == 2 + first_call = mock_client.runs.wait.call_args_list[0] + second_call = mock_client.runs.wait.call_args_list[1] + assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == "" + assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == "" + + _run(go()) + def test_handle_chat_uses_user_session_overrides(self): from app.channels.manager import ChannelManager @@ -1238,6 +1363,8 @@ class TestChannelManager: call_args = mock_client.runs.stream.call_args assert call_args[1]["input"]["messages"][0]["content"] == "hello" + assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == "" + assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123" assert call_args[1]["context"]["is_bootstrap"] is True # Final message should be published