mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-30 12:28:10 +00:00
fix(runtime): harden JSONL async I/O and DB put_batch thread validation (#3084)
* fix(runtime): harden JSONL async I/O and DB put_batch thread validation (#2816)

  - JsonlRunEventStore: offload all file I/O to asyncio.to_thread() so the event loop is never blocked; add per-thread asyncio.Lock to serialise concurrent puts and prevent interleaved JSONL lines
  - Split _ensure_seq_loaded into a sync _compute_max_seq (runs in thread) and an async wrapper; seq counter is recovered from disk on fresh store init
  - DbRunEventStore.put_batch: raise ValueError when events span multiple thread_ids (previously silently assumed same thread)
  - Add test_jsonl_event_store_async_io.py: 12 tests covering lock reuse, concurrent seq monotonicity, disk recovery, and mixed-thread batch rejection

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: address Copilot review comments

  - delete_by_thread: pop _write_locks after releasing the lock to prevent unbounded growth when threads are repeatedly created and deleted
  - tests: add regression guard asserting asyncio.to_thread is called for _write_record in put(); assert _write_locks entry removed on delete

* fix(lint): move patch import to local scope to fix ruff I001

* fix(lint): apply ruff check+format fixes to test file

* fix(runtime): address review feedback for JSONL async I/O hardening (#2816)

  Use setdefault for atomic lock init in _get_write_lock; pop _write_locks inside the held lock scope in delete_by_thread; update test docstring and assert lock entry also cleared on delete.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: rayhpeng <rayhpeng@gmail.com>
This commit is contained in:
parent
d46a5779bc
commit
cbf8b194e8
@ -144,10 +144,13 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
thread_ids = {e["thread_id"] for e in events}
|
||||
if len(thread_ids) > 1:
|
||||
raise ValueError(f"put_batch requires all events to belong to the same thread; got {thread_ids!r}")
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
# All events belong to the same thread (validated above).
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = max_seq or 0
|
||||
|
||||
@ -6,6 +6,15 @@ Each run's events are stored in a single file:
|
||||
All categories (message, trace, lifecycle) are in the same file.
|
||||
This backend is suitable for lightweight single-node deployments.
|
||||
|
||||
**Single-process guarantee**: the in-memory seq counter is process-local.
|
||||
Multi-process deployments sharing the same directory will produce duplicate
|
||||
or non-monotonic seq values. Use ``DbRunEventStore`` for multi-process or
|
||||
high-concurrency deployments.
|
||||
|
||||
File I/O is offloaded to a thread pool via ``asyncio.to_thread`` so the
|
||||
event loop is never blocked. Per-thread ``asyncio.Lock`` objects serialise
|
||||
writes within a single process to prevent interleaved JSONL lines.
|
||||
|
||||
Known trade-off: ``list_messages()`` must scan all run files for a
|
||||
thread since messages from multiple runs need unified seq ordering.
|
||||
``list_events()`` reads only one file -- the fast path.
|
||||
@ -13,6 +22,7 @@ thread since messages from multiple runs need unified seq ordering.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -30,6 +40,11 @@ class JsonlRunEventStore(RunEventStore):
|
||||
def __init__(self, base_dir: str | Path | None = None):
    """Set up an empty store rooted at ``base_dir``.

    Args:
        base_dir: Directory used as the store root. A falsy value
            (``None`` or an empty string) selects ``.deer-flow``.
    """
    self._base_dir = Path(base_dir or ".deer-flow")
    # thread_id -> highest seq handed out so far (process-local cache).
    self._seq_counters: dict[str, int] = {}
    # thread_id -> asyncio.Lock; serialises concurrent writes within one process.
    self._write_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
def _get_write_lock(self, thread_id: str) -> asyncio.Lock:
    """Return the write lock for ``thread_id``, registering one on first use."""
    # setdefault only stores the freshly built lock when the key is absent;
    # otherwise the already-registered lock is returned unchanged.
    registry = self._write_locks
    return registry.setdefault(thread_id, asyncio.Lock())
|
||||
|
||||
@staticmethod
|
||||
def _validate_id(value: str, label: str) -> str:
|
||||
@ -50,10 +65,8 @@ class JsonlRunEventStore(RunEventStore):
|
||||
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
|
||||
return self._seq_counters[thread_id]
|
||||
|
||||
def _ensure_seq_loaded(self, thread_id: str) -> None:
|
||||
"""Load max seq from existing files if not yet cached."""
|
||||
if thread_id in self._seq_counters:
|
||||
return
|
||||
def _compute_max_seq(self, thread_id: str) -> int:
|
||||
"""Scan all run files for a thread and return the current max seq (blocking I/O)."""
|
||||
max_seq = 0
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
@ -64,7 +77,13 @@ class JsonlRunEventStore(RunEventStore):
|
||||
max_seq = max(max_seq, record.get("seq", 0))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
return max_seq
|
||||
|
||||
async def _ensure_seq_loaded(self, thread_id: str) -> None:
    """Seed the in-memory seq counter for ``thread_id`` if it is not cached yet.

    The scan performed by ``_compute_max_seq`` is blocking, so it is run
    through ``asyncio.to_thread``. Threads that already have a cached
    counter are left untouched.
    """
    if thread_id not in self._seq_counters:
        self._seq_counters[thread_id] = await asyncio.to_thread(self._compute_max_seq, thread_id)
|
||||
|
||||
def _write_record(self, record: dict) -> None:
|
||||
@ -74,7 +93,7 @@ class JsonlRunEventStore(RunEventStore):
|
||||
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
|
||||
|
||||
def _read_thread_events(self, thread_id: str) -> list[dict]:
|
||||
"""Read all events for a thread, sorted by seq."""
|
||||
"""Read all events for a thread, sorted by seq (blocking I/O)."""
|
||||
events = []
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if not thread_dir.exists():
|
||||
@ -87,12 +106,11 @@ class JsonlRunEventStore(RunEventStore):
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
|
||||
"""Read events for a specific run file."""
|
||||
"""Read events for a specific run file (blocking I/O)."""
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
@ -104,25 +122,36 @@ class JsonlRunEventStore(RunEventStore):
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", path)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _delete_thread_files(self, thread_id: str) -> None:
    """Remove every ``*.jsonl`` file in the thread's directory (blocking I/O).

    A missing directory is a no-op. The directory itself is not removed.
    """
    thread_dir = self._thread_dir(thread_id)
    if thread_dir.exists():
        for f in thread_dir.glob("*.jsonl"):
            # missing_ok: a file that vanished between glob() and unlink()
            # is already in the desired state, so do not fail the delete.
            f.unlink(missing_ok=True)
|
||||
|
||||
def _delete_run_file(self, thread_id: str, run_id: str) -> None:
    """Remove the file for one run, if present (blocking I/O)."""
    # unlink(missing_ok=True) replaces the exists()-then-unlink() pair, which
    # could raise FileNotFoundError if the file disappeared in between.
    self._run_file(thread_id, run_id).unlink(missing_ok=True)
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
    """Assign the next seq for ``thread_id``, persist the event and return it.

    The whole sequence (seq recovery, seq allocation, write) runs under the
    per-thread write lock so concurrent callers on the same thread cannot
    obtain the same seq or interleave their writes. The write itself is
    handed to ``asyncio.to_thread``.

    Args:
        thread_id: Thread the event belongs to; also selects the write lock.
        run_id: Run the event belongs to.
        event_type: Stored under the ``event_type`` key.
        category: Stored under the ``category`` key.
        content: Event payload; defaults to an empty string.
        metadata: Optional mapping; a falsy value is stored as ``{}``.
        created_at: Optional timestamp; a falsy value is replaced by the
            current UTC time in ISO-8601 form.

    Returns:
        The record dict that was written, including the assigned ``seq``.
    """
    async with self._get_write_lock(thread_id):
        # Must be awaited: it is a coroutine that seeds the counter from disk.
        await self._ensure_seq_loaded(thread_id)
        seq = self._next_seq(thread_id)
        record = {
            "thread_id": thread_id,
            "run_id": run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            "metadata": metadata or {},
            "seq": seq,
            "created_at": created_at or datetime.now(UTC).isoformat(),
        }
        await asyncio.to_thread(self._write_record, record)
        return record
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
@ -134,7 +163,7 @@ class JsonlRunEventStore(RunEventStore):
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
|
||||
messages = [e for e in all_events if e.get("category") == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
@ -147,13 +176,13 @@ class JsonlRunEventStore(RunEventStore):
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
    """Return up to ``limit`` events of one run, optionally filtered by type.

    Args:
        thread_id: Thread that owns the run.
        run_id: Run whose events are read.
        event_types: Optional container of ``event_type`` values to keep;
            ``None`` disables filtering.
        limit: Maximum number of events returned, taken from the start of
            the list produced by ``_read_run_events``.
    """
    # Single read, off the event loop; no synchronous pre-read.
    events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
    if event_types is not None:
        events = [e for e in events if e.get("event_type") in event_types]
    return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
|
||||
filtered = [e for e in events if e.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||
@ -165,23 +194,25 @@ class JsonlRunEventStore(RunEventStore):
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
    """Return how many events of ``thread_id`` have category ``"message"``."""
    # Single read, off the event loop; no synchronous pre-read.
    all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
    return sum(1 for e in all_events if e.get("category") == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
    """Delete all stored events of a thread and drop its in-memory state.

    Runs under the thread's write lock so it cannot interleave with a
    concurrent ``put`` on the same thread. File reads and deletions are
    handed to ``asyncio.to_thread``.

    Returns:
        The number of events that were readable before deletion.
    """
    async with self._get_write_lock(thread_id):
        all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
        count = len(all_events)
        await asyncio.to_thread(self._delete_thread_files, thread_id)
        self._seq_counters.pop(thread_id, None)
        # Pop the lock inside the held scope to minimise the window where a new caller
        # could obtain a fresh lock while a waiting coroutine still holds the old one.
        # Note: coroutines that already acquired a reference to this lock before the
        # delete will still proceed after we release — this is an accepted narrow race.
        self._write_locks.pop(thread_id, None)
    return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
    """Delete the stored events of one run.

    Runs under the owning thread's write lock; the read and the file removal
    are handed to ``asyncio.to_thread``. The thread's seq counter and lock
    entry are left in place.

    Returns:
        The number of events that were readable before deletion.
    """
    async with self._get_write_lock(thread_id):
        events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
        count = len(events)
        await asyncio.to_thread(self._delete_run_file, thread_id, run_id)
    return count
|
||||
|
||||
223
backend/tests/test_jsonl_event_store_async_io.py
Normal file
223
backend/tests/test_jsonl_event_store_async_io.py
Normal file
@ -0,0 +1,223 @@
|
||||
"""Concurrency-safety tests for JsonlRunEventStore async I/O hardening (#2816).
|
||||
|
||||
Verifies:
|
||||
- write-lock serialises concurrent puts within the same thread_id
|
||||
- put_batch keeps monotonic seq even under concurrent callers
|
||||
- seq recovery from disk on fresh store init
|
||||
- DB put_batch rejects mixed-thread batches
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_store(base_dir: Path) -> JsonlRunEventStore:
    """Build a ``JsonlRunEventStore`` rooted at ``base_dir``."""
    store = JsonlRunEventStore(base_dir=base_dir)
    return store
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write-lock: per-thread lock exists and is reused
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_get_write_lock_returns_asyncio_lock():
    """The per-thread write lock must be an ``asyncio.Lock``."""
    with tempfile.TemporaryDirectory() as workdir:
        write_lock = _make_store(Path(workdir))._get_write_lock("t1")
        assert isinstance(write_lock, asyncio.Lock)
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_get_write_lock_same_thread_reuses_lock():
    """Asking twice for the same thread_id must return the identical lock object."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        first = store._get_write_lock("t1")
        second = store._get_write_lock("t1")
        assert first is second
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_get_write_lock_different_threads_get_different_locks():
    """Distinct thread_ids must not share a lock object."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        locks = [store._get_write_lock(thread_id) for thread_id in ("t1", "t2")]
        assert locks[0] is not locks[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seq monotonicity under concurrent puts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_concurrent_puts_produce_unique_monotonic_seqs():
    """Ten concurrent puts on one thread must hand out exactly the seqs 1..10."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        pending = [
            store.put(thread_id="t1", run_id=f"r{i}", event_type="trace", category="trace", content=f"msg{i}")
            for i in range(10)
        ]
        results = await asyncio.gather(*pending)
        # Completion order is not asserted; only the set of assigned seqs is.
        seqs = sorted(record["seq"] for record in results)
        assert seqs == list(range(1, 11)), f"Expected 1-10, got {seqs}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_concurrent_puts_different_threads_independent_seqs():
    """Two threads written concurrently must each get their own 1..5 seq range."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))

        def burst(thread_id, run_id):
            # Five concurrent puts for one (thread, run) pair.
            return asyncio.gather(*[store.put(thread_id=thread_id, run_id=run_id, event_type="trace", category="trace") for _ in range(5)])

        t1_results, t2_results = await asyncio.gather(burst("t1", "r1"), burst("t2", "r2"))

        assert sorted(r["seq"] for r in t1_results) == [1, 2, 3, 4, 5]
        assert sorted(r["seq"] for r in t2_results) == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# put_batch: delegates to put() and preserves order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_put_batch_seqs_are_monotonic():
    """put_batch must return records in non-decreasing seq order with no duplicates."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        batch = []
        for i in range(5):
            batch.append({"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace", "content": str(i)})

        results = await store.put_batch(batch)

        seqs = [record["seq"] for record in results]
        assert seqs == sorted(seqs)
        assert len(set(seqs)) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _ensure_seq_loaded: recovers max_seq from disk after fresh store init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_ensure_seq_loaded_recovers_from_disk():
    """A new store over the same directory must continue after the last persisted seq."""
    with tempfile.TemporaryDirectory() as workdir:
        base = Path(workdir)

        # First instance writes seq 1..3.
        writer = _make_store(base)
        for payload in ("0", "1", "2"):
            await writer.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content=payload)

        # Second instance starts with an empty in-memory counter.
        fresh = _make_store(base)
        record = await fresh.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content="new")
        assert record["seq"] == 4, f"Expected seq=4 after recovery, got {record['seq']}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# asyncio.to_thread regression guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_put_offloads_write_via_to_thread():
    """Regression guard: put() must hand ``_write_record`` to ``asyncio.to_thread``."""
    from unittest.mock import patch

    real_to_thread = asyncio.to_thread
    seen: list[str] = []

    async def spy(*args, **kwargs):
        # Record what was off-loaded, then delegate to the real implementation.
        target = args[0]
        if callable(target):
            seen.append(target.__name__)
        else:
            seen.append(repr(target))
        return await real_to_thread(*args, **kwargs)

    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        with patch("asyncio.to_thread", new=spy):
            await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content="x")

    assert "_write_record" in seen, f"Expected asyncio.to_thread(_write_record, ...) — got: {seen}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read methods are non-blocking (asyncio.to_thread path exercised)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_list_messages_reads_written_records():
    """Messages written with put() must come back from list_messages() in write order."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        for event_type, text in (("human_message", "hello"), ("ai_message", "world")):
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category="message", content=text)

        messages = await store.list_messages("t1")

        assert len(messages) == 2
        first, second = messages
        assert first["content"] == "hello"
        assert second["content"] == "world"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_count_messages_accurate_after_concurrent_writes():
    """count_messages() must see every message written by concurrent puts."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        writes = [store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") for _ in range(7)]
        await asyncio.gather(*writes)
        assert await store.count_messages("t1") == 7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete_by_thread and delete_by_run use the write lock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_delete_by_thread_clears_seq_counter_and_lock():
    """delete_by_thread must drop both the cached seq counter and the lock entry."""
    with tempfile.TemporaryDirectory() as workdir:
        store = _make_store(Path(workdir))
        await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace")

        await store.delete_by_thread("t1")

        for registry in (store._seq_counters, store._write_locks):
            assert "t1" not in registry
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_delete_by_run_removes_run_events():
    """delete_by_run must remove only the targeted run and report the removed count."""
    with tempfile.TemporaryDirectory() as tmp:
        store = _make_store(Path(tmp))
        await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace")
        await store.put(thread_id="t1", run_id="r2", event_type="trace", category="trace")

        deleted = await store.delete_by_run("t1", "r1")

        # One event was written for r1, so exactly one must be reported.
        assert deleted == 1
        events = await store.list_events("t1", "r1")
        assert events == []
        # Guard against over-deletion: the sibling run in the same thread must survive.
        remaining = await store.list_events("t1", "r2")
        assert [e["run_id"] for e in remaining] == ["r2"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB put_batch: rejects mixed-thread batches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_db_put_batch_rejects_mixed_thread_ids():
    """A batch whose events span more than one thread_id must raise ValueError."""
    from unittest.mock import MagicMock

    from deerflow.runtime.events.store.db import DbRunEventStore

    store = DbRunEventStore(session_factory=MagicMock())

    # Two events that differ in thread_id.
    mixed_batch = [{"thread_id": f"t{n}", "run_id": f"r{n}", "event_type": "trace", "category": "trace"} for n in (1, 2)]

    with pytest.raises(ValueError, match="same thread"):
        await store.put_batch(mixed_batch)
|
||||
Loading…
x
Reference in New Issue
Block a user