feat(persistence): unify ThreadMetaStore interface with user isolation and factory

Add user_id parameter to all ThreadMetaStore abstract methods. Implement
owner isolation in MemoryThreadMetaStore with _get_owned_record helper.
Add check_access to base class and memory implementation. Add
make_thread_store factory to simplify deps.py initialization. Add
memory-backend isolation tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-10 15:05:32 +08:00
parent 8da1903168
commit b2ec1f99b9
4 changed files with 251 additions and 25 deletions

View File

@ -1,13 +1,38 @@
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
if TYPE_CHECKING:
from langgraph.store.base import BaseStore
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
__all__ = [
"MemoryThreadMetaStore",
"ThreadMetaRepository",
"ThreadMetaRow",
"ThreadMetaStore",
"make_thread_store",
]
def make_thread_store(
    session_factory: async_sessionmaker[AsyncSession] | None,
    store: BaseStore | None = None,
) -> ThreadMetaStore:
    """Build a ThreadMetaStore for whichever backend is available.

    The SQL backend takes precedence: when ``session_factory`` is given, a
    ``ThreadMetaRepository`` bound to it is returned and ``store`` is not
    consulted. Otherwise ``store`` is wrapped in a ``MemoryThreadMetaStore``.

    Raises:
        ValueError: If both ``session_factory`` and ``store`` are ``None``.
    """
    if session_factory is None:
        # No SQL backend configured: the memory backend needs a store to wrap.
        if store is None:
            raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
        return MemoryThreadMetaStore(store)
    return ThreadMetaRepository(session_factory)

View File

