feat(events): add pagination to list_messages_by_run on all store backends

Replicates the existing before_seq/after_seq/limit cursor-pagination pattern
from list_messages onto list_messages_by_run across the abstract interface,
MemoryRunEventStore, JsonlRunEventStore, and DbRunEventStore.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-12 15:58:33 +08:00
parent a36186cf54
commit 82374eb18c
5 changed files with 162 additions and 10 deletions

View File

@ -83,8 +83,18 @@ class RunEventStore(abc.ABC):
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict]:
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending."""
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
Supports bidirectional cursor pagination:
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
- neither: return the latest ``limit`` records (ascending)
"""
@abc.abstractmethod
async def count_messages(self, thread_id: str) -> int:

View File

@ -205,16 +205,35 @@ class DbRunEventStore(RunEventStore):
thread_id,
run_id,
*,
limit=50,
before_seq=None,
after_seq=None,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
stmt = select(RunEventRow).where(
RunEventRow.thread_id == thread_id,
RunEventRow.run_id == run_id,
RunEventRow.category == "message",
)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
stmt = stmt.order_by(RunEventRow.seq.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
stmt = stmt.where(RunEventRow.seq > after_seq)
if after_seq is not None:
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
else:
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def count_messages(
self,

View File

@ -152,9 +152,17 @@ class JsonlRunEventStore(RunEventStore):
events = [e for e in events if e.get("event_type") in event_types]
return events[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
    """Return one page of ``category == "message"`` events for a single run.

    Events are taken in the order ``self._read_run_events`` yields them
    (presumably ascending ``seq`` -- confirm against that helper); this method
    does not sort.

    Args:
        thread_id: Thread whose run is being read.
        run_id: Run whose messages are returned.
        limit: Maximum number of events in the page. A non-positive value
            yields an empty page.
        before_seq: If given, keep only events with ``seq`` < ``before_seq``.
        after_seq: If given, keep only events with ``seq`` > ``after_seq``.

    Returns:
        A list of event dicts. With ``after_seq`` set, the first ``limit``
        matches are returned; otherwise the last ``limit`` matches are.
    """
    events = self._read_run_events(thread_id, run_id)
    # Events lacking a "seq" key are treated as seq 0, as before.
    matches = [
        e
        for e in events
        if e.get("category") == "message"
        and (before_seq is None or e.get("seq", 0) < before_seq)
        and (after_seq is None or e.get("seq", 0) > after_seq)
    ]
    # Guard explicitly: ``matches[-0:]`` is the whole list, so limit=0 would
    # otherwise return everything on the backward/latest path.
    if limit <= 0:
        return []
    if after_seq is not None:
        # Forward pagination: the page starts right after the cursor.
        return matches[:limit]
    # Backward / latest pagination: the page ends at the newest match.
    return matches[-limit:]
async def count_messages(self, thread_id):
all_events = self._read_thread_events(thread_id)

View File

@ -97,9 +97,17 @@ class MemoryRunEventStore(RunEventStore):
filtered = [e for e in filtered if e["event_type"] in event_types]
return filtered[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
    """Return one page of ``category == "message"`` events for a single run.

    Events are taken in the order they are stored under
    ``self._events[thread_id]``; this method does not sort. An unknown
    ``thread_id`` yields an empty list.

    Args:
        thread_id: Thread whose events are searched.
        run_id: Only events whose ``run_id`` equals this value are kept.
        limit: Maximum number of events in the page. A non-positive value
            yields an empty page.
        before_seq: If given, keep only events with ``seq`` < ``before_seq``.
        after_seq: If given, keep only events with ``seq`` > ``after_seq``.

    Returns:
        A list of event dicts. With ``after_seq`` set, the first ``limit``
        matches are returned; otherwise the last ``limit`` matches are.
    """
    thread_events = self._events.get(thread_id, [])
    matches = [
        e
        for e in thread_events
        if e["run_id"] == run_id
        and e["category"] == "message"
        and (before_seq is None or e["seq"] < before_seq)
        and (after_seq is None or e["seq"] > after_seq)
    ]
    # Guard explicitly: ``matches[-0:]`` is the whole list, so limit=0 would
    # otherwise return everything on the backward/latest path.
    if limit <= 0:
        return []
    if after_seq is not None:
        # Forward pagination: the page starts right after the cursor.
        return matches[:limit]
    # Backward / latest pagination: the page ends at the newest match.
    return matches[-limit:]
async def count_messages(self, thread_id):
all_events = self._events.get(thread_id, [])

View File

@ -0,0 +1,107 @@
"""Tests for paginated list_messages_by_run, exercised against MemoryRunEventStore."""
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
@pytest.fixture
def base_store():
    """Provide a newly constructed ``MemoryRunEventStore`` to the requesting test."""
    store = MemoryRunEventStore()
    return store
@pytest.mark.anyio
async def test_list_messages_by_run_default_returns_all(base_store):
    """With no cursor and the default limit, all seven run-a messages come back.

    The thread also holds three run-b messages and one run-a trace event,
    none of which may appear in the result.
    """
    store = base_store

    # Seven alternating human/ai messages for the run under test.
    for index in range(7):
        kind = "ai_message" if index % 2 else "human_message"
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type=kind,
            category="message",
            content=f"msg-a-{index}",
        )

    # Noise: messages from a sibling run in the same thread.
    for index in range(3):
        await store.put(
            thread_id="t1",
            run_id="run-b",
            event_type="human_message",
            category="message",
            content=f"msg-b-{index}",
        )

    # Noise: a non-message event in the run under test.
    await store.put(
        thread_id="t1",
        run_id="run-a",
        event_type="tool_call",
        category="trace",
        content="trace",
    )

    msgs = await store.list_messages_by_run("t1", "run-a")

    assert len(msgs) == 7
    assert all(m["category"] == "message" for m in msgs)
    assert all(m["run_id"] == "run-a" for m in msgs)
@pytest.mark.anyio
async def test_list_messages_by_run_with_limit(base_store):
    """Without a cursor, ``limit`` selects the latest messages in ascending seq order."""
    store = base_store
    for i in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type="human_message" if i % 2 == 0 else "ai_message",
            category="message",
            content=f"msg-a-{i}",
        )

    # Reference ordering: the default limit (50) exceeds the 7 stored messages.
    all_msgs = await store.list_messages_by_run("t1", "run-a")
    all_seqs = [m["seq"] for m in all_msgs]

    msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
    seqs = [m["seq"] for m in msgs]

    assert len(msgs) == 3
    assert seqs == sorted(seqs)
    # Pin the window itself: length and ordering alone would also be satisfied
    # by an implementation that returned the oldest three messages.
    assert seqs == all_seqs[-3:]
@pytest.mark.anyio
async def test_list_messages_by_run_after_seq(base_store):
    """``after_seq`` returns messages strictly after the cursor, starting at the cursor."""
    store = base_store
    for i in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type="human_message" if i % 2 == 0 else "ai_message",
            category="message",
            content=f"msg-a-{i}",
        )

    all_msgs = await store.list_messages_by_run("t1", "run-a")
    cursor_seq = all_msgs[2]["seq"]
    expected_seqs = [m["seq"] for m in all_msgs[3:]]

    msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
    assert all(m["seq"] > cursor_seq for m in msgs)
    assert len(msgs) == 4
    assert [m["seq"] for m in msgs] == expected_seqs

    # A limit smaller than the remainder must keep the messages closest to the
    # cursor (forward pagination), not the newest ones.
    page = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=2)
    assert [m["seq"] for m in page] == expected_seqs[:2]
