Mirror of https://github.com/bytedance/deer-flow.git (synced 2026-04-27 04:08:30 +00:00)

feat: implement process-local internal authentication for Gateway and enhance CSRF handling

parent 897dae5475
commit da174dfd4d
@@ -419,6 +419,7 @@ channels:

Notes:

- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
+- IM channel workers call Gateway's LangGraph-compatible API internally and automatically attach process-local internal auth plus the CSRF cookie/header pair required for thread and run creation.

Set the corresponding API keys in your `.env` file:
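The notes above describe the routing contract: a custom agent name still goes through `lead_agent`, with the name carried along as `agent_name`. A rough sketch of what a channel worker's call might look like with the `langgraph_sdk` client follows; the exact input key used to carry `agent_name` is an assumption for illustration, not taken from this diff.

```python
# Hypothetical sketch (not from this commit) of a channel worker starting a run
# through the default assistant while forwarding a custom agent name.
import asyncio

from langgraph_sdk import get_client


async def start_run(langgraph_url: str, text: str, agent_name: str | None = None) -> None:
    client = get_client(url=langgraph_url)
    thread = await client.threads.create()
    # "lead_agent" is the default assistant; agent_name (if any) is injected so the
    # custom agent's SOUL/config takes effect. The input key name is an assumption.
    async for chunk in client.runs.stream(
        thread["thread_id"],
        "lead_agent",
        input={"messages": [{"role": "user", "content": text}], "agent_name": agent_name},
        stream_mode="messages",
    ):
        print(chunk.event)


# asyncio.run(start_run("http://localhost:8001", "hello"))
```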
@@ -314,7 +314,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →

Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via Gateway's LangGraph-compatible API.

-**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side.
+**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.

**Components**:

- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
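A condensed sketch of the header wiring that paragraph describes, mirroring the `ChannelManager._get_client` change later in this commit. The header and cookie names are taken from this diff; the standalone token generation here is illustrative, not the production code path.

```python
# Illustrative only: how a same-process caller attaches internal auth plus a
# double-submit CSRF pair when building the langgraph_sdk client.
import secrets

from langgraph_sdk import get_client

INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"  # from internal_auth.py below
CSRF_HEADER_NAME = "X-CSRF-Token"                        # names seen in the tests below
CSRF_COOKIE_NAME = "csrf_token"

_internal_token = secrets.token_urlsafe(32)  # process-local; never leaves this process
_csrf_token = secrets.token_urlsafe(32)

client = get_client(
    url="http://localhost:8001",
    headers={
        INTERNAL_AUTH_HEADER_NAME: _internal_token,
        CSRF_HEADER_NAME: _csrf_token,
        # Double-submit: the same token travels as both header and cookie.
        "Cookie": f"{CSRF_COOKIE_NAME}={_csrf_token}",
    },
)
```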
@@ -17,6 +17,8 @@ from langgraph_sdk.errors import ConflictError

from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
+from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
+from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import get_effective_user_id

logger = logging.getLogger(__name__)
@@ -534,6 +536,7 @@ class ChannelManager:

        self._default_session = _as_dict(default_session)
        self._channel_sessions = dict(channel_sessions or {})
        self._client = None  # lazy init — langgraph_sdk async client
+       self._csrf_token = generate_csrf_token()
        self._semaphore: asyncio.Semaphore | None = None
        self._running = False
        self._task: asyncio.Task | None = None
@@ -586,7 +589,14 @@ class ChannelManager:

        if self._client is None:
            from langgraph_sdk import get_client

-           self._client = get_client(url=self._langgraph_url)
+           self._client = get_client(
+               url=self._langgraph_url,
+               headers={
+                   **create_internal_auth_headers(),
+                   CSRF_HEADER_NAME: self._csrf_token,
+                   "Cookie": f"{CSRF_COOKIE_NAME}={self._csrf_token}",
+               },
+           )
        return self._client

    # -- lifecycle ---------------------------------------------------------
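For context, the cookie/header pair above only helps if Gateway compares the two values. `csrf_middleware` is not part of this diff, so the following double-submit check is an assumption about what it does, sketched with the constant names imported above.

```python
# Assumed Gateway-side validation (csrf_middleware is not shown in this commit):
# a standard double-submit check that the CSRF header matches the CSRF cookie.
import hmac

CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_COOKIE_NAME = "csrf_token"


def csrf_ok(headers: dict[str, str], cookies: dict[str, str]) -> bool:
    """Return True when the header token matches the cookie token (double-submit)."""
    header_token = headers.get(CSRF_HEADER_NAME, "")
    cookie_token = cookies.get(CSRF_COOKIE_NAME, "")
    return bool(header_token) and hmac.compare_digest(header_token, cookie_token)
```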
@@ -79,6 +79,11 @@ async def _ensure_admin_user(app: FastAPI) -> None:

        # Skip admin migration work rather than failing gateway startup.
        logger.warning("Auth persistence not ready; skipping admin bootstrap check")
        return

+   sf = get_session_factory()
+   if sf is None:
+       return
+
    admin_count = await provider.count_admin_users()

    if admin_count == 0:

@@ -90,10 +95,6 @@ async def _ensure_admin_user(app: FastAPI) -> None:

    # Admin already exists — run orphan thread migration for any
    # LangGraph thread metadata that pre-dates the auth module.
-   sf = get_session_factory()
-   if sf is None:
-       return
-
    async with sf() as session:
        stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
        row = (await session.execute(stmt)).scalar_one_or_none()
@@ -18,6 +18,7 @@ from starlette.types import ASGIApp

from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
+from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user

# Paths that never require authentication.
@@ -75,8 +76,12 @@ class AuthMiddleware(BaseHTTPMiddleware):

        if _is_public(request.url.path):
            return await call_next(request)

+       internal_user = None
+       if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
+           internal_user = get_internal_user()
+
        # Non-public path: require session cookie
-       if not request.cookies.get("access_token"):
+       if internal_user is None and not request.cookies.get("access_token"):
            return JSONResponse(
                status_code=401,
                content={
@@ -100,10 +105,13 @@ class AuthMiddleware(BaseHTTPMiddleware):

        # bubble up, so we catch and render it as JSONResponse here.
        from app.gateway.deps import get_current_user_from_request

-       try:
-           user = await get_current_user_from_request(request)
-       except HTTPException as exc:
-           return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
+       if internal_user is not None:
+           user = internal_user
+       else:
+           try:
+               user = await get_current_user_from_request(request)
+           except HTTPException as exc:
+               return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})

        # Stamp both request.state.user (for the contextvar pattern)
        # and request.state.auth (so @require_permission's "auth is
backend/app/gateway/internal_auth.py (new file, 26 lines)

@@ -0,0 +1,26 @@
"""Process-local authentication for Gateway internal callers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||
_INTERNAL_AUTH_TOKEN = secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def create_internal_auth_headers() -> dict[str, str]:
|
||||
"""Return headers that authenticate same-process Gateway internal calls."""
|
||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
|
||||
|
||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
"""Return True when *token* matches the process-local internal token."""
|
||||
return bool(token) and secrets.compare_digest(token, _INTERNAL_AUTH_TOKEN)
|
||||
|
||||
|
||||
def get_internal_user():
|
||||
"""Return the synthetic user used for trusted internal channel calls."""
|
||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
|
||||
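A minimal round trip with the module above, matching how `AuthMiddleware` consumes it in this commit: the caller attaches the headers, the middleware validates the token and substitutes the synthetic internal user.

```python
# Usage sketch of backend/app/gateway/internal_auth.py as added in this commit.
from app.gateway.internal_auth import (
    INTERNAL_AUTH_HEADER_NAME,
    create_internal_auth_headers,
    get_internal_user,
    is_valid_internal_auth_token,
)

headers = create_internal_auth_headers()
assert is_valid_internal_auth_token(headers[INTERNAL_AUTH_HEADER_NAME])  # same process: accepted
assert not is_valid_internal_auth_token(None)                            # missing header: rejected
assert not is_valid_internal_auth_token("forged-token")                  # wrong token: rejected

user = get_internal_user()
assert user.system_role == "internal"
```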
@@ -70,6 +70,7 @@ class MemoryUpdateQueue:

            thread_id=thread_id,
            messages=messages,
            agent_name=agent_name,
            user_id=user_id,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )

@@ -82,6 +83,7 @@ class MemoryUpdateQueue:

        thread_id: str,
        messages: list[Any],
        agent_name: str | None = None,
        user_id: str | None = None,
        correction_detected: bool = False,
        reinforcement_detected: bool = False,
    ) -> None:

@@ -95,6 +97,7 @@ class MemoryUpdateQueue:

            thread_id=thread_id,
            messages=messages,
            agent_name=agent_name,
            user_id=user_id,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )

@@ -108,6 +111,7 @@ class MemoryUpdateQueue:

        thread_id: str,
        messages: list[Any],
        agent_name: str | None,
        user_id: str | None,
        correction_detected: bool,
        reinforcement_detected: bool,
    ) -> None:

@@ -121,6 +125,7 @@ class MemoryUpdateQueue:

            thread_id=thread_id,
            messages=messages,
            agent_name=agent_name,
            user_id=user_id,
            correction_detected=merged_correction_detected,
            reinforcement_detected=merged_reinforcement_detected,
        )
@@ -66,7 +66,7 @@ class FileMemoryStorage(MemoryStorage):

        """Initialize the file memory storage."""
        # Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
        # Value: (memory_data, file_mtime)
-       self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
+       self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
        # Guards all reads and writes to _memory_cache across concurrent callers.
        self._cache_lock = threading.Lock()

@@ -116,9 +116,14 @@ class FileMemoryStorage(MemoryStorage):

            logger.warning("Failed to load memory file: %s", e)
            return create_empty_memory()

+   @staticmethod
+   def _cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]:
+       return (user_id, agent_name)
+
    def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
        """Load memory data (cached with file modification time check)."""
        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
+       cache_key = self._cache_key(agent_name, user_id=user_id)

        try:
            current_mtime = file_path.stat().st_mtime if file_path.exists() else None

@@ -126,14 +131,14 @@ class FileMemoryStorage(MemoryStorage):

            current_mtime = None

        with self._cache_lock:
-           cached = self._memory_cache.get(agent_name)
+           cached = self._memory_cache.get(cache_key)
            if cached is not None and cached[1] == current_mtime:
                return cached[0]

-       memory_data = self._load_memory_from_file(agent_name)
+       memory_data = self._load_memory_from_file(agent_name, user_id=user_id)

        with self._cache_lock:
-           self._memory_cache[agent_name] = (memory_data, current_mtime)
+           self._memory_cache[cache_key] = (memory_data, current_mtime)

        return memory_data

@@ -141,6 +146,7 @@ class FileMemoryStorage(MemoryStorage):

        """Reload memory data from file, forcing cache invalidation."""
        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
        memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
+       cache_key = self._cache_key(agent_name, user_id=user_id)

        try:
            mtime = file_path.stat().st_mtime if file_path.exists() else None

@@ -148,12 +154,13 @@ class FileMemoryStorage(MemoryStorage):

            mtime = None

        with self._cache_lock:
-           self._memory_cache[agent_name] = (memory_data, mtime)
+           self._memory_cache[cache_key] = (memory_data, mtime)
        return memory_data

    def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
        """Save memory data to file and update cache."""
        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
+       cache_key = self._cache_key(agent_name, user_id=user_id)

        try:
            file_path.parent.mkdir(parents=True, exist_ok=True)

@@ -174,7 +181,7 @@ class FileMemoryStorage(MemoryStorage):

                mtime = None

            with self._cache_lock:
-               self._memory_cache[agent_name] = (memory_data, mtime)
+               self._memory_cache[cache_key] = (memory_data, mtime)
            logger.info("Memory saved to %s", file_path)
            return True
        except OSError as e:
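Why the cache key changed: keying the in-memory cache by `agent_name` alone lets two users who talk to the same agent hit the same cache slot. A toy illustration of the new `(user_id, agent_name)` keying follows; the user ids are made up and the dict is a simplified stand-in for the real class.

```python
# Simplified stand-in for FileMemoryStorage._memory_cache, showing why the key
# moved from agent_name alone to a (user_id, agent_name) tuple.
from typing import Any

cache: dict[tuple[str | None, str | None], dict[str, Any]] = {}


def cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]:
    return (user_id, agent_name)


# Hypothetical users sharing the same agent name no longer collide.
cache[cache_key("lead_agent", user_id="alice")] = {"facts": ["likes tea"]}
cache[cache_key("lead_agent", user_id="bob")] = {"facts": ["likes coffee"]}

assert cache[cache_key("lead_agent", user_id="alice")] != cache[cache_key("lead_agent", user_id="bob")]
```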
@@ -141,7 +141,7 @@ class RunJournal(BaseCallbackHandler):

        logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")

        # Capture the first human message sent to any LLM in this run.
-       if not self._first_human_msg:
+       if not self._first_human_msg and not messages:
            for batch in messages.reversed():
                for m in batch.reversed():
                    if isinstance(m, HumanMessage) and m.name != "summary":
@@ -66,7 +66,10 @@ def _normalize_presented_filepath(

    virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")

    if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
-       actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
+       try:
+           actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
+       except TypeError:
+           actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
    else:
        actual_path = Path(filepath).expanduser().resolve()
@@ -174,6 +174,20 @@ def test_protected_post_no_cookie_returns_401(client):

    assert res.status_code == 401


def test_protected_post_with_internal_auth_header_passes():
    from app.gateway.internal_auth import create_internal_auth_headers

    app = _make_app()
    client = TestClient(app)

    res = client.post(
        "/api/threads/abc/runs/stream",
        headers=create_internal_auth_headers(),
    )

    assert res.status_code == 200


# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
@@ -414,6 +414,27 @@ def _make_async_iterator(items):


class TestChannelManager:
    def test_get_client_includes_csrf_header_and_cookie(self):
        from app.channels.manager import ChannelManager

        bus = MessageBus()
        store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
        manager = ChannelManager(bus=bus, store=store, langgraph_url="http://localhost:8001")

        with patch("langgraph_sdk.get_client") as get_client:
            get_client.return_value = object()

            manager._get_client()

        get_client.assert_called_once()
        kwargs = get_client.call_args.kwargs
        assert kwargs["url"] == "http://localhost:8001"
        headers = kwargs["headers"]
        csrf_token = headers["X-CSRF-Token"]
        assert csrf_token
        assert headers["Cookie"] == f"csrf_token={csrf_token}"
        assert headers["X-DeerFlow-Internal-Token"]

    def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
        from app.channels.manager import ChannelManager
@@ -7,6 +7,15 @@ from deerflow.agents.lead_agent import prompt as prompt_module

from deerflow.skills.types import Skill


+def _set_skills_cache_state(*, skills=None, active=False, version=0):
+    prompt_module._get_cached_skills_prompt_section.cache_clear()
+    with prompt_module._enabled_skills_lock:
+        prompt_module._enabled_skills_cache = skills
+        prompt_module._enabled_skills_refresh_active = active
+        prompt_module._enabled_skills_refresh_version = version
+        prompt_module._enabled_skills_refresh_event.clear()
+
+
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
    config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
    monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)

@@ -84,7 +93,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc

    state = {"skills": [make_skill("first-skill")]}
    monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
-   prompt_module.clear_skills_system_prompt_cache()
+   _set_skills_cache_state()

    try:
        prompt_module.warm_enabled_skills_cache()

@@ -95,7 +104,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc

        assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["second-skill"]
    finally:
-       prompt_module.clear_skills_system_prompt_cache()
+       _set_skills_cache_state()


def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):

@@ -137,7 +146,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa

        return [make_skill(f"skill-{current_call}")]

    monkeypatch.setattr(prompt_module, "load_skills", fake_load_skills)
-   prompt_module.clear_skills_system_prompt_cache()
+   _set_skills_cache_state()

    try:
        prompt_module.clear_skills_system_prompt_cache()

@@ -151,7 +160,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa

        assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["skill-2"]
    finally:
        release.set()
-       prompt_module.clear_skills_system_prompt_cache()
+       _set_skills_cache_state()


def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
@@ -93,7 +93,10 @@ class TestTitleMiddlewareCoreLogic:

        assert title == "短标题"
        title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
        model.ainvoke.assert_awaited_once()
-       assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "title_agent"}
+       assert model.ainvoke.await_args.kwargs["config"] == {
+           "run_name": "title_agent",
+           "tags": ["middleware:title"],
+       }

    def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
        _set_test_title_config(max_chars=20)
@@ -49,7 +49,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):

        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
-       result = asyncio.run(uploads.upload_files("thread-mounted", files=[file]))
+       result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file]))

    assert result.success is True
    assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"

@@ -75,7 +75,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):

        patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
    ):
        file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
-       result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
+       result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))

    assert result.success is True
    assert len(result.files) == 1