mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(persistence): rename owner_id to user_id and thread_meta_repo to thread_store
Rename owner_id to user_id across all persistence models, repositories, stores, routers, and tests for clearer semantics. Rename thread_meta_repo to thread_store for consistency with run_store/run_event_store naming. Add ThreadMetaStore return type annotation to get_thread_store(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
03952eca53
commit
8da1903168
@ -42,6 +42,11 @@ logger = logging.getLogger(__name__)
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise.
|
||||
|
||||
After admin creation, migrate orphan threads from the LangGraph
|
||||
store (metadata.user_id unset) to the admin account. This is the
|
||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||
authentication have existing LangGraph thread data that needs an
|
||||
owner assigned.
|
||||
First boot (no admin exists):
|
||||
- Generates a one-time ``init_token`` stored in ``app.state.init_token``
|
||||
- Logs the token to stdout so the operator can copy-paste it into the
|
||||
@ -52,7 +57,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
||||
existing LangGraph thread metadata that has no owner_id.
|
||||
|
||||
No SQL persistence migration is needed: the four owner_id columns
|
||||
No SQL persistence migration is needed: the four user_id columns
|
||||
(threads_meta, runs, run_events, feedback) only come into existence
|
||||
alongside the auth module via create_all, so freshly created tables
|
||||
never contain NULL-owner rows.
|
||||
@ -96,6 +101,8 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
admin_id = str(row.id)
|
||||
|
||||
# LangGraph store orphan migration — non-fatal.
|
||||
# This covers the "no-auth → with-auth" upgrade path for users
|
||||
# whose existing LangGraph thread metadata has no user_id set.
|
||||
store = getattr(app.state, "store", None)
|
||||
if store is not None:
|
||||
try:
|
||||
@ -127,7 +134,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
"""Migrate LangGraph store threads with no owner_id to the given admin.
|
||||
"""Migrate LangGraph store threads with no user_id to the given admin.
|
||||
|
||||
Uses cursor pagination so all orphans are migrated regardless of
|
||||
count. Returns the number of rows migrated.
|
||||
@ -135,8 +142,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
migrated = 0
|
||||
async for item in _iter_store_items(store, ("threads",)):
|
||||
metadata = item.value.get("metadata", {})
|
||||
if not metadata.get("owner_id"):
|
||||
metadata["owner_id"] = admin_user_id
|
||||
if not metadata.get("user_id"):
|
||||
metadata["user_id"] = admin_user_id
|
||||
item.value["metadata"] = metadata
|
||||
await store.aput(("threads",), item.key, item.value)
|
||||
migrated += 1
|
||||
|
||||
@ -233,18 +233,18 @@ def require_permission(
|
||||
# (``threads_meta`` table). We verify ownership via
|
||||
# ``ThreadMetaStore.check_access``: it returns True for
|
||||
# missing rows (untracked legacy thread) and for rows whose
|
||||
# ``owner_id`` is NULL (shared / pre-auth data), so this is
|
||||
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
||||
# strict-deny rather than strict-allow — only an *existing*
|
||||
# row with a *different* owner_id triggers 404.
|
||||
# row with a *different* user_id triggers 404.
|
||||
if owner_check:
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
allowed = await thread_meta_repo.check_access(
|
||||
thread_store = get_thread_store(request)
|
||||
allowed = await thread_store.check_access(
|
||||
thread_id,
|
||||
str(auth.user.id),
|
||||
require_existing=require_existing,
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
||||
``None``.
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
@ -20,6 +19,7 @@ from deerflow.runtime import RunContext, RunManager
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
if sf is not None:
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
app.state.run_store = RunRepository(sf)
|
||||
app.state.feedback_repo = FeedbackRepository(sf)
|
||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
||||
else:
|
||||
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
app.state.run_store = MemoryRunStore()
|
||||
app.state.feedback_repo = None
|
||||
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
||||
|
||||
from deerflow.persistence.thread_meta import make_thread_store
|
||||
|
||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
@ -110,7 +110,12 @@ def get_store(request: Request):
|
||||
return getattr(request.app.state, "store", None)
|
||||
|
||||
|
||||
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
||||
def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
"""Return the thread metadata store (SQL or memory-backed)."""
|
||||
val = getattr(request.app.state, "thread_store", None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||
return val
|
||||
|
||||
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
@ -128,7 +133,7 @@ def get_run_context(request: Request) -> RunContext:
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||
thread_meta_repo=get_thread_meta_repo(request),
|
||||
thread_store=get_thread_store(request),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -93,14 +93,14 @@ async def authenticate(request):
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||
"""Inject user_id metadata on writes; filter by user_id on reads.
|
||||
|
||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||
Gateway stores thread ownership as ``metadata.user_id``.
|
||||
This handler ensures LangGraph Server enforces the same isolation.
|
||||
"""
|
||||
# On create/update: stamp owner_id into metadata
|
||||
# On create/update: stamp user_id into metadata
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["owner_id"] = ctx.user.identity
|
||||
metadata["user_id"] = ctx.user.identity
|
||||
|
||||
# Return filter dict — LangGraph applies it to search/read/delete
|
||||
return {"owner_id": ctx.user.identity}
|
||||
return {"user_id": ctx.user.identity}
|
||||
|
||||
@ -34,7 +34,7 @@ class FeedbackResponse(BaseModel):
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
owner_id: str | None = None
|
||||
user_id: str | None = None
|
||||
message_id: str | None = None
|
||||
rating: int
|
||||
comment: str | None = None
|
||||
@ -80,7 +80,7 @@ async def create_feedback(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
owner_id=user_id,
|
||||
user_id=user_id,
|
||||
message_id=body.message_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
@ -34,7 +34,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
||||
# inbound model below so a malicious client cannot reflect a forged
|
||||
# owner identity through the API surface. Defense-in-depth — the
|
||||
# row-level invariant is still ``threads_meta.owner_id`` populated from
|
||||
# row-level invariant is still ``threads_meta.user_id`` populated from
|
||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
||||
|
||||
@ -194,7 +194,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
and removes the thread_meta row from the configured ThreadMetaStore
|
||||
(sqlite or memory).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id)
|
||||
@ -211,8 +211,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||
# so the deleted thread no longer appears in /threads/search.
|
||||
try:
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
await thread_meta_repo.delete(thread_id)
|
||||
thread_store = get_thread_store(request)
|
||||
await thread_store.delete(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||
|
||||
@ -227,17 +227,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
and an empty checkpoint (so state endpoints work immediately).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_meta_repo.get(thread_id)
|
||||
existing_record = await thread_store.get(thread_id)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
@ -249,7 +249,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
|
||||
# Write thread_meta so the thread appears in /threads/search immediately
|
||||
try:
|
||||
await thread_meta_repo.create(
|
||||
await thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
metadata=body.metadata,
|
||||
@ -293,9 +293,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
Delegates to the configured ThreadMetaStore implementation
|
||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
repo = get_thread_meta_repo(request)
|
||||
repo = get_thread_store(request)
|
||||
rows = await repo.search(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
@ -320,22 +320,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
record = await thread_meta_repo.get(thread_id)
|
||||
thread_store = get_thread_store(request)
|
||||
record = await thread_store.get(thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
||||
try:
|
||||
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
||||
await thread_store.update_metadata(thread_id, body.metadata)
|
||||
except Exception:
|
||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||
|
||||
# Re-read to get the merged metadata + refreshed updated_at
|
||||
record = await thread_meta_repo.get(thread_id) or record
|
||||
record = await thread_store.get(thread_id) or record
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=record.get("status", "idle"),
|
||||
@ -354,12 +354,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
execution status from the checkpointer. Falls back to the checkpointer
|
||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_store = get_thread_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
record: dict | None = await thread_meta_repo.get(thread_id)
|
||||
record: dict | None = await thread_store.get(thread_id)
|
||||
|
||||
# Derive accurate status from the checkpointer
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
@ -462,10 +462,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||
change immediately in both sqlite and memory backends.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_store = get_thread_store(request)
|
||||
|
||||
# checkpoint_ns must be present in the config for aput — default to ""
|
||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||
@ -529,7 +529,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
new_title = body.values["title"]
|
||||
if new_title: # Skip empty strings and None
|
||||
try:
|
||||
await thread_meta_repo.update_display_name(thread_id, new_title)
|
||||
await thread_store.update_display_name(thread_id, new_title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
|
||||
@ -229,15 +229,15 @@ async def start_run(
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_meta_repo.create(
|
||||
await run_ctx.thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
@ -285,7 +285,7 @@ async def start_run(
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
|
||||
@ -16,7 +16,7 @@ class FeedbackRow(Base):
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||
# message_id is an optional RunEventStore event identifier —
|
||||
# allows feedback to target a specific message or the entire run
|
||||
|
||||
@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
@ -33,19 +33,19 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
owner_id=resolved_owner_id,
|
||||
user_id=resolved_user_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
@ -61,14 +61,14 @@ class FeedbackRepository:
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
@ -78,12 +78,12 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@ -94,12 +94,12 @@ class FeedbackRepository:
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@ -109,14 +109,14 @@ class FeedbackRepository:
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@ -19,7 +19,7 @@ class RunEventRow(Base):
|
||||
# Owner of the conversation this event belongs to. Nullable for data
|
||||
# created before auth was introduced; populated by auth middleware on
|
||||
# new writes and by the boot-time orphan migration on existing rows.
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
|
||||
@ -16,7 +16,7 @@ class RunRow(Base):
|
||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
@ -69,7 +69,7 @@ class RunRepository(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@ -78,13 +78,13 @@ class RunRepository(RunStore):
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=resolved_owner_id,
|
||||
user_id=resolved_user_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
@ -102,14 +102,14 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
@ -117,13 +117,13 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
limit=100,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@ -141,14 +141,14 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
||||
# creates an orphan row (used by migration scripts).
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=resolved_owner_id,
|
||||
user_id=resolved_user_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return None
|
||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``owner_id`` has access to ``thread_id``.
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``.
|
||||
|
||||
Two modes — one row, two distinct semantics depending on what
|
||||
the caller is about to do:
|
||||
|
||||
- ``require_existing=False`` (default, permissive):
|
||||
Returns True for: row missing (untracked legacy thread),
|
||||
``row.owner_id`` is None (shared / pre-auth data),
|
||||
or ``row.owner_id == owner_id``. Use for **read-style**
|
||||
``row.user_id`` is None (shared / pre-auth data),
|
||||
or ``row.user_id == user_id``. Use for **read-style**
|
||||
decorators where treating an untracked thread as accessible
|
||||
preserves backward-compat.
|
||||
|
||||
- ``require_existing=True`` (strict):
|
||||
Returns True **only** when the row exists AND
|
||||
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
||||
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||
state-update) so a thread that has *already been deleted*
|
||||
cannot be re-targeted by any caller — closing the
|
||||
@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return not require_existing
|
||||
if row.owner_id is None:
|
||||
if row.user_id is None:
|
||||
return True
|
||||
return row.owner_id == owner_id
|
||||
return row.user_id == user_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
if resolved_owner_id is None:
|
||||
if resolved_user_id is None:
|
||||
return True # explicit bypass
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return row is not None and row.owner_id == resolved_owner_id
|
||||
return row is not None and row.user_id == resolved_user_id
|
||||
|
||||
async def update_display_name(
|
||||
self,
|
||||
thread_id: str,
|
||||
display_name: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
status: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
metadata: dict,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Merge ``metadata`` into ``metadata_json``.
|
||||
|
||||
Read-modify-write inside a single session/transaction so concurrent
|
||||
callers see consistent state. No-op if the row does not exist or
|
||||
the owner_id check fails.
|
||||
the user_id check fails.
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
merged = dict(row.metadata_json or {})
|
||||
merged.update(metadata)
|
||||
@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -55,8 +55,8 @@ class DbRunEventStore(RunEventStore):
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _owner_from_context() -> str | None:
|
||||
"""Soft read of owner_id from contextvar for write paths.
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
@ -81,7 +81,7 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
owner_id = self._owner_from_context()
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
@ -92,7 +92,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
owner_id=owner_id,
|
||||
user_id=user_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
@ -106,7 +106,7 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
owner_id = self._owner_from_context()
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
@ -130,7 +130,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
owner_id=e.get("owner_id", owner_id),
|
||||
user_id=e.get("user_id", user_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
@ -149,12 +149,12 @@ class DbRunEventStore(RunEventStore):
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
@ -181,12 +181,12 @@ class DbRunEventStore(RunEventStore):
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
@ -199,12 +199,12 @@ class DbRunEventStore(RunEventStore):
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@ -214,12 +214,12 @@ class DbRunEventStore(RunEventStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
@ -227,13 +227,13 @@ class DbRunEventStore(RunEventStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
@ -246,13 +246,13 @@ class DbRunEventStore(RunEventStore):
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
|
||||
@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
|
||||
- MemoryRunStore: in-memory dict (development, tests)
|
||||
- Future: RunRepository backed by SQLAlchemy ORM
|
||||
|
||||
All methods accept an optional owner_id for user isolation.
|
||||
When owner_id is None, no user filtering is applied (single-user mode).
|
||||
All methods accept an optional user_id for user isolation.
|
||||
When user_id is None, no user filtering is applied (single-user mode).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -22,7 +22,7 @@ class RunStore(abc.ABC):
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
status: str = "pending",
|
||||
multitask_strategy: str = "reject",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
@ -42,7 +42,7 @@ class RunStore(abc.ABC):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id=None,
|
||||
user_id=None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
|
||||
"run_id": run_id,
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"owner_id": owner_id,
|
||||
"user_id": user_id,
|
||||
"status": status,
|
||||
"multitask_strategy": multitask_strategy,
|
||||
"metadata": metadata or {},
|
||||
@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
|
||||
async def get(self, run_id):
|
||||
return self._runs.get(run_id)
|
||||
|
||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
|
||||
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
|
||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
|
||||
results.sort(key=lambda r: r["created_at"], reverse=True)
|
||||
return results[:limit]
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ class RunContext:
|
||||
store: Any | None = field(default=None)
|
||||
event_store: Any | None = field(default=None)
|
||||
run_events_config: Any | None = field(default=None)
|
||||
thread_meta_repo: Any | None = field(default=None)
|
||||
thread_store: Any | None = field(default=None)
|
||||
follow_up_to_run_id: str | None = field(default=None)
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ async def run_agent(
|
||||
store = ctx.store
|
||||
event_store = ctx.event_store
|
||||
run_events_config = ctx.run_events_config
|
||||
thread_meta_repo = ctx.thread_meta_repo
|
||||
thread_store = ctx.thread_store
|
||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||
|
||||
run_id = record.run_id
|
||||
@ -376,14 +376,14 @@ async def run_agent(
|
||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||
title = ckpt.get("channel_values", {}).get("title")
|
||||
if title:
|
||||
await thread_meta_repo.update_display_name(thread_id, title)
|
||||
await thread_store.update_display_name(thread_id, title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||
|
||||
# Update threads_meta status based on run outcome
|
||||
try:
|
||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||
await thread_meta_repo.update_status(thread_id, final_status)
|
||||
await thread_store.update_status(thread_id, final_status)
|
||||
except Exception:
|
||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
"""Request-scoped user context for owner-based authorization.
|
||||
"""Request-scoped user context for user-based authorization.
|
||||
|
||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||
auth middleware sets after a successful authentication. Repository
|
||||
methods read the contextvar via a sentinel default parameter, letting
|
||||
routers stay free of ``owner_id`` boilerplate.
|
||||
routers stay free of ``user_id`` boilerplate.
|
||||
|
||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||
Three-state semantics for the repository ``user_id`` parameter (the
|
||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||
|
||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||
@ -91,16 +91,16 @@ def require_current_user() -> CurrentUser:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based owner_id resolution
|
||||
# Sentinel-based user_id resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||
# Repository methods accept a ``user_id`` keyword-only argument that
|
||||
# defaults to ``AUTO``. The three possible values drive distinct
|
||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||
# behaviours; see the docstring on :func:`resolve_user_id`.
|
||||
|
||||
|
||||
class _AutoSentinel:
|
||||
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
||||
"""Singleton marker meaning 'resolve user_id from contextvar'."""
|
||||
|
||||
_instance: _AutoSentinel | None = None
|
||||
|
||||
@ -116,12 +116,12 @@ class _AutoSentinel:
|
||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||
|
||||
|
||||
def resolve_owner_id(
|
||||
def resolve_user_id(
|
||||
value: str | None | _AutoSentinel,
|
||||
*,
|
||||
method_name: str = "repository method",
|
||||
) -> str | None:
|
||||
"""Resolve the owner_id parameter passed to a repository method.
|
||||
"""Resolve the user_id parameter passed to a repository method.
|
||||
|
||||
Three-state semantics:
|
||||
|
||||
@ -131,16 +131,16 @@ def resolve_owner_id(
|
||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||
contextvar value. Useful for tests and admin-override flows.
|
||||
- Explicit ``None``: no filter — the repository should skip the
|
||||
owner_id WHERE clause entirely. Reserved for migration scripts
|
||||
user_id WHERE clause entirely. Reserved for migration scripts
|
||||
and CLI tools that intentionally bypass isolation.
|
||||
"""
|
||||
if isinstance(value, _AutoSentinel):
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
||||
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
|
||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||
# ``UUID`` for the API surface, but the persistence layer
|
||||
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||
# not supported"). Honour the documented return type here
|
||||
# rather than ripple a type change through every caller.
|
||||
|
||||
@ -3,16 +3,16 @@
|
||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||
decorators that read ``request.state.auth`` and call
|
||||
``thread_meta_repo.check_access``. Router-level unit tests construct
|
||||
``thread_store.check_access``. Router-level unit tests construct
|
||||
**bare** FastAPI apps that include only one router — they have neither
|
||||
the auth middleware nor a real thread_meta_repo, so the decorators raise
|
||||
the auth middleware nor a real thread_store, so the decorators raise
|
||||
401 (TestClient path) or ValueError (direct-call path).
|
||||
|
||||
This module provides two surfaces:
|
||||
|
||||
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
||||
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
||||
request, plus a permissive ``thread_meta_repo`` mock on
|
||||
request, plus a permissive ``thread_store`` mock on
|
||||
``app.state``. Use from TestClient-based router tests.
|
||||
|
||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||
@ -86,20 +86,20 @@ def make_authed_test_app(
|
||||
user_factory: Callable[[], User] | None = None,
|
||||
owner_check_passes: bool = True,
|
||||
) -> FastAPI:
|
||||
"""Build a FastAPI test app with stub auth + permissive thread_meta_repo.
|
||||
"""Build a FastAPI test app with stub auth + permissive thread_store.
|
||||
|
||||
Args:
|
||||
user_factory: Override the default test user. Must return a fully
|
||||
populated :class:`User`. Useful for cross-user isolation tests
|
||||
that need a stable id across requests.
|
||||
owner_check_passes: When True (default), ``thread_meta_repo.check_access``
|
||||
owner_check_passes: When True (default), ``thread_store.check_access``
|
||||
returns True for every call so ``@require_permission(owner_check=True)``
|
||||
never blocks the route under test. Pass False to verify that
|
||||
permission failures surface correctly.
|
||||
|
||||
Returns:
|
||||
A ``FastAPI`` app with the stub middleware installed and
|
||||
``app.state.thread_meta_repo`` set to a permissive mock. The
|
||||
``app.state.thread_store`` set to a permissive mock. The
|
||||
caller is still responsible for ``app.include_router(...)``.
|
||||
"""
|
||||
factory = user_factory or _make_stub_user
|
||||
@ -108,7 +108,7 @@ def make_authed_test_app(
|
||||
|
||||
repo = MagicMock()
|
||||
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
||||
app.state.thread_meta_repo = repo
|
||||
app.state.thread_store = repo
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ def provisioner_module():
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods read ``owner_id`` from a contextvar by default
|
||||
# Repository methods read ``user_id`` from a contextvar by default
|
||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||
# pre-existing persistence test would raise RuntimeError because the
|
||||
# contextvar is unset. The fixture sets a default test user on every
|
||||
|
||||
@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@ -199,12 +199,12 @@ def test_migration_failure_is_non_fatal():
|
||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||
"""First boot finds Store-only legacy threads → stamps admin's id.
|
||||
|
||||
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
||||
(no auth) accumulates threads in the LangGraph Store namespace
|
||||
``("threads",)`` with no ``metadata.owner_id``. After upgrading to
|
||||
``("threads",)`` with no ``metadata.user_id``. After upgrading to
|
||||
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
||||
rewrite each unowned item with the freshly created admin's id.
|
||||
"""
|
||||
@ -215,7 +215,7 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
||||
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
||||
SimpleNamespace(key="t3", value={"metadata": {}}),
|
||||
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
|
||||
SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
|
||||
]
|
||||
store = AsyncMock()
|
||||
# asearch returns the entire batch on first call, then an empty page
|
||||
@ -235,11 +235,11 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
assert len(aput_calls) == 3
|
||||
rewritten_keys = {call[1] for call in aput_calls}
|
||||
assert rewritten_keys == {"t1", "t2", "t3"}
|
||||
# Each rewrite carries the new owner_id; titles preserved where present.
|
||||
# Each rewrite carries the new user_id; titles preserved where present.
|
||||
by_key = {call[1]: call[2] for call in aput_calls}
|
||||
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
|
||||
assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
|
||||
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
||||
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
|
||||
assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
|
||||
# The pre-owned item must NOT have been rewritten.
|
||||
assert "t4" not in rewritten_keys
|
||||
|
||||
|
||||
@ -60,8 +60,8 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
|
||||
assert record["owner_id"] == "user-1"
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
assert record["user_id"] == "user-1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@ -175,46 +175,46 @@ def _make_ctx(user_id):
|
||||
def test_filter_injects_user_id():
|
||||
value = {}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
|
||||
|
||||
def test_filter_preserves_existing_metadata():
|
||||
value = {"metadata": {"title": "hello"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
assert value["metadata"]["title"] == "hello"
|
||||
|
||||
|
||||
def test_filter_returns_user_id_dict():
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||
assert result == {"owner_id": "user-x"}
|
||||
assert result == {"user_id": "user-x"}
|
||||
|
||||
|
||||
def test_filter_read_write_consistency():
|
||||
value = {}
|
||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
||||
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
||||
|
||||
|
||||
def test_different_users_different_filters():
|
||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||
assert f_a["owner_id"] != f_b["owner_id"]
|
||||
assert f_a["user_id"] != f_b["user_id"]
|
||||
|
||||
|
||||
def test_filter_overrides_conflicting_user_id():
|
||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||
value = {"metadata": {"owner_id": "attacker"}}
|
||||
value = {"metadata": {"user_id": "attacker"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||
assert value["metadata"]["owner_id"] == "real-owner"
|
||||
assert value["metadata"]["user_id"] == "real-owner"
|
||||
|
||||
|
||||
def test_filter_with_empty_metadata():
|
||||
"""Explicit empty metadata dict is fine."""
|
||||
value = {"metadata": {}}
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-z"
|
||||
assert result == {"owner_id": "user-z"}
|
||||
assert value["metadata"]["user_id"] == "user-z"
|
||||
assert result == {"user_id": "user-z"}
|
||||
|
||||
|
||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||
|
||||
@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
|
||||
owner filter directly by switching the ``user_context`` contextvar
|
||||
between two users. The safety property under test is:
|
||||
|
||||
After a repository write with owner_id=A, a subsequent read with
|
||||
owner_id=B must not return the row, and vice versa.
|
||||
After a repository write with user_id=A, a subsequent read with
|
||||
user_id=B must not return the row, and vice versa.
|
||||
|
||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||
that a request cookie reaches the ``set_current_user`` call. Together
|
||||
@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
|
||||
await cleanup()
|
||||
|
||||
|
||||
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──
|
||||
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
"""Migration scripts pass owner_id=None to see all rows regardless of owner."""
|
||||
"""Migration scripts pass user_id=None to see all rows regardless of owner."""
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
await repo.create("t-beta")
|
||||
|
||||
# Migration-style read: no contextvar, explicit None bypass.
|
||||
all_rows = await repo.search(owner_id=None)
|
||||
all_rows = await repo.search(user_id=None)
|
||||
thread_ids = {r["thread_id"] for r in all_rows}
|
||||
assert thread_ids == {"t-alpha", "t-beta"}
|
||||
|
||||
# Explicit get with None does not apply the filter either.
|
||||
row_a = await repo.get("t-alpha", owner_id=None)
|
||||
row_a = await repo.get("t-alpha", user_id=None)
|
||||
assert row_a is not None
|
||||
row_b = await repo.get("t-beta", owner_id=None)
|
||||
row_b = await repo.get("t-beta", user_id=None)
|
||||
assert row_b is not None
|
||||
finally:
|
||||
await cleanup()
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
Tests:
|
||||
1. DatabaseConfig property derivation (paths, URLs)
|
||||
2. MemoryRunStore CRUD + owner_id filtering
|
||||
2. MemoryRunStore CRUD + user_id filtering
|
||||
3. Base.to_dict() via inspect mixin
|
||||
4. Engine init/close lifecycle (memory + SQLite)
|
||||
5. Postgres missing-dep error message
|
||||
@ -106,17 +106,17 @@ class TestMemoryRunStore:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, store):
|
||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await store.list_by_thread("t1", owner_id="alice")
|
||||
await store.put("r1", thread_id="t1", user_id="alice")
|
||||
await store.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await store.list_by_thread("t1", user_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, store):
|
||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await store.list_by_thread("t1", owner_id=None)
|
||||
await store.put("r1", thread_id="t1", user_id="alice")
|
||||
await store.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await store.list_by_thread("t1", user_id=None)
|
||||
assert len(rows) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@ -73,11 +73,11 @@ class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id="alice")
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -189,8 +189,8 @@ class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id=None)
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id=None)
|
||||
assert len(rows) == 2
|
||||
await _cleanup()
|
||||
|
||||
@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2", "Q3"]
|
||||
@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == []
|
||||
|
||||
@ -43,8 +43,8 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner_and_display_name(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", owner_id="user1", display_name="My Thread")
|
||||
assert record["owner_id"] == "user1"
|
||||
record = await repo.create("t1", user_id="user1", display_name="My Thread")
|
||||
assert record["user_id"] == "user1"
|
||||
assert record["display_name"] == "My Thread"
|
||||
await _cleanup()
|
||||
|
||||
@ -61,26 +61,6 @@ class TestThreadMetaRepository:
|
||||
assert await repo.get("nonexistent") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_owner(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t2", owner_id="user1")
|
||||
await repo.create("t3", owner_id="user2")
|
||||
results = await repo.list_by_owner("user1")
|
||||
assert len(results) == 2
|
||||
assert all(r["owner_id"] == "user1" for r in results)
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_owner_with_limit_and_offset(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
for i in range(5):
|
||||
await repo.create(f"t{i}", owner_id="user1")
|
||||
results = await repo.list_by_owner("user1", limit=2, offset=1)
|
||||
assert len(results) == 2
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_record_allows(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
@ -90,23 +70,23 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_matches(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_mismatch(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2") is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
# Explicit owner_id=None to bypass the new AUTO default that
|
||||
# Explicit user_id=None to bypass the new AUTO default that
|
||||
# would otherwise pick up the test user from the autouse fixture.
|
||||
await repo.create("t1", owner_id=None)
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone") is True
|
||||
await _cleanup()
|
||||
|
||||
@ -125,27 +105,27 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2", require_existing=True) is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
||||
"""Even in strict mode, a row with NULL owner_id stays shared.
|
||||
"""Even in strict mode, a row with NULL user_id stays shared.
|
||||
|
||||
The strict flag tightens the *missing row* case, not the *shared
|
||||
row* case — legacy pre-auth rows that survived a clean migration
|
||||
without an owner are still everyone's.
|
||||
"""
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id=None)
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
|
||||
@ -113,14 +113,8 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path):
|
||||
# ── Server-reserved metadata key stripping ──────────────────────────────────
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_removes_owner_id():
|
||||
"""Client-supplied owner_id is dropped to prevent reflection attacks."""
|
||||
out = threads._strip_reserved_metadata({"owner_id": "victim-id", "title": "ok"})
|
||||
assert out == {"title": "ok"}
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_removes_user_id():
|
||||
"""user_id is also reserved (defense in depth for any future use)."""
|
||||
"""Client-supplied user_id is dropped to prevent reflection attacks."""
|
||||
out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"})
|
||||
assert out == {"title": "ok"}
|
||||
|
||||
@ -136,6 +130,6 @@ def test_strip_reserved_metadata_empty_input():
|
||||
assert threads._strip_reserved_metadata({}) == {}
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_strips_both_simultaneously():
|
||||
out = threads._strip_reserved_metadata({"owner_id": "x", "user_id": "y", "keep": "me"})
|
||||
def test_strip_reserved_metadata_strips_all_reserved_keys():
|
||||
out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
|
||||
assert out == {"keep": "me"}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user