Willem Jiang 519200728a
fix(middleware): offload memory injection off event loop to prevent tiktoken blocking (#3402) (#3411)
* fix(middleware): offload memory injection off event loop to prevent tiktoken blocking (#3402)

  DynamicContextMiddleware.abefore_agent() called _inject() synchronously
  on the asyncio event loop.  The first time memory is injected (second
  request), _inject() → format_memory_for_injection() → _count_tokens()
  → tiktoken.get_encoding("cl100k_base") needs to download the BPE data
  from openaipublic.blob.core.windows.net.  In network-restricted
  environments this download blocks until the OS TCP timeout (~26 min),
  starving ALL concurrent handlers including /api/v1/auth/me.
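
The failure mode can be illustrated with a small sketch that is not DeerFlow code. `blocking_call` is a hypothetical stand-in for the slow tiktoken download, and the tick counter stands in for concurrent handlers. The delays are arbitrary illustration values.

```python
import asyncio
import time


def blocking_call(seconds: float) -> None:
    # Hypothetical stand-in for a slow synchronous step (e.g. a download).
    time.sleep(seconds)


async def count_ticks(run_blocking) -> int:
    """Return how many times a background task ran while run_blocking() was in flight."""
    ticks = 0

    async def ticker() -> None:
        nonlocal ticks
        while True:
            await asyncio.sleep(0.01)
            ticks += 1

    task = asyncio.create_task(ticker())
    await run_blocking()
    observed = ticks  # read before yielding control again
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return observed


async def on_loop() -> None:
    # Runs the blocking call directly on the event loop thread.
    blocking_call(0.2)


async def off_loop() -> None:
    # Offloads the same call to a worker thread; the loop stays responsive.
    await asyncio.to_thread(blocking_call, 0.2)


def demo() -> tuple[int, int]:
    return asyncio.run(count_ticks(on_loop)), asyncio.run(count_ticks(off_loop))
```

In the first case the background task never gets a turn while the call is running; in the second it keeps ticking, which is the behaviour the fix relies on.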

  Fix:
  - abefore_agent now uses asyncio.to_thread(self._inject, state) so
    file I/O and tiktoken never block the event loop.
  - Extract _get_tiktoken_encoding() with a module-level cache so
    tiktoken.get_encoding() is called at most once per encoding name.
  - Add warm_tiktoken_cache() startup helper; gateway lifespan pre-warms
    the cache via asyncio.to_thread so the first request never triggers a
    cold download.
  - _count_tokens falls back to len(text) // 4 on any encoding failure.
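
A minimal sketch of the cache, warm-up, and fallback pieces listed above, under stated assumptions: the `loader` parameter is a hypothetical stand-in for `tiktoken.get_encoding` so the snippet runs without tiktoken or network access, and the function names are illustrative rather than the ones in prompt.py. The lock is also an assumption; the commit only states that the encoding is loaded at most once per name.

```python
import threading
from typing import Any, Callable, Dict

# Module-level cache: one entry per encoding name (illustrative names).
_ENCODING_CACHE: Dict[str, Any] = {}
_CACHE_LOCK = threading.Lock()


def get_cached_encoding(name: str, loader: Callable[[str], Any]) -> Any:
    """Load an encoding once per name; later calls reuse the cached object."""
    with _CACHE_LOCK:
        if name not in _ENCODING_CACHE:
            # A failed load raises here and is not cached, so it can be retried.
            _ENCODING_CACHE[name] = loader(name)
        return _ENCODING_CACHE[name]


def count_tokens(text: str, loader: Callable[[str], Any], name: str = "cl100k_base") -> int:
    """Count tokens, falling back to len(text) // 4 on any encoding failure."""
    try:
        return len(get_cached_encoding(name, loader).encode(text))
    except Exception:
        return len(text) // 4


def warm_cache(loader: Callable[[str], Any], name: str = "cl100k_base") -> bool:
    """Startup helper: return True if the encoding is now cached, False otherwise."""
    try:
        get_cached_encoding(name, loader)
        return True
    except Exception:
        return False
```

With a real loader, `warm_cache` would be run off the event loop at startup (for example via `asyncio.to_thread`), so that request-time calls only hit the cache.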

  Tests:
  - tests/test_tiktoken_cache_and_count_tokens.py (12 tests): cache
    hit/miss, fallback paths, warm-up helper.
  - tests/blocking_io/test_dynamic_context_middleware.py (2 tests):
    Blockbuster gate verifies abefore_agent does not block the event
    loop; async/sync parity check.

  Fixes #3402

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix the lint error

* fix(memory): use future annotations to avoid NameError when tiktoken is absent

Add `from __future__ import annotations` to prompt.py so that
tiktoken.Encoding type hints are never evaluated at runtime.  Without
this, environments where tiktoken is not installed could raise NameError
on the module-level cache and function return annotations.

Addresses Copilot review comment on PR #3411.

* fix(middleware): bound abefore_agent injection with timeout to prevent hung requests

Wrap the asyncio.to_thread(self._inject) offload in asyncio.wait_for()
with a 5-second cap.  If the startup warm-up failed silently (e.g.
network blip during deploy), a cold tiktoken BPE download on the first
request can block until the OS TCP timeout (~26 min).  The bounded
timeout ensures the request degrades gracefully (no memory/date context
for that turn) rather than hanging.
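
The bounded offload can be sketched as follows. This is an illustration under assumptions, not the middleware itself: `inject_with_timeout` and its `inject` argument are hypothetical names standing in for abefore_agent and self._inject.

```python
import asyncio
import time
from typing import Any, Callable, Optional


async def inject_with_timeout(
    inject: Callable[[dict], Any],
    state: dict,
    timeout: float = 5.0,
) -> Optional[Any]:
    """Run a synchronous inject(state) in a worker thread, bounded by timeout.

    Returns the injection result, or None if the timeout elapses first.
    """
    try:
        return await asyncio.wait_for(asyncio.to_thread(inject, state), timeout=timeout)
    except asyncio.TimeoutError:
        # Degrade gracefully: no injected context for this turn.
        return None


def slow_inject(state: dict) -> dict:
    # Hypothetical stand-in for a cold download inside the injection path.
    time.sleep(0.3)
    return {"injected": True}


def fast_inject(state: dict) -> dict:
    return {"injected": True}
```

Note that asyncio.wait_for only stops waiting for the result; the worker thread itself cannot be interrupted and keeps running until the underlying call returns.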

Adds test_abefore_agent_returns_none_on_timeout to the blocking-IO
regression anchors.

Addresses review feedback from xg-gh-25 on PR #3411.

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-08 12:21:55 +08:00


import asyncio
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
    agents,
    artifacts,
    assistants_compat,
    auth,
    channels,
    feedback,
    mcp,
    memory,
    models,
    runs,
    skills,
    suggestions,
    thread_runs,
    threads,
    uploads,
)
from deerflow.config import app_config as deerflow_app_config
from deerflow.config.app_config import apply_logging_level

AppConfig = deerflow_app_config.AppConfig
get_app_config = deerflow_app_config.get_app_config

# Default logging; lifespan overrides from config.yaml log_level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)

# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
# Bounds worker exit time so uvicorn's reload supervisor does not keep
# firing signals into a worker that is stuck waiting for shutdown cleanup.
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0


async def _ensure_admin_user(app: FastAPI) -> None:
    """Startup hook: handle first boot and migrate orphan threads otherwise.

    After admin creation, migrate orphan threads from the LangGraph
    store (metadata.user_id unset) to the admin account. This is the
    "no-auth → with-auth" upgrade path: users who ran DeerFlow without
    authentication have existing LangGraph thread data that needs an
    owner assigned.

    First boot (no admin exists):
      - Does NOT create any user accounts automatically.
      - The operator must visit ``/setup`` to create the first admin.

    Subsequent boots (admin already exists):
      - Runs the one-time "no-auth → with-auth" orphan thread migration for
        existing LangGraph thread metadata that has no user_id.

    No SQL persistence migration is needed: the four user_id columns
    (threads_meta, runs, run_events, feedback) only come into existence
    alongside the auth module via create_all, so freshly created tables
    never contain NULL-owner rows.
    """
    from sqlalchemy import select

    from app.gateway.deps import get_local_provider
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.user.model import UserRow

    try:
        provider = get_local_provider()
    except RuntimeError:
        # Auth persistence may not be initialized in some test/boot paths.
        # Skip admin migration work rather than failing gateway startup.
        logger.warning("Auth persistence not ready; skipping admin bootstrap check")
        return

    sf = get_session_factory()
    if sf is None:
        return

    admin_count = await provider.count_admin_users()

    if admin_count == 0:
        logger.info("=" * 60)
        logger.info(" First boot detected — no admin account exists.")
        logger.info(" Visit /setup to complete admin account creation.")
        logger.info("=" * 60)
        return

    # Admin already exists — run orphan thread migration for any
    # LangGraph thread metadata that pre-dates the auth module.
    async with sf() as session:
        stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
        row = (await session.execute(stmt)).scalar_one_or_none()

    if row is None:
        return  # Should not happen (admin_count > 0 above), but be safe.

    admin_id = str(row.id)

    # LangGraph store orphan migration — non-fatal.
    # This covers the "no-auth → with-auth" upgrade path for users
    # whose existing LangGraph thread metadata has no user_id set.
    store = getattr(app.state, "store", None)
    if store is not None:
        try:
            migrated = await _migrate_orphaned_threads(store, admin_id)
            if migrated:
                logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
        except Exception:
            logger.exception("LangGraph thread migration failed (non-fatal)")


async def _iter_store_items(store, namespace, *, page_size: int = 500):
    """Paginated async iterator over a LangGraph store namespace.

    Replaces the old hardcoded ``limit=1000`` call with an offset-based
    pagination loop so that environments with more than one page of
    orphans do not silently lose data. Terminates when a page is empty
    OR when a short page arrives (indicating the last page).
    """
    offset = 0
    while True:
        batch = await store.asearch(namespace, limit=page_size, offset=offset)
        if not batch:
            return
        for item in batch:
            yield item
        if len(batch) < page_size:
            return
        offset += page_size


async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
    """Migrate LangGraph store threads with no user_id to the given admin.

    Uses offset-based pagination (via ``_iter_store_items``) so all
    orphans are migrated regardless of count. Returns the number of
    threads migrated.
    """
    migrated = 0
    async for item in _iter_store_items(store, ("threads",)):
        metadata = item.value.get("metadata", {})
        if not metadata.get("user_id"):
            metadata["user_id"] = admin_user_id
            item.value["metadata"] = metadata
            await store.aput(("threads",), item.key, item.value)
            migrated += 1
    return migrated


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan handler."""

    # Load config and check necessary environment variables at startup.
    # `startup_config` is a local snapshot used only for one-shot bootstrap
    # work (logging level, langgraph_runtime engines, channels). Request-time
    # config resolution always routes through `get_app_config()` in
    # `app/gateway/deps.py::get_config()` so `config.yaml` edits become
    # visible without a process restart. We deliberately do NOT cache this
    # snapshot on `app.state` to keep that contract enforceable.
    try:
        startup_config = get_app_config()
        apply_logging_level(startup_config.log_level)
        logger.info("Configuration loaded successfully")
    except Exception as e:
        error_msg = f"Failed to load configuration during gateway startup: {e}"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e

    config = get_gateway_config()
    logger.info(f"Starting API Gateway on {config.host}:{config.port}")

    # Pre-warm tiktoken encoding cache so the first memory-injection request
    # never blocks on the BPE data download (which hits an OpenAI/Azure URL
    # that may be unreachable in restricted networks — see issue #3402).
    try:
        from deerflow.agents.memory.prompt import warm_tiktoken_cache

        warmed = await asyncio.wait_for(
            asyncio.to_thread(warm_tiktoken_cache),
            timeout=5,
        )
        if warmed:
            logger.info("tiktoken encoding cache warmed successfully")
        else:
            logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
    except TimeoutError:
        logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
    except Exception:
        logger.warning("tiktoken warm-up skipped", exc_info=True)

    # Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
    async with langgraph_runtime(app, startup_config):
        logger.info("LangGraph runtime initialised")

        # Check admin bootstrap state and migrate orphan threads after admin exists.
        # Must run AFTER langgraph_runtime so app.state.store is available for thread migration
        await _ensure_admin_user(app)

        # Start IM channel service if any channels are configured
        try:
            from app.channels.service import start_channel_service

            channel_service = await start_channel_service(startup_config)
            logger.info("Channel service started: %s", channel_service.get_status())
        except Exception:
            logger.exception("No IM channels configured or channel service failed to start")

        yield

        # Stop channel service on shutdown (bounded to prevent worker hang)
        try:
            from app.channels.service import stop_channel_service

            await asyncio.wait_for(
                stop_channel_service(),
                timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
            )
        except TimeoutError:
            logger.warning(
                "Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
                _SHUTDOWN_HOOK_TIMEOUT_SECONDS,
            )
        except Exception:
            logger.exception("Failed to stop channel service")

    logger.info("Shutting down API Gateway")


def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Returns:
        Configured FastAPI application instance.
    """
    config = get_gateway_config()
    docs_url = "/docs" if config.enable_docs else None
    redoc_url = "/redoc" if config.enable_docs else None
    openapi_url = "/openapi.json" if config.enable_docs else None

    app = FastAPI(
        title="DeerFlow API Gateway",
        description="""
## DeerFlow API Gateway
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
### Features
- **Models Management**: Query and retrieve available AI models
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
- **Memory Management**: Access and manage global memory data for personalized conversations
- **Skills Management**: Query and manage skills and their enabled status
- **Artifacts**: Access thread artifacts and generated files
- **Health Monitoring**: System health check endpoints
### Architecture
LangGraph-compatible requests are routed through nginx to this gateway.
This gateway provides runtime endpoints for agent runs plus custom endpoints for models, MCP configuration, skills, and artifacts.
""",
        version="0.1.0",
        lifespan=lifespan,
        docs_url=docs_url,
        redoc_url=redoc_url,
        openapi_url=openapi_url,
        openapi_tags=[
            {
                "name": "models",
                "description": "Operations for querying available AI models and their configurations",
            },
            {
                "name": "mcp",
                "description": "Manage Model Context Protocol (MCP) server configurations",
            },
            {
                "name": "memory",
                "description": "Access and manage global memory data for personalized conversations",
            },
            {
                "name": "skills",
                "description": "Manage skills and their configurations",
            },
            {
                "name": "artifacts",
                "description": "Access and download thread artifacts and generated files",
            },
            {
                "name": "uploads",
                "description": "Upload and manage user files for threads",
            },
            {
                "name": "threads",
                "description": "Manage DeerFlow thread-local filesystem data",
            },
            {
                "name": "agents",
                "description": "Create and manage custom agents with per-agent config and prompts",
            },
            {
                "name": "suggestions",
                "description": "Generate follow-up question suggestions for conversations",
            },
            {
                "name": "channels",
                "description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
            },
            {
                "name": "assistants-compat",
                "description": "LangGraph Platform-compatible assistants API (stub)",
            },
            {
                "name": "runs",
                "description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
            },
            {
                "name": "health",
                "description": "Health check and system status endpoints",
            },
        ],
    )

    # Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
    app.add_middleware(AuthMiddleware)

    # CSRF: Double Submit Cookie pattern for state-changing requests
    app.add_middleware(CSRFMiddleware)

    # CORS: the unified nginx endpoint is same-origin by default. Split-origin
    # browser clients must opt in with this explicit Gateway allowlist so CORS
    # and CSRF origin checks share the same source of truth.
    cors_origins = sorted(get_configured_cors_origins())
    if cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Include routers
    # Models API is mounted at /api/models
    app.include_router(models.router)

    # MCP API is mounted at /api/mcp
    app.include_router(mcp.router)

    # Memory API is mounted at /api/memory
    app.include_router(memory.router)

    # Skills API is mounted at /api/skills
    app.include_router(skills.router)

    # Artifacts API is mounted at /api/threads/{thread_id}/artifacts
    app.include_router(artifacts.router)

    # Uploads API is mounted at /api/threads/{thread_id}/uploads
    app.include_router(uploads.router)

    # Thread cleanup API is mounted at /api/threads/{thread_id}
    app.include_router(threads.router)

    # Agents API is mounted at /api/agents
    app.include_router(agents.router)

    # Suggestions API is mounted at /api/threads/{thread_id}/suggestions
    app.include_router(suggestions.router)

    # Channels API is mounted at /api/channels
    app.include_router(channels.router)

    # Assistants compatibility API (LangGraph Platform stub)
    app.include_router(assistants_compat.router)

    # Auth API is mounted at /api/v1/auth
    app.include_router(auth.router)

    # Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
    app.include_router(feedback.router)

    # Thread Runs API (LangGraph Platform-compatible runs lifecycle)
    app.include_router(thread_runs.router)

    # Stateless Runs API (stream/wait without a pre-existing thread)
    app.include_router(runs.router)

    @app.get("/health", tags=["health"])
    async def health_check() -> dict[str, str]:
        """Health check endpoint.

        Returns:
            Service health status information.
        """
        return {"status": "healthy", "service": "deer-flow-gateway"}

    return app


# Create app instance for uvicorn
app = create_app()