mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-27 12:18:14 +00:00
refactor(gateway): restructure gateway with new dependency injection and services
Reorganize app/gateway/ with: - common/ - lifespan management - dependencies/ - FastAPI dependency injection (db, checkpointer, repositories, stream_bridge) - services/runs/ - run execution services (facade_factory, input adapters, store operations) - registrar.py - router registration - router.py - main router setup Simplify app.py to use the new modular structure. Remove deprecated utils.py. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
274255b1a5
commit
39a575617b
@ -1,4 +1,23 @@
|
||||
from .app import app, create_app
|
||||
from .config import GatewayConfig, get_gateway_config
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]
|
||||
__all__ = ["GatewayConfig", "app", "get_gateway_config", "register_app"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Lazily resolve package attributes (PEP 562).

    Defers the heavy submodule imports (``.app``, ``.config``,
    ``.registrar``) until the attribute is first accessed, so importing the
    package itself stays cheap.

    Raises:
        AttributeError: for any name not in ``__all__``, with the
            conventional "module ... has no attribute ..." message so
            tracebacks read like a normal missing attribute.
    """
    if name == "app":
        from .app import app

        return app
    if name == "GatewayConfig":
        from .config import GatewayConfig

        return GatewayConfig
    if name == "get_gateway_config":
        from .config import get_gateway_config

        return get_gateway_config
    if name == "register_app":
        from .registrar import register_app

        return register_app
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -1,363 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from app.gateway.registrar import register_app
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
from app.gateway.routers import (
|
||||
agents,
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
memory,
|
||||
models,
|
||||
runs,
|
||||
skills,
|
||||
suggestions,
|
||||
thread_runs,
|
||||
threads,
|
||||
uploads,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
def create_app():
    """Build and return the gateway application by delegating to the registrar."""
    gateway_app = register_app()
    return gateway_app
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
    """Startup hook: handle first boot and migrate orphan threads otherwise.

    After admin creation, migrate orphan threads from the LangGraph
    store (metadata.user_id unset) to the admin account. This is the
    "no-auth → with-auth" upgrade path: users who ran DeerFlow without
    authentication have existing LangGraph thread data that needs an
    owner assigned.

    First boot (no admin exists):
    - Does NOT create any user accounts automatically.
    - The operator must visit ``/setup`` to create the first admin.

    Subsequent boots (admin already exists):
    - Runs the one-time "no-auth → with-auth" orphan thread migration for
      existing LangGraph thread metadata that has no owner_id.

    No SQL persistence migration is needed: the four user_id columns
    (threads_meta, runs, run_events, feedback) only come into existence
    alongside the auth module via create_all, so freshly created tables
    never contain NULL-owner rows.
    """
    # Local imports: keep SQLAlchemy / persistence out of module import time.
    from sqlalchemy import select

    from app.gateway.deps import get_local_provider
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.user.model import UserRow

    provider = get_local_provider()
    admin_count = await provider.count_admin_users()

    if admin_count == 0:
        # First boot: do nothing but tell the operator where to go.
        logger.info("=" * 60)
        logger.info(" First boot detected — no admin account exists.")
        logger.info(" Visit /setup to complete admin account creation.")
        logger.info("=" * 60)
        return

    # Admin already exists — run orphan thread migration for any
    # LangGraph thread metadata that pre-dates the auth module.
    sf = get_session_factory()
    if sf is None:
        # No SQL persistence configured; nothing to migrate against.
        return

    async with sf() as session:
        stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
        row = (await session.execute(stmt)).scalar_one_or_none()

        if row is None:
            return  # Should not happen (admin_count > 0 above), but be safe.

        # Read the id while the session is still open (avoids lazy-load
        # issues on an expired row after the session closes).
        admin_id = str(row.id)

        # LangGraph store orphan migration — non-fatal.
        # This covers the "no-auth → with-auth" upgrade path for users
        # whose existing LangGraph thread metadata has no user_id set.
        store = getattr(app.state, "store", None)
        if store is not None:
            try:
                migrated = await _migrate_orphaned_threads(store, admin_id)
                if migrated:
                    logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
            except Exception:
                logger.exception("LangGraph thread migration failed (non-fatal)")
|
||||
|
||||
|
||||
async def _iter_store_items(store, namespace, *, page_size: int = 500):
    """Yield every item in *namespace*, fetching one page at a time.

    Walks ``store.asearch`` with a limit/offset cursor so namespaces larger
    than a single page are fully enumerated (the old code used a hardcoded
    ``limit=1000`` and silently dropped the rest). Iteration ends on an
    empty page, or on a short page — which must be the last one.
    """
    cursor = 0
    while True:
        page = await store.asearch(namespace, limit=page_size, offset=cursor)
        if not page:
            break
        for entry in page:
            yield entry
        if len(page) < page_size:
            break
        cursor += page_size
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
    """Assign ownerless LangGraph threads to *admin_user_id*.

    Walks the ``("threads",)`` namespace via cursor pagination and stamps
    ``metadata.user_id`` onto every item that lacks one, writing the item
    back with ``store.aput``. Returns the number of items rewritten.
    """
    count = 0
    async for entry in _iter_store_items(store, ("threads",)):
        meta = entry.value.get("metadata", {})
        if meta.get("user_id"):
            # Already owned — leave untouched.
            continue
        meta["user_id"] = admin_user_id
        entry.value["metadata"] = meta
        await store.aput(("threads",), entry.key, entry.value)
        count += 1
    return count
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan handler.

    Startup order: load and validate config, bring up the LangGraph
    runtime, run the first-boot/migration hook, start the optional IM
    channel service. Shutdown runs the reverse: stop the channel service,
    then tear down the runtime via the context manager exit.
    """

    # Load config and check necessary environment variables at startup
    try:
        get_app_config()
        logger.info("Configuration loaded successfully")
    except Exception as e:
        # Fail fast: a gateway with broken config must not start serving.
        error_msg = f"Failed to load configuration during gateway startup: {e}"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e
    config = get_gateway_config()
    logger.info(f"Starting API Gateway on {config.host}:{config.port}")

    # Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
    async with langgraph_runtime(app):
        logger.info("LangGraph runtime initialised")

        # Ensure admin user exists (auto-create on first boot)
        # Must run AFTER langgraph_runtime so app.state.store is available for thread migration
        await _ensure_admin_user(app)

        # Start IM channel service if any channels are configured
        # Best-effort: a missing/failed channel service must not block startup.
        try:
            from app.channels.service import start_channel_service

            channel_service = await start_channel_service()
            logger.info("Channel service started: %s", channel_service.get_status())
        except Exception:
            logger.exception("No IM channels configured or channel service failed to start")

        # Application serves requests while suspended here.
        yield

        # Stop channel service on shutdown
        try:
            from app.channels.service import stop_channel_service

            await stop_channel_service()
        except Exception:
            logger.exception("Failed to stop channel service")

    logger.info("Shutting down API Gateway")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Builds the app with OpenAPI metadata, installs auth/CSRF middleware
    (plus CORS in dev when GATEWAY_CORS_ORIGINS is set), mounts every API
    router, and registers the /health probe endpoint.

    Returns:
        Configured FastAPI application instance.
    """

    app = FastAPI(
        title="DeerFlow API Gateway",
        description="""
## DeerFlow API Gateway

API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.

### Features

- **Models Management**: Query and retrieve available AI models
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
- **Memory Management**: Access and manage global memory data for personalized conversations
- **Skills Management**: Query and manage skills and their enabled status
- **Artifacts**: Access thread artifacts and generated files
- **Health Monitoring**: System health check endpoints

### Architecture

LangGraph requests are handled by nginx reverse proxy.
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
""",
        version="0.1.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
        openapi_tags=[
            {
                "name": "models",
                "description": "Operations for querying available AI models and their configurations",
            },
            {
                "name": "mcp",
                "description": "Manage Model Context Protocol (MCP) server configurations",
            },
            {
                "name": "memory",
                "description": "Access and manage global memory data for personalized conversations",
            },
            {
                "name": "skills",
                "description": "Manage skills and their configurations",
            },
            {
                "name": "artifacts",
                "description": "Access and download thread artifacts and generated files",
            },
            {
                "name": "uploads",
                "description": "Upload and manage user files for threads",
            },
            {
                "name": "threads",
                "description": "Manage DeerFlow thread-local filesystem data",
            },
            {
                "name": "agents",
                "description": "Create and manage custom agents with per-agent config and prompts",
            },
            {
                "name": "suggestions",
                "description": "Generate follow-up question suggestions for conversations",
            },
            {
                "name": "channels",
                "description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
            },
            {
                "name": "assistants-compat",
                "description": "LangGraph Platform-compatible assistants API (stub)",
            },
            {
                "name": "runs",
                "description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
            },
            {
                "name": "health",
                "description": "Health check and system status endpoints",
            },
        ],
    )

    # Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
    app.add_middleware(AuthMiddleware)

    # CSRF: Double Submit Cookie pattern for state-changing requests
    app.add_middleware(CSRFMiddleware)

    # CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
    # In production, nginx handles CORS and no middleware is needed.
    cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
    if cors_origins_env:
        cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
        # Validate: wildcard origin with credentials is a security misconfiguration
        for origin in cors_origins:
            if origin == "*":
                logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
                # Drop the wildcard but keep any explicit origins alongside it.
                cors_origins = [o for o in cors_origins if o != "*"]
                break
        if cors_origins:
            app.add_middleware(
                CORSMiddleware,
                allow_origins=cors_origins,
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )

    # Include routers
    # Models API is mounted at /api/models
    app.include_router(models.router)

    # MCP API is mounted at /api/mcp
    app.include_router(mcp.router)

    # Memory API is mounted at /api/memory
    app.include_router(memory.router)

    # Skills API is mounted at /api/skills
    app.include_router(skills.router)

    # Artifacts API is mounted at /api/threads/{thread_id}/artifacts
    app.include_router(artifacts.router)

    # Uploads API is mounted at /api/threads/{thread_id}/uploads
    app.include_router(uploads.router)

    # Thread cleanup API is mounted at /api/threads/{thread_id}
    app.include_router(threads.router)

    # Agents API is mounted at /api/agents
    app.include_router(agents.router)

    # Suggestions API is mounted at /api/threads/{thread_id}/suggestions
    app.include_router(suggestions.router)

    # Channels API is mounted at /api/channels
    app.include_router(channels.router)

    # Assistants compatibility API (LangGraph Platform stub)
    app.include_router(assistants_compat.router)

    # Auth API is mounted at /api/v1/auth
    app.include_router(auth.router)

    # Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
    app.include_router(feedback.router)

    # Thread Runs API (LangGraph Platform-compatible runs lifecycle)
    app.include_router(thread_runs.router)

    # Stateless Runs API (stream/wait without a pre-existing thread)
    app.include_router(runs.router)

    @app.get("/health", tags=["health"])
    async def health_check() -> dict:
        """Health check endpoint.

        Returns:
            Service health status information.
        """
        return {"status": "healthy", "service": "deer-flow-gateway"}

    return app
|
||||
|
||||
|
||||
# Create app instance for uvicorn
app = create_app()
# NOTE(review): this immediately rebinds `app` to the registrar-built
# application, discarding the create_app() result above — confirm whether
# the create_app() call is still needed or is a leftover from the refactor.
app = register_app()
|
||||
|
||||
3
backend/app/gateway/common/__init__.py
Normal file
3
backend/app/gateway/common/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .lifespan import lifespan_manager
|
||||
|
||||
__all__ = ["lifespan_manager"]
|
||||
52
backend/app/gateway/common/lifespan.py
Normal file
52
backend/app/gateway/common/lifespan.py
Normal file
@ -0,0 +1,52 @@
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]]
|
||||
|
||||
|
||||
class LifespanManager:
    """Collects FastAPI lifespan hooks and combines them into a single one."""

    def __init__(self) -> None:
        # Registered hooks, kept in registration order.
        self._hooks: list[LifespanFunc] = []

    def register(self, func: LifespanFunc) -> LifespanFunc:
        """Register a lifespan hook (idempotent).

        Returns the hook unchanged so this method can be used as a decorator.

        :param func: lifespan hook
        :return: the same hook
        """
        if func not in self._hooks:
            self._hooks.append(func)
        return func

    def build(self) -> LifespanFunc:
        """Build one combined lifespan hook from everything registered.

        The combined hook enters each registered hook in order, merges any
        state dicts they yield, copies the merged state onto ``app.state``,
        and tears the hooks down in reverse order on shutdown.

        :return: the combined lifespan context manager
        """

        @asynccontextmanager
        async def combined_lifespan(app: FastAPI):  # noqa: ANN202
            merged: dict[str, Any] = {}
            async with AsyncExitStack() as stack:
                for hook in self._hooks:
                    produced = await stack.enter_async_context(hook(app))
                    if isinstance(produced, dict):
                        merged.update(produced)

                for attr_name, attr_value in merged.items():
                    setattr(app.state, attr_name, attr_value)

                yield merged or None

        return combined_lifespan


# Singleton lifespan_manager instance
lifespan_manager = LifespanManager()
|
||||
59
backend/app/gateway/dependencies/__init__.py
Normal file
59
backend/app/gateway/dependencies/__init__.py
Normal file
@ -0,0 +1,59 @@
|
||||
from app.gateway.dependencies.checkpointer import (
|
||||
CurrentCheckpointer,
|
||||
get_checkpointer,
|
||||
)
|
||||
from app.plugins.auth.security.dependencies import (
|
||||
CurrentAuthService,
|
||||
CurrentUserRepository,
|
||||
get_auth_service,
|
||||
get_current_user_from_request,
|
||||
get_current_user_id,
|
||||
get_optional_user_from_request,
|
||||
get_user_repository,
|
||||
)
|
||||
from app.gateway.dependencies.db import (
|
||||
CurrentSession,
|
||||
CurrentSessionTransaction,
|
||||
get_db_session,
|
||||
get_db_session_transaction,
|
||||
)
|
||||
from app.gateway.dependencies.repositories import (
|
||||
CurrentFeedbackRepository,
|
||||
CurrentRunRepository,
|
||||
CurrentThreadMetaRepository,
|
||||
CurrentThreadMetaStorage,
|
||||
get_feedback_repository,
|
||||
get_run_repository,
|
||||
get_thread_meta_repository,
|
||||
get_thread_meta_storage,
|
||||
)
|
||||
from app.gateway.dependencies.stream_bridge import (
|
||||
CurrentStreamBridge,
|
||||
get_stream_bridge,
|
||||
)
|
||||
|
||||
# Public dependency-injection API — kept alphabetically sorted
# (CurrentAuthService/CurrentCheckpointer were previously swapped).
__all__ = [
    "CurrentAuthService",
    "CurrentCheckpointer",
    "CurrentFeedbackRepository",
    "CurrentRunRepository",
    "CurrentSession",
    "CurrentSessionTransaction",
    "CurrentStreamBridge",
    "CurrentThreadMetaRepository",
    "CurrentThreadMetaStorage",
    "CurrentUserRepository",
    "get_auth_service",
    "get_checkpointer",
    "get_current_user_from_request",
    "get_current_user_id",
    "get_db_session",
    "get_db_session_transaction",
    "get_feedback_repository",
    "get_optional_user_from_request",
    "get_run_repository",
    "get_stream_bridge",
    "get_thread_meta_repository",
    "get_thread_meta_storage",
    "get_user_repository",
]
|
||||
20
backend/app/gateway/dependencies/checkpointer.py
Normal file
20
backend/app/gateway/dependencies/checkpointer.py
Normal file
@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
|
||||
def get_checkpointer(request: Request) -> Checkpointer:
    """Dependency: the LangGraph checkpointer from ``app.state.persistence``.

    Raises HTTP 503 while the persistence layer (or its checkpointer) has
    not been initialised yet.
    """
    persistence_layer = getattr(request.app.state, "persistence", None)
    if persistence_layer is None:
        raise HTTPException(status_code=503, detail="Persistence not available")
    cp = getattr(persistence_layer, "checkpointer", None)
    if cp is None:
        raise HTTPException(status_code=503, detail="Checkpointer not available")
    return cp
||||
|
||||
|
||||
# Dependency alias: resolves via get_checkpointer (503s until persistence is ready).
CurrentCheckpointer = Annotated[Checkpointer, Depends(get_checkpointer)]
|
||||
37
backend/app/gateway/dependencies/db.py
Normal file
37
backend/app/gateway/dependencies/db.py
Normal file
@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession]:
    """Return the SQLAlchemy async session factory from ``app.state.persistence``.

    Raises:
        HTTPException: 503 when the persistence layer or its session
            factory has not been initialised yet.
    """
    # Guard the persistence object itself too (mirrors get_checkpointer):
    # before lifespan startup completes, app.state has no `persistence`
    # attribute and a direct access would surface as an AttributeError
    # (opaque HTTP 500) instead of a clean 503.
    persistence = getattr(request.app.state, "persistence", None)
    factory = getattr(persistence, "session_factory", None)
    if factory is None:
        raise HTTPException(status_code=503, detail="Database session factory not available")
    return factory


async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """Open a session without auto-commit. Use for read-only endpoints."""
    session_factory = _get_session_factory(request)
    async with session_factory() as session:
        yield session


async def get_db_session_transaction(request: Request) -> AsyncIterator[AsyncSession]:
    """Open a session and commit on success, rollback on error.

    The commit runs after the endpoint finishes; any exception raised while
    the session is in use triggers a rollback and is re-raised.
    """
    session_factory = _get_session_factory(request)
    async with session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
||||
|
||||
|
||||
# Read-only session dependency (no auto-commit).
CurrentSession = Annotated[AsyncSession, Depends(get_db_session)]
# Transactional session dependency: commit on success, rollback on exception.
CurrentSessionTransaction = Annotated[AsyncSession, Depends(get_db_session_transaction)]
|
||||
41
backend/app/gateway/dependencies/repositories.py
Normal file
41
backend/app/gateway/dependencies/repositories.py
Normal file
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
|
||||
from app.infra.storage import ThreadMetaStorage
|
||||
from store.repositories.contracts import (
|
||||
FeedbackRepositoryProtocol,
|
||||
RunRepositoryProtocol,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
|
||||
|
||||
def _require_state(request: Request, attr: str, label: str):
    """Return ``app.state.<attr>``, or raise HTTP 503 labelled *label*."""
    found = getattr(request.app.state, attr, None)
    if found is None:
        raise HTTPException(status_code=503, detail=f"{label} not available")
    return found


def get_run_repository(request: Request) -> RunRepositoryProtocol:
    """Dependency: the run repository installed on app state at startup."""
    return _require_state(request, "run_store", "Run store")


def get_thread_meta_repository(request: Request) -> ThreadMetaRepositoryProtocol:
    """Dependency: the thread-metadata repository installed on app state."""
    return _require_state(request, "thread_meta_repo", "Thread metadata store")


def get_thread_meta_storage(request: Request) -> ThreadMetaStorage:
    """Dependency: the thread-metadata storage facade installed on app state."""
    return _require_state(request, "thread_meta_storage", "Thread metadata storage")


def get_feedback_repository(request: Request) -> FeedbackRepositoryProtocol:
    """Dependency: the feedback repository installed on app state."""
    return _require_state(request, "feedback_repo", "Feedback")
|
||||
|
||||
|
||||
# Dependency aliases for the repositories installed on app.state (503 until startup completes).
CurrentRunRepository = Annotated[RunRepositoryProtocol, Depends(get_run_repository)]
CurrentThreadMetaRepository = Annotated[ThreadMetaRepositoryProtocol, Depends(get_thread_meta_repository)]
CurrentThreadMetaStorage = Annotated[ThreadMetaStorage, Depends(get_thread_meta_storage)]
CurrentFeedbackRepository = Annotated[FeedbackRepositoryProtocol, Depends(get_feedback_repository)]
|
||||
18
backend/app/gateway/dependencies/stream_bridge.py
Normal file
18
backend/app/gateway/dependencies/stream_bridge.py
Normal file
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import StreamBridge
|
||||
|
||||
|
||||
def get_stream_bridge(request: Request) -> StreamBridge:
    """Dependency: the StreamBridge placed on ``app.state`` during startup.

    Raises HTTP 503 while the runtime has not been initialised.
    """
    candidate = getattr(request.app.state, "stream_bridge", None)
    if candidate is None:
        raise HTTPException(status_code=503, detail="Stream bridge not available")
    return candidate
|
||||
|
||||
|
||||
# Dependency alias: resolves via get_stream_bridge (503s until runtime startup completes).
CurrentStreamBridge = Annotated[StreamBridge, Depends(get_stream_bridge)]
|
||||
@ -5,16 +5,17 @@ from pathlib import Path
|
||||
from fastapi import HTTPException
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
virtual_path: The virtual path as seen inside the sandbox
|
||||
(e.g., /mnt/user-data/outputs/file.txt).
|
||||
user_id: Explicit user id override. Falls back to the current actor context.
|
||||
|
||||
Returns:
|
||||
The resolved filesystem path.
|
||||
@ -23,7 +24,8 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
HTTPException: If the path is invalid or outside allowed directories.
|
||||
"""
|
||||
try:
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
|
||||
resolved_user_id = get_effective_user_id() if user_id is None else user_id
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=resolved_user_id)
|
||||
except ValueError as e:
|
||||
status = 403 if "traversal" in str(e) else 400
|
||||
raise HTTPException(status_code=status, detail=str(e))
|
||||
|
||||
132
backend/app/gateway/registrar.py
Normal file
132
backend/app/gateway/registrar.py
Normal file
@ -0,0 +1,132 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from scalar_fastapi import AgentScalarConfig, get_scalar_api_reference
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from store.persistence import create_persistence
|
||||
|
||||
from app.gateway.common import lifespan_manager
|
||||
from app.gateway.router import router as gateway_router
|
||||
from app.infra.run_events import build_run_event_store
|
||||
from app.infra.storage import FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||
from app.plugins.auth.injection import install_route_guards, load_route_policy_registry, validate_route_policy_registry
|
||||
from app.plugins.auth.security import AuthMiddleware, CSRFMiddleware
|
||||
|
||||
STATIC_DIR = Path(__file__).resolve().parents[1] / "static"
|
||||
STATIC_MOUNT = "/api/static"
|
||||
SCALAR_JS_URL = f"{STATIC_MOUNT}/scalar.js"
|
||||
|
||||
|
||||
@lifespan_manager.register
@asynccontextmanager
async def init_persistence(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
    """Initialize persistence layer (DB, checkpointer, store).

    Yields a state dict that the lifespan manager copies onto
    ``app.state``; the persistence layer is closed again on shutdown.
    """
    app_persistence = await create_persistence()

    await app_persistence.setup()
    # All three adapters share the same session factory.
    run_store = RunStoreAdapter(app_persistence.session_factory)
    thread_meta_store = ThreadMetaStoreAdapter(app_persistence.session_factory)
    feedback_store = FeedbackStoreAdapter(app_persistence.session_factory)

    try:
        yield {
            "persistence": app_persistence,
            "checkpointer": app_persistence.checkpointer,
            # No LangGraph store is wired here; consumers must tolerate None.
            "store": None,
            "session_factory": app_persistence.session_factory,
            # The single RunStoreAdapter instance serves every run-repo role.
            "run_store": run_store,
            "run_read_repo": run_store,
            "run_write_repo": run_store,
            "run_delete_repo": run_store,
            "feedback_repo": feedback_store,
            "thread_meta_repo": thread_meta_store,
            "thread_meta_storage": ThreadMetaStorage(thread_meta_store),
            "run_event_store": build_run_event_store(app_persistence.session_factory),
        }
    finally:
        # Always release DB resources, even if startup of later hooks fails.
        await app_persistence.aclose()
|
||||
|
||||
|
||||
@lifespan_manager.register
@asynccontextmanager
async def init_runtime(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
    """Initialize StreamBridge for LangGraph-compatible runtime endpoints."""
    # Imported here rather than at module top — presumably to defer heavy
    # runtime dependencies until startup; confirm before hoisting.
    from app.infra.stream_bridge import build_stream_bridge

    async with build_stream_bridge() as stream_bridge:
        yield {
            "stream_bridge": stream_bridge,
        }
|
||||
|
||||
|
||||
def register_app() -> FastAPI:
    """Build the gateway FastAPI app: static assets, routes, docs, auth policies, middleware."""
    app = FastAPI(
        title="DeerFlow API Gateway",
        version="0.1.0",
        # Built-in docs UIs are disabled; /docs is served by Scalar instead
        # (see _register_scalar).
        docs_url=None,
        redoc_url=None,
        lifespan=lifespan_manager.build(),
        openapi_tags=[
            {
                "name": "threads",
                "description": "Endpoints for managing threads, which are conversations between a human and an assistant. A thread can have multiple runs as the conversation progresses."
            }
        ]
    )

    app.state.authz_hooks = build_authz_hooks()

    # Routes are registered before the auth route policies so the policy
    # registry can be validated against the app's actual routes.
    _register_static(app)
    _register_routes(app)
    _register_scalar(app)
    _register_auth_route_policies(app)
    _register_middlewares(app)

    return app
|
||||
|
||||
|
||||
def _register_static(app: FastAPI) -> None:
    """Serve bundled static assets (e.g. the Scalar JS bundle) under /api/static."""
    app.mount(STATIC_MOUNT, StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
|
||||
def _register_routes(app: FastAPI) -> None:
    """Attach the aggregated gateway router (see router.py) to the app."""
    app.include_router(gateway_router)
|
||||
|
||||
|
||||
def _register_auth_route_policies(app: FastAPI) -> None:
    """Load the route policy registry, validate it against the app, and install guards."""
    registry = load_route_policy_registry()
    # Validate against the fully-registered app so stale/unknown policy
    # entries are caught at startup rather than at request time.
    validate_route_policy_registry(app, registry)
    app.state.auth_route_policy_registry = registry
    install_route_guards(app)
|
||||
|
||||
|
||||
def _register_middlewares(app: FastAPI) -> None:
    # Middleware added last runs first, so the request order is:
    # AuthMiddleware -> CSRFMiddleware -> CORSMiddleware -> routes.
    #
    # NOTE(review/security): wildcard allow_origins=["*"] combined with
    # allow_credentials=True is a CORS misconfiguration — browsers reject
    # credentialed responses for "*" (the pre-refactor app.py logged exactly
    # this as an error). Use explicit scheme://host:port origins in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["*"],
    )
    app.add_middleware(CSRFMiddleware)
    app.add_middleware(AuthMiddleware)
|
||||
|
||||
|
||||
def _register_scalar(app: FastAPI) -> None:
    """Serve the Scalar API reference UI at /docs (replacing FastAPI's default)."""

    @app.get("/docs", include_in_schema=False)
    def scalar_docs() -> HTMLResponse:
        return get_scalar_api_reference(
            openapi_url=app.openapi_url,
            title=app.title,
            # Serve the Scalar bundle from our own static mount, not a CDN.
            scalar_js_url=SCALAR_JS_URL,
            agent=AgentScalarConfig(disabled=True),
            hide_client_button=True,
            overrides={"mcp": {"disabled": True}},
        )
|
||||
22
backend/app/gateway/router.py
Normal file
22
backend/app/gateway/router.py
Normal file
@ -0,0 +1,22 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.plugins.auth.api.router import router as auth_router
|
||||
|
||||
from .routers import artifacts, channels, mcp, models, skills, uploads
|
||||
from .routers.agents import router as agents_router
|
||||
from .routers.langgraph import feedback_router, runs_router, suggestion_router, threads_router
|
||||
|
||||
# Aggregate gateway router: mounts the auth router plus every feature router.
router = APIRouter()

router.include_router(auth_router)
# LangGraph-compatible thread lifecycle endpoints share the /api/threads prefix.
router.include_router(threads_router, prefix="/api/threads")
router.include_router(runs_router, prefix="/api/threads")
router.include_router(feedback_router, prefix="/api/threads")
router.include_router(suggestion_router)
router.include_router(agents_router)
# Feature routers below carry their own prefixes.
router.include_router(channels.router)
router.include_router(artifacts.router)
router.include_router(mcp.router)
router.include_router(models.router)
router.include_router(skills.router)
router.include_router(uploads.router)
|
||||
5
backend/app/gateway/services/__init__.py
Normal file
5
backend/app/gateway/services/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Gateway service layer."""
|
||||
|
||||
"""Compatibility package for app service submodules."""
|
||||
|
||||
__all__: list[str] = []
|
||||
29
backend/app/gateway/services/runs/__init__.py
Normal file
29
backend/app/gateway/services/runs/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Runs app layer services."""
|
||||
|
||||
from app.infra.storage import StorageRunObserver
|
||||
from .input import (
|
||||
AdaptedRunRequest,
|
||||
RunSpecBuilder,
|
||||
UnsupportedRunFeatureError,
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||
|
||||
__all__ = [
|
||||
"AdaptedRunRequest",
|
||||
"AppRunCreateStore",
|
||||
"AppRunDeleteStore",
|
||||
"AppRunQueryStore",
|
||||
"RunSpecBuilder",
|
||||
"StorageRunObserver",
|
||||
"UnsupportedRunFeatureError",
|
||||
"adapt_create_run_request",
|
||||
"adapt_create_stream_request",
|
||||
"adapt_create_wait_request",
|
||||
"adapt_join_stream_request",
|
||||
"adapt_join_wait_request",
|
||||
]
|
||||
150
backend/app/gateway/services/runs/facade_factory.py
Normal file
150
backend/app/gateway/services/runs/facade_factory.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""Facade factory - assembles RunsFacade with dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.gateway.dependencies import get_checkpointer, get_stream_bridge
|
||||
from deerflow.runtime.runs.facade import RunsFacade
|
||||
from deerflow.runtime.runs.facade import RunsRuntime
|
||||
from deerflow.runtime.runs.internal.execution.supervisor import RunSupervisor
|
||||
from deerflow.runtime.runs.internal.planner import ExecutionPlanner
|
||||
from deerflow.runtime.runs.internal.registry import RunRegistry
|
||||
from deerflow.runtime.runs.internal.streams import RunStreamService
|
||||
from deerflow.runtime.runs.internal.wait import RunWaitService
|
||||
|
||||
from app.infra.storage import StorageRunObserver, ThreadMetaStorage
|
||||
from app.infra.storage.runs import RunDeleteRepository, RunReadRepository, RunWriteRepository
|
||||
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
|
||||
# Factory callable that builds an agent graph; signature intentionally loose.
type AgentFactory = Callable[..., object]


# Module-level singleton registry (shared across requests)
_registry: RunRegistry | None = None
_supervisor: RunSupervisor | None = None
||||
|
||||
|
||||
def _get_state(request: Request, attr: str, label: str):
    """Fetch *attr* from ``request.app.state``; 503 when absent or None."""
    state_value = getattr(request.app.state, attr, None)
    if state_value is not None:
        return state_value
    raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
|
||||
|
||||
def get_registry() -> RunRegistry:
    """Return the process-wide RunRegistry, creating it lazily on first call."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = RunRegistry()
    return _registry
|
||||
|
||||
|
||||
def get_supervisor() -> RunSupervisor:
    """Return the process-wide RunSupervisor, creating it lazily on first call."""
    global _supervisor
    if _supervisor is not None:
        return _supervisor
    _supervisor = RunSupervisor()
    return _supervisor
|
||||
|
||||
|
||||
def resolve_agent_factory(assistant_id: str | None) -> AgentFactory:
    """Resolve the agent factory callable from config.

    NOTE(review): ``assistant_id`` is currently ignored — every assistant maps
    to the lead agent. Presumably per-assistant factories arrive in a later
    phase; confirm before relying on the parameter.
    """
    # Imported lazily so agent machinery is not loaded at module import time.
    from deerflow.agents.lead_agent.agent import make_lead_agent

    return make_lead_agent
|
||||
|
||||
|
||||
def build_runs_facade(
    *,
    stream_bridge: "StreamBridge",
    checkpointer: object,
    store: object | None = None,
    run_read_repo: RunReadRepository | None = None,
    run_write_repo: RunWriteRepository | None = None,
    run_delete_repo: RunDeleteRepository | None = None,
    thread_meta_storage: ThreadMetaStorage | None = None,
    run_event_store: object | None = None,
) -> RunsFacade:
    """
    Build RunsFacade with all dependencies.

    Registry and supervisor are process-wide singletons (shared across
    requests); planner and stream/wait services are built per call.

    Args:
        stream_bridge: StreamBridge instance
        checkpointer: LangGraph checkpointer
        store: Optional LangGraph runtime store
        run_read_repo: Optional run repository for durable reads
        run_write_repo: Optional run repository for durable writes
        run_delete_repo: Optional run repository for durable deletes
        thread_meta_storage: Optional thread metadata storage adapter
        run_event_store: Optional durable event store passed to the runtime

    Returns:
        Configured RunsFacade instance
    """
    registry = get_registry()
    planner = ExecutionPlanner()
    supervisor = get_supervisor()

    # Wait service delegates to the stream service, so build them in order.
    stream_service = RunStreamService(stream_bridge)
    wait_service = RunWaitService(stream_service)
    # Each store adapter is only wired when its repository was supplied;
    # otherwise the facade operates without that durable capability.
    query_store = AppRunQueryStore(run_read_repo) if run_read_repo else None
    create_store = (
        AppRunCreateStore(run_write_repo, thread_meta_storage=thread_meta_storage)
        if run_write_repo
        else None
    )
    delete_store = AppRunDeleteStore(run_delete_repo) if run_delete_repo else None

    # Build storage observer if repositories provided
    storage_observer = None
    if run_write_repo or thread_meta_storage:
        storage_observer = StorageRunObserver(
            run_write_repo=run_write_repo,
            thread_meta_storage=thread_meta_storage,
        )

    return RunsFacade(
        registry=registry,
        planner=planner,
        supervisor=supervisor,
        stream_service=stream_service,
        wait_service=wait_service,
        runtime=RunsRuntime(
            bridge=stream_bridge,
            checkpointer=checkpointer,
            store=store,
            event_store=run_event_store,
            # Module-level resolver; currently always yields the lead agent.
            agent_factory_resolver=resolve_agent_factory,
        ),
        observer=storage_observer,
        query_store=query_store,
        create_store=create_store,
        delete_store=delete_store,
    )
|
||||
|
||||
|
||||
def build_runs_facade_from_request(request: "Request") -> RunsFacade:
    """
    Build RunsFacade from FastAPI request context.

    Extracts dependencies from ``request.app.state``: the stream bridge and
    checkpointer via their dependency helpers, and every optional durable
    component via ``getattr`` so missing attributes degrade to ``None``.
    """
    app_state = request.app.state

    return build_runs_facade(
        stream_bridge=get_stream_bridge(request),
        checkpointer=get_checkpointer(request),
        # Consistency fix: read `store` through the cached app_state reference
        # like every other optional dependency below (was request.app.state).
        store=getattr(app_state, "store", None),
        run_read_repo=getattr(app_state, "run_read_repo", None),
        run_write_repo=getattr(app_state, "run_write_repo", None),
        run_delete_repo=getattr(app_state, "run_delete_repo", None),
        thread_meta_storage=getattr(app_state, "thread_meta_storage", None),
        run_event_store=getattr(app_state, "run_event_store", None),
    )
|
||||
22
backend/app/gateway/services/runs/input/__init__.py
Normal file
22
backend/app/gateway/services/runs/input/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Input adapters for app-owned runs entrypoints."""
|
||||
|
||||
from .request_adapter import (
|
||||
AdaptedRunRequest,
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from .spec_builder import RunSpecBuilder, UnsupportedRunFeatureError
|
||||
|
||||
__all__ = [
|
||||
"AdaptedRunRequest",
|
||||
"RunSpecBuilder",
|
||||
"UnsupportedRunFeatureError",
|
||||
"adapt_create_run_request",
|
||||
"adapt_create_stream_request",
|
||||
"adapt_create_wait_request",
|
||||
"adapt_join_stream_request",
|
||||
"adapt_join_wait_request",
|
||||
]
|
||||
127
backend/app/gateway/services/runs/input/request_adapter.py
Normal file
127
backend/app/gateway/services/runs/input/request_adapter.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""App-owned request adapter for runs entrypoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deerflow.runtime.stream_bridge import JSONValue
|
||||
from deerflow.runtime.runs.types import RunIntent
|
||||
|
||||
type RequestBody = dict[str, JSONValue]
|
||||
type RequestQuery = dict[str, str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AdaptedRunRequest:
    """
    Unified internal request DTO.

    The route layer only extracts path/query/body; this adapter-produced
    structure is the stable shape the rest of the runs services consume.
    """

    intent: RunIntent
    thread_id: str | None
    run_id: str | None
    body: RequestBody
    headers: dict[str, str]
    query: RequestQuery

    @property
    def last_event_id(self) -> str | None:
        """Extract Last-Event-ID from headers (lowercase key wins)."""
        lowercase_value = self.headers.get("last-event-id")
        if lowercase_value:
            return lowercase_value
        return self.headers.get("Last-Event-ID")

    @property
    def is_stateless(self) -> bool:
        """True when the caller supplied no thread_id."""
        return self.thread_id is None
|
||||
|
||||
|
||||
def adapt_create_run_request(
    *,
    thread_id: str | None,
    body: RequestBody,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt POST /threads/{thread_id}/runs or POST /runs."""
    # Missing header/query maps default to fresh empty dicts.
    normalized_headers = headers or {}
    normalized_query = query or {}
    return AdaptedRunRequest(
        intent="create_background",
        thread_id=thread_id,
        run_id=None,
        body=body,
        headers=normalized_headers,
        query=normalized_query,
    )
|
||||
|
||||
|
||||
def adapt_create_stream_request(
    *,
    thread_id: str | None,
    body: RequestBody,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt POST /threads/{thread_id}/runs/stream or POST /runs/stream."""
    # Same shape as the background variant but with a streaming intent.
    return AdaptedRunRequest(
        intent="create_and_stream",
        thread_id=thread_id,
        run_id=None,
        body=body,
        headers=headers or {},
        query=query or {},
    )
|
||||
|
||||
|
||||
def adapt_create_wait_request(
    *,
    thread_id: str | None,
    body: RequestBody,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt POST /threads/{thread_id}/runs/wait or POST /runs/wait."""
    empty_headers: dict[str, str] = {}
    empty_query: RequestQuery = {}
    return AdaptedRunRequest(
        intent="create_and_wait",
        thread_id=thread_id,
        run_id=None,
        body=body,
        headers=headers or empty_headers,
        query=query or empty_query,
    )
|
||||
|
||||
|
||||
def adapt_join_stream_request(
    *,
    thread_id: str,
    run_id: str,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt GET /threads/{thread_id}/runs/{run_id}/stream."""
    # Join requests carry no payload — only path identifiers.
    return AdaptedRunRequest(
        intent="join_stream",
        thread_id=thread_id,
        run_id=run_id,
        body={},
        headers=headers or {},
        query=query or {},
    )
|
||||
|
||||
|
||||
def adapt_join_wait_request(
    *,
    thread_id: str,
    run_id: str,
    headers: dict[str, str] | None = None,
    query: RequestQuery | None = None,
) -> AdaptedRunRequest:
    """Adapt GET /threads/{thread_id}/runs/{run_id}/join."""
    # Blocking join: identical shape to the stream join, different intent.
    return AdaptedRunRequest(
        intent="join_wait",
        thread_id=thread_id,
        run_id=run_id,
        body={},
        headers=headers or {},
        query=query or {},
    )
|
||||
254
backend/app/gateway/services/runs/input/spec_builder.py
Normal file
254
backend/app/gateway/services/runs/input/spec_builder.py
Normal file
@ -0,0 +1,254 @@
|
||||
"""App-owned RunSpec builder."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from deerflow.runtime.runs.types import CheckpointRequest, RunScope, RunSpec
|
||||
from deerflow.runtime.stream_bridge import JSONValue
|
||||
|
||||
from .request_adapter import AdaptedRunRequest
|
||||
|
||||
type JSONMapping = dict[str, JSONValue]
|
||||
type GraphInput = dict[str, object]
|
||||
type RunnableConfigDict = dict[str, object]
|
||||
|
||||
|
||||
class UnsupportedRunFeatureError(ValueError):
    """Raised when a run request asks for a feature phase1 does not support."""
|
||||
|
||||
|
||||
class RunSpecBuilder:
    """
    Build RunSpec from AdaptedRunRequest.

    Phase 1 rules:
    1. messages-tuple normalized to messages
    2. enqueue not supported
    3. rollback not supported
    4. after_seconds not supported
    5. stream_resumable accepted
    6. stateless auto-generates temporary thread
    """

    # Phase 1 unsupported features
    UNSUPPORTED_MULTITASK_STRATEGIES = {"enqueue"}
    UNSUPPORTED_ACTIONS = {"rollback"}

    # Default stream modes
    DEFAULT_STREAM_MODES = ["values", "messages"]
    # Keys copied from the request "context" into RunnableConfig.configurable
    # (setdefault only — explicit configurable values win).
    CONTEXT_CONFIGURABLE_KEYS = frozenset({
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    })
    # assistant_id equal to this default never injects an agent_name override.
    DEFAULT_ASSISTANT_ID = "lead_agent"

    @staticmethod
    def _as_json_mapping(value: JSONValue | None) -> JSONMapping | None:
        # Narrow an arbitrary JSON value to a dict; anything else becomes None.
        return value if isinstance(value, dict) else None

    @staticmethod
    def _as_string_list(value: JSONValue | None) -> list[str] | None:
        # Narrow to a list of strings; non-list input becomes None, and
        # non-string items inside a list are silently dropped.
        if not isinstance(value, list):
            return None
        return [item for item in value if isinstance(item, str)]

    def build(self, request: AdaptedRunRequest) -> RunSpec:
        """Build RunSpec from adapted request.

        Raises:
            UnsupportedRunFeatureError: when the body requests a phase1-
                unsupported feature (enqueue, rollback, after_seconds).
        """
        body = request.body

        # Validate phase1 constraints
        self._validate_constraints(body)

        # Build scope
        scope = self._build_scope(request)

        # Normalize stream modes
        stream_modes = self._normalize_stream_modes(body.get("stream_mode"))

        # Build checkpoint request
        checkpoint_request = self._build_checkpoint_request(body)

        config = self._build_runnable_config(
            thread_id=scope.thread_id,
            request_config=self._as_json_mapping(body.get("config")),
            metadata=self._as_json_mapping(body.get("metadata")),
            assistant_id=body.get("assistant_id"),
            context=self._as_json_mapping(body.get("context")),
        )

        # Enum-like fields fall back to a safe default whenever the body
        # value is absent or outside the accepted set.
        return RunSpec(
            intent=request.intent,
            scope=scope,
            assistant_id=body.get("assistant_id") if isinstance(body.get("assistant_id"), str) else None,
            input=self._normalize_input(self._as_json_mapping(body.get("input"))),
            command=self._as_json_mapping(body.get("command")),
            runnable_config=config,
            context=self._as_json_mapping(body.get("context")),
            metadata=self._as_json_mapping(body.get("metadata")) or {},
            stream_modes=stream_modes,
            stream_subgraphs=bool(body.get("stream_subgraphs", False)),
            stream_resumable=bool(body.get("stream_resumable", False)),
            on_disconnect=body.get("on_disconnect", "cancel") if body.get("on_disconnect") in {"cancel", "continue"} else "cancel",
            on_completion=body.get("on_completion", "keep") if body.get("on_completion") in {"delete", "keep"} else "keep",
            multitask_strategy=body.get("multitask_strategy", "reject") if body.get("multitask_strategy") in {"reject", "interrupt"} else "reject",
            # "*" means interrupt everywhere; otherwise a filtered node list.
            interrupt_before="*" if body.get("interrupt_before") == "*" else self._as_string_list(body.get("interrupt_before")),
            interrupt_after="*" if body.get("interrupt_after") == "*" else self._as_string_list(body.get("interrupt_after")),
            checkpoint_request=checkpoint_request,
            follow_up_to_run_id=body.get("follow_up_to_run_id") if isinstance(body.get("follow_up_to_run_id"), str) else None,
            webhook=body.get("webhook") if isinstance(body.get("webhook"), str) else None,
            feedback_keys=self._as_string_list(body.get("feedback_keys")),
        )

    def _validate_constraints(self, body: JSONMapping) -> None:
        """Validate phase1 constraints, raise UnsupportedRunFeatureError if violated."""
        # Check multitask_strategy
        strategy = body.get("multitask_strategy", "reject")
        if strategy in self.UNSUPPORTED_MULTITASK_STRATEGIES:
            raise UnsupportedRunFeatureError(
                f"multitask_strategy '{strategy}' is not supported in phase1. "
                f"Supported: reject, interrupt"
            )

        # Check for rollback action
        command = self._as_json_mapping(body.get("command")) or {}
        if command.get("action") in self.UNSUPPORTED_ACTIONS:
            raise UnsupportedRunFeatureError(
                f"action '{command.get('action')}' is not supported in phase1"
            )

        # Check for after_seconds
        if body.get("after_seconds") is not None:
            raise UnsupportedRunFeatureError("after_seconds is not supported in phase1")

    def _build_scope(self, request: AdaptedRunRequest) -> RunScope:
        """Build RunScope from request."""
        if request.is_stateless:
            # Stateless: generate temporary thread
            return RunScope(
                kind="stateless",
                thread_id=str(uuid.uuid4()),
                temporary=True,
            )
        else:
            assert request.thread_id is not None
            return RunScope(
                kind="stateful",
                thread_id=request.thread_id,
                temporary=False,
            )

    def _normalize_stream_modes(self, stream_mode: JSONValue | None) -> list[str]:
        """Normalize stream_mode to list, convert messages-tuple to messages."""
        if stream_mode is None:
            return self.DEFAULT_STREAM_MODES.copy()

        if isinstance(stream_mode, str):
            modes = [stream_mode]
        elif isinstance(stream_mode, list):
            # Non-string entries are dropped silently.
            modes = [mode for mode in stream_mode if isinstance(mode, str)]
        else:
            # Unrecognized type: fall back to the defaults.
            return self.DEFAULT_STREAM_MODES.copy()

        return ["messages" if m == "messages-tuple" else m for m in modes]

    def _build_checkpoint_request(self, body: JSONMapping) -> CheckpointRequest | None:
        """Build CheckpointRequest if checkpoint data is provided."""
        checkpoint_id = body.get("checkpoint_id")
        checkpoint = self._as_json_mapping(body.get("checkpoint"))

        # Neither field supplied (or both malformed): no checkpoint request.
        if not isinstance(checkpoint_id, str) and checkpoint is None:
            return None

        return CheckpointRequest(
            checkpoint_id=checkpoint_id if isinstance(checkpoint_id, str) else None,
            checkpoint=checkpoint,
        )

    def _normalize_input(self, raw_input: JSONMapping | None) -> GraphInput | None:
        """Convert HTTP-friendly message dicts into LangChain message objects."""
        if raw_input is None:
            return None

        messages = raw_input.get("messages")
        if not messages or not isinstance(messages, list):
            return raw_input

        converted: list[object] = []
        for msg in messages:
            if isinstance(msg, dict):
                role = msg.get("role", msg.get("type", "user"))
                content = msg.get("content", "")
                # NOTE(review): both branches build a HumanMessage, so every
                # dict message — including assistant/system roles — is coerced
                # to a human message. Presumably only user input is expected
                # here in phase1; confirm whether other roles need AIMessage/
                # SystemMessage handling.
                if role in ("user", "human"):
                    converted.append(HumanMessage(content=content))
                else:
                    converted.append(HumanMessage(content=content))
            else:
                # Already a message object (or unknown type): pass through.
                converted.append(msg)
        return {**raw_input, "messages": converted}

    def _build_runnable_config(
        self,
        *,
        thread_id: str,
        request_config: JSONMapping | None,
        metadata: JSONMapping | None,
        assistant_id: str | None,
        context: JSONMapping | None,
    ) -> RunnableConfigDict:
        """Build RunnableConfig from request payload and app-side rules.

        Raises:
            ValueError: when a non-default assistant_id does not normalize to
                a valid agent name (letters, digits, hyphens).
        """
        config: RunnableConfigDict = {"recursion_limit": 100}

        if request_config:
            if "context" in request_config:
                # NOTE(review): when the client supplies config.context, no
                # "configurable" mapping is built at all, so thread_id is not
                # injected and the context-key merge below is skipped —
                # confirm this context-style config path is intentional.
                config["context"] = request_config["context"]
            else:
                # thread_id is the baseline; client-supplied configurable
                # entries may override it.
                configurable = {"thread_id": thread_id}
                raw_configurable = request_config.get("configurable")
                if isinstance(raw_configurable, dict):
                    configurable.update(raw_configurable)
                config["configurable"] = configurable

            # Copy remaining top-level config keys verbatim.
            for key, value in request_config.items():
                if key not in ("configurable", "context"):
                    config[key] = value
        else:
            config["configurable"] = {"thread_id": thread_id}

        configurable = config.get("configurable")
        # Map a non-default assistant_id onto configurable.agent_name unless
        # the client already set one explicitly.
        if (
            assistant_id
            and assistant_id != self.DEFAULT_ASSISTANT_ID
            and isinstance(configurable, dict)
            and "agent_name" not in configurable
        ):
            normalized = assistant_id.strip().lower().replace("_", "-")
            if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
                raise ValueError(
                    f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization."
                )
            configurable["agent_name"] = normalized

        if metadata:
            # Merge into existing metadata when present; otherwise copy.
            existing_metadata = config.get("metadata")
            if isinstance(existing_metadata, dict):
                existing_metadata.update(metadata)
            else:
                config["metadata"] = dict(metadata)

        if context and isinstance(configurable, dict):
            # setdefault: explicit configurable values beat context values.
            for key in self.CONTEXT_CONFIGURABLE_KEYS:
                if key in context:
                    configurable.setdefault(key, context[key])

        return config
|
||||
5
backend/app/gateway/services/runs/storage_observer.py
Normal file
5
backend/app/gateway/services/runs/storage_observer.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Compatibility wrapper for the app-owned storage observer."""
|
||||
|
||||
from app.infra.storage.runs import StorageRunObserver
|
||||
|
||||
__all__ = ["StorageRunObserver"]
|
||||
11
backend/app/gateway/services/runs/store/__init__.py
Normal file
11
backend/app/gateway/services/runs/store/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""App-owned runs store adapters."""
|
||||
|
||||
from .create_store import AppRunCreateStore
|
||||
from .delete_store import AppRunDeleteStore
|
||||
from .query_store import AppRunQueryStore
|
||||
|
||||
__all__ = [
|
||||
"AppRunCreateStore",
|
||||
"AppRunDeleteStore",
|
||||
"AppRunQueryStore",
|
||||
]
|
||||
38
backend/app/gateway/services/runs/store/create_store.py
Normal file
38
backend/app/gateway/services/runs/store/create_store.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""App-owned durable run creation adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.runtime.runs.store import RunCreateStore
|
||||
from deerflow.runtime.runs.types import RunRecord
|
||||
|
||||
from app.infra.storage import ThreadMetaStorage
|
||||
from app.infra.storage.runs import RunWriteRepository
|
||||
|
||||
|
||||
class AppRunCreateStore(RunCreateStore):
    """Write the initial durable row for a newly created run.

    When a thread-metadata storage is supplied, the owning thread is also
    created on demand and its assistant_id realigned with the run's.
    """

    def __init__(self, repo: RunWriteRepository, thread_meta_storage: ThreadMetaStorage | None = None) -> None:
        # Repository that persists run rows; thread metadata sync is optional.
        self._repo = repo
        self._thread_meta_storage = thread_meta_storage

    async def create_run(self, record: RunRecord) -> None:
        """Persist *record* and reconcile the owning thread's assistant_id."""
        await self._repo.create(
            run_id=record.run_id,
            thread_id=record.thread_id,
            assistant_id=record.assistant_id,
            # status may be an enum; stored as its string form.
            status=str(record.status),
            metadata=record.metadata,
            follow_up_to_run_id=record.follow_up_to_run_id,
            created_at=record.created_at,
        )
        # Thread bookkeeping only applies when an assistant is attached.
        if self._thread_meta_storage is not None and record.assistant_id:
            thread = await self._thread_meta_storage.ensure_thread(
                thread_id=record.thread_id,
                assistant_id=record.assistant_id,
            )
            # ensure_thread may return a pre-existing thread created under a
            # different assistant; realign it with this run's assistant.
            if thread.assistant_id != record.assistant_id:
                await self._thread_meta_storage.sync_thread_assistant_id(
                    thread_id=record.thread_id,
                    assistant_id=record.assistant_id,
                )
|
||||
17
backend/app/gateway/services/runs/store/delete_store.py
Normal file
17
backend/app/gateway/services/runs/store/delete_store.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""App-owned durable run deletion adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.runtime.runs.store import RunDeleteStore
|
||||
|
||||
from app.infra.storage.runs import RunDeleteRepository
|
||||
|
||||
|
||||
class AppRunDeleteStore(RunDeleteStore):
    """Remove persisted run rows through the app storage adapter."""

    def __init__(self, repo: RunDeleteRepository) -> None:
        self._repository = repo

    async def delete_run(self, run_id: str) -> bool:
        """Delete the durable row for *run_id*; True when a row was removed."""
        return await self._repository.delete(run_id)
|
||||
47
backend/app/gateway/services/runs/store/query_store.py
Normal file
47
backend/app/gateway/services/runs/store/query_store.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""App-owned durable run query adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.runtime.runs.store import RunQueryStore
|
||||
from deerflow.runtime.runs.types import RunRecord, RunStatus
|
||||
|
||||
from app.infra.storage.runs import RunReadRepository, RunRow
|
||||
|
||||
|
||||
class AppRunQueryStore(RunQueryStore):
    """Map app-side durable run rows into harness RunRecord DTOs."""

    def __init__(self, repo: RunReadRepository) -> None:
        self._repo = repo

    async def get_run(self, run_id: str) -> RunRecord | None:
        """Return the RunRecord for *run_id*, or None when no row exists."""
        row = await self._repo.get(run_id)
        if row is None:
            return None
        return self._to_run_record(row)

    async def list_runs(
        self,
        thread_id: str,
        *,
        limit: int = 100,
    ) -> list[RunRecord]:
        """Return up to *limit* RunRecords belonging to *thread_id*."""
        rows = await self._repo.list_by_thread(thread_id, limit=limit)
        return [self._to_run_record(row) for row in rows]

    def _to_run_record(self, row: RunRow) -> RunRecord:
        """Translate one storage row into a RunRecord.

        Defaults: status falls back to "pending", multitask_strategy to
        "reject", timestamps to empty strings. "run_id" and "thread_id" are
        required keys — a malformed row raises KeyError here.
        """
        return RunRecord(
            run_id=row["run_id"],
            thread_id=row["thread_id"],
            assistant_id=row.get("assistant_id"),
            status=RunStatus(row.get("status", "pending")),
            temporary=False,  # durable rows never represent temporary threads
            multitask_strategy=row.get("multitask_strategy", "reject"),
            metadata=row.get("metadata", {}),
            follow_up_to_run_id=row.get("follow_up_to_run_id"),
            created_at=row.get("created_at", ""),
            updated_at=row.get("updated_at", ""),
            started_at=row.get("started_at"),
            ended_at=row.get("ended_at"),
            error=row.get("error"),
        )
|
||||
@ -1,6 +0,0 @@
|
||||
"""Shared utility helpers for the Gateway layer."""
|
||||
|
||||
|
||||
def sanitize_log_param(value: str) -> str:
    """Strip control characters to prevent log injection."""
    # Single C-level pass removes newline, carriage return, and NUL bytes.
    return value.translate({ord(ch): None for ch in "\n\r\x00"})
|
||||
Loading…
x
Reference in New Issue
Block a user