feat(auth): enforce owner_id across 2.0-rc persistence layer

Add request-scoped contextvar-based owner filtering to threads_meta,
runs, run_events, and feedback repositories. Router code is unchanged
— isolation is enforced at the storage layer so that any caller that
forgets to pass owner_id still gets filtered results, and new routes
cannot accidentally leak data.

Core infrastructure
-------------------
- deerflow/runtime/user_context.py (new):
  - ContextVar[CurrentUser | None] with default None
  - runtime_checkable CurrentUser Protocol (structural subtype with .id)
  - set/reset/get/require helpers
  - AUTO sentinel + resolve_owner_id(value, method_name) for sentinel
    three-state resolution: AUTO reads contextvar, explicit str
    overrides, explicit None bypasses the filter (for migration/CLI)

Repository changes
------------------
- ThreadMetaRepository: create/get/search/update_*/delete gain
  owner_id=AUTO kwarg; read paths filter by owner, writes stamp it,
  mutations check ownership before applying
- RunRepository: put/get/list_by_thread/delete gain owner_id=AUTO kwarg
- FeedbackRepository: create/get/list_by_run/list_by_thread/delete
  gain owner_id=AUTO kwarg
- DbRunEventStore: list_messages/list_events/list_messages_by_run/
  count_messages/delete_by_thread/delete_by_run gain owner_id=AUTO
  kwarg. Write paths (put/put_batch) read the contextvar softly: when
  a request-scoped user is available, its id is stamped as owner_id;
  background-worker writes made without a user context store
  owner_id=None, which is valid (an orphan row to be bound later by
  the migration)

Schema
------
- persistence/models/run_event.py: RunEventRow.owner_id: Mapped[
  str | None] = mapped_column(String(64), nullable=True, index=True)
- No alembic migration needed: 2.0 ships fresh, Base.metadata.create_all
  picks up the new column automatically

Middleware
----------
- auth_middleware.py: after cookie check, call get_optional_user_from_
  request to load the real User, stamp it into request.state.user AND
  the contextvar via set_current_user, reset in a try/finally. Public
  paths and unauthenticated requests continue without contextvar, and
  @require_auth handles the strict 401 path

Test infrastructure
-------------------
- tests/conftest.py: @pytest.fixture(autouse=True) _auto_user_context
  sets a default SimpleNamespace(id="test-user-autouse") on every test
  unless marked @pytest.mark.no_auto_user. Keeps existing 20+
  persistence tests passing without modification
- pyproject.toml [tool.pytest.ini_options]: register no_auto_user
  marker so pytest does not emit warnings for opt-out tests
- tests/test_user_context.py: 6 tests covering three-state semantics,
  Protocol duck typing, and require/optional APIs
- tests/test_thread_meta_repo.py: one test updated to pass owner_id=
  None explicitly where it was previously relying on the old default

Test results
------------
- test_user_context.py: 6 passed
- test_auth*.py + test_langgraph_auth.py + test_ensure_admin.py: 127
  passed
- test_run_event_store / test_run_repository / test_thread_meta_repo
  / test_feedback: 92 passed
- Full backend suite: 1905 passed, 2 failed (both @requires_llm flaky
  integration tests unrelated to auth), 1 skipped
This commit is contained in:
greatmengqi 2026-04-08 11:08:48 +08:00
parent 2531cce0d1
commit 4b139fb689
11 changed files with 566 additions and 58 deletions

View File

@ -1,6 +1,11 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401.
Rejects unauthenticated requests to non-public paths with 401. When a
request passes the cookie check, resolves the JWT payload to a real
``User`` object and stamps it into both ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filtering works automatically via the sentinel pattern.
Fine-grained permission checks remain in authz.py decorators.
"""
@ -12,6 +17,7 @@ from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
@ -68,4 +74,22 @@ class AuthMiddleware(BaseHTTPMiddleware):
},
)
return await call_next(request)
# Resolve the full user now so repository-layer owner filters
# can read from the contextvar. We use the "optional" flavour so
# middleware never raises on bad tokens — the cookie-presence
# check above plus the @require_auth decorator provide the
# strict gates. A stale/invalid token yields user=None here;
# the request continues without a contextvar, and any protected
# endpoint will still be rejected by @require_auth.
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return await call_next(request)
request.state.user = user
token = set_current_user(user)
try:
return await call_next(request)
finally:
reset_current_user(token)

View File

@ -12,6 +12,7 @@ from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class FeedbackRepository:
@ -32,18 +33,19 @@ class FeedbackRepository:
run_id: str,
thread_id: str,
rating: int,
owner_id: str | None = None,
owner_id: "str | None | _AutoSentinel" = AUTO,
message_id: str | None = None,
comment: str | None = None,
) -> dict:
"""Create a feedback record. rating must be +1 or -1."""
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
row = FeedbackRow(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
owner_id=owner_id,
owner_id=resolved_owner_id,
message_id=message_id,
rating=rating,
comment=comment,
@ -55,27 +57,66 @@ class FeedbackRepository:
await session.refresh(row)
return self._row_to_dict(row)
async def get(self, feedback_id: str) -> dict | None:
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
return self._row_to_dict(row) if row else None
async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def delete(self, feedback_id: str) -> bool:
async def get(
self,
feedback_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 100,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_by_thread(
self,
thread_id: str,
*,
limit: int = 100,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def delete(
self,
feedback_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> bool:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return False
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return False
await session.delete(row)
await session.commit()

View File

@ -16,6 +16,10 @@ class RunEventRow(Base):
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
# Owner of the conversation this event belongs to. Nullable for data
# created before auth was introduced; populated by auth middleware on
# new writes and by the boot-time orphan migration on existing rows.
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle"

View File

@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class RunRepository(RunStore):
@ -68,7 +69,7 @@ class RunRepository(RunStore):
*,
thread_id,
assistant_id=None,
owner_id=None,
owner_id: "str | None | _AutoSentinel" = AUTO,
status="pending",
multitask_strategy="reject",
metadata=None,
@ -77,12 +78,13 @@ class RunRepository(RunStore):
created_at=None,
follow_up_to_run_id=None,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=owner_id,
owner_id=resolved_owner_id,
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
@ -96,15 +98,32 @@ class RunRepository(RunStore):
session.add(row)
await session.commit()
async def get(self, run_id):
async def get(
self,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
return self._row_to_dict(row) if row else None
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
async def list_by_thread(
self,
thread_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
limit=100,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if owner_id is not None:
stmt = stmt.where(RunRow.owner_id == owner_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
@ -118,12 +137,21 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def delete(self, run_id):
async def delete(
self,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is not None:
await session.delete(row)
await session.commit()
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
await session.delete(row)
await session.commit()
async def list_pending(self, *, before=None):
if before is None:

View File

@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class ThreadMetaRepository(ThreadMetaStore):
@ -31,15 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
owner_id: "str | None | _AutoSentinel" = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
# Auto-resolve owner_id from contextvar when AUTO; explicit None
# creates an orphan row (used by migration scripts).
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
now = datetime.now(UTC)
row = ThreadMetaRow(
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=owner_id,
owner_id=resolved_owner_id,
display_name=display_name,
metadata_json=metadata or {},
created_at=now,
@ -51,10 +55,21 @@ class ThreadMetaRepository(ThreadMetaStore):
await session.refresh(row)
return self._row_to_dict(row)
async def get(self, thread_id: str) -> dict | None:
async def get(
self,
thread_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
return self._row_to_dict(row) if row else None
if row is None:
return None
# Enforce owner filter unless explicitly bypassed (owner_id=None).
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
@ -83,9 +98,17 @@ class ThreadMetaRepository(ThreadMetaStore):
status: str | None = None,
limit: int = 100,
offset: int = 0,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> list[dict]:
"""Search threads with optional metadata and status filters."""
"""Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user
context. Pass ``owner_id=None`` to bypass (migration/CLI).
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_owner_id is not None:
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
if status:
stmt = stmt.where(ThreadMetaRow.status == status)
@ -105,36 +128,80 @@ class ThreadMetaRepository(ThreadMetaStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_display_name(self, thread_id: str, display_name: str) -> None:
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
    """Tell whether a mutation on ``thread_id`` may proceed.

    ``resolved_owner_id`` is the already-resolved owner: ``None`` means
    the caller explicitly bypassed the owner filter, so the check passes
    without touching the database. Otherwise the row is loaded through
    ``session`` and the result is ``True`` only when the row exists and
    its ``owner_id`` equals ``resolved_owner_id``.
    """
    if resolved_owner_id is not None:
        record = await session.get(ThreadMetaRow, thread_id)
        if record is None:
            return False
        return record.owner_id == resolved_owner_id
    # Explicit bypass: no ownership lookup is performed.
    return True
async def update_display_name(
self,
thread_id: str,
display_name: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
"""Update the display_name (title) for a thread."""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
await session.commit()
async def update_status(self, thread_id: str, status: str) -> None:
async def update_status(
self,
thread_id: str,
status: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit()
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
async def update_metadata(
self,
thread_id: str,
metadata: dict,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
"""Merge ``metadata`` into ``metadata_json``.
Read-modify-write inside a single session/transaction so concurrent
callers see consistent state. No-op if the row does not exist.
callers see consistent state. No-op if the row does not exist or
the owner_id check fails.
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
merged = dict(row.metadata_json or {})
merged.update(metadata)
row.metadata_json = merged
row.updated_at = datetime.now(UTC)
await session.commit()
async def delete(self, thread_id: str) -> None:
async def delete(
self,
thread_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is not None:
await session.delete(row)
await session.commit()
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
await session.delete(row)
await session.commit()

View File

@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
logger = logging.getLogger(__name__)
@ -53,6 +54,18 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _owner_from_context() -> str | None:
    """Soft read of owner_id from the contextvar for write paths.

    Returns ``None`` (row is written without an owner stamp) when
    ``get_current_user()`` yields no user. Otherwise returns the user's
    id as a string so it can be stored in the ``owner_id`` column.

    Unlike the ``owner_id=AUTO`` read paths, this never raises when the
    contextvar is unset.
    """
    user = get_current_user()
    if user is None:
        return None
    # ``CurrentUser`` is a structural Protocol, so nothing enforces that
    # ``id`` is really a ``str`` at runtime. Coerce here so the stamped
    # value always matches the declared return type and compares equal
    # to the string owner ids used by the read-side filters.
    return str(user.id)
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
@ -68,6 +81,7 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
owner_id = self._owner_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
@ -78,6 +92,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
owner_id=owner_id,
event_type=event_type,
category=category,
content=db_content,
@ -91,6 +106,7 @@ class DbRunEventStore(RunEventStore):
async def put_batch(self, events):
if not events:
return []
owner_id = self._owner_from_context()
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
@ -114,6 +130,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
owner_id=e.get("owner_id", owner_id),
event_type=e["event_type"],
category=category,
content=db_content,
@ -125,8 +142,19 @@ class DbRunEventStore(RunEventStore):
rows.append(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
async def list_messages(
self,
thread_id,
*,
limit=50,
before_seq=None,
after_seq=None,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
@ -146,8 +174,19 @@ class DbRunEventStore(RunEventStore):
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
async def list_events(
self,
thread_id,
run_id,
*,
event_types=None,
limit=500,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
@ -155,31 +194,68 @@ class DbRunEventStore(RunEventStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_messages_by_run(self, thread_id, run_id):
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
async def list_messages_by_run(
self,
thread_id,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(RunEventRow.seq.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def count_messages(self, thread_id):
async def count_messages(
self,
thread_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def delete_by_thread(self, thread_id):
async def delete_by_thread(
self,
thread_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
async with self._sf() as session:
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
count_conditions = [RunEventRow.thread_id == thread_id]
if resolved_owner_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
async def delete_by_run(self, thread_id, run_id):
async def delete_by_run(
self,
thread_id,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
async with self._sf() as session:
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
if resolved_owner_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count

View File

@ -0,0 +1,148 @@
"""Request-scoped user context for owner-based authorization.
This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``owner_id`` boilerplate.
Three-state semantics for the repository ``owner_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``):
- ``AUTO`` (the sentinel instance of the module-private
``_AutoSentinel`` class, default): read from contextvar;
raise :class:`RuntimeError` if unset.
- Explicit ``str``: use the provided value, overriding contextvar.
- Explicit ``None``: no WHERE clause — used only by migration scripts
and admin CLIs that intentionally bypass isolation.
Dependency direction
--------------------
``persistence`` (lower layer) reads from this module; ``gateway.auth``
(higher layer) writes to it. ``CurrentUser`` is defined here as a
:class:`typing.Protocol` so that ``persistence`` never needs to import
the concrete ``User`` class from ``gateway.auth.models``. Any object
with an ``.id: str`` attribute structurally satisfies the protocol.
Asyncio semantics
-----------------
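``ContextVar`` is task-local under asyncio, not thread-local. Each
FastAPI request runs in its own task, so the context is naturally
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` run with a
copy of the caller's context, so they see the caller's user; this is
typically the intended behaviour. ``contextvars.copy_context()`` does
not help when a background task must *not* see the foreground user,
because the copy still carries that user. Start such a task in an
empty context instead, e.g.
``asyncio.create_task(coro, context=contextvars.Context())`` (Python
3.11+), or ``contextvars.Context().run(fn)`` for synchronous code.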
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Final, Protocol, runtime_checkable
@runtime_checkable
class CurrentUser(Protocol):
    """Structural description of "the authenticated user".

    The protocol has a single member, ``id``. Because the class is
    ``runtime_checkable``, ``isinstance(obj, CurrentUser)`` succeeds for
    any object that exposes an ``id`` attribute; the attribute's type is
    not inspected at runtime.
    """

    id: str
# Slot holding the authenticated user for the current context. The
# default of ``None`` means "no user has been installed".
_current_user: Final[ContextVar["CurrentUser | None"]] = ContextVar(
    "deerflow_current_user", default=None
)


def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
    """Install ``user`` as the current user for the running context.

    Returns the reset token produced by ``ContextVar.set``. Pass it to
    :func:`reset_current_user` in a ``finally`` block to restore the
    previous value.
    """
    return _current_user.set(user)


def reset_current_user(token: Token[CurrentUser | None]) -> None:
    """Restore the contextvar to the state captured by ``token``."""
    _current_user.reset(token)


def get_current_user() -> CurrentUser | None:
    """Return the current user, or ``None`` if none is set.

    Never raises. Intended for code paths that can proceed without a
    user (e.g. migration scripts, public endpoints).
    """
    return _current_user.get()


def require_current_user() -> CurrentUser:
    """Return the current user, or raise :class:`RuntimeError`.

    Raises:
        RuntimeError: if no user is set in the current context.
    """
    user = _current_user.get()
    if user is None:
        raise RuntimeError("repository accessed without user context")
    return user


# ---------------------------------------------------------------------------
# Sentinel-based owner_id resolution
# ---------------------------------------------------------------------------
#
# Repository methods accept an ``owner_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_owner_id`.


class _AutoSentinel:
    """Singleton marker meaning 'resolve owner_id from contextvar'."""

    # The one shared instance; created lazily by ``__new__``.
    _instance: "_AutoSentinel | None" = None

    def __new__(cls) -> "_AutoSentinel":
        # Always hand back the same object so every ``_AutoSentinel()``
        # call is identical to ``AUTO``.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self) -> str:
        return "<AUTO>"


AUTO: Final[_AutoSentinel] = _AutoSentinel()


def resolve_owner_id(
    value: "str | None | _AutoSentinel",
    *,
    method_name: str = "repository method",
) -> str | None:
    """Resolve the ``owner_id`` argument passed to a repository method.

    Three-state semantics:

    - :data:`AUTO` (default): read the user from the contextvar and
      return its id as a string.
    - Explicit ``str``: returned verbatim, regardless of any contextvar
      value.
    - Explicit ``None``: returned as ``None`` — callers treat this as
      "skip the owner_id filter entirely".

    Args:
        value: the caller-supplied ``owner_id`` (or :data:`AUTO`).
        method_name: name of the calling method, used only in the error
            message.

    Raises:
        RuntimeError: if ``value`` is :data:`AUTO` and no user is set in
            the current context.
    """
    if isinstance(value, _AutoSentinel):
        user = _current_user.get()
        if user is None:
            raise RuntimeError(
                f"{method_name} called with owner_id=AUTO but no user context is set; "
                "pass an explicit owner_id, set the contextvar via auth middleware, "
                "or opt out with owner_id=None for migration/CLI paths."
            )
        # ``CurrentUser`` is structural and ``runtime_checkable`` only
        # verifies that ``id`` exists, not its type. Coerce to ``str`` so
        # the declared return type holds and comparisons against stored
        # string owner ids cannot silently fail for non-str ids.
        return str(user.id)
    return value

View File

@ -30,6 +30,11 @@ postgres = [
[dependency-groups]
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
[tool.pytest.ini_options]
markers = [
"no_auto_user: disable the conftest autouse contextvar fixture for this test",
]
[tool.uv.workspace]
members = ["packages/harness"]

View File

@ -6,8 +6,11 @@ issues when unit-testing lightweight config/registry code in isolation.
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
# Make 'app' and 'deerflow' importable from any working directory
sys.path.insert(0, str(Path(__file__).parent.parent))
@ -31,3 +34,44 @@ _executor_mock.MAX_CONCURRENT_SUBAGENTS = 3
_executor_mock.get_background_task_result = MagicMock()
sys.modules["deerflow.subagents.executor"] = _executor_mock
# ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user
# ---------------------------------------------------------------------------
#
# Repository methods read ``owner_id`` from a contextvar by default
# (see ``deerflow.runtime.user_context``). Without this fixture, every
# pre-existing persistence test would raise RuntimeError because the
# contextvar is unset. The fixture sets a default test user on every
# test; tests that explicitly want to verify behaviour *without* a user
# context should mark themselves ``@pytest.mark.no_auto_user``.
@pytest.fixture(autouse=True)
def _auto_user_context(request):
    """Run each test with a default ``test-user-autouse`` in the contextvar.

    Tests marked ``@pytest.mark.no_auto_user`` run without any user
    installed. The import of ``deerflow.runtime.user_context`` happens
    inside the fixture and is guarded: if it raises ``ImportError`` the
    fixture degrades to a no-op instead of failing the test.
    """
    helpers = None
    if not request.node.get_closest_marker("no_auto_user"):
        try:
            from deerflow.runtime.user_context import (
                reset_current_user,
                set_current_user,
            )
        except ImportError:
            pass
        else:
            helpers = (set_current_user, reset_current_user)

    if helpers is None:
        # Either the test opted out or the module is unavailable.
        yield
        return

    install, restore = helpers
    token = install(SimpleNamespace(id="test-user-autouse", email="test@local"))
    try:
        yield
    finally:
        # Always undo the set so the user never leaks into the next test.
        restore(token)

View File

@ -104,7 +104,9 @@ class TestThreadMetaRepository:
@pytest.mark.anyio
async def test_check_access_no_owner_allows_all(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.create("t1") # owner_id=None
# Explicit owner_id=None to bypass the new AUTO default that
# would otherwise pick up the test user from the autouse fixture.
await repo.create("t1", owner_id=None)
assert await repo.check_access("t1", "anyone") is True
await _cleanup()

View File

@ -0,0 +1,69 @@
"""Tests for runtime.user_context — contextvar three-state semantics.
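These tests opt out of the autouse contextvar fixture (added in
commit 6) because they explicitly test the cases where the contextvar
is set or unset.
Note: only the set/reset/get/require helpers and the ``CurrentUser``
Protocol are covered here; the AUTO / explicit ``str`` / explicit
``None`` resolution performed by ``resolve_owner_id`` is not imported
or exercised by this module.
"""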
from types import SimpleNamespace
import pytest
from deerflow.runtime.user_context import (
CurrentUser,
get_current_user,
require_current_user,
reset_current_user,
set_current_user,
)
@pytest.mark.no_auto_user
def test_default_is_none():
    """With nothing installed, the contextvar reads back as None."""
    current = get_current_user()
    assert current is None
@pytest.mark.no_auto_user
def test_set_and_reset_roundtrip():
    """Setting a user is visible until the returned token is reset."""
    installed = SimpleNamespace(id="user-1")
    reset_token = set_current_user(installed)
    try:
        observed = get_current_user()
        assert observed is installed
    finally:
        reset_current_user(reset_token)
    # After the reset the contextvar is back to its default.
    assert get_current_user() is None
@pytest.mark.no_auto_user
def test_require_current_user_raises_when_unset():
    """Without a user in context, require_current_user raises RuntimeError."""
    # Precondition: nothing is installed.
    assert get_current_user() is None
    expect_failure = pytest.raises(RuntimeError, match="without user context")
    with expect_failure:
        require_current_user()
@pytest.mark.no_auto_user
def test_require_current_user_returns_user_when_set():
    """With a user in context, require_current_user hands back that object."""
    installed = SimpleNamespace(id="user-2")
    reset_token = set_current_user(installed)
    try:
        returned = require_current_user()
        assert returned is installed
    finally:
        reset_current_user(reset_token)
@pytest.mark.no_auto_user
def test_protocol_accepts_duck_typed():
    """An object exposing ``.id`` passes the runtime CurrentUser check."""
    candidate = SimpleNamespace(id="user-3")
    matches = isinstance(candidate, CurrentUser)
    assert matches
@pytest.mark.no_auto_user
def test_protocol_rejects_no_id():
    """An object lacking ``.id`` fails the runtime CurrentUser check."""
    candidate = SimpleNamespace(email="no-id@example.com")
    matches = isinstance(candidate, CurrentUser)
    assert not matches