diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 9374769f3..b7e54754f 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -11,7 +11,7 @@ import logging from datetime import UTC, datetime from typing import Any -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow @@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore): user = get_current_user() return str(user.id) if user is not None else None + @staticmethod + async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None: + """Return the current max seq while serializing writers per thread. + + PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate + results are not lockable rows. As a release-safe workaround, take a + transaction-level advisory lock keyed by thread_id before reading the + aggregate. Other dialects keep the existing row-locking statement. + """ + stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id) + bind = session.get_bind() + dialect_name = bind.dialect.name if bind is not None else "" + + if dialect_name == "postgresql": + await session.execute( + text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"), + {"thread_id": thread_id}, + ) + return await session.scalar(stmt) + + return await session.scalar(stmt.with_for_update()) + async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 """Write a single event — low-frequency path only. @@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore): user_id = self._user_id_from_context() async with self._sf() as session: async with session.begin(): - # Use FOR UPDATE to serialize seq assignment within a thread. - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = (max_seq or 0) + 1 row = RunEventRow( thread_id=thread_id, @@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore): async with self._sf() as session: async with session.begin(): # Get max seq for the thread (assume all events in batch belong to same thread). - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. thread_id = events[0]["thread_id"] - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = max_seq or 0 rows = [] for e in events: diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index d2c78ccf0..17b796af7 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -268,6 +268,39 @@ class TestEdgeCases: class TestDbRunEventStore: """Tests for DbRunEventStore with temp SQLite.""" + @pytest.mark.anyio + async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self): + from sqlalchemy.dialects import postgresql + + from deerflow.runtime.events.store.db import DbRunEventStore + + class FakeSession: + def __init__(self): + self.dialect = postgresql.dialect() + self.execute_calls = [] + self.scalar_stmt = None + + def get_bind(self): + return self + + async def execute(self, stmt, params=None): + self.execute_calls.append((stmt, params)) + + async def scalar(self, stmt): + self.scalar_stmt = stmt + return 41 + + session = FakeSession() + + max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1") + + assert max_seq == 41 + assert session.execute_calls + assert session.execute_calls[0][1] == {"thread_id": "thread-1"} + assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0]) + compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect())) + assert "FOR UPDATE" not in compiled + @pytest.mark.anyio async def test_basic_crud(self, tmp_path): from deerflow.persistence.engine import close_engine, get_session_factory, init_engine