@ -3,12 +3,21 @@
Implementations:
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
All mutating and querying methods accept a ``user_id`` parameter with
three-state semantics (see :mod:`deerflow.runtime.user_context`):
- ``AUTO`` (default): resolve from the request-scoped contextvar.
- Explicit ``str``: use the provided value verbatim.
- Explicit ``None``: bypass owner filtering (migration/CLI only).
"""
from __future__ import annotations
import abc
from deerflow.runtime.user_context import AUTO, _AutoSentinel
class ThreadMetaStore(abc.ABC):
@abc.abstractmethod
@ -17,14 +26,14 @@ class ThreadMetaStore(abc.ABC):
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
pass
@abc.abstractmethod
async def get(self, thread_id: str) -> dict | None:
async def get(
    self,
    thread_id: str,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
    """Fetch the metadata record for ``thread_id``, or ``None``.

    ``user_id`` follows the three-state convention from the module
    docstring: ``AUTO`` resolves from the request context, a ``str`` is
    used verbatim, and ``None`` bypasses owner filtering.
    """
@abc.abstractmethod
@ -35,26 +44,33 @@ class ThreadMetaStore(abc.ABC):
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
pass
@abc.abstractmethod
async def update_display_name(self, thread_id: str, display_name: str) -> None:
async def update_display_name(
    self,
    thread_id: str,
    display_name: str,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
) -> None:
    """Set the display name stored for ``thread_id``.

    ``user_id`` follows the three-state convention from the module
    docstring (``AUTO`` / explicit ``str`` / explicit ``None``).
    """
@abc.abstractmethod
async def update_status(self, thread_id: str, status: str) -> None:
async def update_status(
    self,
    thread_id: str,
    status: str,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
) -> None:
    """Set the status stored for ``thread_id``.

    ``user_id`` follows the three-state convention from the module
    docstring (``AUTO`` / explicit ``str`` / explicit ``None``).
    """
@abc.abstractmethod
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
async def update_metadata(
    self,
    thread_id: str,
    metadata: dict,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
) -> None:
    """Merge ``metadata`` into the thread's metadata field.

    Keys present in ``metadata`` overwrite the stored values; stored keys
    that ``metadata`` does not mention are preserved. The call is a no-op
    if the thread does not exist or the owner check fails.
    """
@abc.abstractmethod
async def delete(self, thread_id: str) -> None:
async def check_access(
    self,
    thread_id: str,
    user_id: str,
    *,
    require_existing: bool = False,
) -> bool:
    """Report whether ``user_id`` may access ``thread_id``.

    ``require_existing`` is keyword-only; how a missing thread is treated
    is left to the concrete store.
    """
@abc.abstractmethod
async def delete(
    self,
    thread_id: str,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
) -> None:
    """Remove the metadata record for ``thread_id``.

    ``user_id`` follows the three-state convention from the module
    docstring (``AUTO`` / explicit ``str`` / explicit ``None``).
    """

View File

@ -13,6 +13,7 @@ from typing import Any
from langgraph.store.base import BaseStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
THREADS_NS: tuple[str, ...] = ("threads",)
@ -21,20 +22,37 @@ class MemoryThreadMetaStore(ThreadMetaStore):
def __init__(self, store: BaseStore) -> None:
    """Keep a reference to the LangGraph ``store`` that holds the records."""
    self._store = store
async def _get_owned_record(
    self,
    thread_id: str,
    user_id: str | None | _AutoSentinel,
    method_name: str,
) -> dict | None:
    """Load a thread record and apply the ownership filter.

    ``user_id`` is resolved through ``resolve_user_id`` first, with
    ``method_name`` passed along to it. The result is a shallow copy of the
    stored value, or ``None`` when the record is missing or its ``user_id``
    differs from the resolved one. A resolved value of ``None`` disables
    the ownership comparison.
    """
    owner = resolve_user_id(user_id, method_name=method_name)
    stored = await self._store.aget(THREADS_NS, thread_id)
    if stored is None:
        return None
    snapshot = dict(stored.value)
    owner_matches = owner is None or snapshot.get("user_id") == owner
    return snapshot if owner_matches else None
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
now = time.time()
record: dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"owner_id": owner_id,
"user_id": resolved_user_id,
"display_name": display_name,
"status": "idle",
"metadata": metadata or {},
@ -45,9 +63,8 @@ class MemoryThreadMetaStore(ThreadMetaStore):
await self._store.aput(THREADS_NS, thread_id, record)
return record
async def get(self, thread_id: str) -> dict | None:
item = await self._store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
    """Return the record for ``thread_id`` if it passes the owner check, else ``None``."""
    owned = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
    return owned
async def search(
self,
@ -56,12 +73,16 @@ class MemoryThreadMetaStore(ThreadMetaStore):
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
filter_dict: dict[str, Any] = {}
if metadata:
filter_dict.update(metadata)
if status:
filter_dict["status"] = status
if resolved_user_id is not None:
filter_dict["user_id"] = resolved_user_id
items = await self._store.asearch(
THREADS_NS,
@ -71,37 +92,45 @@ class MemoryThreadMetaStore(ThreadMetaStore):
)
return [self._item_to_dict(item) for item in items]
async def update_display_name(self, thread_id: str, display_name: str) -> None:
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
    """Report whether ``user_id`` may access ``thread_id``.

    A missing record counts as accessible unless ``require_existing`` is
    true. A record whose ``user_id`` is ``None`` is accessible to anyone;
    otherwise the stored ``user_id`` must equal the one given.
    """
    item = await self._store.aget(THREADS_NS, thread_id)
    if item is None:
        return not require_existing
    owner = item.value.get("user_id")
    return owner is None or owner == user_id
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
    """Rename a thread, provided it passes the owner check.

    Does nothing when ``_get_owned_record`` returns ``None``, i.e. the
    thread is missing or owned by a different user.
    """
    record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
    if record is None:
        return
    # ``record`` is already a copy made by ``_get_owned_record``; there is no
    # ``item`` in this scope to copy from again.
    record["display_name"] = display_name
    record["updated_at"] = time.time()
    await self._store.aput(THREADS_NS, thread_id, record)
async def update_status(self, thread_id: str, status: str) -> None:
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
    """Change a thread's status, provided it passes the owner check.

    Does nothing when ``_get_owned_record`` returns ``None``, i.e. the
    thread is missing or owned by a different user.
    """
    record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
    if record is None:
        return
    # ``record`` is already a copy made by ``_get_owned_record``; there is no
    # ``item`` in this scope to copy from again.
    record["status"] = status
    record["updated_at"] = time.time()
    await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
"""Merge ``metadata`` into the in-memory record. No-op if absent."""
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
    """Merge ``metadata`` into the stored record's metadata mapping.

    Keys in ``metadata`` overwrite stored ones; other stored keys are kept.
    Does nothing when ``_get_owned_record`` returns ``None``, i.e. the
    thread is missing or owned by a different user.
    """
    record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
    if record is None:
        return
    # ``record`` is only a shallow copy, so merge into a fresh dict rather
    # than mutating the nested mapping it shares with the stored value.
    merged = dict(record.get("metadata") or {})
    merged.update(metadata)
    record["metadata"] = merged
    record["updated_at"] = time.time()
    await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str) -> None:
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
    """Delete the record for ``thread_id`` if it passes the owner check.

    Missing threads and threads owned by a different user are left alone.
    """
    owned = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
    if owned is not None:
        await self._store.adelete(THREADS_NS, thread_id)
@staticmethod
@ -111,7 +140,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
return {
"thread_id": item.key,
"assistant_id": val.get("assistant_id"),
"owner_id": val.get("owner_id"),
"user_id": val.get("user_id"),
"display_name": val.get("display_name"),
"status": val.get("status", "idle"),
"metadata": val.get("metadata", {}),

View File

@ -0,0 +1,156 @@
"""Owner isolation tests for MemoryThreadMetaStore.
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
the in-memory LangGraph Store backend used when database.backend=memory.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from langgraph.store.memory import InMemoryStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime.user_context import reset_current_user, set_current_user
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
def _as_user(user):
    """Return a context manager that makes ``user`` the current user.

    Entering calls ``set_current_user(user)`` and yields ``user``; leaving
    passes the token obtained on entry to ``reset_current_user``.
    """

    class _UserScope:
        def __enter__(self):
            self._token = set_current_user(user)
            return user

        def __exit__(self, *exc_info):
            reset_current_user(self._token)

    scope = _UserScope()
    return scope
