diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 982c87035..9beceeb3a 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -26,6 +26,7 @@ from app.channels.message_bus import ( from app.channels.store import ChannelStore from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token from app.gateway.internal_auth import create_internal_auth_headers +from deerflow.config.paths import make_safe_user_id from deerflow.runtime.user_context import get_effective_user_id logger = logging.getLogger(__name__) @@ -670,12 +671,20 @@ class ChannelManager: configurable["checkpoint_ns"] = "" configurable["thread_id"] = thread_id + # ``user_id`` drives user-scoped filesystem buckets that only accept + # ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value + # under ``channel_user_id`` for platform-facing lookups. + run_context_identity: dict[str, Any] = {"thread_id": thread_id} + if msg.user_id: + run_context_identity["user_id"] = make_safe_user_id(msg.user_id) + run_context_identity["channel_user_id"] = msg.user_id + run_context = _merge_dicts( DEFAULT_RUN_CONTEXT, self._default_session.get("context"), channel_layer.get("context"), user_layer.get("context"), - {"thread_id": thread_id}, + run_context_identity, ) # Custom agents are implemented as lead_agent + agent_name context. diff --git a/backend/app/gateway/internal_auth.py b/backend/app/gateway/internal_auth.py index 51ed89a99..3a00a9662 100644 --- a/backend/app/gateway/internal_auth.py +++ b/backend/app/gateway/internal_auth.py @@ -10,6 +10,7 @@ from deerflow.runtime.user_context import DEFAULT_USER_ID INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" +INTERNAL_SYSTEM_ROLE = "internal" def _load_internal_auth_token() -> str: @@ -34,4 +35,4 @@ def is_valid_internal_auth_token(token: str | None) -> bool: def get_internal_user(): """Return the synthetic user used for trusted internal channel calls.""" - return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal") + return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 63ac0c1bf..2c5c01e61 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -19,6 +19,7 @@ from langchain_core.messages import BaseMessage from langchain_core.messages.utils import convert_to_messages from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge +from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE from app.gateway.utils import sanitize_log_param from deerflow.config.app_config import get_app_config from deerflow.runtime import ( @@ -140,7 +141,14 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An """Merge whitelisted keys from ``body.context`` into both ``config['configurable']`` and ``config['context']`` so they are visible to legacy configurable readers and to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool — - see issue #2677).""" + see issue #2677). + + ``user_id`` is intentionally propagated into ``config['context']`` in addition to + the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in + ``body.context`` keep it on ``ToolRuntime.context``. It is merged with + ``setdefault`` so a server-authenticated id stamped by + :func:`inject_authenticated_user_context` always wins over the client-supplied one. + """ if not context: return configurable = config.setdefault("configurable", {}) @@ -151,6 +159,8 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An configurable.setdefault(key, context[key]) if isinstance(runtime_context, dict): runtime_context.setdefault(key, context[key]) + if "user_id" in context and isinstance(runtime_context, dict): + runtime_context.setdefault("user_id", context["user_id"]) def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None: @@ -166,6 +176,9 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request) if user_id is None: return + if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE: + return + runtime_context = config.setdefault("context", {}) if isinstance(runtime_context, dict): runtime_context["user_id"] = str(user_id) diff --git a/backend/packages/harness/deerflow/config/paths.py b/backend/packages/harness/deerflow/config/paths.py index c06839040..f01959657 100644 --- a/backend/packages/harness/deerflow/config/paths.py +++ b/backend/packages/harness/deerflow/config/paths.py @@ -1,3 +1,4 @@ +import hashlib import os import re import shutil @@ -10,6 +11,8 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data" _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") +_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]") +_SAFE_USER_ID_DIGEST_HEX_LEN = 16 def _default_local_base_dir() -> Path: @@ -31,6 +34,23 @@ def _validate_user_id(user_id: str) -> str: return user_id +def make_safe_user_id(raw: str) -> str: + """Normalize an external identity into the user-id charset (``[A-Za-z0-9_-]``). + + IM channel ids (Feishu/Slack/Telegram) may contain characters that + :func:`_validate_user_id` rejects. Already-safe ids pass through unchanged; + lossy ones get a short digest suffix so two distinct inputs never share a + storage bucket. + """ + if not raw: + raise ValueError("user_id must be a non-empty string.") + sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw) + if sanitized == raw: + return raw + digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN] + return f"{sanitized}-{digest}" + + def _join_host_path(base: str, *parts: str) -> str: """Join host filesystem path segments while preserving native style. diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index f38b70375..e425efe0c 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from collections.abc import Mapping from typing import Any from langchain_core.tools import BaseTool, StructuredTool @@ -137,7 +138,15 @@ def _make_session_pool_tool( from langchain_mcp_adapters.interceptors import MCPToolCallRequest async def base_handler(request: MCPToolCallRequest) -> Any: - return await session.call_tool(request.name, request.args) + # Preserve interceptor-injected headers for stdio MCP calls by + # forwarding them through MCP call meta. + call_kwargs: dict[str, Any] = {} + if request.headers: + if isinstance(request.headers, Mapping): + call_kwargs["meta"] = {"headers": dict(request.headers)} + else: + logger.warning("Ignoring MCP interceptor headers with unsupported type: %s", type(request.headers).__name__) + return await session.call_tool(request.name, request.args, **call_kwargs) handler = base_handler for interceptor in reversed(tool_interceptors): diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 960e32a4b..060d414b2 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -1787,6 +1787,51 @@ class TestChannelManager: _run(go()) +class TestResolveRunParamsUserId: + """Regression for PR #3294: channel identity must reach ``run_context`` + while staying safe for user-scoped filesystem buckets. + """ + + def _manager(self): + from app.channels.manager import ChannelManager + + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + return ChannelManager(bus=bus, store=store) + + def test_safe_user_id_is_passed_through(self): + manager = self._manager() + msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == "123456" + assert run_context["channel_user_id"] == "123456" + + def test_unsafe_user_id_is_normalized_but_raw_preserved(self): + from deerflow.config.paths import make_safe_user_id + + manager = self._manager() + raw = "user@example.com" + msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == make_safe_user_id(raw) + assert run_context["user_id"] != raw + assert run_context["channel_user_id"] == raw + + @pytest.mark.parametrize("raw_user_id", ["", None]) + def test_empty_or_none_user_id_is_not_injected(self, raw_user_id): + manager = self._manager() + msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert "user_id" not in run_context + assert "channel_user_id" not in run_context + + # --------------------------------------------------------------------------- # ChannelService tests # --------------------------------------------------------------------------- diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py index 2ccd372bf..d62ed9371 100644 --- a/backend/tests/test_gateway_services.py +++ b/backend/tests/test_gateway_services.py @@ -431,6 +431,49 @@ def test_inject_authenticated_user_context_overrides_client_user_id(): assert config["context"]["user_id"] == "auth-user-42" +def test_merge_run_context_overrides_propagates_user_id(): + """Regression for PR #3294: ``user_id`` from ``body.context`` must land in + ``config['context']`` so non-web callers (e.g. IM channels) keep their identity + on ``ToolRuntime.context``. + """ + from app.gateway.services import build_run_config, merge_run_context_overrides + + config = build_run_config("thread-1", None, None) + merge_run_context_overrides(config, {"user_id": "channel-user-7"}) + + assert config["context"]["user_id"] == "channel-user-7" + + +def test_merge_run_context_overrides_does_not_clobber_existing_user_id(): + """``merge_run_context_overrides`` must not override an already-stamped + authenticated ``context.user_id`` with the client-supplied value. + """ + from app.gateway.services import build_run_config, merge_run_context_overrides + + config = build_run_config("thread-1", {"context": {"user_id": "auth-user-42"}}, None) + merge_run_context_overrides(config, {"user_id": "spoofed-client"}) + + assert config["context"]["user_id"] == "auth-user-42" + + +def test_inject_authenticated_user_context_skips_internal_role(): + """Regression for PR #3294: internal system-role callers must not overwrite an + already-present ``context.user_id`` (e.g. a channel-supplied identity), so the + real end user keeps owning the per-user storage bucket. + """ + from types import SimpleNamespace + + from app.gateway.services import build_run_config, inject_authenticated_user_context + + config = build_run_config("thread-1", None, None) + config["context"] = {"user_id": "channel-user-7"} + request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="internal-bot", system_role="internal"))) + + inject_authenticated_user_context(config, request) + + assert config["context"]["user_id"] == "channel-user-7" + + # --------------------------------------------------------------------------- # build_run_config — context / configurable precedence (LangGraph >= 0.6.0) # --------------------------------------------------------------------------- diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py index 1e2ce7adc..40d02e61c 100644 --- a/backend/tests/test_mcp_session_pool.py +++ b/backend/tests/test_mcp_session_pool.py @@ -256,6 +256,136 @@ async def test_session_pool_tool_wrapping(): mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"}) +@pytest.mark.asyncio +async def test_session_pool_tool_forwards_interceptor_headers(): + """Regression for PR #3294: when an interceptor sets ``request.headers``, the + pooled stdio call must forward them via ``meta={"headers": ...}`` so downstream + MCP servers can read auth/context headers. + """ + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="srv_act", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + async def header_interceptor(request, handler): + return await handler(request.override(headers={"X-User-Id": "u-42"})) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool( + original_tool, + "srv", + {"transport": "stdio", "command": "x", "args": []}, + tool_interceptors=[header_interceptor], + ) + await wrapped.coroutine(runtime=None, x=1) + + mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}, meta={"headers": {"X-User-Id": "u-42"}}) + + +@pytest.mark.asyncio +async def test_session_pool_tool_no_headers_omits_meta(): + """When no interceptor sets headers, the pooled call must not pass a ``meta`` + kwarg (falls back to the plain two-argument ``call_tool``). + """ + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="srv_act", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + async def passthrough_interceptor(request, handler): + return await handler(request) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool( + original_tool, + "srv", + {"transport": "stdio", "command": "x", "args": []}, + tool_interceptors=[passthrough_interceptor], + ) + await wrapped.coroutine(runtime=None, x=1) + + mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}) + + +@pytest.mark.asyncio +async def test_session_pool_tool_ignores_unsupported_header_type(caplog): + """Defensive path: non-mapping truthy headers should be ignored safely.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + class TruthyHeaders: + def __bool__(self) -> bool: + return True + + original_tool = StructuredTool( + name="srv_act", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + async def invalid_header_interceptor(request, handler): + return await handler(request.override(headers=TruthyHeaders())) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool( + original_tool, + "srv", + {"transport": "stdio", "command": "x", "args": []}, + tool_interceptors=[invalid_header_interceptor], + ) + await wrapped.coroutine(runtime=None, x=1) + + mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}) + assert "unsupported type" in caplog.text + + @pytest.mark.asyncio async def test_session_pool_tool_extracts_thread_id(): """Thread ID is extracted from runtime.config when not in context.""" diff --git a/backend/tests/test_paths_user_isolation.py b/backend/tests/test_paths_user_isolation.py index d5c0f540f..692c526ed 100644 --- a/backend/tests/test_paths_user_isolation.py +++ b/backend/tests/test_paths_user_isolation.py @@ -30,6 +30,41 @@ class TestValidateUserId: paths.user_dir("") +class TestMakeSafeUserId: + def test_already_safe_id_is_unchanged(self): + from deerflow.config.paths import make_safe_user_id + + assert make_safe_user_id("ou_abc-123") == "ou_abc-123" + assert make_safe_user_id("123456") == "123456" + + def test_unsafe_chars_are_sanitized_with_stable_suffix(self): + from deerflow.config.paths import make_safe_user_id + + result = make_safe_user_id("user@example.com") + # Sanitized prefix plus a stable digest of the original. + assert result.startswith("user-example-com-") + assert len(result.rsplit("-", 1)[1]) == 16 + assert make_safe_user_id("user@example.com") == result + + def test_sanitized_id_passes_validation(self, paths: Paths): + from deerflow.config.paths import make_safe_user_id + + safe = make_safe_user_id("用户/../etc") + # Must be usable as a filesystem-scoped bucket without raising. + assert paths.user_dir(safe) == paths.base_dir / "users" / safe + + def test_distinct_unsafe_ids_do_not_collide(self): + from deerflow.config.paths import make_safe_user_id + + assert make_safe_user_id("a.b") != make_safe_user_id("a/b") + + def test_empty_id_rejected(self): + from deerflow.config.paths import make_safe_user_id + + with pytest.raises(ValueError, match="non-empty"): + make_safe_user_id("") + + class TestUserDir: def test_user_dir(self, paths: Paths): assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"