feat: implement process-local internal authentication for Gateway and enhance CSRF handling

This commit is contained in:
JeffJiang 2026-04-26 22:15:43 +08:00
parent 897dae5475
commit da174dfd4d
15 changed files with 134 additions and 26 deletions

View File

@ -419,6 +419,7 @@ channels:
Notes:
- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
- IM channel workers call Gateway's LangGraph-compatible API internally and automatically attach process-local internal auth plus the CSRF cookie/header pair required for thread and run creation.
Set the corresponding API keys in your `.env` file:

View File

@ -314,7 +314,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via Gateway's LangGraph-compatible API.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
**Components**:
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)

View File

@ -17,6 +17,8 @@ from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@ -534,6 +536,7 @@ class ChannelManager:
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._client = None # lazy init — langgraph_sdk async client
self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None
self._running = False
self._task: asyncio.Task | None = None
@ -586,7 +589,14 @@ class ChannelManager:
if self._client is None:
from langgraph_sdk import get_client
self._client = get_client(url=self._langgraph_url)
self._client = get_client(
url=self._langgraph_url,
headers={
**create_internal_auth_headers(),
CSRF_HEADER_NAME: self._csrf_token,
"Cookie": f"{CSRF_COOKIE_NAME}={self._csrf_token}",
},
)
return self._client
# -- lifecycle ---------------------------------------------------------

View File

@ -79,6 +79,11 @@ async def _ensure_admin_user(app: FastAPI) -> None:
# Skip admin migration work rather than failing gateway startup.
logger.warning("Auth persistence not ready; skipping admin bootstrap check")
return
sf = get_session_factory()
if sf is None:
return
admin_count = await provider.count_admin_users()
if admin_count == 0:
@ -90,10 +95,6 @@ async def _ensure_admin_user(app: FastAPI) -> None:
# Admin already exists — run orphan thread migration for any
# LangGraph thread metadata that pre-dates the auth module.
sf = get_session_factory()
if sf is None:
return
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()

View File

@ -18,6 +18,7 @@ from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
@ -75,8 +76,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
if _is_public(request.url.path):
return await call_next(request)
internal_user = None
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user()
# Non-public path: require session cookie
if not request.cookies.get("access_token"):
if internal_user is None and not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
@ -100,10 +105,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
if internal_user is not None:
user = internal_user
else:
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is

View File

@ -0,0 +1,26 @@
"""Process-local authentication for Gateway internal callers."""
from __future__ import annotations
import secrets
from types import SimpleNamespace
from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
_INTERNAL_AUTH_TOKEN = secrets.token_urlsafe(32)
def create_internal_auth_headers() -> dict[str, str]:
    """Build the header mapping that authenticates same-process Gateway internal calls.

    Returns a fresh dict on every call so callers may merge or mutate it freely.
    """
    headers: dict[str, str] = {}
    headers[INTERNAL_AUTH_HEADER_NAME] = _INTERNAL_AUTH_TOKEN
    return headers
def is_valid_internal_auth_token(token: str | None) -> bool:
"""Return True when *token* matches the process-local internal token."""
return bool(token) and secrets.compare_digest(token, _INTERNAL_AUTH_TOKEN)
def get_internal_user():
    """Return the synthetic principal used for trusted internal channel calls.

    The object mimics the attributes the auth layer reads (``id`` and
    ``system_role``); a new instance is created per call.
    """
    internal_principal = SimpleNamespace(
        id=DEFAULT_USER_ID,
        system_role="internal",
    )
    return internal_principal

View File

@ -70,6 +70,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@ -82,6 +83,7 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
user_id: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
@ -95,6 +97,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@ -108,6 +111,7 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None,
user_id: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
@ -121,6 +125,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)

View File