@pytest.mark.anyio
async def test_list_messages_by_run_before_seq(base_store):
    """``before_seq`` returns messages strictly before the cursor, ending at the cursor."""
    store = base_store
    for i in range(7):
        await store.put(
            thread_id="t1",
            run_id="run-a",
            event_type="human_message" if i % 2 == 0 else "ai_message",
            category="message",
            content=f"msg-a-{i}",
        )

    all_msgs = await store.list_messages_by_run("t1", "run-a")
    cursor_seq = all_msgs[4]["seq"]
    expected_seqs = [m["seq"] for m in all_msgs[:4]]

    msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
    assert all(m["seq"] < cursor_seq for m in msgs)
    assert len(msgs) == 4
    assert [m["seq"] for m in msgs] == expected_seqs

    # A limit smaller than the remainder must keep the messages closest to the
    # cursor (backward pagination), still in ascending order.
    page = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=2)
    assert [m["seq"] for m in page] == expected_seqs[-2:]
@pytest.mark.anyio
async def test_list_messages_by_run_does_not_include_other_run(base_store):
    """Messages from a sibling run in the same thread are excluded from the result."""
    store = base_store

    # Seed run-a first, then run-b, so both runs share thread "t1".
    for run_id, label, count in (("run-a", "a", 7), ("run-b", "b", 3)):
        for n in range(count):
            await store.put(
                thread_id="t1",
                run_id=run_id,
                event_type="human_message",
                category="message",
                content=f"msg-{label}-{n}",
            )

    msgs = await store.list_messages_by_run("t1", "run-b")

    assert len(msgs) == 3
    assert all(m["run_id"] == "run-b" for m in msgs)
@pytest.mark.anyio
async def test_list_messages_by_run_empty_run(base_store):
    """Querying a run for which nothing was stored yields an empty list."""
    result = await base_store.list_messages_by_run("t1", "nonexistent")
    assert result == []