@pytest.fixture
def store():
    """Provide a MemoryThreadMetaStore wrapping a new InMemoryStore."""
    backend = InMemoryStore()
    return MemoryThreadMetaStore(backend)
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_search_isolation(store):
    """Each user's search() lists that user's threads and nobody else's."""
    for owner, thread_id, title in (
        (USER_A, "t-alpha", "A's thread"),
        (USER_B, "t-beta", "B's thread"),
    ):
        with _as_user(owner):
            await store.create(thread_id, display_name=title)

    for viewer, expected_ids in ((USER_A, ["t-alpha"]), (USER_B, ["t-beta"])):
        with _as_user(viewer):
            found = await store.search()
            assert [row["thread_id"] for row in found] == expected_ids
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_get_isolation(store):
    """get() hides a thread from a user who does not own it."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="A's thread")

    with _as_user(USER_B):
        seen_by_other = await store.get("t-alpha")
        assert seen_by_other is None

    with _as_user(USER_A):
        seen_by_owner = await store.get("t-alpha")
        assert seen_by_owner is not None
        assert seen_by_owner["display_name"] == "A's thread"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_display_name_denied(store):
    """A rename attempted by User B leaves User A's thread name untouched."""
    thread_id = "t-alpha"
    with _as_user(USER_A):
        await store.create(thread_id, display_name="original")

    with _as_user(USER_B):
        await store.update_display_name(thread_id, "hacked")

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
        assert owner_view["display_name"] == "original"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_status_denied(store):
    """A status change attempted by User B leaves User A's thread at "idle"."""
    thread_id = "t-alpha"
    with _as_user(USER_A):
        await store.create(thread_id)

    with _as_user(USER_B):
        await store.update_status(thread_id, "error")

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
        assert owner_view["status"] == "idle"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_metadata_denied(store):
    """A metadata update attempted by User B does not alter User A's thread."""
    thread_id = "t-alpha"
    with _as_user(USER_A):
        await store.create(thread_id, metadata={"key": "original"})

    with _as_user(USER_B):
        await store.update_metadata(thread_id, {"key": "hacked"})

    with _as_user(USER_A):
        owner_view = await store.get(thread_id)
        assert owner_view is not None
        assert owner_view["metadata"]["key"] == "original"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_delete_denied(store):
    """A delete attempted by User B leaves User A's thread in place."""
    thread_id = "t-alpha"
    with _as_user(USER_A):
        await store.create(thread_id)

    with _as_user(USER_B):
        await store.delete(thread_id)

    with _as_user(USER_A):
        survivor = await store.get(thread_id)
        assert survivor is not None
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_no_context_raises(store):
    """search() with no user context set fails with RuntimeError."""
    expected_message = "no user context is set"
    with pytest.raises(RuntimeError, match=expected_message):
        await store.search()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(store):
    """Passing user_id=None turns off owner filtering for search() and get()."""
    for owner, thread_id in ((USER_A, "t-alpha"), (USER_B, "t-beta")):
        with _as_user(owner):
            await store.create(thread_id)

    everything = await store.search(user_id=None)
    assert {row["thread_id"] for row in everything} == {"t-alpha", "t-beta"}

    unfiltered = await store.get("t-alpha", user_id=None)
    assert unfiltered is not None