mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-26 03:38:06 +00:00
refactor(gateway): simplify deps.py with getter factory + inline repos
- Replace 6 identical getter functions with _require() factory. - Inline 3 _make_*_repo() factories into langgraph_runtime(), call get_session_factory() once instead of 3 times. - Add thread_meta upsert in start_run (services.py). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
3b4622a26f
commit
e362aaefbd
@ -1,7 +1,8 @@
|
||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
||||
``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
@ -13,9 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import RunManager, StreamBridge
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime import RunManager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -29,25 +28,41 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, init_engine_from_config
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
# Initialize persistence layer from unified database config
|
||||
|
||||
# Initialize persistence engine BEFORE checkpointer so that
|
||||
# auto-create-database logic runs first (postgres backend).
|
||||
config = get_app_config()
|
||||
await init_engine_from_config(config.database)
|
||||
|
||||
# Initialize run store (RunRepository if DB available, else MemoryRunStore)
|
||||
app.state.run_store = _make_run_store()
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
|
||||
# Initialize run event store based on config
|
||||
app.state.run_event_store = _make_run_event_store(config)
|
||||
# Initialize repositories — one get_session_factory() call for all.
|
||||
sf = get_session_factory()
|
||||
if sf is not None:
|
||||
from deerflow.persistence.repositories.feedback_repo import FeedbackRepository
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository
|
||||
|
||||
# Initialize feedback repository (None when no DB engine)
|
||||
app.state.feedback_repo = _make_feedback_repo()
|
||||
app.state.run_store = RunRepository(sf)
|
||||
app.state.feedback_repo = FeedbackRepository(sf)
|
||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
||||
else:
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
app.state.run_store = MemoryRunStore()
|
||||
app.state.feedback_repo = None
|
||||
app.state.thread_meta_repo = None
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
app.state.run_event_store = make_run_event_store(run_events_config)
|
||||
|
||||
# RunManager with store backing for persistence
|
||||
app.state.run_manager = RunManager(store=app.state.run_store)
|
||||
@ -58,71 +73,30 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await close_engine()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_run_store() -> RunStore:
|
||||
"""Create a RunStore: RunRepository if DB engine is available, else MemoryRunStore."""
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is not None:
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
|
||||
return RunRepository(sf)
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
return MemoryRunStore()
|
||||
|
||||
|
||||
def _make_feedback_repo():
|
||||
"""Create a FeedbackRepository if DB engine is available, else None."""
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is not None:
|
||||
from deerflow.persistence.repositories.feedback_repo import FeedbackRepository
|
||||
|
||||
return FeedbackRepository(sf)
|
||||
return None
|
||||
|
||||
|
||||
def _make_run_event_store(config) -> RunEventStore:
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
return make_run_event_store(run_events_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters -- called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_stream_bridge(request: Request) -> StreamBridge:
|
||||
"""Return the global :class:`StreamBridge`, or 503."""
|
||||
bridge = getattr(request.app.state, "stream_bridge", None)
|
||||
if bridge is None:
|
||||
raise HTTPException(status_code=503, detail="Stream bridge not available")
|
||||
return bridge
|
||||
def _require(attr: str, label: str):
|
||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
||||
|
||||
def dep(request: Request):
|
||||
val = getattr(request.app.state, attr, None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
return val
|
||||
|
||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
||||
return dep
|
||||
|
||||
|
||||
def get_run_manager(request: Request) -> RunManager:
|
||||
"""Return the global :class:`RunManager`, or 503."""
|
||||
mgr = getattr(request.app.state, "run_manager", None)
|
||||
if mgr is None:
|
||||
raise HTTPException(status_code=503, detail="Run manager not available")
|
||||
return mgr
|
||||
|
||||
|
||||
def get_checkpointer(request: Request):
|
||||
"""Return the global checkpointer, or 503."""
|
||||
cp = getattr(request.app.state, "checkpointer", None)
|
||||
if cp is None:
|
||||
raise HTTPException(status_code=503, detail="Checkpointer not available")
|
||||
return cp
|
||||
get_stream_bridge = _require("stream_bridge", "Stream bridge")
|
||||
get_run_manager = _require("run_manager", "Run manager")
|
||||
get_checkpointer = _require("checkpointer", "Checkpointer")
|
||||
get_run_event_store = _require("run_event_store", "Run event store")
|
||||
get_feedback_repo = _require("feedback_repo", "Feedback")
|
||||
get_run_store = _require("run_store", "Run store")
|
||||
|
||||
|
||||
def get_store(request: Request):
|
||||
@ -130,28 +104,9 @@ def get_store(request: Request):
|
||||
return getattr(request.app.state, "store", None)
|
||||
|
||||
|
||||
def get_run_event_store(request: Request) -> RunEventStore:
|
||||
"""Return the RunEventStore, or 503 if not available."""
|
||||
store = getattr(request.app.state, "run_event_store", None)
|
||||
if store is None:
|
||||
raise HTTPException(status_code=503, detail="Run event store not available")
|
||||
return store
|
||||
|
||||
|
||||
def get_feedback_repo(request: Request):
|
||||
"""Return the FeedbackRepository, or 503 if not available."""
|
||||
repo = getattr(request.app.state, "feedback_repo", None)
|
||||
if repo is None:
|
||||
raise HTTPException(status_code=503, detail="Feedback not available")
|
||||
return repo
|
||||
|
||||
|
||||
def get_run_store(request: Request) -> RunStore:
|
||||
"""Return the RunStore, or 503 if not available."""
|
||||
store = getattr(request.app.state, "run_store", None)
|
||||
if store is None:
|
||||
raise HTTPException(status_code=503, detail="Run store not available")
|
||||
return store
|
||||
def get_thread_meta_repo(request: Request):
|
||||
"""Return the ThreadMetaRepository, or None if not available."""
|
||||
return getattr(request.app.state, "thread_meta_repo", None)
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> str | None:
|
||||
|
||||
@ -17,7 +17,7 @@ from typing import Any
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
|
||||
from deerflow.runtime import (
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
@ -274,6 +274,22 @@ async def start_run(
|
||||
if store is not None:
|
||||
await _upsert_thread_in_store(store, thread_id, body.metadata)
|
||||
|
||||
# Upsert thread metadata in the SQL-backed threads_meta table
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
if thread_meta_repo is not None:
|
||||
try:
|
||||
existing = await thread_meta_repo.get(thread_id)
|
||||
if existing is None:
|
||||
await thread_meta_repo.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await thread_meta_repo.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id)
|
||||
|
||||
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
||||
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
||||
if follow_up_to_run_id is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user