diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py
index 37d528c98..fdf5df24b 100644
--- a/backend/packages/harness/deerflow/client.py
+++ b/backend/packages/harness/deerflow/client.py
@@ -315,6 +315,120 @@ class DeerFlowClient:
             return "\n".join(pieces) if pieces else ""
         return str(content)
 
+    # ------------------------------------------------------------------
+    # Public API — threads
+    # ------------------------------------------------------------------
+
+    def list_threads(self, limit: int = 10) -> dict:
+        """List the most recently created threads.
+
+        Every stored checkpoint is scanned and folded into one summary per
+        thread, so the cost grows with the total number of checkpoints rather
+        than with ``limit``.
+
+        Args:
+            limit: Maximum number of threads to return. Default is 10.
+                Values below 1 yield an empty list.
+
+        Returns:
+            Dict with "thread_list" key containing list of thread info dicts,
+            sorted by thread creation time descending.
+        """
+        checkpointer = self._checkpointer
+        if checkpointer is None:
+            from deerflow.agents.checkpointer.provider import get_checkpointer
+
+            checkpointer = get_checkpointer()
+
+        thread_info_map = {}
+
+        # Do not forward `limit` to the checkpointer: there it caps the number of
+        # checkpoints, not threads, so one busy thread would crowd out the others
+        # and `created_at` would be derived from a truncated history. The thread
+        # limit is applied after aggregation instead.
+        for cp in checkpointer.list(config=None):
+            cfg = cp.config.get("configurable", {})
+            thread_id = cfg.get("thread_id")
+            if not thread_id:
+                continue
+
+            ts = cp.checkpoint.get("ts")
+            checkpoint_id = cfg.get("checkpoint_id")
+
+            if thread_id not in thread_info_map:
+                channel_values = cp.checkpoint.get("channel_values", {})
+                thread_info_map[thread_id] = {
+                    "thread_id": thread_id,
+                    "created_at": ts,
+                    "updated_at": ts,
+                    "latest_checkpoint_id": checkpoint_id,
+                    "title": channel_values.get("title"),
+                }
+            else:
+                # Explicitly compare timestamps to ensure accuracy when iterating over unordered namespaces.
+                # Treat None as "missing" and only compare when existing values are non-None.
+                if ts is not None:
+                    current_created = thread_info_map[thread_id]["created_at"]
+                    if current_created is None or ts < current_created:
+                        thread_info_map[thread_id]["created_at"] = ts
+
+                    current_updated = thread_info_map[thread_id]["updated_at"]
+                    if current_updated is None or ts > current_updated:
+                        thread_info_map[thread_id]["updated_at"] = ts
+                        thread_info_map[thread_id]["latest_checkpoint_id"] = checkpoint_id
+                        channel_values = cp.checkpoint.get("channel_values", {})
+                        thread_info_map[thread_id]["title"] = channel_values.get("title")
+
+        threads = list(thread_info_map.values())
+        threads.sort(key=lambda x: x.get("created_at") or "", reverse=True)
+
+        return {"thread_list": threads[: max(limit, 0)]}
+
+    def get_thread(self, thread_id: str) -> dict:
+        """Get the complete thread record, including all node execution records.
+
+        Args:
+            thread_id: Thread ID.
+
+        Returns:
+            Dict with "thread_id" and "checkpoints" keys; checkpoints are sorted
+            by timestamp ascending, with entries lacking a timestamp first.
+        """
+        checkpointer = self._checkpointer
+        if checkpointer is None:
+            from deerflow.agents.checkpointer.provider import get_checkpointer
+
+            checkpointer = get_checkpointer()
+
+        config = {"configurable": {"thread_id": thread_id}}
+        checkpoints = []
+
+        for cp in checkpointer.list(config):
+            channel_values = dict(cp.checkpoint.get("channel_values", {}))
+            if "messages" in channel_values:
+                channel_values["messages"] = [self._serialize_message(m) if hasattr(m, "content") else m for m in channel_values["messages"]]
+
+            cfg = cp.config.get("configurable", {})
+            parent_cfg = cp.parent_config.get("configurable", {}) if cp.parent_config else {}
+            # pending_writes can be present but None (believed to be the CheckpointTuple default), so normalise before iterating.
+            pending_writes = getattr(cp, "pending_writes", None) or []
+
+            checkpoints.append(
+                {
+                    "checkpoint_id": cfg.get("checkpoint_id"),
+                    "parent_checkpoint_id": parent_cfg.get("checkpoint_id"),
+                    "ts": cp.checkpoint.get("ts"),
+                    "metadata": cp.metadata,
+                    "values": channel_values,
+                    "pending_writes": [{"task_id": w[0], "channel": w[1], "value": w[2]} for w in pending_writes],
+                }
+            )
+
+        # Sort globally by timestamp to prevent partial ordering issues caused by different namespaces (e.g., subgraphs)
+        checkpoints.sort(key=lambda x: x["ts"] if x["ts"] else "")
+
+        return {"thread_id": thread_id, "checkpoints": checkpoints}
+
     # ------------------------------------------------------------------
     # Public API — conversation
     # ------------------------------------------------------------------
diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py
index a88bb43c6..29574b085 100644
--- a/backend/tests/test_client.py
+++ b/backend/tests/test_client.py
@@ -570,6 +570,179 @@ class TestGetModel:
         assert client.get_model("nonexistent") is None
 
 
+# ---------------------------------------------------------------------------
+# Thread Queries (list_threads / get_thread)
+# ---------------------------------------------------------------------------
+
+
+class TestThreadQueries:
+    def _make_mock_checkpoint_tuple(
+        self,
+        thread_id: str,
+        checkpoint_id: str,
+        ts: str | None,
+        title: str | None = None,
+        parent_id: str | None = None,
+        messages: list | None = None,
+        pending_writes: list | None = None,
+    ):
+        cp = MagicMock()
+        cp.config = {"configurable": {"thread_id": thread_id, "checkpoint_id": checkpoint_id}}
+
+        channel_values = {}
+        if title is not None:
+            channel_values["title"] = title
+        if messages is not None:
+            channel_values["messages"] = messages
+
+        cp.checkpoint = {"ts": ts, "channel_values": channel_values}
+        cp.metadata = {"source": "test"}
+
+        if parent_id:
+            cp.parent_config = {"configurable": {"thread_id": thread_id, "checkpoint_id": parent_id}}
+        else:
+            cp.parent_config = {}
+
+        cp.pending_writes = pending_writes or []
+        return cp
+
+    def test_list_threads_empty(self, client):
+        mock_checkpointer = MagicMock()
+        mock_checkpointer.list.return_value = []
+        client._checkpointer = mock_checkpointer
+
+        result = client.list_threads()
+        assert result == {"thread_list": []}
+        mock_checkpointer.list.assert_called_once_with(config=None)
+
+    def test_list_threads_basic(self, client):
+        mock_checkpointer = MagicMock()
+        client._checkpointer = mock_checkpointer
+
+        cp1 = self._make_mock_checkpoint_tuple("t1", "c1", "2023-01-01T10:00:00Z", title="Thread 1")
+        cp2 = self._make_mock_checkpoint_tuple("t1", "c2", "2023-01-01T10:05:00Z", title="Thread 1 Updated")
+        cp3 = self._make_mock_checkpoint_tuple("t2", "c3", "2023-01-02T10:00:00Z", title="Thread 2")
+        cp_empty = self._make_mock_checkpoint_tuple("", "c4", "2023-01-03T10:00:00Z", title="Thread Empty")
+
+        # Mock list returns out of order to test the timestamp sorting/comparison
+        # Also includes a checkpoint with an empty thread_id which should be skipped
+        mock_checkpointer.list.return_value = [cp2, cp1, cp_empty, cp3]
+
+        result = client.list_threads(limit=5)
+        mock_checkpointer.list.assert_called_once_with(config=None)
+
+        threads = result["thread_list"]
+        assert len(threads) == 2
+
+        # t2 should be first because its created_at (2023-01-02) is newer than t1 (2023-01-01)
+        assert threads[0]["thread_id"] == "t2"
+        assert threads[0]["created_at"] == "2023-01-02T10:00:00Z"
+        assert threads[0]["title"] == "Thread 2"
+
+        assert threads[1]["thread_id"] == "t1"
+        assert threads[1]["created_at"] == "2023-01-01T10:00:00Z"
+        assert threads[1]["updated_at"] == "2023-01-01T10:05:00Z"
+        assert threads[1]["latest_checkpoint_id"] == "c2"
+        assert threads[1]["title"] == "Thread 1 Updated"
+
+    def test_list_threads_limit_counts_threads(self, client):
+        mock_checkpointer = MagicMock()
+        client._checkpointer = mock_checkpointer
+
+        # Three checkpoints on the newest thread must not crowd out the older threads
+        mock_checkpointer.list.return_value = [
+            self._make_mock_checkpoint_tuple("t3", "c5", "2023-01-03T10:02:00Z"),
+            self._make_mock_checkpoint_tuple("t3", "c4", "2023-01-03T10:01:00Z"),
+            self._make_mock_checkpoint_tuple("t3", "c3", "2023-01-03T10:00:00Z"),
+            self._make_mock_checkpoint_tuple("t2", "c2", "2023-01-02T10:00:00Z"),
+            self._make_mock_checkpoint_tuple("t1", "c1", "2023-01-01T10:00:00Z"),
+        ]
+
+        result = client.list_threads(limit=2)
+        mock_checkpointer.list.assert_called_once_with(config=None)
+
+        assert [t["thread_id"] for t in result["thread_list"]] == ["t3", "t2"]
+        assert result["thread_list"][0]["created_at"] == "2023-01-03T10:00:00Z"
+
+    def test_list_threads_fallback_checkpointer(self, client):
+        mock_checkpointer = MagicMock()
+        mock_checkpointer.list.return_value = []
+
+        with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
+            # No internal checkpointer, should fetch from provider
+            result = client.list_threads()
+
+        assert result == {"thread_list": []}
+        mock_checkpointer.list.assert_called_once()
+
+    def test_get_thread(self, client):
+        mock_checkpointer = MagicMock()
+        client._checkpointer = mock_checkpointer
+
+        msg1 = HumanMessage(content="Hello", id="m1")
+        msg2 = AIMessage(content="Hi there", id="m2")
+
+        cp1 = self._make_mock_checkpoint_tuple("t1", "c1", "2023-01-01T10:00:00Z", messages=[msg1])
+        cp2 = self._make_mock_checkpoint_tuple("t1", "c2", "2023-01-01T10:01:00Z", parent_id="c1", messages=[msg1, msg2], pending_writes=[("task_1", "messages", {"text": "pending"})])
+        cp3_no_ts = self._make_mock_checkpoint_tuple("t1", "c3", None)
+
+        # checkpointer.list yields in reverse time or random order, test sorting
+        mock_checkpointer.list.return_value = [cp2, cp1, cp3_no_ts]
+
+        result = client.get_thread("t1")
+
+        mock_checkpointer.list.assert_called_once_with({"configurable": {"thread_id": "t1"}})
+
+        assert result["thread_id"] == "t1"
+        checkpoints = result["checkpoints"]
+        assert len(checkpoints) == 3
+
+        # None timestamp remains None but is sorted first via a fallback key
+        assert checkpoints[0]["checkpoint_id"] == "c3"
+        assert checkpoints[0]["ts"] is None
+
+        # Should be sorted by timestamp globally
+        assert checkpoints[1]["checkpoint_id"] == "c1"
+        assert checkpoints[1]["ts"] == "2023-01-01T10:00:00Z"
+        assert len(checkpoints[1]["values"]["messages"]) == 1
+
+        assert checkpoints[2]["checkpoint_id"] == "c2"
+        assert checkpoints[2]["parent_checkpoint_id"] == "c1"
+        assert checkpoints[2]["ts"] == "2023-01-01T10:01:00Z"
+        assert len(checkpoints[2]["values"]["messages"]) == 2
+        # Verify message serialization
+        assert checkpoints[2]["values"]["messages"][1]["content"] == "Hi there"
+
+        # Verify pending writes
+        assert len(checkpoints[2]["pending_writes"]) == 1
+        assert checkpoints[2]["pending_writes"][0]["task_id"] == "task_1"
+        assert checkpoints[2]["pending_writes"][0]["channel"] == "messages"
+
+    def test_get_thread_pending_writes_none(self, client):
+        mock_checkpointer = MagicMock()
+        client._checkpointer = mock_checkpointer
+
+        cp = self._make_mock_checkpoint_tuple("t1", "c1", "2023-01-01T10:00:00Z")
+        # The helper coerces None to [], so set the attribute directly to exercise the None path
+        cp.pending_writes = None
+        mock_checkpointer.list.return_value = [cp]
+
+        result = client.get_thread("t1")
+
+        assert result["checkpoints"][0]["pending_writes"] == []
+
+    def test_get_thread_fallback_checkpointer(self, client):
+        mock_checkpointer = MagicMock()
+        mock_checkpointer.list.return_value = []
+
+        with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
+            result = client.get_thread("t99")
+
+        assert result["thread_id"] == "t99"
+        assert result["checkpoints"] == []
+        mock_checkpointer.list.assert_called_once_with({"configurable": {"thread_id": "t99"}})
+
+
 # ---------------------------------------------------------------------------
 # MCP config
 # ---------------------------------------------------------------------------