fix(persistence): address review feedback on PR #1851

- Fix naive datetime.now() → datetime.now(UTC) in all ORM models
- Fix seq race condition in DbRunEventStore.put() with FOR UPDATE
  and UNIQUE(thread_id, seq) constraint
- Encapsulate _store access in RunManager.update_run_completion()
- Deduplicate _store.put() logic in RunManager via _persist_to_store()
- Add update_run_completion to RunStore ABC + MemoryRunStore
- Wire follow_up_to_run_id through the full create path
- Add error recovery for the lost-event scenario in RunJournal._flush_sync()
- Add migration note for search_threads breaking change
- Fix test_checkpointer_none_fix mock to set database=None

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-05 22:02:50 +08:00
parent 7fdf9cad99
commit fc4e3a52d4
14 changed files with 172 additions and 110 deletions

View File

@ -317,7 +317,16 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads from the threads_meta table."""
"""Search and list threads from the threads_meta table.
NOTE: Migration from pre-persistence-layer deployments:
Threads created via LangGraph Server before this change are NOT
automatically indexed in threads_meta. They will not appear in
search results until a new run is created on them (which triggers
thread_meta upsert in services.py). A bulk migration command,
python -m deerflow.persistence.migrate_threads_from_checkpointer
is planned but does not exist yet (migration script TBD in a follow-up PR).
"""
from app.gateway.deps import get_thread_meta_repo
repo = get_thread_meta_repo(request)

View File

@ -266,6 +266,17 @@ async def start_run(
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
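# Resolve follow_up_to_run_id: use the explicit value from the request; otherwise fall back to
# the single run returned by list_by_thread(limit=1) (presumably the latest) if its status is "success"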
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
try:
record = await run_mgr.create_or_reject(
thread_id,
@ -274,6 +285,7 @@ async def start_run(
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
follow_up_to_run_id=follow_up_to_run_id,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
@ -302,17 +314,6 @@ async def start_run(
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id)
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from datetime import datetime
from datetime import UTC, datetime
from sqlalchemy import DateTime, String, Text
from sqlalchemy.orm import Mapped, mapped_column
@ -27,4 +27,4 @@ class FeedbackRow(Base):
comment: Mapped[str | None] = mapped_column(Text)
# Optional text feedback from the user
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from datetime import datetime
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, Index, String, Text
from sqlalchemy.orm import Mapped, mapped_column
@ -43,7 +43,7 @@ class RunRow(Base):
# Follow-up association
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)

View File

@ -2,9 +2,9 @@
from __future__ import annotations
from datetime import datetime
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, Index, String, Text
from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
@ -22,9 +22,10 @@ class RunEventRow(Base):
content: Mapped[str] = mapped_column(Text, default="")
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
seq: Mapped[int] = mapped_column(nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
__table_args__ = (
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from datetime import datetime
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
@ -19,5 +19,5 @@ class ThreadMetaRow(Base):
display_name: Mapped[str | None] = mapped_column(String(256))
status: Mapped[str] = mapped_column(String(20), default="idle")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))

View File

@ -78,6 +78,7 @@ class RunRepository(RunStore):
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
now = datetime.now(UTC)
row = RunRow(
@ -90,6 +91,7 @@ class RunRepository(RunStore):
metadata_json=self._safe_json(metadata) or {},
kwargs_json=self._safe_json(kwargs) or {},
error=error,
follow_up_to_run_id=follow_up_to_run_id,
created_at=datetime.fromisoformat(created_at) if created_at else now,
updated_at=now,
)

View File

@ -54,58 +54,68 @@ class DbRunEventStore(RunEventStore):
else:
db_content = content
async with self._sf() as session:
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
event_type=event_type,
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
)
session.add(row)
await session.commit()
await session.refresh(row)
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
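# NOTE: SQLite has no SELECT ... FOR UPDATE, so with_for_update() locks nothing there;
# a concurrent writer that picks the same seq is instead rejected by the
# UNIQUE(thread_id, seq) constraint when the row is inserted.
# NOTE(review): PostgreSQL rejects FOR UPDATE on aggregate queries
# ("FOR UPDATE is not allowed with aggregate functions") — confirm against the target backend.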
max_seq = await session.scalar(
select(func.max(RunEventRow.seq))
.where(RunEventRow.thread_id == thread_id)
.with_for_update()
)
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
event_type=event_type,
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
)
session.add(row)
return self._row_to_dict(row)
async def put_batch(self, events):
if not events:
return []
async with self._sf() as session:
# Get max seq for the thread (assume all events in batch belong to same thread)
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
seq = max_seq or 0
rows = []
for e in events:
seq += 1
content = e.get("content", "")
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
event_type=e["event_type"],
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
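# NOTE: SQLite has no SELECT ... FOR UPDATE, so with_for_update() locks nothing there;
# a concurrent writer that picks the same seq is instead rejected by the
# UNIQUE(thread_id, seq) constraint when the rows are inserted.
# NOTE(review): PostgreSQL rejects FOR UPDATE on aggregate queries
# ("FOR UPDATE is not allowed with aggregate functions") — confirm against the target backend.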
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(
select(func.max(RunEventRow.seq))
.where(RunEventRow.thread_id == thread_id)
.with_for_update()
)
session.add(row)
rows.append(row)
await session.commit()
for row in rows:
await session.refresh(row)
seq = max_seq or 0
rows = []
for e in events:
seq += 1
content = e.get("content", "")
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
event_type=e["event_type"],
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
)
session.add(row)
rows.append(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):

View File

@ -386,13 +386,27 @@ class RunJournal(BaseCallbackHandler):
return
batch = self._buffer.copy()
self._buffer.clear()
loop.create_task(self._flush_async(batch))
task = loop.create_task(self._flush_async(batch))
task.add_done_callback(self._on_flush_done)
async def _flush_async(self, batch: list[dict]) -> None:
try:
await self._store.put_batch(batch)
except Exception:
logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True)
logger.warning(
"Failed to flush %d events for run %s — returning to buffer",
len(batch), self.run_id, exc_info=True,
)
# Return failed events to buffer for retry on next flush
self._buffer = batch + self._buffer
@staticmethod
def _on_flush_done(task: asyncio.Task) -> None:
    """Done-callback for a flush task: log the exception it finished with, if any.

    Cancelled tasks are skipped up front, because ``Task.exception()`` raises
    ``CancelledError`` when called on a cancelled task.
    """
    failure = None if task.cancelled() else task.exception()
    if failure:
        logger.warning("Journal flush task failed: %s", failure)
def _identify_caller(self, kwargs: dict) -> str:
for tag in kwargs.get("tags") or []:

View File

@ -54,6 +54,33 @@ class RunManager:
self._lock = asyncio.Lock()
self._store = store
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
    """Write *record* to the backing store when one is configured.

    This is best-effort: any exception raised while building the payload or
    calling the store is logged as a warning and not propagated.

    Args:
        record: The run record whose fields are copied into the store.
        follow_up_to_run_id: Forwarded to the store unchanged.
    """
    store = self._store
    if store is None:
        # No backing store configured; nothing to persist.
        return
    try:
        fields = {
            "thread_id": record.thread_id,
            "assistant_id": record.assistant_id,
            "status": record.status.value,
            "multitask_strategy": record.multitask_strategy,
            "metadata": record.metadata or {},
            "kwargs": record.kwargs or {},
            "created_at": record.created_at,
            "follow_up_to_run_id": follow_up_to_run_id,
        }
        await store.put(record.run_id, **fields)
    except Exception:
        logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
    """Forward completion data for *run_id* to the backing store.

    Does nothing when no store is configured. Store errors are logged as
    warnings and swallowed, so callers never see them.
    """
    store = self._store
    if store is None:
        return
    try:
        await store.update_run_completion(run_id, **kwargs)
    except Exception:
        logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def create(
self,
thread_id: str,
@ -63,6 +90,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
@ -81,20 +109,7 @@ class RunManager:
)
async with self._lock:
self._runs[run_id] = record
if self._store is not None:
try:
await self._store.put(
run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending.value,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
)
except Exception:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@ -161,6 +176,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@ -214,21 +230,7 @@ class RunManager:
)
self._runs[run_id] = record
if self._store is not None:
try:
await self._store.put(
run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending.value,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
)
except Exception:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record

View File

@ -29,6 +29,7 @@ class RunStore(abc.ABC):
kwargs: dict[str, Any] | None = None,
error: str | None = None,
created_at: str | None = None,
follow_up_to_run_id: str | None = None,
) -> None: ...
@abc.abstractmethod
@ -55,5 +56,24 @@ class RunStore(abc.ABC):
@abc.abstractmethod
async def delete(self, run_id: str) -> None:
    """Remove the run identified by *run_id* from the store."""
@abc.abstractmethod
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None: ...
@abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: ...

View File

@ -28,6 +28,7 @@ class MemoryRunStore(RunStore):
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
now = datetime.now(UTC).isoformat()
self._runs[run_id] = {
@ -40,6 +41,7 @@ class MemoryRunStore(RunStore):
"metadata": metadata or {},
"kwargs": kwargs or {},
"error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"created_at": created_at or now,
"updated_at": now,
}
@ -62,6 +64,14 @@ class MemoryRunStore(RunStore):
async def delete(self, run_id):
    """Drop the run stored under *run_id*; an unknown id is silently ignored."""
    if run_id in self._runs:
        del self._runs[run_id]
async def update_run_completion(self, run_id, *, status, **kwargs):
if run_id in self._runs:
self._runs[run_id]["status"] = status
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]

View File

@ -257,16 +257,8 @@ async def run_agent(
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
# Persist token usage + convenience fields to RunStore
if run_manager._store is not None:
try:
completion = journal.get_completion_data()
await run_manager._store.update_run_completion(
run_id,
status=record.status.value,
**completion,
)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
completion = journal.get_completion_data()
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
# Sync title from checkpoint to threads_meta.display_name
if thread_meta_repo is not None and checkpointer is not None:

View File

@ -14,9 +14,10 @@ class TestCheckpointerNoneFix:
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
# Mock get_app_config to return a config with checkpointer=None
# Mock get_app_config to return a config with checkpointer=None and database=None
mock_config = MagicMock()
mock_config.checkpointer = None
mock_config.database = None
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
async with make_checkpointer() as checkpointer: