diff --git a/README.md b/README.md index 908f11c69..b64c75f97 100644 --- a/README.md +++ b/README.md @@ -419,6 +419,7 @@ channels: Notes: - `assistant_id: lead_agent` calls the default LangGraph assistant directly. - If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels. +- IM channel workers call Gateway's LangGraph-compatible API internally and automatically attach process-local internal auth plus the CSRF cookie/header pair required for thread and run creation. Set the corresponding API keys in your `.env` file: diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 10b9db6c4..4414a0203 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -314,7 +314,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via Gateway's LangGraph-compatible API. -**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. +**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies. 
**Components**: - `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels) diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 5c5848bc2..5680943b0 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -17,6 +17,8 @@ from langgraph_sdk.errors import ConflictError from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.store import ChannelStore +from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token +from app.gateway.internal_auth import create_internal_auth_headers from deerflow.runtime.user_context import get_effective_user_id logger = logging.getLogger(__name__) @@ -534,6 +536,7 @@ class ChannelManager: self._default_session = _as_dict(default_session) self._channel_sessions = dict(channel_sessions or {}) self._client = None # lazy init — langgraph_sdk async client + self._csrf_token = generate_csrf_token() self._semaphore: asyncio.Semaphore | None = None self._running = False self._task: asyncio.Task | None = None @@ -586,7 +589,14 @@ class ChannelManager: if self._client is None: from langgraph_sdk import get_client - self._client = get_client(url=self._langgraph_url) + self._client = get_client( + url=self._langgraph_url, + headers={ + **create_internal_auth_headers(), + CSRF_HEADER_NAME: self._csrf_token, + "Cookie": f"{CSRF_COOKIE_NAME}={self._csrf_token}", + }, + ) return self._client # -- lifecycle --------------------------------------------------------- diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index cdf353299..852c787fd 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -79,6 +79,11 @@ async def _ensure_admin_user(app: FastAPI) -> None: # Skip admin migration work rather than failing gateway startup. 
logger.warning("Auth persistence not ready; skipping admin bootstrap check") return + + sf = get_session_factory() + if sf is None: + return + admin_count = await provider.count_admin_users() if admin_count == 0: @@ -90,10 +95,6 @@ async def _ensure_admin_user(app: FastAPI) -> None: # Admin already exists — run orphan thread migration for any # LangGraph thread metadata that pre-dates the auth module. - sf = get_session_factory() - if sf is None: - return - async with sf() as session: stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1) row = (await session.execute(stmt)).scalar_one_or_none() diff --git a/backend/app/gateway/auth_middleware.py b/backend/app/gateway/auth_middleware.py index fd982cd79..6b6452264 100644 --- a/backend/app/gateway/auth_middleware.py +++ b/backend/app/gateway/auth_middleware.py @@ -18,6 +18,7 @@ from starlette.types import ASGIApp from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse from app.gateway.authz import _ALL_PERMISSIONS, AuthContext +from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token from deerflow.runtime.user_context import reset_current_user, set_current_user # Paths that never require authentication. @@ -75,8 +76,12 @@ class AuthMiddleware(BaseHTTPMiddleware): if _is_public(request.url.path): return await call_next(request) + internal_user = None + if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)): + internal_user = get_internal_user() + # Non-public path: require session cookie - if not request.cookies.get("access_token"): + if internal_user is None and not request.cookies.get("access_token"): return JSONResponse( status_code=401, content={ @@ -100,10 +105,13 @@ class AuthMiddleware(BaseHTTPMiddleware): # bubble up, so we catch and render it as JSONResponse here. 
from app.gateway.deps import get_current_user_from_request - try: - user = await get_current_user_from_request(request) - except HTTPException as exc: - return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) + if internal_user is not None: + user = internal_user + else: + try: + user = await get_current_user_from_request(request) + except HTTPException as exc: + return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) # Stamp both request.state.user (for the contextvar pattern) # and request.state.auth (so @require_permission's "auth is diff --git a/backend/app/gateway/internal_auth.py b/backend/app/gateway/internal_auth.py new file mode 100644 index 000000000..b0380379b --- /dev/null +++ b/backend/app/gateway/internal_auth.py @@ -0,0 +1,26 @@ +"""Process-local authentication for Gateway internal callers.""" + +from __future__ import annotations + +import secrets +from types import SimpleNamespace + +from deerflow.runtime.user_context import DEFAULT_USER_ID + +INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" +_INTERNAL_AUTH_TOKEN = secrets.token_urlsafe(32) + + +def create_internal_auth_headers() -> dict[str, str]: + """Return headers that authenticate same-process Gateway internal calls.""" + return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} + + +def is_valid_internal_auth_token(token: str | None) -> bool: + """Return True when *token* matches the process-local internal token.""" + return bool(token) and secrets.compare_digest(token, _INTERNAL_AUTH_TOKEN) + + +def get_internal_user(): + """Return the synthetic user used for trusted internal channel calls.""" + return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal") diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index 0cdff8e6b..b2a147bce 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ 
b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -70,6 +70,7 @@ class MemoryUpdateQueue: thread_id=thread_id, messages=messages, agent_name=agent_name, + user_id=user_id, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, ) @@ -82,6 +83,7 @@ class MemoryUpdateQueue: thread_id: str, messages: list[Any], agent_name: str | None = None, + user_id: str | None = None, correction_detected: bool = False, reinforcement_detected: bool = False, ) -> None: @@ -95,6 +97,7 @@ class MemoryUpdateQueue: thread_id=thread_id, messages=messages, agent_name=agent_name, + user_id=user_id, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, ) @@ -108,6 +111,7 @@ class MemoryUpdateQueue: thread_id: str, messages: list[Any], agent_name: str | None, + user_id: str | None, correction_detected: bool, reinforcement_detected: bool, ) -> None: @@ -121,6 +125,7 @@ class MemoryUpdateQueue: thread_id=thread_id, messages=messages, agent_name=agent_name, + user_id=user_id, correction_detected=merged_correction_detected, reinforcement_detected=merged_reinforcement_detected, ) diff --git a/backend/packages/harness/deerflow/agents/memory/storage.py b/backend/packages/harness/deerflow/agents/memory/storage.py index f8a527b35..3d0a0e9af 100644 --- a/backend/packages/harness/deerflow/agents/memory/storage.py +++ b/backend/packages/harness/deerflow/agents/memory/storage.py @@ -66,7 +66,7 @@ class FileMemoryStorage(MemoryStorage): """Initialize the file memory storage.""" # Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global) # Value: (memory_data, file_mtime) - self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {} + self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {} # Guards all reads and writes to _memory_cache across concurrent callers. 
self._cache_lock = threading.Lock() @@ -116,9 +116,14 @@ class FileMemoryStorage(MemoryStorage): logger.warning("Failed to load memory file: %s", e) return create_empty_memory() + @staticmethod + def _cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]: + return (user_id, agent_name) + def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Load memory data (cached with file modification time check).""" file_path = self._get_memory_file_path(agent_name, user_id=user_id) + cache_key = self._cache_key(agent_name, user_id=user_id) try: current_mtime = file_path.stat().st_mtime if file_path.exists() else None @@ -126,14 +131,14 @@ class FileMemoryStorage(MemoryStorage): current_mtime = None with self._cache_lock: - cached = self._memory_cache.get(agent_name) + cached = self._memory_cache.get(cache_key) if cached is not None and cached[1] == current_mtime: return cached[0] - memory_data = self._load_memory_from_file(agent_name) + memory_data = self._load_memory_from_file(agent_name, user_id=user_id) with self._cache_lock: - self._memory_cache[agent_name] = (memory_data, current_mtime) + self._memory_cache[cache_key] = (memory_data, current_mtime) return memory_data @@ -141,6 +146,7 @@ class FileMemoryStorage(MemoryStorage): """Reload memory data from file, forcing cache invalidation.""" file_path = self._get_memory_file_path(agent_name, user_id=user_id) memory_data = self._load_memory_from_file(agent_name, user_id=user_id) + cache_key = self._cache_key(agent_name, user_id=user_id) try: mtime = file_path.stat().st_mtime if file_path.exists() else None @@ -148,12 +154,13 @@ class FileMemoryStorage(MemoryStorage): mtime = None with self._cache_lock: - self._memory_cache[agent_name] = (memory_data, mtime) + self._memory_cache[cache_key] = (memory_data, mtime) return memory_data def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = 
None) -> bool: """Save memory data to file and update cache.""" file_path = self._get_memory_file_path(agent_name, user_id=user_id) + cache_key = self._cache_key(agent_name, user_id=user_id) try: file_path.parent.mkdir(parents=True, exist_ok=True) @@ -174,7 +181,7 @@ class FileMemoryStorage(MemoryStorage): mtime = None with self._cache_lock: - self._memory_cache[agent_name] = (memory_data, mtime) + self._memory_cache[cache_key] = (memory_data, mtime) logger.info("Memory saved to %s", file_path) return True except OSError as e: diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 5f1838888..e47bb96e1 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -141,7 +141,7 @@ class RunJournal(BaseCallbackHandler): logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}") # Capture the first human message sent to any LLM in this run. 
- if not self._first_human_msg: + if not self._first_human_msg and messages: for batch in messages.reversed(): for m in batch.reversed(): if isinstance(m, HumanMessage) and m.name != "summary": diff --git a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py index 21c2aa6c8..13a7a017e 100644 --- a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py @@ -66,7 +66,10 @@ def _normalize_presented_filepath( virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/") if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"): - actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id()) + try: + actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id()) + except TypeError: + actual_path = get_paths().resolve_virtual_path(thread_id, filepath) else: actual_path = Path(filepath).expanduser().resolve() diff --git a/backend/tests/test_auth_middleware.py b/backend/tests/test_auth_middleware.py index 398f9cec6..726786ac9 100644 --- a/backend/tests/test_auth_middleware.py +++ b/backend/tests/test_auth_middleware.py @@ -174,6 +174,20 @@ def test_protected_post_no_cookie_returns_401(client): assert res.status_code == 401 +def test_protected_post_with_internal_auth_header_passes(): + from app.gateway.internal_auth import create_internal_auth_headers + + app = _make_app() + client = TestClient(app) + + res = client.post( + "/api/threads/abc/runs/stream", + headers=create_internal_auth_headers(), + ) + + assert res.status_code == 200 + + # ── Method matrix: PUT/DELETE/PATCH also protected ──────────────────────── diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index b2a5573a3..779b75b08 100644 @@ -414,6 +414,27 
@@ def _make_async_iterator(items): class TestChannelManager: + def test_get_client_includes_csrf_header_and_cookie(self): + from app.channels.manager import ChannelManager + + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store, langgraph_url="http://localhost:8001") + + with patch("langgraph_sdk.get_client") as get_client: + get_client.return_value = object() + + manager._get_client() + + get_client.assert_called_once() + kwargs = get_client.call_args.kwargs + assert kwargs["url"] == "http://localhost:8001" + headers = kwargs["headers"] + csrf_token = headers["X-CSRF-Token"] + assert csrf_token + assert headers["Cookie"] == f"csrf_token={csrf_token}" + assert headers["X-DeerFlow-Internal-Token"] + def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch): from app.channels.manager import ChannelManager diff --git a/backend/tests/test_lead_agent_prompt.py b/backend/tests/test_lead_agent_prompt.py index c63e32d99..e82cc7ccb 100644 --- a/backend/tests/test_lead_agent_prompt.py +++ b/backend/tests/test_lead_agent_prompt.py @@ -7,6 +7,15 @@ from deerflow.agents.lead_agent import prompt as prompt_module from deerflow.skills.types import Skill +def _set_skills_cache_state(*, skills=None, active=False, version=0): + prompt_module._get_cached_skills_prompt_section.cache_clear() + with prompt_module._enabled_skills_lock: + prompt_module._enabled_skills_cache = skills + prompt_module._enabled_skills_refresh_active = active + prompt_module._enabled_skills_refresh_version = version + prompt_module._enabled_skills_refresh_event.clear() + + def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch): config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[])) monkeypatch.setattr("deerflow.config.get_app_config", lambda: config) @@ -84,7 +93,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc state = {"skills": 
[make_skill("first-skill")]} monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"])) - prompt_module.clear_skills_system_prompt_cache() + _set_skills_cache_state() try: prompt_module.warm_enabled_skills_cache() @@ -95,7 +104,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["second-skill"] finally: - prompt_module.clear_skills_system_prompt_cache() + _set_skills_cache_state() def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path): @@ -137,7 +146,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa return [make_skill(f"skill-{current_call}")] monkeypatch.setattr(prompt_module, "load_skills", fake_load_skills) - prompt_module.clear_skills_system_prompt_cache() + _set_skills_cache_state() try: prompt_module.clear_skills_system_prompt_cache() @@ -151,7 +160,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["skill-2"] finally: release.set() - prompt_module.clear_skills_system_prompt_cache() + _set_skills_cache_state() def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog): diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/test_title_middleware_core_logic.py index 684de2345..afd10f2b3 100644 --- a/backend/tests/test_title_middleware_core_logic.py +++ b/backend/tests/test_title_middleware_core_logic.py @@ -93,7 +93,10 @@ class TestTitleMiddlewareCoreLogic: assert title == "短标题" title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False) model.ainvoke.assert_awaited_once() - assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "title_agent"} + assert model.ainvoke.await_args.kwargs["config"] == { + "run_name": "title_agent", + "tags": ["middleware:title"], + } def 
test_generate_title_normalizes_structured_message_content(self, monkeypatch): _set_test_title_config(max_chars=20) diff --git a/backend/tests/test_uploads_router.py b/backend/tests/test_uploads_router.py index e2f51625d..65f3fb811 100644 --- a/backend/tests/test_uploads_router.py +++ b/backend/tests/test_uploads_router.py @@ -49,7 +49,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path): patch.object(uploads, "get_sandbox_provider", return_value=provider), ): file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) - result = asyncio.run(uploads.upload_files("thread-mounted", files=[file])) + result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file])) assert result.success is True assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads" @@ -75,7 +75,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path): patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock, ): file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes")) - result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file])) assert result.success is True assert len(result.files) == 1