mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(auth): enforce owner_id across 2.0-rc persistence layer
Add request-scoped contextvar-based owner filtering to threads_meta,
runs, run_events, and feedback repositories. Router code is unchanged
— isolation is enforced at the storage layer so that any caller that
forgets to pass owner_id still gets filtered results, and new routes
cannot accidentally leak data.
Core infrastructure
-------------------
- deerflow/runtime/user_context.py (new):
- ContextVar[CurrentUser | None] with default None
- runtime_checkable CurrentUser Protocol (structural subtype with .id)
- set/reset/get/require helpers
- AUTO sentinel + resolve_owner_id(value, method_name) for sentinel
three-state resolution: AUTO reads contextvar, explicit str
overrides, explicit None bypasses the filter (for migration/CLI)
Repository changes
------------------
- ThreadMetaRepository: create/get/search/update_*/delete gain
owner_id=AUTO kwarg; read paths filter by owner, writes stamp it,
mutations check ownership before applying
- RunRepository: put/get/list_by_thread/delete gain owner_id=AUTO kwarg
- FeedbackRepository: create/get/list_by_run/list_by_thread/delete
gain owner_id=AUTO kwarg
- DbRunEventStore: list_messages/list_events/list_messages_by_run/
count_messages/delete_by_thread/delete_by_run gain owner_id=AUTO
kwarg. Write paths (put/put_batch) read contextvar softly: when a
request-scoped user is available, owner_id is stamped; background-worker
writes that run without a user context store None, which is valid
(an orphan row to be bound later by migration)
Schema
------
- persistence/models/run_event.py: RunEventRow.owner_id: Mapped[
str | None] = mapped_column(String(64), nullable=True, index=True)
- No alembic migration needed: 2.0 ships fresh, Base.metadata.create_all
picks up the new column automatically
Middleware
----------
- auth_middleware.py: after cookie check, call get_optional_user_from_
request to load the real User, stamp it into request.state.user AND
the contextvar via set_current_user, reset in a try/finally. Public
paths and unauthenticated requests continue without contextvar, and
@require_auth handles the strict 401 path
Test infrastructure
-------------------
- tests/conftest.py: @pytest.fixture(autouse=True) _auto_user_context
sets a default SimpleNamespace(id="test-user-autouse") on every test
unless marked @pytest.mark.no_auto_user. Keeps existing 20+
persistence tests passing without modification
- pyproject.toml [tool.pytest.ini_options]: register no_auto_user
marker so pytest does not emit warnings for opt-out tests
- tests/test_user_context.py: 6 tests covering three-state semantics,
Protocol duck typing, and require/optional APIs
- tests/test_thread_meta_repo.py: one test updated to pass owner_id=
None explicitly where it was previously relying on the old default
Test results
------------
- test_user_context.py: 6 passed
- test_auth*.py + test_langgraph_auth.py + test_ensure_admin.py: 127 passed
- test_run_event_store / test_run_repository / test_thread_meta_repo
/ test_feedback: 92 passed
- Full backend suite: 1905 passed, 2 failed (both @requires_llm flaky
integration tests unrelated to auth), 1 skipped
This commit is contained in:
parent
2531cce0d1
commit
4b139fb689
@ -1,6 +1,11 @@
|
||||
"""Global authentication middleware — fail-closed safety net.
|
||||
|
||||
Rejects unauthenticated requests to non-public paths with 401.
|
||||
Rejects unauthenticated requests to non-public paths with 401. When a
|
||||
request passes the cookie check, resolves the JWT payload to a real
|
||||
``User`` object and stamps it into both ``request.state.user`` and the
|
||||
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
||||
owner filtering works automatically via the sentinel pattern.
|
||||
|
||||
Fine-grained permission checks remain in authz.py decorators.
|
||||
"""
|
||||
|
||||
@ -12,6 +17,7 @@ from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.errors import AuthErrorCode
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
# Paths that never require authentication.
|
||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||
@ -68,4 +74,22 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
# Resolve the full user now so repository-layer owner filters
|
||||
# can read from the contextvar. We use the "optional" flavour so
|
||||
# middleware never raises on bad tokens — the cookie-presence
|
||||
# check above plus the @require_auth decorator provide the
|
||||
# strict gates. A stale/invalid token yields user=None here;
|
||||
# the request continues without a contextvar, and any protected
|
||||
# endpoint will still be rejected by @require_auth.
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
if user is None:
|
||||
return await call_next(request)
|
||||
|
||||
request.state.user = user
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
return await call_next(request)
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
@ -12,6 +12,7 @@ from sqlalchemy import case, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
@ -32,18 +33,19 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None = None,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
owner_id=owner_id,
|
||||
owner_id=resolved_owner_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
@ -55,27 +57,66 @@ class FeedbackRepository:
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(self, feedback_id: str) -> dict | None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
|
||||
async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def delete(self, feedback_id: str) -> bool:
|
||||
async def get(
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> bool:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@ -16,6 +16,10 @@ class RunEventRow(Base):
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# Owner of the conversation this event belongs to. Nullable for data
|
||||
# created before auth was introduced; populated by auth middleware on
|
||||
# new writes and by the boot-time orphan migration on existing rows.
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
|
||||
@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
@ -68,7 +69,7 @@ class RunRepository(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id=None,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@ -77,12 +78,13 @@ class RunRepository(RunStore):
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=owner_id,
|
||||
owner_id=resolved_owner_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
@ -96,15 +98,32 @@ class RunRepository(RunStore):
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def get(self, run_id):
|
||||
async def get(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
limit=100,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == owner_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@ -118,12 +137,21 @@ class RunRepository(RunStore):
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, run_id):
|
||||
async def delete(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is not None:
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
if before is None:
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
@ -31,15 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||
# creates an orphan row (used by migration scripts).
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=owner_id,
|
||||
owner_id=resolved_owner_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
@ -51,10 +55,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
async def get(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
if row is None:
|
||||
return None
|
||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||
@ -83,9 +98,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters."""
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
@ -105,36 +128,80 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
if resolved_owner_id is None:
|
||||
return True # explicit bypass
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return row is not None and row.owner_id == resolved_owner_id
|
||||
|
||||
async def update_display_name(
|
||||
self,
|
||||
thread_id: str,
|
||||
display_name: str,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
async def update_status(
|
||||
self,
|
||||
thread_id: str,
|
||||
status: str,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||
async def update_metadata(
|
||||
self,
|
||||
thread_id: str,
|
||||
metadata: dict,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> None:
|
||||
"""Merge ``metadata`` into ``metadata_json``.
|
||||
|
||||
Read-modify-write inside a single session/transaction so concurrent
|
||||
callers see consistent state. No-op if the row does not exist.
|
||||
callers see consistent state. No-op if the row does not exist or
|
||||
the owner_id check fails.
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
merged = dict(row.metadata_json or {})
|
||||
merged.update(metadata)
|
||||
row.metadata_json = merged
|
||||
row.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
async def delete(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is not None:
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -53,6 +54,18 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _owner_from_context() -> str | None:
|
||||
"""Soft read of owner_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
request writes will have the contextvar set by auth middleware
|
||||
and get their user_id stamped automatically.
|
||||
"""
|
||||
user = get_current_user()
|
||||
return user.id if user is not None else None
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
@ -68,6 +81,7 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
owner_id = self._owner_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
@ -78,6 +92,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
owner_id=owner_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
@ -91,6 +106,7 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
owner_id = self._owner_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
@ -114,6 +130,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
owner_id=e.get("owner_id", owner_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
@ -125,8 +142,19 @@ class DbRunEventStore(RunEventStore):
|
||||
rows.append(row)
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
@ -146,8 +174,19 @@ class DbRunEventStore(RunEventStore):
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
@ -155,31 +194,68 @@ class DbRunEventStore(RunEventStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
async def count_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
async def delete_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
async def delete_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
148
backend/packages/harness/deerflow/runtime/user_context.py
Normal file
148
backend/packages/harness/deerflow/runtime/user_context.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""Request-scoped user context for owner-based authorization.
|
||||
|
||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||
auth middleware sets after a successful authentication. Repository
|
||||
methods read the contextvar via a sentinel default parameter, letting
|
||||
routers stay free of ``owner_id`` boilerplate.
|
||||
|
||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||
|
||||
- ``AUTO`` (sentinel of the private ``_AutoSentinel`` type, default): read from contextvar;
|
||||
raise :class:`RuntimeError` if unset.
|
||||
- Explicit ``str``: use the provided value, overriding contextvar.
|
||||
- Explicit ``None``: no WHERE clause — used only by migration scripts
|
||||
and admin CLIs that intentionally bypass isolation.
|
||||
|
||||
Dependency direction
|
||||
--------------------
|
||||
``persistence`` (lower layer) reads from this module; ``gateway.auth``
|
||||
(higher layer) writes to it. ``CurrentUser`` is defined here as a
|
||||
:class:`typing.Protocol` so that ``persistence`` never needs to import
|
||||
the concrete ``User`` class from ``gateway.auth.models``. Any object
|
||||
with an ``.id: str`` attribute structurally satisfies the protocol.
|
||||
|
||||
Asyncio semantics
|
||||
-----------------
|
||||
``ContextVar`` is task-local under asyncio, not thread-local. Each
|
||||
FastAPI request runs in its own task, so the context is naturally
|
||||
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
|
||||
parent task's context, which is typically the intended behaviour; if
|
||||
a background task must *not* see the foreground user, run it in a new,
|
||||
empty ``contextvars.Context()``; ``copy_context()`` copies the current user too.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Final, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
class CurrentUser(Protocol):
    """Minimal structural view of the authenticated user.

    The protocol declares a single ``id`` attribute, so any object that
    exposes ``.id`` satisfies it; no inheritance is required. Because the
    class is ``runtime_checkable``, ``isinstance`` checks work against it.
    The concrete implementation is documented as living in
    ``app.gateway.auth.models.User``.
    """

    # Identifier of the authenticated user.
    id: str
|
||||
|
||||
|
||||
# Backing storage for the current user; the ``None`` default means
# "no user has been set in this context".
_current_user: Final[ContextVar["CurrentUser | None"]] = ContextVar("deerflow_current_user", default=None)
|
||||
|
||||
|
||||
def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
    """Bind ``user`` as the current user in the active context.

    The returned token remembers the value that was in place before this
    call. Hand it to :func:`reset_current_user`, ideally from a
    ``finally`` block, to undo the binding.
    """
    restore_token = _current_user.set(user)
    return restore_token
|
||||
|
||||
|
||||
def reset_current_user(token: Token[CurrentUser | None]) -> None:
    """Undo a :func:`set_current_user` call using the ``token`` it returned."""
    _current_user.reset(token)
    return None
|
||||
|
||||
|
||||
def get_current_user() -> CurrentUser | None:
    """Look up the current user, yielding ``None`` when none is set.

    Intended for code paths that can carry on without a user, such as
    migration scripts or public endpoints.
    """
    current = _current_user.get()
    return current
|
||||
|
||||
|
||||
def require_current_user() -> CurrentUser:
    """Fetch the current user, failing loudly when there is none.

    Meant for repository code that must only run inside an
    authenticated request.

    Raises:
        RuntimeError: if no user has been set in the current context.
    """
    current = _current_user.get()
    if current is not None:
        return current
    raise RuntimeError("repository accessed without user context")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based owner_id resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||
# defaults to ``AUTO``. The three possible values drive distinct
|
||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||
|
||||
|
||||
class _AutoSentinel:
    """Marker type whose single instance means 'take owner_id from the contextvar'."""

    # Cached sole instance; populated on first construction.
    _instance: "_AutoSentinel | None" = None

    def __new__(cls) -> "_AutoSentinel":
        # Reuse the cached instance so every construction yields the
        # same object.
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls)
        cls._instance = created
        return created

    def __repr__(self) -> str:
        return "<AUTO>"


# The one shared sentinel value that repository signatures use as default.
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||
|
||||
|
||||
def resolve_owner_id(
    value: "str | None | _AutoSentinel",
    *,
    method_name: str = "repository method",
) -> str | None:
    """Turn a repository method's ``owner_id`` argument into a concrete value.

    The argument has three possible states:

    - :data:`AUTO` (the default): take the id of the user held in the
      contextvar. This is the usual request-scoped case.
    - An explicit ``str``: returned verbatim, regardless of what the
      contextvar holds. Handy for tests and admin-override flows.
    - An explicit ``None``: returned as ``None``, telling the repository
      to skip the owner_id WHERE clause. Reserved for migration scripts
      and CLI tools that intentionally bypass isolation.

    Args:
        value: The ``owner_id`` argument as received by the repository.
        method_name: Caller label interpolated into the error message.

    Raises:
        RuntimeError: if ``value`` is the sentinel and no user is set.
    """
    if not isinstance(value, _AutoSentinel):
        # Explicit str or explicit None: pass through untouched.
        return value

    current = _current_user.get()
    if current is not None:
        return current.id

    raise RuntimeError(
        f"{method_name} called with owner_id=AUTO but no user context is set; "
        "pass an explicit owner_id, set the contextvar via auth middleware, "
        "or opt out with owner_id=None for migration/CLI paths."
    )
|
||||
@ -30,6 +30,11 @@ postgres = [
|
||||
[dependency-groups]
|
||||
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"no_auto_user: disable the conftest autouse contextvar fixture for this test",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["packages/harness"]
|
||||
|
||||
|
||||
@ -6,8 +6,11 @@ issues when unit-testing lightweight config/registry code in isolation.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Make 'app' and 'deerflow' importable from any working directory
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
@ -31,3 +34,44 @@ _executor_mock.MAX_CONCURRENT_SUBAGENTS = 3
|
||||
_executor_mock.get_background_task_result = MagicMock()
|
||||
|
||||
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods read ``owner_id`` from a contextvar by default
|
||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||
# pre-existing persistence test would raise RuntimeError because the
|
||||
# contextvar is unset. The fixture sets a default test user on every
|
||||
# test; tests that explicitly want to verify behaviour *without* a user
|
||||
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _auto_user_context(request):
    """Run each test with a default ``test-user-autouse`` in the contextvar.

    Tests marked ``@pytest.mark.no_auto_user`` are left alone. The
    ``user_context`` helpers are imported inside the fixture; if that
    import fails, the test runs without a user instead of erroring.
    The binding is undone once the test finishes.
    """
    opted_out = request.node.get_closest_marker("no_auto_user")
    if opted_out:
        yield
        return

    try:
        from deerflow.runtime.user_context import (
            reset_current_user as restore_user,
            set_current_user as bind_user,
        )
    except ImportError:
        # Helpers unavailable: fall through without a user context.
        yield
        return

    token = bind_user(SimpleNamespace(id="test-user-autouse", email="test@local"))
    try:
        yield
    finally:
        restore_user(token)
|
||||
|
||||
@ -104,7 +104,9 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
async def test_check_access_no_owner_allows_all(self, tmp_path):
    """A thread created with no owner passes ``check_access`` for any user."""
    repo = await _make_repo(tmp_path)
    # Explicit owner_id=None to bypass the new AUTO default that
    # would otherwise pick up the test user from the autouse fixture.
    await repo.create("t1", owner_id=None)
    assert await repo.check_access("t1", "anyone") is True
    await _cleanup()
|
||||
|
||||
|
||||
69
backend/tests/test_user_context.py
Normal file
69
backend/tests/test_user_context.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Tests for runtime.user_context — contextvar three-state semantics.
|
||||
|
||||
These tests opt out of the autouse contextvar fixture (added in
|
||||
commit 6) because they explicitly test the cases where the contextvar
|
||||
is set or unset.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
CurrentUser,
|
||||
get_current_user,
|
||||
require_current_user,
|
||||
reset_current_user,
|
||||
set_current_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_default_is_none():
    """With nothing set, the lookup yields ``None``."""
    observed = get_current_user()
    assert observed is None
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_set_and_reset_roundtrip():
    """Setting a user makes it visible; resetting with the token clears it."""
    alice = SimpleNamespace(id="user-1")
    reset_token = set_current_user(alice)
    try:
        seen = get_current_user()
        assert seen is alice
    finally:
        reset_current_user(reset_token)
    # After the reset the context is back to its unset state.
    assert get_current_user() is None
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_require_current_user_raises_when_unset():
    """Without a user in context, ``require_current_user`` raises RuntimeError."""
    # Precondition: nothing is bound before the call under test.
    precondition = get_current_user()
    assert precondition is None
    with pytest.raises(RuntimeError, match="without user context"):
        require_current_user()
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_require_current_user_returns_user_when_set():
    """With a user bound, ``require_current_user`` hands back that same object."""
    bob = SimpleNamespace(id="user-2")
    reset_token = set_current_user(bob)
    try:
        fetched = require_current_user()
        assert fetched is bob
    finally:
        reset_current_user(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_protocol_accepts_duck_typed():
    """An object that merely has ``.id`` passes the ``CurrentUser`` isinstance check."""
    duck = SimpleNamespace(id="user-3")
    matches = isinstance(duck, CurrentUser)
    assert matches
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
def test_protocol_rejects_no_id():
    """An object lacking ``.id`` fails the ``CurrentUser`` isinstance check."""
    impostor = SimpleNamespace(email="no-id@example.com")
    matches = isinstance(impostor, CurrentUser)
    assert not matches
|
||||
Loading…
x
Reference in New Issue
Block a user