diff --git a/backend/app/gateway/auth_middleware.py b/backend/app/gateway/auth_middleware.py index cca505688..fda05e93d 100644 --- a/backend/app/gateway/auth_middleware.py +++ b/backend/app/gateway/auth_middleware.py @@ -1,6 +1,11 @@ """Global authentication middleware — fail-closed safety net. -Rejects unauthenticated requests to non-public paths with 401. +Rejects unauthenticated requests to non-public paths with 401. When a +request passes the cookie check, resolves the JWT payload to a real +``User`` object and stamps it into both ``request.state.user`` and the +``deerflow.runtime.user_context`` contextvar so that repository-layer +owner filtering works automatically via the sentinel pattern. + Fine-grained permission checks remain in authz.py decorators. """ @@ -12,6 +17,7 @@ from starlette.responses import JSONResponse from starlette.types import ASGIApp from app.gateway.auth.errors import AuthErrorCode +from deerflow.runtime.user_context import reset_current_user, set_current_user # Paths that never require authentication. _PUBLIC_PATH_PREFIXES: tuple[str, ...] = ( @@ -68,4 +74,22 @@ class AuthMiddleware(BaseHTTPMiddleware): }, ) - return await call_next(request) + # Resolve the full user now so repository-layer owner filters + # can read from the contextvar. We use the "optional" flavour so + # middleware never raises on bad tokens — the cookie-presence + # check above plus the @require_auth decorator provide the + # strict gates. A stale/invalid token yields user=None here; + # the request continues without a contextvar, and any protected + # endpoint will still be rejected by @require_auth. 
+ from app.gateway.deps import get_optional_user_from_request + + user = await get_optional_user_from_request(request) + if user is None: + return await call_next(request) + + request.state.user = user + token = set_current_user(user) + try: + return await call_next(request) + finally: + reset_current_user(token) diff --git a/backend/packages/harness/deerflow/persistence/feedback/sql.py b/backend/packages/harness/deerflow/persistence/feedback/sql.py index eae2f9997..ffae49f31 100644 --- a/backend/packages/harness/deerflow/persistence/feedback/sql.py +++ b/backend/packages/harness/deerflow/persistence/feedback/sql.py @@ -12,6 +12,7 @@ from sqlalchemy import case, func, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.feedback.model import FeedbackRow +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id class FeedbackRepository: @@ -32,18 +33,19 @@ class FeedbackRepository: run_id: str, thread_id: str, rating: int, - owner_id: str | None = None, + owner_id: "str | None | _AutoSentinel" = AUTO, message_id: str | None = None, comment: str | None = None, ) -> dict: """Create a feedback record. 
rating must be +1 or -1.""" if rating not in (1, -1): raise ValueError(f"rating must be +1 or -1, got {rating}") + resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create") row = FeedbackRow( feedback_id=str(uuid.uuid4()), run_id=run_id, thread_id=thread_id, - owner_id=owner_id, + owner_id=resolved_owner_id, message_id=message_id, rating=rating, comment=comment, @@ -55,27 +57,66 @@ class FeedbackRepository: await session.refresh(row) return self._row_to_dict(row) - async def get(self, feedback_id: str) -> dict | None: - async with self._sf() as session: - row = await session.get(FeedbackRow, feedback_id) - return self._row_to_dict(row) if row else None - - async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]: - stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]: - stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def delete(self, feedback_id: str) -> bool: + async def get( + self, + feedback_id: str, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> dict | None: + resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get") async with self._sf() as session: row = await session.get(FeedbackRow, feedback_id) if row is None: + return None + if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + return None + return self._row_to_dict(row) + + async def list_by_run( + self, + thread_id: str, + run_id: str, + *, + limit: int = 
100, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> list[dict]: + resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run") + stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id) + if resolved_owner_id is not None: + stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id) + stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] + + async def list_by_thread( + self, + thread_id: str, + *, + limit: int = 100, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> list[dict]: + resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread") + stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) + if resolved_owner_id is not None: + stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id) + stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] + + async def delete( + self, + feedback_id: str, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> bool: + resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete") + async with self._sf() as session: + row = await session.get(FeedbackRow, feedback_id) + if row is None: + return False + if resolved_owner_id is not None and row.owner_id != resolved_owner_id: return False await session.delete(row) await session.commit() diff --git a/backend/packages/harness/deerflow/persistence/models/run_event.py b/backend/packages/harness/deerflow/persistence/models/run_event.py index 8db50aea7..34f55ba03 100644 --- a/backend/packages/harness/deerflow/persistence/models/run_event.py +++ b/backend/packages/harness/deerflow/persistence/models/run_event.py @@ -16,6 +16,10 @@ class 
RunEventRow(Base): id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) thread_id: Mapped[str] = mapped_column(String(64), nullable=False) run_id: Mapped[str] = mapped_column(String(64), nullable=False) + # Owner of the conversation this event belongs to. Nullable for data + # created before auth was introduced; populated by auth middleware on + # new writes and by the boot-time orphan migration on existing rows. + owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True) event_type: Mapped[str] = mapped_column(String(32), nullable=False) category: Mapped[str] = mapped_column(String(16), nullable=False) # "message" | "trace" | "lifecycle" diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index fac88d968..9847825ef 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.run.model import RunRow from deerflow.runtime.runs.store.base import RunStore +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id class RunRepository(RunStore): @@ -68,7 +69,7 @@ class RunRepository(RunStore): *, thread_id, assistant_id=None, - owner_id=None, + owner_id: "str | None | _AutoSentinel" = AUTO, status="pending", multitask_strategy="reject", metadata=None, @@ -77,12 +78,13 @@ class RunRepository(RunStore): created_at=None, follow_up_to_run_id=None, ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put") now = datetime.now(UTC) row = RunRow( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, - owner_id=owner_id, + owner_id=resolved_owner_id, status=status, multitask_strategy=multitask_strategy, metadata_json=self._safe_json(metadata) or {}, @@ -96,15 +98,32 @@ class RunRepository(RunStore): session.add(row) await 
session.commit() - async def get(self, run_id): + async def get( + self, + run_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get") async with self._sf() as session: row = await session.get(RunRow, run_id) - return self._row_to_dict(row) if row else None + if row is None: + return None + if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + return None + return self._row_to_dict(row) - async def list_by_thread(self, thread_id, *, owner_id=None, limit=100): + async def list_by_thread( + self, + thread_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + limit=100, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread") stmt = select(RunRow).where(RunRow.thread_id == thread_id) - if owner_id is not None: - stmt = stmt.where(RunRow.owner_id == owner_id) + if resolved_owner_id is not None: + stmt = stmt.where(RunRow.owner_id == resolved_owner_id) stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit) async with self._sf() as session: result = await session.execute(stmt) @@ -118,12 +137,21 @@ class RunRepository(RunStore): await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() - async def delete(self, run_id): + async def delete( + self, + run_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete") async with self._sf() as session: row = await session.get(RunRow, run_id) - if row is not None: - await session.delete(row) - await session.commit() + if row is None: + return + if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + return + await session.delete(row) + await session.commit() async def list_pending(self, *, before=None): if before is None: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py 
b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 86c73030e..d49b9ee29 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id class ThreadMetaRepository(ThreadMetaStore): @@ -31,15 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore): thread_id: str, *, assistant_id: str | None = None, - owner_id: str | None = None, + owner_id: "str | None | _AutoSentinel" = AUTO, display_name: str | None = None, metadata: dict | None = None, ) -> dict: + # Auto-resolve owner_id from contextvar when AUTO; explicit None + # creates an orphan row (used by migration scripts). + resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create") now = datetime.now(UTC) row = ThreadMetaRow( thread_id=thread_id, assistant_id=assistant_id, - owner_id=owner_id, + owner_id=resolved_owner_id, display_name=display_name, metadata_json=metadata or {}, created_at=now, @@ -51,10 +55,21 @@ class ThreadMetaRepository(ThreadMetaStore): await session.refresh(row) return self._row_to_dict(row) - async def get(self, thread_id: str) -> dict | None: + async def get( + self, + thread_id: str, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> dict | None: + resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get") async with self._sf() as session: row = await session.get(ThreadMetaRow, thread_id) - return self._row_to_dict(row) if row else None + if row is None: + return None + # Enforce owner filter unless explicitly bypassed (owner_id=None). 
+ if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + return None + return self._row_to_dict(row) async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]: stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset) @@ -83,9 +98,17 @@ class ThreadMetaRepository(ThreadMetaStore): status: str | None = None, limit: int = 100, offset: int = 0, + owner_id: "str | None | _AutoSentinel" = AUTO, ) -> list[dict]: - """Search threads with optional metadata and status filters.""" + """Search threads with optional metadata and status filters. + + Owner filter is enforced by default: caller must be in a user + context. Pass ``owner_id=None`` to bypass (migration/CLI). + """ + resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search") stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) + if resolved_owner_id is not None: + stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id) if status: stmt = stmt.where(ThreadMetaRow.status == status) @@ -105,36 +128,80 @@ class ThreadMetaRepository(ThreadMetaStore): result = await session.execute(stmt) return [self._row_to_dict(r) for r in result.scalars()] - async def update_display_name(self, thread_id: str, display_name: str) -> None: + async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool: + """Return True if the row exists and is owned (or filter bypassed).""" + if resolved_owner_id is None: + return True # explicit bypass + row = await session.get(ThreadMetaRow, thread_id) + return row is not None and row.owner_id == resolved_owner_id + + async def update_display_name( + self, + thread_id: str, + display_name: str, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> None: """Update the display_name (title) for a thread.""" + resolved_owner_id = resolve_owner_id(owner_id, 
method_name="ThreadMetaRepository.update_display_name") async with self._sf() as session: + if not await self._check_ownership(session, thread_id, resolved_owner_id): + return await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC))) await session.commit() - async def update_status(self, thread_id: str, status: str) -> None: + async def update_status( + self, + thread_id: str, + status: str, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> None: + resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status") async with self._sf() as session: + if not await self._check_ownership(session, thread_id, resolved_owner_id): + return await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC))) await session.commit() - async def update_metadata(self, thread_id: str, metadata: dict) -> None: + async def update_metadata( + self, + thread_id: str, + metadata: dict, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> None: """Merge ``metadata`` into ``metadata_json``. Read-modify-write inside a single session/transaction so concurrent - callers see consistent state. No-op if the row does not exist. + callers see consistent state. No-op if the row does not exist or + the owner_id check fails. 
""" + resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata") async with self._sf() as session: row = await session.get(ThreadMetaRow, thread_id) if row is None: return + if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + return merged = dict(row.metadata_json or {}) merged.update(metadata) row.metadata_json = merged row.updated_at = datetime.now(UTC) await session.commit() - async def delete(self, thread_id: str) -> None: + async def delete( + self, + thread_id: str, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ) -> None: + resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete") async with self._sf() as session: row = await session.get(ThreadMetaRow, thread_id) - if row is not None: - await session.delete(row) - await session.commit() + if row is None: + return + if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + return + await session.delete(row) + await session.commit() diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 0502cd879..dc8772751 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow from deerflow.runtime.events.store.base import RunEventStore +from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id logger = logging.getLogger(__name__) @@ -53,6 +54,18 @@ class DbRunEventStore(RunEventStore): metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)} return content, metadata or {} + @staticmethod + def _owner_from_context() -> str | None: + """Soft read of owner_id from contextvar for write paths. 
+ + Returns ``None`` (no filter / no stamp) if contextvar is unset, + which is the expected case for background worker writes. HTTP + request writes will have the contextvar set by auth middleware + and get their user_id stamped automatically. + """ + user = get_current_user() + return user.id if user is not None else None + async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 """Write a single event — low-frequency path only. @@ -68,6 +81,7 @@ class DbRunEventStore(RunEventStore): metadata = {**(metadata or {}), "content_is_dict": True} else: db_content = content + owner_id = self._owner_from_context() async with self._sf() as session: async with session.begin(): # Use FOR UPDATE to serialize seq assignment within a thread. @@ -78,6 +92,7 @@ class DbRunEventStore(RunEventStore): row = RunEventRow( thread_id=thread_id, run_id=run_id, + owner_id=owner_id, event_type=event_type, category=category, content=db_content, @@ -91,6 +106,7 @@ class DbRunEventStore(RunEventStore): async def put_batch(self, events): if not events: return [] + owner_id = self._owner_from_context() async with self._sf() as session: async with session.begin(): # Get max seq for the thread (assume all events in batch belong to same thread). 
@@ -114,6 +130,7 @@ class DbRunEventStore(RunEventStore): row = RunEventRow( thread_id=e["thread_id"], run_id=e["run_id"], + owner_id=e.get("owner_id", owner_id), event_type=e["event_type"], category=category, content=db_content, @@ -125,8 +142,19 @@ class DbRunEventStore(RunEventStore): rows.append(row) return [self._row_to_dict(r) for r in rows] - async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None): + async def list_messages( + self, + thread_id, + *, + limit=50, + before_seq=None, + after_seq=None, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") + if resolved_owner_id is not None: + stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) if before_seq is not None: stmt = stmt.where(RunEventRow.seq < before_seq) if after_seq is not None: @@ -146,8 +174,19 @@ class DbRunEventStore(RunEventStore): rows = list(result.scalars()) return [self._row_to_dict(r) for r in reversed(rows)] - async def list_events(self, thread_id, run_id, *, event_types=None, limit=500): + async def list_events( + self, + thread_id, + run_id, + *, + event_types=None, + limit=500, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id) + if resolved_owner_id is not None: + stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) if event_types: stmt = stmt.where(RunEventRow.event_type.in_(event_types)) stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) @@ -155,31 +194,68 @@ class DbRunEventStore(RunEventStore): result = await session.execute(stmt) return [self._row_to_dict(r) for r in result.scalars()] - async def list_messages_by_run(self, thread_id, 
run_id): - stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc()) + async def list_messages_by_run( + self, + thread_id, + run_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run") + stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message") + if resolved_owner_id is not None: + stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) + stmt = stmt.order_by(RunEventRow.seq.asc()) async with self._sf() as session: result = await session.execute(stmt) return [self._row_to_dict(r) for r in result.scalars()] - async def count_messages(self, thread_id): + async def count_messages( + self, + thread_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages") stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") + if resolved_owner_id is not None: + stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) async with self._sf() as session: return await session.scalar(stmt) or 0 - async def delete_by_thread(self, thread_id): + async def delete_by_thread( + self, + thread_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread") async with self._sf() as session: - count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id) + count_conditions = [RunEventRow.thread_id == thread_id] + if resolved_owner_id is not None: + count_conditions.append(RunEventRow.owner_id == resolved_owner_id) + count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) 
count = await session.scalar(count_stmt) or 0 if count > 0: - await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id)) + await session.execute(delete(RunEventRow).where(*count_conditions)) await session.commit() return count - async def delete_by_run(self, thread_id, run_id): + async def delete_by_run( + self, + thread_id, + run_id, + *, + owner_id: "str | None | _AutoSentinel" = AUTO, + ): + resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run") async with self._sf() as session: - count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id) + count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id] + if resolved_owner_id is not None: + count_conditions.append(RunEventRow.owner_id == resolved_owner_id) + count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) count = await session.scalar(count_stmt) or 0 if count > 0: - await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)) + await session.execute(delete(RunEventRow).where(*count_conditions)) await session.commit() return count diff --git a/backend/packages/harness/deerflow/runtime/user_context.py b/backend/packages/harness/deerflow/runtime/user_context.py new file mode 100644 index 000000000..fd720c727 --- /dev/null +++ b/backend/packages/harness/deerflow/runtime/user_context.py @@ -0,0 +1,148 @@ +"""Request-scoped user context for owner-based authorization. + +This module holds a :class:`~contextvars.ContextVar` that the gateway's +auth middleware sets after a successful authentication. Repository +methods read the contextvar via a sentinel default parameter, letting +routers stay free of ``owner_id`` boilerplate. 
+ +Three-state semantics for the repository ``owner_id`` parameter (the +consumer side of this module lives in ``deerflow.persistence.*``): + +- ``AUTO`` (instance of the private ``_AutoSentinel``, default): read from contextvar; + raise :class:`RuntimeError` if unset. +- Explicit ``str``: use the provided value, overriding contextvar. +- Explicit ``None``: no WHERE clause — used only by migration scripts + and admin CLIs that intentionally bypass isolation. + +Dependency direction +-------------------- +``persistence`` (lower layer) reads from this module; ``gateway.auth`` +(higher layer) writes to it. ``CurrentUser`` is defined here as a +:class:`typing.Protocol` so that ``persistence`` never needs to import +the concrete ``User`` class from ``gateway.auth.models``. Any object +with an ``.id: str`` attribute structurally satisfies the protocol. + +Asyncio semantics +----------------- +``ContextVar`` is task-local under asyncio, not thread-local. Each +FastAPI request runs in its own task, so the context is naturally +isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the +parent task's context, which is typically the intended behaviour; if +a background task must *not* see the foreground user, run it in an +empty ``contextvars.Context()`` (``copy_context()`` keeps the user). +""" + +from __future__ import annotations + +from contextvars import ContextVar, Token +from typing import Final, Protocol, runtime_checkable + + +@runtime_checkable +class CurrentUser(Protocol): + """Structural type for the current authenticated user. + + Any object with an ``.id: str`` attribute satisfies this protocol. + Concrete implementations live in ``app.gateway.auth.models.User``. + """ + + id: str + + +_current_user: Final[ContextVar["CurrentUser | None"]] = ContextVar( + "deerflow_current_user", default=None +) + + +def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]: + """Set the current user for this async task. 
+ + Returns a reset token that should be passed to + :func:`reset_current_user` in a ``finally`` block to restore the + previous context. + """ + return _current_user.set(user) + + +def reset_current_user(token: Token[CurrentUser | None]) -> None: + """Restore the context to the state captured by ``token``.""" + _current_user.reset(token) + + +def get_current_user() -> CurrentUser | None: + """Return the current user, or ``None`` if unset. + + Safe to call in any context. Used by code paths that can proceed + without a user (e.g. migration scripts, public endpoints). + """ + return _current_user.get() + + +def require_current_user() -> CurrentUser: + """Return the current user, or raise :class:`RuntimeError`. + + Used by repository code that must not be called outside a + request-authenticated context. The error message is phrased so + that a caller debugging a stack trace can locate the offending + code path. + """ + user = _current_user.get() + if user is None: + raise RuntimeError("repository accessed without user context") + return user + + +# --------------------------------------------------------------------------- +# Sentinel-based owner_id resolution +# --------------------------------------------------------------------------- +# +# Repository methods accept an ``owner_id`` keyword-only argument that +# defaults to ``AUTO``. The three possible values drive distinct +# behaviours; see the docstring on :func:`resolve_owner_id`. 
+ + +class _AutoSentinel: + """Singleton marker meaning 'resolve owner_id from contextvar'.""" + + _instance: "_AutoSentinel | None" = None + + def __new__(cls) -> "_AutoSentinel": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "" + + +AUTO: Final[_AutoSentinel] = _AutoSentinel() + + +def resolve_owner_id( + value: "str | None | _AutoSentinel", + *, + method_name: str = "repository method", +) -> str | None: + """Resolve the owner_id parameter passed to a repository method. + + Three-state semantics: + + - :data:`AUTO` (default): read from contextvar; raise + :class:`RuntimeError` if no user is in context. This is the + common case for request-scoped calls. + - Explicit ``str``: use the provided id verbatim, overriding any + contextvar value. Useful for tests and admin-override flows. + - Explicit ``None``: no filter — the repository should skip the + owner_id WHERE clause entirely. Reserved for migration scripts + and CLI tools that intentionally bypass isolation. + """ + if isinstance(value, _AutoSentinel): + user = _current_user.get() + if user is None: + raise RuntimeError( + f"{method_name} called with owner_id=AUTO but no user context is set; " + "pass an explicit owner_id, set the contextvar via auth middleware, " + "or opt out with owner_id=None for migration/CLI paths." 
+ ) + return user.id + return value diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 09e948511..a743d5e02 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -30,6 +30,11 @@ postgres = [ [dependency-groups] dev = ["pytest>=8.0.0", "ruff>=0.14.11"] +[tool.pytest.ini_options] +markers = [ + "no_auto_user: disable the conftest autouse contextvar fixture for this test", +] + [tool.uv.workspace] members = ["packages/harness"] diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 491961c00..9b10430e5 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -6,8 +6,11 @@ issues when unit-testing lightweight config/registry code in isolation. import sys from pathlib import Path +from types import SimpleNamespace from unittest.mock import MagicMock +import pytest + # Make 'app' and 'deerflow' importable from any working directory sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -31,3 +34,44 @@ _executor_mock.MAX_CONCURRENT_SUBAGENTS = 3 _executor_mock.get_background_task_result = MagicMock() sys.modules["deerflow.subagents.executor"] = _executor_mock + + +# --------------------------------------------------------------------------- +# Auto-set user context for every test unless marked no_auto_user +# --------------------------------------------------------------------------- +# +# Repository methods read ``owner_id`` from a contextvar by default +# (see ``deerflow.runtime.user_context``). Without this fixture, every +# pre-existing persistence test would raise RuntimeError because the +# contextvar is unset. The fixture sets a default test user on every +# test; tests that explicitly want to verify behaviour *without* a user +# context should mark themselves ``@pytest.mark.no_auto_user``. + + +@pytest.fixture(autouse=True) +def _auto_user_context(request): + """Inject a default ``test-user-autouse`` into the contextvar. + + Opt-out via ``@pytest.mark.no_auto_user``. 
The import is lazy and + guarded by ``except ImportError`` so tests still run when + ``deerflow.runtime.user_context`` cannot be imported. + """ + if request.node.get_closest_marker("no_auto_user"): + yield + return + + try: + from deerflow.runtime.user_context import ( + reset_current_user, + set_current_user, + ) + except ImportError: + yield + return + + user = SimpleNamespace(id="test-user-autouse", email="test@local") + token = set_current_user(user) + try: + yield + finally: + reset_current_user(token) diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index dbb747a26..cc944ea81 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -104,7 +104,9 @@ class TestThreadMetaRepository: @pytest.mark.anyio async def test_check_access_no_owner_allows_all(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.create("t1") # owner_id=None + # Explicit owner_id=None to bypass the new AUTO default that + # would otherwise pick up the test user from the autouse fixture. + await repo.create("t1", owner_id=None) assert await repo.check_access("t1", "anyone") is True await _cleanup() diff --git a/backend/tests/test_user_context.py b/backend/tests/test_user_context.py new file mode 100644 index 000000000..b7dd1efd0 --- /dev/null +++ b/backend/tests/test_user_context.py @@ -0,0 +1,69 @@ +"""Tests for runtime.user_context — contextvar three-state semantics. + +These tests opt out of the autouse ``_auto_user_context`` fixture +(see ``conftest.py``) because they explicitly test the cases where the +contextvar is set or unset. 
+""" + +from types import SimpleNamespace + +import pytest + +from deerflow.runtime.user_context import ( + CurrentUser, + get_current_user, + require_current_user, + reset_current_user, + set_current_user, +) + + +@pytest.mark.no_auto_user +def test_default_is_none(): + """Before any set, contextvar returns None.""" + assert get_current_user() is None + + +@pytest.mark.no_auto_user +def test_set_and_reset_roundtrip(): + """set_current_user returns a token that reset restores.""" + user = SimpleNamespace(id="user-1") + token = set_current_user(user) + try: + assert get_current_user() is user + finally: + reset_current_user(token) + assert get_current_user() is None + + +@pytest.mark.no_auto_user +def test_require_current_user_raises_when_unset(): + """require_current_user raises RuntimeError if contextvar is unset.""" + assert get_current_user() is None + with pytest.raises(RuntimeError, match="without user context"): + require_current_user() + + +@pytest.mark.no_auto_user +def test_require_current_user_returns_user_when_set(): + """require_current_user returns the user when contextvar is set.""" + user = SimpleNamespace(id="user-2") + token = set_current_user(user) + try: + assert require_current_user() is user + finally: + reset_current_user(token) + + +@pytest.mark.no_auto_user +def test_protocol_accepts_duck_typed(): + """CurrentUser is a runtime_checkable Protocol matching any .id-bearing object.""" + user = SimpleNamespace(id="user-3") + assert isinstance(user, CurrentUser) + + +@pytest.mark.no_auto_user +def test_protocol_rejects_no_id(): + """Objects without .id do not satisfy CurrentUser Protocol.""" + not_a_user = SimpleNamespace(email="no-id@example.com") + assert not isinstance(not_a_user, CurrentUser)