mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-15 21:23:41 +00:00
fix(runtime): avoid postgres aggregate row lock (#2962)
This commit is contained in:
parent
722c690f4f
commit
45060a9ffc
@ -11,7 +11,7 @@ import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore):
|
||||
user = get_current_user()
|
||||
return str(user.id) if user is not None else None
|
||||
|
||||
@staticmethod
|
||||
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
|
||||
"""Return the current max seq while serializing writers per thread.
|
||||
|
||||
PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
|
||||
results are not lockable rows. As a release-safe workaround, take a
|
||||
transaction-level advisory lock keyed by thread_id before reading the
|
||||
aggregate. Other dialects keep the existing row-locking statement.
|
||||
"""
|
||||
stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
|
||||
bind = session.get_bind()
|
||||
dialect_name = bind.dialect.name if bind is not None else ""
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
await session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
|
||||
{"thread_id": thread_id},
|
||||
)
|
||||
return await session.scalar(stmt)
|
||||
|
||||
return await session.scalar(stmt.with_for_update())
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore):
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore):
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
|
||||
@ -268,6 +268,39 @@ class TestEdgeCases:
|
||||
class TestDbRunEventStore:
|
||||
"""Tests for DbRunEventStore with temp SQLite."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self):
|
||||
self.dialect = postgresql.dialect()
|
||||
self.execute_calls = []
|
||||
self.scalar_stmt = None
|
||||
|
||||
def get_bind(self):
|
||||
return self
|
||||
|
||||
async def execute(self, stmt, params=None):
|
||||
self.execute_calls.append((stmt, params))
|
||||
|
||||
async def scalar(self, stmt):
|
||||
self.scalar_stmt = stmt
|
||||
return 41
|
||||
|
||||
session = FakeSession()
|
||||
|
||||
max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
|
||||
|
||||
assert max_seq == 41
|
||||
assert session.execute_calls
|
||||
assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
|
||||
assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
|
||||
compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "FOR UPDATE" not in compiled
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user