"""Centralized accessors for singleton objects stored on ``app.state``. **Getters** (used by routers): raise 503 when a required dependency is missing, except ``get_store`` which returns ``None``. Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. """ from __future__ import annotations from collections.abc import AsyncGenerator from contextlib import AsyncExitStack, asynccontextmanager from typing import TYPE_CHECKING from fastapi import FastAPI, HTTPException, Request from deerflow.runtime import RunContext, RunManager if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from deerflow.persistence.thread_meta.base import ThreadMetaStore @asynccontextmanager async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: """Bootstrap and tear down all LangGraph runtime singletons. Usage in ``app.py``:: async with langgraph_runtime(app): yield """ from deerflow.config import get_app_config from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config from deerflow.runtime import make_store, make_stream_bridge from deerflow.runtime.checkpointer.async_provider import make_checkpointer from deerflow.runtime.events.store import make_run_event_store async with AsyncExitStack() as stack: app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) # Initialize persistence engine BEFORE checkpointer so that # auto-create-database logic runs first (postgres backend). config = get_app_config() await init_engine_from_config(config.database) app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) app.state.store = await stack.enter_async_context(make_store()) # Initialize repositories — one get_session_factory() call for all. sf = get_session_factory() if sf is not None: from deerflow.persistence.feedback import FeedbackRepository from deerflow.persistence.run import RunRepository app.state.run_store = RunRepository(sf) app.state.feedback_repo = FeedbackRepository(sf) else: from deerflow.runtime.runs.store.memory import MemoryRunStore app.state.run_store = MemoryRunStore() app.state.feedback_repo = None from deerflow.persistence.thread_meta import make_thread_store app.state.thread_store = make_thread_store(sf, app.state.store) # Run event store (has its own factory with config-driven backend selection) run_events_config = getattr(config, "run_events", None) app.state.run_event_store = make_run_event_store(run_events_config) # RunManager with store backing for persistence app.state.run_manager = RunManager(store=app.state.run_store) try: yield finally: await close_engine() # --------------------------------------------------------------------------- # Getters – called by routers per-request # --------------------------------------------------------------------------- def _require(attr: str, label: str): """Create a FastAPI dependency that returns ``app.state.`` or 503.""" def dep(request: Request): val = getattr(request.app.state, attr, None) if val is None: raise HTTPException(status_code=503, detail=f"{label} not available") return val dep.__name__ = dep.__qualname__ = f"get_{attr}" return dep get_stream_bridge = _require("stream_bridge", "Stream bridge") get_run_manager = _require("run_manager", "Run manager") get_checkpointer = _require("checkpointer", "Checkpointer") get_run_event_store = _require("run_event_store", "Run event store") get_feedback_repo = _require("feedback_repo", "Feedback") get_run_store = _require("run_store", "Run store") def get_store(request: Request): """Return the global store (may be ``None`` if not configured).""" return getattr(request.app.state, "store", None) def get_thread_store(request: Request) -> ThreadMetaStore: """Return the thread metadata store (SQL or memory-backed).""" val = getattr(request.app.state, "thread_store", None) if val is None: raise HTTPException(status_code=503, detail="Thread metadata store not available") return val def get_run_context(request: Request) -> RunContext: """Build a :class:`RunContext` from ``app.state`` singletons. Returns a *base* context with infrastructure dependencies. Callers that need per-run fields (e.g. ``follow_up_to_run_id``) should use ``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it to :func:`run_agent`. """ from deerflow.config import get_app_config return RunContext( checkpointer=get_checkpointer(request), store=get_store(request), event_store=get_run_event_store(request), run_events_config=getattr(get_app_config(), "run_events", None), thread_store=get_thread_store(request), ) # --------------------------------------------------------------------------- # Auth helpers (used by authz.py and auth middleware) # --------------------------------------------------------------------------- # Cached singletons to avoid repeated instantiation per request _cached_local_provider: LocalAuthProvider | None = None _cached_repo: SQLiteUserRepository | None = None def get_local_provider() -> LocalAuthProvider: """Get or create the cached LocalAuthProvider singleton. Must be called after ``init_engine_from_config()`` — the shared session factory is required to construct the user repository. """ global _cached_local_provider, _cached_repo if _cached_repo is None: from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from deerflow.persistence.engine import get_session_factory sf = get_session_factory() if sf is None: raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table") _cached_repo = SQLiteUserRepository(sf) if _cached_local_provider is None: from app.gateway.auth.local_provider import LocalAuthProvider _cached_local_provider = LocalAuthProvider(repository=_cached_repo) return _cached_local_provider async def get_current_user_from_request(request: Request): """Get the current authenticated user from the request cookie. Raises HTTPException 401 if not authenticated. """ from app.gateway.auth import decode_token from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code access_token = request.cookies.get("access_token") if not access_token: raise HTTPException( status_code=401, detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(), ) payload = decode_token(access_token) if isinstance(payload, TokenError): raise HTTPException( status_code=401, detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(), ) provider = get_local_provider() user = await provider.get_user(payload.sub) if user is None: raise HTTPException( status_code=401, detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(), ) # Token version mismatch → password was changed, token is stale if user.token_version != payload.ver: raise HTTPException( status_code=401, detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(), ) return user async def get_optional_user_from_request(request: Request): """Get optional authenticated user from request. Returns None if not authenticated. """ try: return await get_current_user_from_request(request) except HTTPException: return None async def get_current_user(request: Request) -> str | None: """Extract user_id from request cookie, or None if not authenticated. Thin adapter that returns the string id for callers that only need identification (e.g., ``feedback.py``). Full-user callers should use ``get_current_user_from_request`` or ``get_optional_user_from_request``. """ user = await get_optional_user_from_request(request) return str(user.id) if user else None