From 46ddc346ade831db5d1d9f4609f2f5b9962f5b8c Mon Sep 17 00:00:00 2001 From: kia <2486726988@qq.com> Date: Sun, 31 May 2026 22:43:07 +0800 Subject: [PATCH] fix(channels): preserve Feishu clarification thread continuity (#3285) * fix(channels): preserve Feishu clarification thread continuity * fix(channels): address Feishu clarification review feedback --------- Co-authored-by: zzp1221 Co-authored-by: Willem Jiang --- backend/app/channels/feishu.py | 197 +++++++++++++++++++++- backend/app/channels/manager.py | 65 +++++++- backend/app/channels/message_bus.py | 3 + backend/tests/test_channels.py | 190 ++++++++++++++++++++- backend/tests/test_feishu_parser.py | 246 +++++++++++++++++++++++++++- 5 files changed, 685 insertions(+), 16 deletions(-) diff --git a/backend/app/channels/feishu.py b/backend/app/channels/feishu.py index 75892d54d..eb6fb72ca 100644 --- a/backend/app/channels/feishu.py +++ b/backend/app/channels/feishu.py @@ -7,16 +7,26 @@ import json import logging import re import threading +import time from typing import Any, Literal from app.channels.base import Channel from app.channels.commands import KNOWN_CHANNEL_COMMANDS -from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment +from app.channels.message_bus import ( + PENDING_CLARIFICATION_METADATA_KEY, + RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY, + InboundMessage, + InboundMessageType, + MessageBus, + OutboundMessage, + ResolvedAttachment, +) from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.runtime.user_context import get_effective_user_id from deerflow.sandbox.sandbox_provider import get_sandbox_provider logger = logging.getLogger(__name__) +PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60 def _is_feishu_command(text: str) -> bool: @@ -56,6 +66,7 @@ class FeishuChannel(Channel): self._background_tasks: set[asyncio.Task] = set() self._running_card_ids: dict[str, str] = {} self._running_card_tasks: dict[str, asyncio.Task] = {} + self._pending_clarifications: dict[tuple[str, str], list[dict[str, Any]]] = {} self._CreateFileRequest = None self._CreateFileRequestBody = None self._CreateImageRequest = None @@ -63,6 +74,16 @@ class FeishuChannel(Channel): self._GetMessageResourceRequest = None self._thread_lock = threading.Lock() + @staticmethod + def _non_empty_str(value: Any) -> str | None: + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + @staticmethod + def _pending_key(chat_id: str, user_id: str) -> tuple[str, str]: + return (chat_id, user_id) + @property def supports_streaming(self) -> bool: return True @@ -531,18 +552,25 @@ class FeishuChannel(Channel): "[Feishu] failed to patch running card %s, falling back to final reply", running_card_id, ) - await self._reply_card(source_message_id, msg.text) + fallback_card_id = await self._reply_card(source_message_id, msg.text) + self._remember_thread_mapping(msg, source_message_id, fallback_card_id) + self._remember_pending_clarification(msg, fallback_card_id) else: + self._remember_thread_mapping(msg, source_message_id, running_card_id) + self._remember_pending_clarification(msg, running_card_id) logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id) elif msg.is_final: - await self._reply_card(source_message_id, msg.text) + final_card_id = await self._reply_card(source_message_id, msg.text) + self._remember_thread_mapping(msg, source_message_id, final_card_id) + self._remember_pending_clarification(msg, final_card_id) elif awaited_running_card_task: logger.warning( "[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation", source_message_id, ) else: - await self._ensure_running_card(source_message_id, msg.text) + created_card_id = await self._ensure_running_card(source_message_id, msg.text) + self._remember_thread_mapping(msg, source_message_id, created_card_id) if msg.is_final: self._running_card_ids.pop(source_message_id, None) @@ -553,6 +581,129 @@ class FeishuChannel(Channel): # -- internal ---------------------------------------------------------- + def _remember_thread_mapping(self, msg: OutboundMessage, *topic_ids: str | None) -> None: + store = self.config.get("channel_store") + if store is None or not msg.thread_id: + return + + metadata_topic_ids = [ + msg.metadata.get("message_id"), + msg.metadata.get("root_id"), + msg.metadata.get("parent_id"), + msg.metadata.get("thread_id"), + msg.metadata.get("topic_id"), + ] + user_id = "" + raw_user_id = msg.metadata.get("user_id") + if isinstance(raw_user_id, str): + user_id = raw_user_id + + seen: set[str] = set() + for topic_id in [*topic_ids, *metadata_topic_ids]: + topic_id = self._non_empty_str(topic_id) + if not topic_id or topic_id in seen: + continue + seen.add(topic_id) + try: + store.set_thread_id( + self.name, + msg.chat_id, + msg.thread_id, + topic_id=topic_id, + user_id=user_id, + ) + except Exception: + logger.exception("[Feishu] failed to remember thread mapping for topic_id=%s", topic_id) + + def _remember_pending_clarification(self, msg: OutboundMessage, card_message_id: str | None) -> None: + if not msg.is_final or msg.metadata.get(PENDING_CLARIFICATION_METADATA_KEY) is not True: + return + + user_id = self._non_empty_str(msg.metadata.get("user_id")) + topic_id = self._non_empty_str(msg.metadata.get("topic_id")) + source_message_id = self._non_empty_str(msg.thread_ts) or self._non_empty_str(msg.metadata.get("message_id")) + if not (user_id and topic_id and msg.thread_id and source_message_id and card_message_id): + return + + key = self._pending_key(msg.chat_id, user_id) + pending = { + "thread_id": msg.thread_id, + "topic_id": topic_id, + "source_message_id": source_message_id, + "card_message_id": card_message_id, + "created_at": time.time(), + } + with self._thread_lock: + # Plain-message clarification continuity is a short-lived in-memory + # hint; explicit Feishu replies are still covered by persisted + # message-id mappings. + self._pending_clarifications.setdefault(key, []).append(pending) + logger.info( + "[Feishu] pending clarification remembered: chat_id=%s user_id=%s topic_id=%s thread_id=%s", + msg.chat_id, + user_id, + topic_id, + msg.thread_id, + ) + + def _consume_pending_clarification(self, chat_id: str, user_id: str) -> dict[str, Any] | None: + key = self._pending_key(chat_id, user_id) + with self._thread_lock: + pending_items = self._pending_clarifications.get(key) + if not pending_items: + return None + + now = time.time() + while pending_items: + pending = pending_items.pop(0) + created_at = pending.get("created_at") + if isinstance(created_at, (int, float)) and now - created_at <= PENDING_CLARIFICATION_TTL_SECONDS: + if pending_items: + self._pending_clarifications[key] = pending_items + else: + self._pending_clarifications.pop(key, None) + return pending + logger.info("[Feishu] pending clarification expired: chat_id=%s user_id=%s", chat_id, user_id) + + self._pending_clarifications.pop(key, None) + return None + + def _ensure_pending_thread_mapping(self, chat_id: str, user_id: str, pending: dict[str, Any]) -> None: + store = self.config.get("channel_store") + topic_id = self._non_empty_str(pending.get("topic_id")) + thread_id = self._non_empty_str(pending.get("thread_id")) + if store is None or not topic_id or not thread_id: + return + try: + store.set_thread_id(self.name, chat_id, thread_id, topic_id=topic_id, user_id=user_id) + except Exception: + logger.exception("[Feishu] failed to restore pending clarification mapping for topic_id=%s", topic_id) + + def _resolve_topic_id( + self, + chat_id: str, + msg_id: str, + *, + root_id: str | None, + parent_id: str | None, + thread_id: str | None, + ) -> tuple[str, bool]: + store = self.config.get("channel_store") + candidates = [root_id, parent_id, thread_id] + + if store is not None: + for candidate in candidates: + candidate = self._non_empty_str(candidate) + if not candidate: + continue + try: + if store.get_thread_id(self.name, chat_id, topic_id=candidate): + return candidate, True + except Exception: + logger.exception("[Feishu] failed to resolve stored topic mapping for topic_id=%s", candidate) + + return root_id or msg_id, False + @staticmethod def _log_future_error(fut, name: str, msg_id: str) -> None: """Callback for run_coroutine_threadsafe futures to surface errors.""" @@ -593,7 +744,9 @@ class FeishuChannel(Channel): # root_id is set when the message is a reply within a Feishu thread. # Use it as topic_id so all replies share the same DeerFlow thread. - root_id = getattr(message, "root_id", None) or None + root_id = self._non_empty_str(getattr(message, "root_id", None)) + parent_id = self._non_empty_str(getattr(message, "parent_id", None)) + feishu_thread_id = self._non_empty_str(getattr(message, "thread_id", None)) # Parse message content content = json.loads(message.content) @@ -654,10 +807,12 @@ class FeishuChannel(Channel): text = text.strip() logger.info( - "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, sender=%s, text=%r", + "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r", chat_id, msg_id, root_id, + parent_id, + feishu_thread_id, sender_id, text[:100] if text else "", ) @@ -673,8 +828,24 @@ class FeishuChannel(Channel): else: msg_type = InboundMessageType.CHAT - # topic_id: use root_id for replies (same topic), msg_id for new messages (new topic) - topic_id = root_id or msg_id + # Prefer any platform message id that already maps to a DeerFlow + # thread. This keeps replies to bot clarification cards in the + # original conversation even when Feishu reports the card as root. + topic_id, resolved_from_stored_mapping = self._resolve_topic_id( + chat_id, + msg_id, + root_id=root_id, + parent_id=parent_id, + thread_id=feishu_thread_id, + ) + resolved_from_pending = False + if msg_type == InboundMessageType.CHAT and not resolved_from_stored_mapping: + pending = self._consume_pending_clarification(chat_id, sender_id) + pending_topic_id = self._non_empty_str(pending.get("topic_id")) if pending else None + if pending_topic_id: + topic_id = pending_topic_id + self._ensure_pending_thread_mapping(chat_id, sender_id, pending) + resolved_from_pending = True inbound = self._make_inbound( chat_id=chat_id, @@ -683,7 +854,15 @@ class FeishuChannel(Channel): msg_type=msg_type, thread_ts=msg_id, files=files_list, - metadata={"message_id": msg_id, "root_id": root_id}, + metadata={ + "message_id": msg_id, + "root_id": root_id, + "parent_id": parent_id, + "thread_id": feishu_thread_id, + "topic_id": topic_id, + "user_id": sender_id, + RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY: resolved_from_pending, + }, ) inbound.topic_id = topic_id diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index c245d9448..982c87035 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -15,7 +15,14 @@ import httpx from langgraph_sdk.errors import ConflictError from app.channels.commands import KNOWN_CHANNEL_COMMANDS -from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment +from app.channels.message_bus import ( + PENDING_CLARIFICATION_METADATA_KEY, + InboundMessage, + InboundMessageType, + MessageBus, + OutboundMessage, + ResolvedAttachment, +) from app.channels.store import ChannelStore from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token from app.gateway.internal_auth import create_internal_auth_headers @@ -202,6 +209,54 @@ def _extract_response_text(result: dict | list) -> str: return "" +def _messages_from_result(result: dict | list) -> list[Any]: + if isinstance(result, list): + return result + if isinstance(result, dict): + messages = result.get("messages", []) + if isinstance(messages, list): + return messages + return [] + + +def _current_turn_messages(result: dict | list) -> list[dict[str, Any]]: + messages = _messages_from_result(result) + current_turn: list[dict[str, Any]] = [] + for msg in reversed(messages): + if not isinstance(msg, dict): + continue + if msg.get("type") == "human": + break + current_turn.append(msg) + current_turn.reverse() + return current_turn + + +def _has_current_turn_clarification(result: dict | list) -> bool: + """Return True only when the current turn's final result is clarification.""" + for msg in reversed(_current_turn_messages(result)): + msg_type = msg.get("type") + if msg_type == "tool": + return msg.get("name") == "ask_clarification" + if msg_type == "ai": + content = msg.get("content") + if isinstance(content, str): + if content: + return False + elif content: + return False + if msg.get("tool_calls"): + return False + return False + + +def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification: bool = False) -> dict[str, Any]: + metadata = _slim_metadata(base_metadata) + if pending_clarification: + metadata[PENDING_CLARIFICATION_METADATA_KEY] = True + return metadata + + def _extract_text_content(content: Any) -> str: """Extract text from a streaming payload content field.""" if isinstance(content, str): @@ -806,6 +861,7 @@ class ChannelManager: raise response_text = _extract_response_text(result) + pending_clarification = _has_current_turn_clarification(result) artifacts = _extract_artifacts(result) logger.info( @@ -831,7 +887,7 @@ class ChannelManager: artifacts=artifacts, attachments=attachments, thread_ts=msg.thread_ts, - metadata=_slim_metadata(msg.metadata), + metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), ) logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id) await self.bus.publish_outbound(outbound) @@ -893,7 +949,7 @@ class ChannelManager: text=latest_text, is_final=False, thread_ts=msg.thread_ts, - metadata=_slim_metadata(msg.metadata), + metadata=_response_metadata(msg.metadata), ) ) last_published_text = latest_text @@ -907,6 +963,7 @@ class ChannelManager: finally: result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]} response_text = _extract_response_text(result) + pending_clarification = _has_current_turn_clarification(result) artifacts = _extract_artifacts(result) response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts) @@ -938,7 +995,7 @@ class ChannelManager: attachments=attachments, is_final=True, thread_ts=msg.thread_ts, - metadata=_slim_metadata(msg.metadata), + metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), ) ) diff --git a/backend/app/channels/message_bus.py b/backend/app/channels/message_bus.py index 4d0818aca..4e847cca0 100644 --- a/backend/app/channels/message_bus.py +++ b/backend/app/channels/message_bus.py @@ -13,6 +13,9 @@ from typing import Any logger = logging.getLogger(__name__) +PENDING_CLARIFICATION_METADATA_KEY = "pending_clarification" +RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY = "resolved_from_pending_clarification" + # --------------------------------------------------------------------------- # Message types diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 0e1c69f50..960e32a4b 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -12,7 +12,14 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.channels.base import Channel -from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment +from app.channels.message_bus import ( + PENDING_CLARIFICATION_METADATA_KEY, + InboundMessage, + InboundMessageType, + MessageBus, + OutboundMessage, + ResolvedAttachment, +) from app.channels.store import ChannelStore @@ -392,6 +399,47 @@ class TestExtractResponseText: assert _extract_response_text(result) == "Here is the plan." +class TestClarificationDetection: + def test_final_clarification_tool_message_is_pending(self): + from app.channels.manager import _has_current_turn_clarification + + result = { + "messages": [ + {"type": "human", "content": "deploy"}, + {"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]}, + {"type": "tool", "name": "ask_clarification", "content": "Which environment?"}, + ] + } + assert _has_current_turn_clarification(result) is True + + def test_clarification_followed_by_regular_ai_is_not_pending(self): + from app.channels.manager import _has_current_turn_clarification + + result = { + "messages": [ + {"type": "human", "content": "deploy"}, + {"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]}, + {"type": "tool", "name": "ask_clarification", "content": "Which environment?"}, + {"type": "ai", "content": "I will continue without pending clarification."}, + ] + } + assert _has_current_turn_clarification(result) is False + + def test_previous_turn_clarification_does_not_mark_current_turn(self): + from app.channels.manager import _has_current_turn_clarification + + result = { + "messages": [ + {"type": "human", "content": "deploy"}, + {"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]}, + {"type": "tool", "name": "ask_clarification", "content": "Which environment?"}, + {"type": "human", "content": "prod"}, + {"type": "ai", "content": "Deploying to prod."}, + ] + } + assert _has_current_turn_clarification(result) is False + + # --------------------------------------------------------------------------- # ChannelManager tests # --------------------------------------------------------------------------- @@ -637,6 +685,74 @@ class TestChannelManager: _run(go()) + def test_handle_chat_marks_clarification_outbound_metadata(self): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + outbound_received: list[OutboundMessage] = [] + + async def capture_outbound(msg: OutboundMessage) -> None: + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + mock_client = _make_mock_langgraph_client( + run_result={ + "messages": [ + {"type": "human", "content": "deploy"}, + {"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]}, + {"type": "tool", "name": "ask_clarification", "content": "Which environment?"}, + ] + } + ) + manager._client = mock_client + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="deploy", + metadata={"message_id": "msg-1"}, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + assert outbound_received[0].text == "Which environment?" + assert outbound_received[0].metadata["message_id"] == "msg-1" + assert outbound_received[0].metadata[PENDING_CLARIFICATION_METADATA_KEY] is True + + _run(go()) + + def test_handle_chat_does_not_mark_regular_outbound_as_clarification(self): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + outbound_received: list[OutboundMessage] = [] + + async def capture_outbound(msg: OutboundMessage) -> None: + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + await manager.start() + + await bus.publish_inbound(InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi")) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + assert outbound_received[0].text == "Hello from agent!" + assert PENDING_CLARIFICATION_METADATA_KEY not in outbound_received[0].metadata + + _run(go()) + def test_handle_chat_outbound_drops_large_metadata_keys(self): """Large metadata keys like raw_message should be stripped from outbound messages.""" from app.channels.manager import ChannelManager @@ -1018,6 +1134,67 @@ class TestChannelManager: _run(go()) + def test_handle_feishu_streaming_marks_only_final_clarification_outbound(self, monkeypatch): + from app.channels.manager import ChannelManager + + monkeypatch.setattr("app.channels.manager.STREAM_UPDATE_MIN_INTERVAL_SECONDS", 0.0) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + outbound_received: list[OutboundMessage] = [] + + async def capture_outbound(msg: OutboundMessage) -> None: + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + stream_events = [ + _make_stream_part( + "messages-tuple", + [ + {"id": "ai-1", "content": "Thinking", "type": "AIMessageChunk"}, + {"langgraph_node": "agent"}, + ], + ), + _make_stream_part( + "values", + { + "messages": [ + {"type": "human", "content": "deploy"}, + {"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]}, + {"type": "tool", "name": "ask_clarification", "content": "Which environment?"}, + ], + "artifacts": [], + }, + ), + ] + mock_client = _make_mock_langgraph_client() + mock_client.runs.stream = MagicMock(return_value=_make_async_iterator(stream_events)) + manager._client = mock_client + await manager.start() + + await bus.publish_inbound( + InboundMessage( + channel_name="feishu", + chat_id="chat1", + user_id="user1", + text="deploy", + thread_ts="om-source-1", + ) + ) + await _wait_for(lambda: len(outbound_received) >= 2) + await manager.stop() + + assert [msg.is_final for msg in outbound_received] == [False, False, True] + assert outbound_received[0].text == "Thinking" + assert outbound_received[1].text == "Which environment?" + assert outbound_received[2].text == "Which environment?" + assert all(PENDING_CLARIFICATION_METADATA_KEY not in msg.metadata for msg in outbound_received[:-1]) + assert outbound_received[-1].metadata[PENDING_CLARIFICATION_METADATA_KEY] is True + + _run(go()) + def test_handle_feishu_stream_error_still_sends_final(self, monkeypatch): """When the stream raises mid-way, a final outbound with is_final=True must still be published.""" from app.channels.manager import ChannelManager @@ -2010,7 +2187,8 @@ class TestFeishuChannel: async def go(): bus = MessageBus() bus.publish_inbound = AsyncMock() - channel = FeishuChannel(bus, config={}) + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + channel = FeishuChannel(bus, config={"channel_store": store}) channel._api_client = MagicMock() reply_started = asyncio.Event() @@ -2046,6 +2224,11 @@ class TestFeishuChannel: text="Hello", is_final=False, thread_ts="om-source-msg", + metadata={ + "user_id": "user-1", + "root_id": "om-root-msg", + "topic_id": "om-root-msg", + }, ) ) ) @@ -2060,6 +2243,9 @@ class TestFeishuChannel: assert channel._reply_card.await_count == 1 channel._update_card.assert_awaited_once_with("om-running-card", "Hello") assert "om-source-msg" not in channel._running_card_tasks + assert store.get_thread_id("feishu", "chat-1", topic_id="om-source-msg") == "thread-1" + assert store.get_thread_id("feishu", "chat-1", topic_id="om-running-card") == "thread-1" + assert store.get_thread_id("feishu", "chat-1", topic_id="om-root-msg") == "thread-1" _run(go()) diff --git a/backend/tests/test_feishu_parser.py b/backend/tests/test_feishu_parser.py index 202862fb1..5ecfb9e0b 100644 --- a/backend/tests/test_feishu_parser.py +++ b/backend/tests/test_feishu_parser.py @@ -1,12 +1,38 @@ import asyncio import json +import tempfile +from pathlib import Path from unittest.mock import AsyncMock, MagicMock import pytest from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.feishu import FeishuChannel -from app.channels.message_bus import InboundMessage, MessageBus +from app.channels.message_bus import ( + PENDING_CLARIFICATION_METADATA_KEY, + RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY, + InboundMessage, + MessageBus, + OutboundMessage, +) +from app.channels.store import ChannelStore + + +def _pending( + topic_id: str, + *, + thread_id: str | None = None, + source_message_id: str | None = None, + card_message_id: str | None = None, + created_at: float = 9999999999, +) -> dict: + return { + "thread_id": thread_id or f"deer-thread-{topic_id}", + "topic_id": topic_id, + "source_message_id": source_message_id or topic_id, + "card_message_id": card_message_id or f"card-{topic_id}", + "created_at": created_at, + } def _run(coro): @@ -138,6 +164,224 @@ def test_feishu_on_message_extracts_image_and_file_keys(): assert "[file]" in mock_make_inbound.call_args[1]["text"] +def test_feishu_on_message_reuses_stored_parent_topic_for_card_replies(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + store.set_thread_id( + "feishu", + "chat_1", + "deer-thread-1", + topic_id="om_clarification_card", + user_id="user_1", + ) + channel = FeishuChannel( + bus, + {"app_id": "test", "app_secret": "test", "channel_store": store}, + ) + + event = MagicMock() + event.event.message.chat_id = "chat_1" + event.event.message.message_id = "msg_reply" + event.event.message.root_id = "om_unknown_root" + event.event.message.parent_id = "om_clarification_card" + event.event.message.thread_id = None + event.event.sender.sender_id.open_id = "user_1" + event.event.message.content = json.dumps({"text": "prod"}) + + with pytest.MonkeyPatch.context() as m: + mock_make_inbound = MagicMock() + m.setattr(channel, "_make_inbound", mock_make_inbound) + channel._on_message(event) + + inbound = mock_make_inbound.return_value + assert inbound.topic_id == "om_clarification_card" + assert mock_make_inbound.call_args.kwargs["metadata"]["topic_id"] == "om_clarification_card" + + +def _make_text_event( + text: str, + *, + chat_id: str = "chat_1", + message_id: str = "msg_1", + user_id: str = "user_1", + root_id: str | None = None, + parent_id: str | None = None, + thread_id: str | None = None, +): + event = MagicMock() + event.event.message.chat_id = chat_id + event.event.message.message_id = message_id + event.event.message.root_id = root_id + event.event.message.parent_id = parent_id + event.event.message.thread_id = thread_id + event.event.sender.sender_id.open_id = user_id + event.event.message.content = json.dumps({"text": text}) + return event + + +def test_feishu_plain_reply_consumes_pending_clarification_topic(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + store.set_thread_id("feishu", "chat_1", "deer-thread-1", topic_id="om_original", user_id="user_1") + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test", "channel_store": store}) + channel._pending_clarifications[channel._pending_key("chat_1", "user_1")] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card")] + + with pytest.MonkeyPatch.context() as m: + mock_make_inbound = MagicMock() + m.setattr(channel, "_make_inbound", mock_make_inbound) + channel._on_message(_make_text_event("2", message_id="msg_plain_2")) + + inbound = mock_make_inbound.return_value + metadata = mock_make_inbound.call_args.kwargs["metadata"] + assert inbound.topic_id == "om_original" + assert metadata["topic_id"] == "om_original" + assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is True + assert channel._pending_key("chat_1", "user_1") not in channel._pending_clarifications + + +def test_feishu_pending_clarification_is_consumed_once(): + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + channel._pending_clarifications[channel._pending_key("chat_1", "user_1")] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card")] + + with pytest.MonkeyPatch.context() as m: + created = [] + + def fake_make_inbound(**kwargs): + inbound = InboundMessage(channel_name="feishu", **kwargs) + created.append(inbound) + return inbound + + mock_make_inbound = MagicMock(side_effect=fake_make_inbound) + m.setattr(channel, "_make_inbound", mock_make_inbound) + channel._on_message(_make_text_event("2", message_id="msg_first")) + channel._on_message(_make_text_event("next", message_id="msg_second")) + + first_inbound = created[0] + second_inbound = created[1] + first_metadata = mock_make_inbound.call_args_list[0].kwargs["metadata"] + second_metadata = mock_make_inbound.call_args_list[1].kwargs["metadata"] + assert first_inbound.topic_id == "om_original" + assert second_inbound.topic_id == "msg_second" + assert first_metadata["topic_id"] == "om_original" + assert first_metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is True + assert second_metadata["topic_id"] == "msg_second" + assert second_metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False + + +def test_feishu_expired_pending_clarification_is_ignored(monkeypatch): + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + monkeypatch.setattr("app.channels.feishu.time.time", lambda: 10_000.0) + channel._pending_clarifications[channel._pending_key("chat_1", "user_1")] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card", created_at=0.0)] + + with pytest.MonkeyPatch.context() as m: + mock_make_inbound = MagicMock() + m.setattr(channel, "_make_inbound", mock_make_inbound) + channel._on_message(_make_text_event("2", message_id="msg_plain_2")) + + metadata = mock_make_inbound.call_args.kwargs["metadata"] + assert metadata["topic_id"] == "msg_plain_2" + assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False + assert channel._pending_key("chat_1", "user_1") not in channel._pending_clarifications + + +def test_feishu_command_does_not_consume_pending_clarification(): + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + key = channel._pending_key("chat_1", "user_1") + channel._pending_clarifications[key] = [_pending("om_original", thread_id="deer-thread-1", card_message_id="om_card")] + + with pytest.MonkeyPatch.context() as m: + mock_make_inbound = MagicMock() + m.setattr(channel, "_make_inbound", mock_make_inbound) + channel._on_message(_make_text_event("/status", message_id="msg_command")) + + metadata = mock_make_inbound.call_args.kwargs["metadata"] + assert mock_make_inbound.call_args.kwargs["msg_type"].value == "command" + assert metadata["topic_id"] == "msg_command" + assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False + assert key in channel._pending_clarifications + + +def test_feishu_remembers_pending_clarification_only_after_final_card_success(): + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + outbound = OutboundMessage( + channel_name="feishu", + chat_id="chat_1", + thread_id="deer-thread-1", + text="clarify?", + thread_ts="om_original", + metadata={ + PENDING_CLARIFICATION_METADATA_KEY: True, + "user_id": "user_1", + "topic_id": "om_original", + "message_id": "om_original", + }, + ) + + channel._remember_pending_clarification(outbound, None) + assert channel._pending_clarifications == {} + + channel._remember_pending_clarification(outbound, "om_card") + pending = channel._pending_clarifications[channel._pending_key("chat_1", "user_1")][0] + assert pending["topic_id"] == "om_original" + assert pending["thread_id"] == "deer-thread-1" + assert pending["card_message_id"] == "om_card" + + +def test_feishu_multiple_pending_clarifications_are_consumed_in_order(): + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + key = channel._pending_key("chat_1", "user_1") + channel._pending_clarifications[key] = [ + _pending("om_first", thread_id="deer-thread-1"), + _pending("om_second", thread_id="deer-thread-2"), + ] + + with pytest.MonkeyPatch.context() as m: + created = [] + + def fake_make_inbound(**kwargs): + inbound = InboundMessage(channel_name="feishu", **kwargs) + created.append(inbound) + return inbound + + m.setattr(channel, "_make_inbound", MagicMock(side_effect=fake_make_inbound)) + channel._on_message(_make_text_event("first answer", message_id="msg_first")) + channel._on_message(_make_text_event("second answer", message_id="msg_second")) + + assert [msg.topic_id for msg in created] == ["om_first", "om_second"] + assert key not in channel._pending_clarifications + + +def test_feishu_explicit_reply_prefers_stored_mapping_over_pending(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + store.set_thread_id("feishu", "chat_1", "deer-thread-card", topic_id="om_card", user_id="user_1") + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test", "channel_store": store}) + key = channel._pending_key("chat_1", "user_1") + channel._pending_clarifications[key] = [_pending("om_pending", thread_id="deer-thread-pending")] + + with pytest.MonkeyPatch.context() as m: + mock_make_inbound = MagicMock() + m.setattr(channel, "_make_inbound", mock_make_inbound) + channel._on_message( + _make_text_event( + "answer", + message_id="msg_reply", + root_id="om_unknown", + parent_id="om_card", + ) + ) + + metadata = mock_make_inbound.call_args.kwargs["metadata"] + assert metadata["topic_id"] == "om_card" + assert metadata[RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY] is False + assert key in channel._pending_clarifications + + @pytest.mark.parametrize("command", sorted(KNOWN_CHANNEL_COMMANDS)) def test_feishu_recognizes_all_known_slash_commands(command): """Every entry in KNOWN_CHANNEL_COMMANDS must be classified as a command."""