From fc4e3a52d4255d055910b30c88dd5346b41e152d Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Sun, 5 Apr 2026 22:02:50 +0800 Subject: [PATCH] fix(persistence): address review feedback on PR #1851 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix naive datetime.now() → datetime.now(UTC) in all ORM models - Fix seq race condition in DbRunEventStore.put() with FOR UPDATE and UNIQUE(thread_id, seq) constraint - Encapsulate _store access in RunManager.update_run_completion() - Deduplicate _store.put() logic in RunManager via _persist_to_store() - Add update_run_completion to RunStore ABC + MemoryRunStore - Wire follow_up_to_run_id through the full create path - Add error recovery to RunJournal._flush_sync() lost-event scenario - Add migration note for search_threads breaking change - Fix test_checkpointer_none_fix mock to set database=None Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/gateway/routers/threads.py | 11 +- backend/app/gateway/services.py | 23 ++-- .../deerflow/persistence/models/feedback.py | 4 +- .../deerflow/persistence/models/run.py | 6 +- .../deerflow/persistence/models/run_event.py | 7 +- .../persistence/models/thread_meta.py | 6 +- .../persistence/repositories/run_repo.py | 2 + .../deerflow/runtime/events/store/db.py | 100 ++++++++++-------- .../harness/deerflow/runtime/journal.py | 18 +++- .../harness/deerflow/runtime/runs/manager.py | 60 ++++++----- .../deerflow/runtime/runs/store/base.py | 20 ++++ .../deerflow/runtime/runs/store/memory.py | 10 ++ .../harness/deerflow/runtime/runs/worker.py | 12 +-- backend/tests/test_checkpointer_none_fix.py | 3 +- 14 files changed, 172 insertions(+), 110 deletions(-) diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index d11c53e37..bfc21b7c9 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -317,7 +317,16 @@ async def create_thread(body: ThreadCreateRequest, 
request: Request) -> ThreadRe @router.post("/search", response_model=list[ThreadResponse]) async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]: - """Search and list threads from the threads_meta table.""" + """Search and list threads from the threads_meta table. + + NOTE: Migration from pre-persistence-layer deployments: + Threads created via LangGraph Server before this change are NOT + automatically indexed in threads_meta. They will not appear in + search results until a new run is created on them (which triggers + thread_meta upsert in services.py). Bulk migration is not available yet: + the planned `python -m deerflow.persistence.migrate_threads_from_checkpointer` + script does not exist yet (TBD in a follow-up PR). + """ from app.gateway.deps import get_thread_meta_repo repo = get_thread_meta_repo(request) diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 0c7683ed9..2bdb1170b 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -266,6 +266,17 @@ async def start_run( disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ + # Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run + follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None) + if follow_up_to_run_id is None: + run_store = get_run_store(request) + try: + recent_runs = await run_store.list_by_thread(thread_id, limit=1) + if recent_runs and recent_runs[0].get("status") == "success": + follow_up_to_run_id = recent_runs[0]["run_id"] + except Exception: + pass # Don't block run creation + try: record = await run_mgr.create_or_reject( thread_id, @@ -274,6 +285,7 @@ async def start_run( metadata=body.metadata or {}, kwargs={"input": body.input, "config": body.config}, multitask_strategy=body.multitask_strategy, + follow_up_to_run_id=follow_up_to_run_id, ) except ConflictError as exc: raise HTTPException(status_code=409, detail=str(exc)) from 
exc @@ -302,17 +314,6 @@ async def start_run( except Exception: logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id) - # Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run - follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None) - if follow_up_to_run_id is None: - run_store = get_run_store(request) - try: - recent_runs = await run_store.list_by_thread(thread_id, limit=1) - if recent_runs and recent_runs[0].get("status") == "success": - follow_up_to_run_id = recent_runs[0]["run_id"] - except Exception: - pass # Don't block run creation - agent_factory = resolve_agent_factory(body.assistant_id) graph_input = normalize_input(body.input) config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) diff --git a/backend/packages/harness/deerflow/persistence/models/feedback.py b/backend/packages/harness/deerflow/persistence/models/feedback.py index 6dbc89d31..221fb5fb1 100644 --- a/backend/packages/harness/deerflow/persistence/models/feedback.py +++ b/backend/packages/harness/deerflow/persistence/models/feedback.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from sqlalchemy import DateTime, String, Text from sqlalchemy.orm import Mapped, mapped_column @@ -27,4 +27,4 @@ class FeedbackRow(Base): comment: Mapped[str | None] = mapped_column(Text) # Optional text feedback from the user - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now()) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) diff --git a/backend/packages/harness/deerflow/persistence/models/run.py b/backend/packages/harness/deerflow/persistence/models/run.py index ff0a1dc6f..67396bc25 100644 --- a/backend/packages/harness/deerflow/persistence/models/run.py +++ b/backend/packages/harness/deerflow/persistence/models/run.py @@ 
-2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from sqlalchemy import JSON, DateTime, Index, String, Text from sqlalchemy.orm import Mapped, mapped_column @@ -43,7 +43,7 @@ class RunRow(Base): # Follow-up association follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now()) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now()) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) __table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),) diff --git a/backend/packages/harness/deerflow/persistence/models/run_event.py b/backend/packages/harness/deerflow/persistence/models/run_event.py index ba6b011f7..8db50aea7 100644 --- a/backend/packages/harness/deerflow/persistence/models/run_event.py +++ b/backend/packages/harness/deerflow/persistence/models/run_event.py @@ -2,9 +2,9 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime -from sqlalchemy import JSON, DateTime, Index, String, Text +from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column from deerflow.persistence.base import Base @@ -22,9 +22,10 @@ class RunEventRow(Base): content: Mapped[str] = mapped_column(Text, default="") event_metadata: Mapped[dict] = mapped_column(JSON, default=dict) seq: Mapped[int] = mapped_column(nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now()) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: 
datetime.now(UTC)) __table_args__ = ( + UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"), Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"), Index("ix_events_run", "thread_id", "run_id", "seq"), ) diff --git a/backend/packages/harness/deerflow/persistence/models/thread_meta.py b/backend/packages/harness/deerflow/persistence/models/thread_meta.py index dec206ae9..34a209277 100644 --- a/backend/packages/harness/deerflow/persistence/models/thread_meta.py +++ b/backend/packages/harness/deerflow/persistence/models/thread_meta.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from sqlalchemy import JSON, DateTime, String from sqlalchemy.orm import Mapped, mapped_column @@ -19,5 +19,5 @@ class ThreadMetaRow(Base): display_name: Mapped[str | None] = mapped_column(String(256)) status: Mapped[str] = mapped_column(String(20), default="idle") metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now()) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now()) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) diff --git a/backend/packages/harness/deerflow/persistence/repositories/run_repo.py b/backend/packages/harness/deerflow/persistence/repositories/run_repo.py index f4727b155..3b209b402 100644 --- a/backend/packages/harness/deerflow/persistence/repositories/run_repo.py +++ b/backend/packages/harness/deerflow/persistence/repositories/run_repo.py @@ -78,6 +78,7 @@ class RunRepository(RunStore): kwargs=None, error=None, created_at=None, + follow_up_to_run_id=None, ): now = datetime.now(UTC) row = 
RunRow( @@ -90,6 +91,7 @@ class RunRepository(RunStore): metadata_json=self._safe_json(metadata) or {}, kwargs_json=self._safe_json(kwargs) or {}, error=error, + follow_up_to_run_id=follow_up_to_run_id, created_at=datetime.fromisoformat(created_at) if created_at else now, updated_at=now, ) diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index ce798813f..916533120 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -54,58 +54,68 @@ class DbRunEventStore(RunEventStore): else: db_content = content async with self._sf() as session: - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)) - seq = (max_seq or 0) + 1 - row = RunEventRow( - thread_id=thread_id, - run_id=run_id, - event_type=event_type, - category=category, - content=db_content, - event_metadata=metadata, - seq=seq, - created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC), - ) - session.add(row) - await session.commit() - await session.refresh(row) + async with session.begin(): + # Use FOR UPDATE to serialize seq assignment within a thread. + # NOTE: with_for_update() is a no-op on SQLite; UNIQUE(thread_id, seq) catches races there. + # NOTE(review): PostgreSQL rejects FOR UPDATE on aggregate queries such as max(); verify. 
+ max_seq = await session.scalar( + select(func.max(RunEventRow.seq)) + .where(RunEventRow.thread_id == thread_id) + .with_for_update() + ) + seq = (max_seq or 0) + 1 + row = RunEventRow( + thread_id=thread_id, + run_id=run_id, + event_type=event_type, + category=category, + content=db_content, + event_metadata=metadata, + seq=seq, + created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC), + ) + session.add(row) return self._row_to_dict(row) async def put_batch(self, events): if not events: return [] async with self._sf() as session: - # Get max seq for the thread (assume all events in batch belong to same thread) - thread_id = events[0]["thread_id"] - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)) - seq = max_seq or 0 - rows = [] - for e in events: - seq += 1 - content = e.get("content", "") - category = e.get("category", "trace") - metadata = e.get("metadata") - content, metadata = self._truncate_trace(category, content, metadata) - if isinstance(content, dict): - db_content = json.dumps(content, default=str, ensure_ascii=False) - metadata = {**(metadata or {}), "content_is_dict": True} - else: - db_content = content - row = RunEventRow( - thread_id=e["thread_id"], - run_id=e["run_id"], - event_type=e["event_type"], - category=category, - content=db_content, - event_metadata=metadata, - seq=seq, - created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC), + async with session.begin(): + # Get max seq for the thread (assume all events in batch belong to same thread). + # NOTE: with_for_update() is a no-op on SQLite; UNIQUE(thread_id, seq) catches races there. + # NOTE(review): PostgreSQL rejects FOR UPDATE on aggregate queries such as max(); verify. 
+ thread_id = events[0]["thread_id"] + max_seq = await session.scalar( + select(func.max(RunEventRow.seq)) + .where(RunEventRow.thread_id == thread_id) + .with_for_update() ) - session.add(row) - rows.append(row) - await session.commit() - for row in rows: - await session.refresh(row) + seq = max_seq or 0 + rows = [] + for e in events: + seq += 1 + content = e.get("content", "") + category = e.get("category", "trace") + metadata = e.get("metadata") + content, metadata = self._truncate_trace(category, content, metadata) + if isinstance(content, dict): + db_content = json.dumps(content, default=str, ensure_ascii=False) + metadata = {**(metadata or {}), "content_is_dict": True} + else: + db_content = content + row = RunEventRow( + thread_id=e["thread_id"], + run_id=e["run_id"], + event_type=e["event_type"], + category=category, + content=db_content, + event_metadata=metadata, + seq=seq, + created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC), + ) + session.add(row) + rows.append(row) return [self._row_to_dict(r) for r in rows] async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None): diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 15e48bab0..55a230d4e 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -386,13 +386,27 @@ class RunJournal(BaseCallbackHandler): return batch = self._buffer.copy() self._buffer.clear() - loop.create_task(self._flush_async(batch)) + task = loop.create_task(self._flush_async(batch)) + task.add_done_callback(self._on_flush_done) async def _flush_async(self, batch: list[dict]) -> None: try: await self._store.put_batch(batch) except Exception: - logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True) + logger.warning( + "Failed to flush %d events for run %s — returning to buffer", + len(batch), 
self.run_id, exc_info=True, + ) + # Return failed events to buffer for retry on next flush + self._buffer = batch + self._buffer + + @staticmethod + def _on_flush_done(task: asyncio.Task) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc: + logger.warning("Journal flush task failed: %s", exc) def _identify_caller(self, kwargs: dict) -> str: for tag in kwargs.get("tags") or []: diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 9e4d9089e..0a0794d87 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -54,6 +54,33 @@ class RunManager: self._lock = asyncio.Lock() self._store = store + async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None: + """Best-effort persist run record to backing store.""" + if self._store is None: + return + try: + await self._store.put( + record.run_id, + thread_id=record.thread_id, + assistant_id=record.assistant_id, + status=record.status.value, + multitask_strategy=record.multitask_strategy, + metadata=record.metadata or {}, + kwargs=record.kwargs or {}, + created_at=record.created_at, + follow_up_to_run_id=follow_up_to_run_id, + ) + except Exception: + logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) + + async def update_run_completion(self, run_id: str, **kwargs) -> None: + """Persist token usage and completion data to the backing store.""" + if self._store is not None: + try: + await self._store.update_run_completion(run_id, **kwargs) + except Exception: + logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) + async def create( self, thread_id: str, @@ -63,6 +90,7 @@ class RunManager: metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", + follow_up_to_run_id: str | None = None, ) -> RunRecord: 
"""Create a new pending run and register it.""" run_id = str(uuid.uuid4()) @@ -81,20 +109,7 @@ class RunManager: ) async with self._lock: self._runs[run_id] = record - if self._store is not None: - try: - await self._store.put( - run_id, - thread_id=thread_id, - assistant_id=assistant_id, - status=RunStatus.pending.value, - multitask_strategy=multitask_strategy, - metadata=metadata or {}, - kwargs=kwargs or {}, - created_at=now, - ) - except Exception: - logger.warning("Failed to persist run %s to store", run_id, exc_info=True) + await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record @@ -161,6 +176,7 @@ class RunManager: metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", + follow_up_to_run_id: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. @@ -214,21 +230,7 @@ class RunManager: ) self._runs[run_id] = record - if self._store is not None: - try: - await self._store.put( - run_id, - thread_id=thread_id, - assistant_id=assistant_id, - status=RunStatus.pending.value, - multitask_strategy=multitask_strategy, - metadata=metadata or {}, - kwargs=kwargs or {}, - created_at=now, - ) - except Exception: - logger.warning("Failed to persist run %s to store", run_id, exc_info=True) - + await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 993db21b4..921460108 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -29,6 +29,7 @@ class RunStore(abc.ABC): kwargs: dict[str, Any] | None = None, error: str | None = None, created_at: str | None = None, + 
follow_up_to_run_id: str | None = None, ) -> None: ... @abc.abstractmethod @@ -55,5 +56,24 @@ class RunStore(abc.ABC): @abc.abstractmethod async def delete(self, run_id: str) -> None: ... + @abc.abstractmethod + async def update_run_completion( + self, + run_id: str, + *, + status: str, + total_input_tokens: int = 0, + total_output_tokens: int = 0, + total_tokens: int = 0, + llm_call_count: int = 0, + lead_agent_tokens: int = 0, + subagent_tokens: int = 0, + middleware_tokens: int = 0, + message_count: int = 0, + last_ai_message: str | None = None, + first_human_message: str | None = None, + error: str | None = None, + ) -> None: ... + @abc.abstractmethod async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: ... diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index a7a22e4c1..937f22e37 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -28,6 +28,7 @@ class MemoryRunStore(RunStore): kwargs=None, error=None, created_at=None, + follow_up_to_run_id=None, ): now = datetime.now(UTC).isoformat() self._runs[run_id] = { @@ -40,6 +41,7 @@ class MemoryRunStore(RunStore): "metadata": metadata or {}, "kwargs": kwargs or {}, "error": error, + "follow_up_to_run_id": follow_up_to_run_id, "created_at": created_at or now, "updated_at": now, } @@ -62,6 +64,14 @@ class MemoryRunStore(RunStore): async def delete(self, run_id): self._runs.pop(run_id, None) + async def update_run_completion(self, run_id, *, status, **kwargs): + if run_id in self._runs: + self._runs[run_id]["status"] = status + for key, value in kwargs.items(): + if value is not None: + self._runs[run_id][key] = value + self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def list_pending(self, *, before=None): now = before or datetime.now(UTC).isoformat() results = [r for r in 
self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index d6312b8ed..c4726f4a4 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -257,16 +257,8 @@ async def run_agent( logger.warning("Failed to flush journal for run %s", run_id, exc_info=True) # Persist token usage + convenience fields to RunStore - if run_manager._store is not None: - try: - completion = journal.get_completion_data() - await run_manager._store.update_run_completion( - run_id, - status=record.status.value, - **completion, - ) - except Exception: - logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) + completion = journal.get_completion_data() + await run_manager.update_run_completion(run_id, status=record.status.value, **completion) # Sync title from checkpoint to threads_meta.display_name if thread_meta_repo is not None and checkpointer is not None: diff --git a/backend/tests/test_checkpointer_none_fix.py b/backend/tests/test_checkpointer_none_fix.py index 4e128adbc..1da435c85 100644 --- a/backend/tests/test_checkpointer_none_fix.py +++ b/backend/tests/test_checkpointer_none_fix.py @@ -14,9 +14,10 @@ class TestCheckpointerNoneFix: """make_checkpointer should return InMemorySaver when config.checkpointer is None.""" from deerflow.agents.checkpointer.async_provider import make_checkpointer - # Mock get_app_config to return a config with checkpointer=None + # Mock get_app_config to return a config with checkpointer=None and database=None mock_config = MagicMock() mock_config.checkpointer = None + mock_config.database = None with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config): async with make_checkpointer() as checkpointer: