diff --git a/backend/packages/harness/deerflow/runtime/events/store/base.py b/backend/packages/harness/deerflow/runtime/events/store/base.py index e5da4ed82..df5136ba5 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/base.py +++ b/backend/packages/harness/deerflow/runtime/events/store/base.py @@ -83,8 +83,18 @@ class RunEventStore(abc.ABC): self, thread_id: str, run_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, ) -> list[dict]: - """Return displayable messages (category=message) for a specific run, ordered by seq ascending.""" + """Return displayable messages (category=message) for a specific run, ordered by seq ascending. + + Supports bidirectional cursor pagination: + - after_seq: return the first ``limit`` records with seq > after_seq (ascending) + - before_seq: return the last ``limit`` records with seq < before_seq (ascending) + - neither: return the latest ``limit`` records (ascending) + """ @abc.abstractmethod async def count_messages(self, thread_id: str) -> int: diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 63328db43..e4a21d006 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -205,16 +205,35 @@ class DbRunEventStore(RunEventStore): thread_id, run_id, *, + limit=50, + before_seq=None, + after_seq=None, user_id: str | None | _AutoSentinel = AUTO, ): resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run") - stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message") + stmt = select(RunEventRow).where( + RunEventRow.thread_id == thread_id, + RunEventRow.run_id == run_id, + RunEventRow.category == "message", + ) if resolved_user_id is not None: stmt = stmt.where(RunEventRow.user_id == resolved_user_id) - stmt = stmt.order_by(RunEventRow.seq.asc()) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] + if before_seq is not None: + stmt = stmt.where(RunEventRow.seq < before_seq) + if after_seq is not None: + stmt = stmt.where(RunEventRow.seq > after_seq) + + if after_seq is not None: + stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] + else: + stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit) + async with self._sf() as session: + result = await session.execute(stmt) + rows = list(result.scalars()) + return [self._row_to_dict(r) for r in reversed(rows)] async def count_messages( self, diff --git a/backend/packages/harness/deerflow/runtime/events/store/jsonl.py b/backend/packages/harness/deerflow/runtime/events/store/jsonl.py index 1a4aac38c..378713afc 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/jsonl.py +++ b/backend/packages/harness/deerflow/runtime/events/store/jsonl.py @@ -152,9 +152,17 @@ class JsonlRunEventStore(RunEventStore): events = [e for e in events if e.get("event_type") in event_types] return events[:limit] - async def list_messages_by_run(self, thread_id, run_id): + async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None): events = self._read_run_events(thread_id, run_id) - return [e for e in events if e.get("category") == "message"] + filtered = [e for e in events if e.get("category") == "message"] + if before_seq is not None: + filtered = [e for e in filtered if e.get("seq", 0) < before_seq] + if after_seq is not None: + filtered = [e for e in filtered if e.get("seq", 0) > after_seq] + if after_seq is not None: + return filtered[:limit] + else: + return filtered[-limit:] if len(filtered) > limit else filtered async def count_messages(self, thread_id): all_events = self._read_thread_events(thread_id) diff --git a/backend/packages/harness/deerflow/runtime/events/store/memory.py b/backend/packages/harness/deerflow/runtime/events/store/memory.py index 889159086..cf70e1cdf 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/events/store/memory.py @@ -97,9 +97,17 @@ class MemoryRunEventStore(RunEventStore): filtered = [e for e in filtered if e["event_type"] in event_types] return filtered[:limit] - async def list_messages_by_run(self, thread_id, run_id): + async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None): all_events = self._events.get(thread_id, []) - return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"] + filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"] + if before_seq is not None: + filtered = [e for e in filtered if e["seq"] < before_seq] + if after_seq is not None: + filtered = [e for e in filtered if e["seq"] > after_seq] + if after_seq is not None: + return filtered[:limit] + else: + return filtered[-limit:] if len(filtered) > limit else filtered async def count_messages(self, thread_id): all_events = self._events.get(thread_id, []) diff --git a/backend/tests/test_run_event_store_pagination.py b/backend/tests/test_run_event_store_pagination.py new file mode 100644 index 000000000..ac5ba4c2d --- /dev/null +++ b/backend/tests/test_run_event_store_pagination.py @@ -0,0 +1,107 @@ +"""Tests for paginated list_messages_by_run across all RunEventStore backends.""" +import pytest + +from deerflow.runtime.events.store.memory import MemoryRunEventStore + + +@pytest.fixture +def base_store(): + return MemoryRunEventStore() + + +@pytest.mark.anyio +async def test_list_messages_by_run_default_returns_all(base_store): + store = base_store + for i in range(7): + await store.put( + thread_id="t1", run_id="run-a", + event_type="human_message" if i % 2 == 0 else "ai_message", + category="message", content=f"msg-a-{i}", + ) + for i in range(3): + await store.put( + thread_id="t1", run_id="run-b", + event_type="human_message", category="message", content=f"msg-b-{i}", + ) + await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace") + + msgs = await store.list_messages_by_run("t1", "run-a") + assert len(msgs) == 7 + assert all(m["category"] == "message" for m in msgs) + assert all(m["run_id"] == "run-a" for m in msgs) + + +@pytest.mark.anyio +async def test_list_messages_by_run_with_limit(base_store): + store = base_store + for i in range(7): + await store.put( + thread_id="t1", run_id="run-a", + event_type="human_message" if i % 2 == 0 else "ai_message", + category="message", content=f"msg-a-{i}", + ) + + msgs = await store.list_messages_by_run("t1", "run-a", limit=3) + assert len(msgs) == 3 + seqs = [m["seq"] for m in msgs] + assert seqs == sorted(seqs) + + +@pytest.mark.anyio +async def test_list_messages_by_run_after_seq(base_store): + store = base_store + for i in range(7): + await store.put( + thread_id="t1", run_id="run-a", + event_type="human_message" if i % 2 == 0 else "ai_message", + category="message", content=f"msg-a-{i}", + ) + + all_msgs = await store.list_messages_by_run("t1", "run-a") + cursor_seq = all_msgs[2]["seq"] + msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50) + assert all(m["seq"] > cursor_seq for m in msgs) + assert len(msgs) == 4 + + +@pytest.mark.anyio +async def test_list_messages_by_run_before_seq(base_store): + store = base_store + for i in range(7): + await store.put( + thread_id="t1", run_id="run-a", + event_type="human_message" if i % 2 == 0 else "ai_message", + category="message", content=f"msg-a-{i}", + ) + + all_msgs = await store.list_messages_by_run("t1", "run-a") + cursor_seq = all_msgs[4]["seq"] + msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50) + assert all(m["seq"] < cursor_seq for m in msgs) + assert len(msgs) == 4 + + +@pytest.mark.anyio +async def test_list_messages_by_run_does_not_include_other_run(base_store): + store = base_store + for i in range(7): + await store.put( + thread_id="t1", run_id="run-a", + event_type="human_message", category="message", content=f"msg-a-{i}", + ) + for i in range(3): + await store.put( + thread_id="t1", run_id="run-b", + event_type="human_message", category="message", content=f"msg-b-{i}", + ) + + msgs = await store.list_messages_by_run("t1", "run-b") + assert len(msgs) == 3 + assert all(m["run_id"] == "run-b" for m in msgs) + + +@pytest.mark.anyio +async def test_list_messages_by_run_empty_run(base_store): + store = base_store + msgs = await store.list_messages_by_run("t1", "nonexistent") + assert msgs == []