diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 111dceaf5..7e418c5ce 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -1,7 +1,8 @@ """Centralized accessors for singleton objects stored on ``app.state``. **Getters** (used by routers): raise 503 when a required dependency is -missing, except ``get_store`` which returns ``None``. +missing, except ``get_store`` and ``get_thread_meta_repo`` which return +``None``. Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. """ @@ -13,9 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager from fastapi import FastAPI, HTTPException, Request -from deerflow.runtime import RunManager, StreamBridge -from deerflow.runtime.events.store.base import RunEventStore -from deerflow.runtime.runs.store.base import RunStore +from deerflow.runtime import RunManager @asynccontextmanager @@ -29,25 +28,41 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: """ from deerflow.agents.checkpointer.async_provider import make_checkpointer from deerflow.config import get_app_config - from deerflow.persistence.engine import close_engine, init_engine_from_config + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config from deerflow.runtime import make_store, make_stream_bridge + from deerflow.runtime.events.store import make_run_event_store async with AsyncExitStack() as stack: app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) - app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) - app.state.store = await stack.enter_async_context(make_store()) - # Initialize persistence layer from unified database config + + # Initialize persistence engine BEFORE checkpointer so that + # auto-create-database logic runs first (postgres backend). config = get_app_config() await init_engine_from_config(config.database) - # Initialize run store (RunRepository if DB available, else MemoryRunStore) - app.state.run_store = _make_run_store() + app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) + app.state.store = await stack.enter_async_context(make_store()) - # Initialize run event store based on config - app.state.run_event_store = _make_run_event_store(config) + # Initialize repositories — one get_session_factory() call for all. + sf = get_session_factory() + if sf is not None: + from deerflow.persistence.repositories.feedback_repo import FeedbackRepository + from deerflow.persistence.repositories.run_repo import RunRepository + from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository - # Initialize feedback repository (None when no DB engine) - app.state.feedback_repo = _make_feedback_repo() + app.state.run_store = RunRepository(sf) + app.state.feedback_repo = FeedbackRepository(sf) + app.state.thread_meta_repo = ThreadMetaRepository(sf) + else: + from deerflow.runtime.runs.store.memory import MemoryRunStore + + app.state.run_store = MemoryRunStore() + app.state.feedback_repo = None + app.state.thread_meta_repo = None + + # Run event store (has its own factory with config-driven backend selection) + run_events_config = getattr(config, "run_events", None) + app.state.run_event_store = make_run_event_store(run_events_config) # RunManager with store backing for persistence app.state.run_manager = RunManager(store=app.state.run_store) @@ -58,71 +73,30 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: await close_engine() -# --------------------------------------------------------------------------- -# Factories -# --------------------------------------------------------------------------- - - -def _make_run_store() -> RunStore: - """Create a RunStore: RunRepository if DB engine is available, else MemoryRunStore.""" - from deerflow.persistence.engine import get_session_factory - - sf = get_session_factory() - if sf is not None: - from deerflow.persistence.repositories.run_repo import RunRepository - - return RunRepository(sf) - from deerflow.runtime.runs.store.memory import MemoryRunStore - - return MemoryRunStore() - - -def _make_feedback_repo(): - """Create a FeedbackRepository if DB engine is available, else None.""" - from deerflow.persistence.engine import get_session_factory - - sf = get_session_factory() - if sf is not None: - from deerflow.persistence.repositories.feedback_repo import FeedbackRepository - - return FeedbackRepository(sf) - return None - - -def _make_run_event_store(config) -> RunEventStore: - from deerflow.runtime.events.store import make_run_event_store - - run_events_config = getattr(config, "run_events", None) - return make_run_event_store(run_events_config) - - # --------------------------------------------------------------------------- # Getters -- called by routers per-request # --------------------------------------------------------------------------- -def get_stream_bridge(request: Request) -> StreamBridge: - """Return the global :class:`StreamBridge`, or 503.""" - bridge = getattr(request.app.state, "stream_bridge", None) - if bridge is None: - raise HTTPException(status_code=503, detail="Stream bridge not available") - return bridge +def _require(attr: str, label: str): + """Create a FastAPI dependency that returns ``app.state.`` or 503.""" + + def dep(request: Request): + val = getattr(request.app.state, attr, None) + if val is None: + raise HTTPException(status_code=503, detail=f"{label} not available") + return val + + dep.__name__ = dep.__qualname__ = f"get_{attr}" + return dep -def get_run_manager(request: Request) -> RunManager: - """Return the global :class:`RunManager`, or 503.""" - mgr = getattr(request.app.state, "run_manager", None) - if mgr is None: - raise HTTPException(status_code=503, detail="Run manager not available") - return mgr - - -def get_checkpointer(request: Request): - """Return the global checkpointer, or 503.""" - cp = getattr(request.app.state, "checkpointer", None) - if cp is None: - raise HTTPException(status_code=503, detail="Checkpointer not available") - return cp +get_stream_bridge = _require("stream_bridge", "Stream bridge") +get_run_manager = _require("run_manager", "Run manager") +get_checkpointer = _require("checkpointer", "Checkpointer") +get_run_event_store = _require("run_event_store", "Run event store") +get_feedback_repo = _require("feedback_repo", "Feedback") +get_run_store = _require("run_store", "Run store") def get_store(request: Request): @@ -130,28 +104,9 @@ def get_store(request: Request): return getattr(request.app.state, "store", None) -def get_run_event_store(request: Request) -> RunEventStore: - """Return the RunEventStore, or 503 if not available.""" - store = getattr(request.app.state, "run_event_store", None) - if store is None: - raise HTTPException(status_code=503, detail="Run event store not available") - return store - - -def get_feedback_repo(request: Request): - """Return the FeedbackRepository, or 503 if not available.""" - repo = getattr(request.app.state, "feedback_repo", None) - if repo is None: - raise HTTPException(status_code=503, detail="Feedback not available") - return repo - - -def get_run_store(request: Request) -> RunStore: - """Return the RunStore, or 503 if not available.""" - store = getattr(request.app.state, "run_store", None) - if store is None: - raise HTTPException(status_code=503, detail="Run store not available") - return store +def get_thread_meta_repo(request: Request): + """Return the ThreadMetaRepository, or None if not available.""" + return getattr(request.app.state, "thread_meta_repo", None) async def get_current_user(request: Request) -> str | None: diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index a5509df33..d76bed1f3 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -17,7 +17,7 @@ from typing import Any from fastapi import HTTPException, Request from langchain_core.messages import HumanMessage -from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge +from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo from deerflow.runtime import ( END_SENTINEL, HEARTBEAT_SENTINEL, @@ -274,6 +274,22 @@ async def start_run( if store is not None: await _upsert_thread_in_store(store, thread_id, body.metadata) + # Upsert thread metadata in the SQL-backed threads_meta table + thread_meta_repo = get_thread_meta_repo(request) + if thread_meta_repo is not None: + try: + existing = await thread_meta_repo.get(thread_id) + if existing is None: + await thread_meta_repo.create( + thread_id, + assistant_id=body.assistant_id, + metadata=body.metadata, + ) + else: + await thread_meta_repo.update_status(thread_id, "running") + except Exception: + logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id) + # Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None) if follow_up_to_run_id is None: