mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(client): add thread query methods list_threads and get_thread (#1609)
* feat(client): add thread query methods `list_threads` and `get_thread`
  Implemented two public API methods in `DeerFlowClient` to query threads using the underlying `checkpointer`.
* Update backend/packages/harness/deerflow/client.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update backend/packages/harness/deerflow/client.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update backend/tests/test_client.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update backend/packages/harness/deerflow/client.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* fix(deerflow): Fix possible KeyError issue when sorting threads
* fix unit test
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
ad6d934a5f
commit
31a3c9a3de
@ -315,6 +315,108 @@ class DeerFlowClient:
|
||||
return "\n".join(pieces) if pieces else ""
|
||||
return str(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — threads
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_threads(self, limit: int = 10) -> dict:
    """Summarise the threads found among the checkpointer's checkpoints.

    Checkpoints are grouped by ``thread_id``. For every thread the smallest
    and largest timestamps seen become ``created_at`` and ``updated_at``,
    and the checkpoint carrying the largest timestamp supplies
    ``latest_checkpoint_id`` and ``title``. Checkpoints without a
    ``thread_id`` are ignored.

    Args:
        limit: Maximum number of threads to return. Default is 10. The same
            value is forwarded to ``checkpointer.list`` as its ``limit``.

    Returns:
        Dict with a "thread_list" key holding one info dict per thread
        (``thread_id``, ``created_at``, ``updated_at``,
        ``latest_checkpoint_id``, ``title``), ordered by ``created_at``
        descending.
    """
    saver = self._checkpointer
    if saver is None:
        # Imported lazily, only when the client has no checkpointer of its own.
        from deerflow.agents.checkpointer.provider import get_checkpointer

        saver = get_checkpointer()

    # NOTE(review): ``limit`` is passed to ``list`` as well; if the saver
    # applies it per checkpoint (not per thread), fewer than ``limit`` threads
    # can come back and ``created_at`` may only reflect the scanned
    # checkpoints — confirm against the checkpointer implementation.
    summaries = {}
    for entry in saver.list(config=None, limit=limit):
        configurable = entry.config.get("configurable", {})
        tid = configurable.get("thread_id")
        if not tid:
            continue

        stamp = entry.checkpoint.get("ts")
        cp_id = configurable.get("checkpoint_id")

        info = summaries.get(tid)
        if info is None:
            # First checkpoint seen for this thread seeds every field.
            summaries[tid] = {
                "thread_id": tid,
                "created_at": stamp,
                "updated_at": stamp,
                "latest_checkpoint_id": cp_id,
                "title": entry.checkpoint.get("channel_values", {}).get("title"),
            }
            continue

        # A missing timestamp carries no ordering information, so it can
        # neither move the creation time nor replace the latest checkpoint.
        if stamp is None:
            continue

        oldest = info["created_at"]
        if oldest is None or stamp < oldest:
            info["created_at"] = stamp

        newest = info["updated_at"]
        if newest is None or stamp > newest:
            info["updated_at"] = stamp
            info["latest_checkpoint_id"] = cp_id
            info["title"] = entry.checkpoint.get("channel_values", {}).get("title")

    ordered = sorted(
        summaries.values(),
        key=lambda item: item.get("created_at") or "",
        reverse=True,
    )
    return {"thread_list": ordered[:limit]}
|
||||
|
||||
def get_thread(self, thread_id: str) -> dict:
    """Get the complete thread record, including all node execution records.

    Args:
        thread_id: Thread ID.

    Returns:
        Dict with "thread_id" and "checkpoints". Each checkpoint entry holds
        "checkpoint_id", "parent_checkpoint_id", "ts", "metadata", "values"
        (channel values, with message objects serialized) and
        "pending_writes" (list of {"task_id", "channel", "value"} dicts).
        Entries are sorted by "ts" ascending; entries without a timestamp
        sort first.
    """
    checkpointer = self._checkpointer
    if checkpointer is None:
        # Imported lazily, only when the client has no checkpointer of its own.
        from deerflow.agents.checkpointer.provider import get_checkpointer

        checkpointer = get_checkpointer()

    config = {"configurable": {"thread_id": thread_id}}
    checkpoints = []

    for cp in checkpointer.list(config):
        # Shallow copy so replacing "messages" below does not touch the
        # checkpoint's own mapping; a missing/None mapping becomes empty.
        channel_values = dict(cp.checkpoint.get("channel_values") or {})
        if "messages" in channel_values:
            # Only objects exposing ``content`` are serialized; anything else
            # (e.g. already-plain data) is passed through unchanged.
            channel_values["messages"] = [
                self._serialize_message(m) if hasattr(m, "content") else m
                for m in channel_values["messages"]
            ]

        cfg = cp.config.get("configurable", {})
        parent_cfg = cp.parent_config.get("configurable", {}) if cp.parent_config else {}

        # ``pending_writes`` may be absent *or* present-but-None; both mean
        # "no pending writes" and must not be iterated.
        pending_writes = getattr(cp, "pending_writes", None) or []

        checkpoints.append(
            {
                "checkpoint_id": cfg.get("checkpoint_id"),
                "parent_checkpoint_id": parent_cfg.get("checkpoint_id"),
                "ts": cp.checkpoint.get("ts"),
                "metadata": cp.metadata,
                "values": channel_values,
                "pending_writes": [
                    {"task_id": w[0], "channel": w[1], "value": w[2]}
                    for w in pending_writes
                ],
            }
        )

    # Sort globally by timestamp to prevent partial ordering issues caused by
    # different namespaces (e.g., subgraphs). Missing timestamps use "" as the
    # key so they compare against strings and land first.
    checkpoints.sort(key=lambda x: x["ts"] or "")

    return {"thread_id": thread_id, "checkpoints": checkpoints}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — conversation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@ -570,6 +570,147 @@ class TestGetModel:
|
||||
assert client.get_model("nonexistent") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread Queries (list_threads / get_thread)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThreadQueries:
    """Tests for the client's thread query methods (list_threads / get_thread)."""

    def _make_mock_checkpoint_tuple(
        self,
        thread_id: str,
        checkpoint_id: str,
        ts: str | None,
        title: str | None = None,
        parent_id: str | None = None,
        messages: list | None = None,
        pending_writes: list | None = None,
    ):
        """Build a MagicMock carrying the checkpoint fields the client reads.

        Args:
            thread_id: Value stored under config["configurable"]["thread_id"].
            checkpoint_id: Value stored under config["configurable"]["checkpoint_id"].
            ts: Checkpoint timestamp; ``None`` simulates a missing timestamp.
            title: Optional "title" channel value; omitted from the channel
                values when ``None``.
            parent_id: Optional parent checkpoint id; when falsy the mock's
                ``parent_config`` is an empty dict.
            messages: Optional "messages" channel value; omitted when ``None``.
            pending_writes: Optional pending writes; defaults to an empty list.

        Returns:
            The configured MagicMock.
        """
        cp = MagicMock()
        cp.config = {"configurable": {"thread_id": thread_id, "checkpoint_id": checkpoint_id}}

        channel_values = {}
        if title is not None:
            channel_values["title"] = title
        if messages is not None:
            channel_values["messages"] = messages

        cp.checkpoint = {"ts": ts, "channel_values": channel_values}
        cp.metadata = {"source": "test"}

        if parent_id:
            cp.parent_config = {"configurable": {"thread_id": thread_id, "checkpoint_id": parent_id}}
        else:
            cp.parent_config = {}

        cp.pending_writes = pending_writes or []
        return cp

    def test_list_threads_empty(self, client):
        """An empty checkpoint listing yields an empty thread list."""
        mock_checkpointer = MagicMock()
        mock_checkpointer.list.return_value = []
        client._checkpointer = mock_checkpointer

        result = client.list_threads()
        assert result == {"thread_list": []}
        mock_checkpointer.list.assert_called_once_with(config=None, limit=10)

    def test_list_threads_basic(self, client):
        """Checkpoints are grouped per thread and ordered by creation time."""
        mock_checkpointer = MagicMock()
        client._checkpointer = mock_checkpointer

        cp1 = self._make_mock_checkpoint_tuple("t1", "c1", "2023-01-01T10:00:00Z", title="Thread 1")
        cp2 = self._make_mock_checkpoint_tuple("t1", "c2", "2023-01-01T10:05:00Z", title="Thread 1 Updated")
        cp3 = self._make_mock_checkpoint_tuple("t2", "c3", "2023-01-02T10:00:00Z", title="Thread 2")
        cp_empty = self._make_mock_checkpoint_tuple("", "c4", "2023-01-03T10:00:00Z", title="Thread Empty")

        # Mock list returns out of order to test the timestamp sorting/comparison
        # Also includes a checkpoint with an empty thread_id which should be skipped
        mock_checkpointer.list.return_value = [cp2, cp1, cp_empty, cp3]

        result = client.list_threads(limit=5)
        mock_checkpointer.list.assert_called_once_with(config=None, limit=5)

        threads = result["thread_list"]
        assert len(threads) == 2

        # t2 should be first because its created_at (2023-01-02) is newer than t1 (2023-01-01)
        assert threads[0]["thread_id"] == "t2"
        assert threads[0]["created_at"] == "2023-01-02T10:00:00Z"
        assert threads[0]["title"] == "Thread 2"

        assert threads[1]["thread_id"] == "t1"
        assert threads[1]["created_at"] == "2023-01-01T10:00:00Z"
        assert threads[1]["updated_at"] == "2023-01-01T10:05:00Z"
        assert threads[1]["latest_checkpoint_id"] == "c2"
        assert threads[1]["title"] == "Thread 1 Updated"

    def test_list_threads_fallback_checkpointer(self, client):
        """Without a client checkpointer, list_threads uses the provider's."""
        mock_checkpointer = MagicMock()
        mock_checkpointer.list.return_value = []

        with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
            # No internal checkpointer, should fetch from provider
            result = client.list_threads()

        assert result == {"thread_list": []}
        mock_checkpointer.list.assert_called_once()

    def test_get_thread(self, client):
        """get_thread returns serialized checkpoints sorted by timestamp."""
        mock_checkpointer = MagicMock()
        client._checkpointer = mock_checkpointer

        msg1 = HumanMessage(content="Hello", id="m1")
        msg2 = AIMessage(content="Hi there", id="m2")

        cp1 = self._make_mock_checkpoint_tuple("t1", "c1", "2023-01-01T10:00:00Z", messages=[msg1])
        cp2 = self._make_mock_checkpoint_tuple(
            "t1",
            "c2",
            "2023-01-01T10:01:00Z",
            parent_id="c1",
            messages=[msg1, msg2],
            pending_writes=[("task_1", "messages", {"text": "pending"})],
        )
        cp3_no_ts = self._make_mock_checkpoint_tuple("t1", "c3", None)

        # checkpointer.list yields in reverse time or random order, test sorting
        mock_checkpointer.list.return_value = [cp2, cp1, cp3_no_ts]

        result = client.get_thread("t1")

        mock_checkpointer.list.assert_called_once_with({"configurable": {"thread_id": "t1"}})

        assert result["thread_id"] == "t1"
        checkpoints = result["checkpoints"]
        assert len(checkpoints) == 3

        # None timestamp remains None but is sorted first via a fallback key
        assert checkpoints[0]["checkpoint_id"] == "c3"
        assert checkpoints[0]["ts"] is None

        # Should be sorted by timestamp globally
        assert checkpoints[1]["checkpoint_id"] == "c1"
        assert checkpoints[1]["ts"] == "2023-01-01T10:00:00Z"
        assert len(checkpoints[1]["values"]["messages"]) == 1

        assert checkpoints[2]["checkpoint_id"] == "c2"
        assert checkpoints[2]["parent_checkpoint_id"] == "c1"
        assert checkpoints[2]["ts"] == "2023-01-01T10:01:00Z"
        assert len(checkpoints[2]["values"]["messages"]) == 2
        # Verify message serialization
        assert checkpoints[2]["values"]["messages"][1]["content"] == "Hi there"

        # Verify pending writes
        assert len(checkpoints[2]["pending_writes"]) == 1
        assert checkpoints[2]["pending_writes"][0]["task_id"] == "task_1"
        assert checkpoints[2]["pending_writes"][0]["channel"] == "messages"

    def test_get_thread_fallback_checkpointer(self, client):
        """Without a client checkpointer, get_thread uses the provider's."""
        mock_checkpointer = MagicMock()
        mock_checkpointer.list.return_value = []

        with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
            result = client.get_thread("t99")

        assert result["thread_id"] == "t99"
        assert result["checkpoints"] == []
        mock_checkpointer.list.assert_called_once_with({"configurable": {"thread_id": "t99"}})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user