@ -66,7 +66,7 @@ class FileMemoryStorage(MemoryStorage):
"""Initialize the file memory storage."""
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
@ -116,9 +116,14 @@ class FileMemoryStorage(MemoryStorage):
logger.warning("Failed to load memory file: %s", e)
return create_empty_memory()
@staticmethod
def _cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]:
    """Compose the (user_id, agent_name) tuple keying the per-user/agent memory cache.

    ``None`` in either slot denotes the global scope for that dimension.
    """
    key = (user_id, agent_name)
    return key
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data (cached with file modification time check)."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
@ -126,14 +131,14 @@ class FileMemoryStorage(MemoryStorage):
current_mtime = None
with self._cache_lock:
cached = self._memory_cache.get(agent_name)
cached = self._memory_cache.get(cache_key)
if cached is not None and cached[1] == current_mtime:
return cached[0]
memory_data = self._load_memory_from_file(agent_name)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, current_mtime)
self._memory_cache[cache_key] = (memory_data, current_mtime)
return memory_data
@ -141,6 +146,7 @@ class FileMemoryStorage(MemoryStorage):
"""Reload memory data from file, forcing cache invalidation."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try:
mtime = file_path.stat().st_mtime if file_path.exists() else None
@ -148,12 +154,13 @@ class FileMemoryStorage(MemoryStorage):
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save memory data to file and update cache."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
@ -174,7 +181,7 @@ class FileMemoryStorage(MemoryStorage):
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path)
return True
except OSError as e:

View File

@ -141,7 +141,7 @@ class RunJournal(BaseCallbackHandler):
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
# Capture the first human message sent to any LLM in this run.
if not self._first_human_msg:
if not self._first_human_msg and not messages:
for batch in messages.reversed():
for m in batch.reversed():
if isinstance(m, HumanMessage) and m.name != "summary":

View File

@ -66,7 +66,10 @@ def _normalize_presented_filepath(
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
try:
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
except TypeError:
actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
else:
actual_path = Path(filepath).expanduser().resolve()

View File

@ -174,6 +174,20 @@ def test_protected_post_no_cookie_returns_401(client):
assert res.status_code == 401
def test_protected_post_with_internal_auth_header_passes():
    """A state-changing request carrying internal-auth headers bypasses the cookie check."""
    from app.gateway.internal_auth import create_internal_auth_headers

    test_app = _make_app()
    http = TestClient(test_app)
    response = http.post(
        "/api/threads/abc/runs/stream",
        headers=create_internal_auth_headers(),
    )
    assert response.status_code == 200
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────

View File

@ -414,6 +414,27 @@ def _make_async_iterator(items):
class TestChannelManager:
def test_get_client_includes_csrf_header_and_cookie(self):
    """The lazily-built SDK client must carry internal-auth, CSRF header, and the matching cookie."""
    from app.channels.manager import ChannelManager

    bus = MessageBus()
    store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
    mgr = ChannelManager(bus=bus, store=store, langgraph_url="http://localhost:8001")
    with patch("langgraph_sdk.get_client") as mock_get_client:
        mock_get_client.return_value = object()
        mgr._get_client()
        mock_get_client.assert_called_once()
        call_kwargs = mock_get_client.call_args.kwargs
        assert call_kwargs["url"] == "http://localhost:8001"
        sent_headers = call_kwargs["headers"]
        token = sent_headers["X-CSRF-Token"]
        assert token
        # Cookie and header must carry the same token for the double-submit check.
        assert sent_headers["Cookie"] == f"csrf_token={token}"
        assert sent_headers["X-DeerFlow-Internal-Token"]
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
from app.channels.manager import ChannelManager

View File

@ -7,6 +7,15 @@ from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.skills.types import Skill
def _set_skills_cache_state(*, skills=None, active=False, version=0):
    """Reset the module-level skills-prompt cache to a known state for tests.

    Clears the memoized prompt section, then — under the module's cache lock —
    overwrites the enabled-skills cache, the refresh-in-progress flag, and the
    refresh version, and clears the refresh event so no waiter is woken early.
    """
    # Drop the lru_cache entry first so subsequent reads rebuild from the
    # state installed below.
    prompt_module._get_cached_skills_prompt_section.cache_clear()
    with prompt_module._enabled_skills_lock:
        prompt_module._enabled_skills_cache = skills
        prompt_module._enabled_skills_refresh_active = active
        prompt_module._enabled_skills_refresh_version = version
        prompt_module._enabled_skills_refresh_event.clear()
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
@ -84,7 +93,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
state = {"skills": [make_skill("first-skill")]}
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
prompt_module.clear_skills_system_prompt_cache()
_set_skills_cache_state()
try:
prompt_module.warm_enabled_skills_cache()
@ -95,7 +104,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["second-skill"]
finally:
prompt_module.clear_skills_system_prompt_cache()
_set_skills_cache_state()
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
@ -137,7 +146,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
return [make_skill(f"skill-{current_call}")]
monkeypatch.setattr(prompt_module, "load_skills", fake_load_skills)
prompt_module.clear_skills_system_prompt_cache()
_set_skills_cache_state()
try:
prompt_module.clear_skills_system_prompt_cache()
@ -151,7 +160,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["skill-2"]
finally:
release.set()
prompt_module.clear_skills_system_prompt_cache()
_set_skills_cache_state()
def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):

View File

@ -93,7 +93,10 @@ class TestTitleMiddlewareCoreLogic:
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "title_agent"}
assert model.ainvoke.await_args.kwargs["config"] == {
"run_name": "title_agent",
"tags": ["middleware:title"],
}
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
_set_test_title_config(max_chars=20)

View File

@ -49,7 +49,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(uploads.upload_files("thread-mounted", files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file]))
assert result.success is True
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
@ -75,7 +75,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
assert result.success is True
assert len(result.files) == 1