From b2ec1f99b9fc16a159e3684927dc689548132148 Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Fri, 10 Apr 2026 15:05:32 +0800 Subject: [PATCH] feat(persistence): unify ThreadMetaStore interface with user isolation and factory Add user_id parameter to all ThreadMetaStore abstract methods. Implement owner isolation in MemoryThreadMetaStore with _get_owned_record helper. Add check_access to base class and memory implementation. Add make_thread_store factory to simplify deps.py initialization. Add memory-backend isolation tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../persistence/thread_meta/__init__.py | 25 +++ .../deerflow/persistence/thread_meta/base.py | 30 +++- .../persistence/thread_meta/memory.py | 65 ++++++-- .../test_memory_thread_meta_isolation.py | 156 ++++++++++++++++++ 4 files changed, 251 insertions(+), 25 deletions(-) create mode 100644 backend/tests/test_memory_thread_meta_isolation.py diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py index 8e497bb7e..080ce8093 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py @@ -1,13 +1,38 @@ """Thread metadata persistence — ORM, abstract store, and concrete implementations.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.sql import ThreadMetaRepository +if TYPE_CHECKING: + from langgraph.store.base import BaseStore + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + __all__ = [ "MemoryThreadMetaStore", "ThreadMetaRepository", "ThreadMetaRow", "ThreadMetaStore", + "make_thread_store", ] + + +def make_thread_store( + session_factory: async_sessionmaker[AsyncSession] | None, + store: BaseStore | None = None, +) -> ThreadMetaStore: + """Create the appropriate ThreadMetaStore based on available backends. + + Returns a SQL-backed repository when a session factory is available, + otherwise falls back to the in-memory LangGraph Store implementation. + """ + if session_factory is not None: + return ThreadMetaRepository(session_factory) + if store is None: + raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)") + return MemoryThreadMetaStore(store) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index 466a82a21..c87c10a16 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -3,12 +3,21 @@ Implementations: - ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy) - MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode) + +All mutating and querying methods accept a ``user_id`` parameter with +three-state semantics (see :mod:`deerflow.runtime.user_context`): + +- ``AUTO`` (default): resolve from the request-scoped contextvar. +- Explicit ``str``: use the provided value verbatim. +- Explicit ``None``: bypass owner filtering (migration/CLI only). """ from __future__ import annotations import abc +from deerflow.runtime.user_context import AUTO, _AutoSentinel + class ThreadMetaStore(abc.ABC): @abc.abstractmethod @@ -17,14 +26,14 @@ class ThreadMetaStore(abc.ABC): thread_id: str, *, assistant_id: str | None = None, - owner_id: str | None = None, + user_id: str | None | _AutoSentinel = AUTO, display_name: str | None = None, metadata: dict | None = None, ) -> dict: pass @abc.abstractmethod - async def get(self, thread_id: str) -> dict | None: + async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None: pass @abc.abstractmethod @@ -35,26 +44,33 @@ class ThreadMetaStore(abc.ABC): status: str | None = None, limit: int = 100, offset: int = 0, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: pass @abc.abstractmethod - async def update_display_name(self, thread_id: str, display_name: str) -> None: + async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: pass @abc.abstractmethod - async def update_status(self, thread_id: str, status: str) -> None: + async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: pass @abc.abstractmethod - async def update_metadata(self, thread_id: str, metadata: dict) -> None: + async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: """Merge ``metadata`` into the thread's metadata field. Existing keys are overwritten by the new values; keys absent from - ``metadata`` are preserved. No-op if the thread does not exist. + ``metadata`` are preserved. No-op if the thread does not exist + or the owner check fails. """ pass @abc.abstractmethod - async def delete(self, thread_id: str) -> None: + async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: + """Check if ``user_id`` has access to ``thread_id``.""" + pass + + @abc.abstractmethod + async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: pass diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index ab921f229..ccf59ad42 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -13,6 +13,7 @@ from typing import Any from langgraph.store.base import BaseStore from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id THREADS_NS: tuple[str, ...] = ("threads",) @@ -21,20 +22,37 @@ class MemoryThreadMetaStore(ThreadMetaStore): def __init__(self, store: BaseStore) -> None: self._store = store + async def _get_owned_record( + self, + thread_id: str, + user_id: str | None | _AutoSentinel, + method_name: str, + ) -> dict | None: + """Fetch a record and verify ownership. Returns a mutable copy, or None.""" + resolved = resolve_user_id(user_id, method_name=method_name) + item = await self._store.aget(THREADS_NS, thread_id) + if item is None: + return None + record = dict(item.value) + if resolved is not None and record.get("user_id") != resolved: + return None + return record + async def create( self, thread_id: str, *, assistant_id: str | None = None, - owner_id: str | None = None, + user_id: str | None | _AutoSentinel = AUTO, display_name: str | None = None, metadata: dict | None = None, ) -> dict: + resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create") now = time.time() record: dict[str, Any] = { "thread_id": thread_id, "assistant_id": assistant_id, - "owner_id": owner_id, + "user_id": resolved_user_id, "display_name": display_name, "status": "idle", "metadata": metadata or {}, @@ -45,9 +63,8 @@ class MemoryThreadMetaStore(ThreadMetaStore): await self._store.aput(THREADS_NS, thread_id, record) return record - async def get(self, thread_id: str) -> dict | None: - item = await self._store.aget(THREADS_NS, thread_id) - return item.value if item is not None else None + async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None: + return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get") async def search( self, @@ -56,12 +73,16 @@ class MemoryThreadMetaStore(ThreadMetaStore): status: str | None = None, limit: int = 100, offset: int = 0, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: + resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") filter_dict: dict[str, Any] = {} if metadata: filter_dict.update(metadata) if status: filter_dict["status"] = status + if resolved_user_id is not None: + filter_dict["user_id"] = resolved_user_id items = await self._store.asearch( THREADS_NS, @@ -71,37 +92,45 @@ class MemoryThreadMetaStore(ThreadMetaStore): ) return [self._item_to_dict(item) for item in items] - async def update_display_name(self, thread_id: str, display_name: str) -> None: + async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: item = await self._store.aget(THREADS_NS, thread_id) if item is None: + return not require_existing + record_user_id = item.value.get("user_id") + if record_user_id is None: + return True + return record_user_id == user_id + + async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name") + if record is None: return - record = dict(item.value) record["display_name"] = display_name record["updated_at"] = time.time() await self._store.aput(THREADS_NS, thread_id, record) - async def update_status(self, thread_id: str, status: str) -> None: - item = await self._store.aget(THREADS_NS, thread_id) - if item is None: + async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status") + if record is None: return - record = dict(item.value) record["status"] = status record["updated_at"] = time.time() await self._store.aput(THREADS_NS, thread_id, record) - async def update_metadata(self, thread_id: str, metadata: dict) -> None: - """Merge ``metadata`` into the in-memory record. No-op if absent.""" - item = await self._store.aget(THREADS_NS, thread_id) - if item is None: + async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata") + if record is None: return - record = dict(item.value) merged = dict(record.get("metadata") or {}) merged.update(metadata) record["metadata"] = merged record["updated_at"] = time.time() await self._store.aput(THREADS_NS, thread_id, record) - async def delete(self, thread_id: str) -> None: + async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete") + if record is None: + return await self._store.adelete(THREADS_NS, thread_id) @staticmethod @@ -111,7 +140,7 @@ class MemoryThreadMetaStore(ThreadMetaStore): return { "thread_id": item.key, "assistant_id": val.get("assistant_id"), - "owner_id": val.get("owner_id"), + "user_id": val.get("user_id"), "display_name": val.get("display_name"), "status": val.get("status", "idle"), "metadata": val.get("metadata", {}), diff --git a/backend/tests/test_memory_thread_meta_isolation.py b/backend/tests/test_memory_thread_meta_isolation.py new file mode 100644 index 000000000..25c9298f0 --- /dev/null +++ b/backend/tests/test_memory_thread_meta_isolation.py @@ -0,0 +1,156 @@ +"""Owner isolation tests for MemoryThreadMetaStore. + +Mirrors the SQL-backed tests in test_owner_isolation.py but exercises +the in-memory LangGraph Store backend used when database.backend=memory. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from langgraph.store.memory import InMemoryStore + +from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore +from deerflow.runtime.user_context import reset_current_user, set_current_user + +USER_A = SimpleNamespace(id="user-a", email="a@test.local") +USER_B = SimpleNamespace(id="user-b", email="b@test.local") + + +def _as_user(user): + class _Ctx: + def __enter__(self): + self._token = set_current_user(user) + return user + + def __exit__(self, *exc): + reset_current_user(self._token) + + return _Ctx() + + +@pytest.fixture +def store(): + return MemoryThreadMetaStore(InMemoryStore()) + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_search_isolation(store): + """search() returns only threads owned by the current user.""" + with _as_user(USER_A): + await store.create("t-alpha", display_name="A's thread") + with _as_user(USER_B): + await store.create("t-beta", display_name="B's thread") + + with _as_user(USER_A): + results = await store.search() + assert [r["thread_id"] for r in results] == ["t-alpha"] + + with _as_user(USER_B): + results = await store.search() + assert [r["thread_id"] for r in results] == ["t-beta"] + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_get_isolation(store): + """get() returns None for threads owned by another user.""" + with _as_user(USER_A): + await store.create("t-alpha", display_name="A's thread") + + with _as_user(USER_B): + assert await store.get("t-alpha") is None + + with _as_user(USER_A): + result = await store.get("t-alpha") + assert result is not None + assert result["display_name"] == "A's thread" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_update_display_name_denied(store): + """User B cannot rename User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha", display_name="original") + + with _as_user(USER_B): + await store.update_display_name("t-alpha", "hacked") + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + assert row["display_name"] == "original" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_update_status_denied(store): + """User B cannot change status of User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha") + + with _as_user(USER_B): + await store.update_status("t-alpha", "error") + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + assert row["status"] == "idle" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_update_metadata_denied(store): + """User B cannot modify metadata of User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha", metadata={"key": "original"}) + + with _as_user(USER_B): + await store.update_metadata("t-alpha", {"key": "hacked"}) + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + assert row["metadata"]["key"] == "original" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_delete_denied(store): + """User B cannot delete User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha") + + with _as_user(USER_B): + await store.delete("t-alpha") + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_no_context_raises(store): + """Calling methods without user context raises RuntimeError.""" + with pytest.raises(RuntimeError, match="no user context is set"): + await store.search() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_explicit_none_bypasses_filter(store): + """user_id=None bypasses isolation (migration/CLI escape hatch).""" + with _as_user(USER_A): + await store.create("t-alpha") + with _as_user(USER_B): + await store.create("t-beta") + + all_rows = await store.search(user_id=None) + assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"} + + row = await store.get("t-alpha", user_id=None) + assert row is not None