diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 6cc0aebb1..caa36f579 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -223,17 +223,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc **Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart. -**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**: +**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`. -| Field | Why a restart is required | -|---|---| -| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. | -| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. | -| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. | -| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. | -| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. | -| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. | -| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. | +Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`. Configuration priority: 1. Explicit `config_path` argument diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 24af1deeb..5739d217d 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -119,6 +119,16 @@ def get_config() -> AppConfig: split-brain where the worker / lead-agent thread saw a stale startup snapshot. + Hot-reload boundary: fields backed by startup-time singletons + (engines, sandbox provider, IM channels, logging handler) require a + process restart to change at runtime. The authoritative list lives in + :mod:`deerflow.config.reload_boundary` and is mirrored by the + standardised ``"startup-only:"`` prefix on the matching + ``Field(description=...)`` in :class:`AppConfig` — IDE hover on those + fields will surface the boundary inline. See + ``backend/CLAUDE.md`` "Config Hot-Reload Boundary" for the operator + summary. + Any failure to materialise the config (missing file, permission denied, YAML parse error, validation error) is reported as 503 — semantically "the gateway cannot serve requests without a usable configuration" — and diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 8fcc564a8..842b49d7a 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -18,6 +18,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_ from deerflow.config.loop_detection_config import LoopDetectionConfig from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.model_config import ModelConfig +from deerflow.config.reload_boundary import format_field_description from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.runtime_paths import existing_project_file from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig @@ -85,10 +86,21 @@ def apply_logging_level(name: str | None) -> None: class AppConfig(BaseModel): """Config for the DeerFlow application""" - log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected") + log_level: str = Field( + default="info", + description=format_field_description( + "log_level", + field_doc="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected.", + ), + ) token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration") models: list[ModelConfig] = Field(default_factory=list, description="Available models") - sandbox: SandboxConfig = Field(description="Sandbox configuration") + sandbox: SandboxConfig = Field( + description=format_field_description( + "sandbox", + field_doc="Sandbox provider configuration (local filesystem or Docker-based aio sandbox).", + ), + ) tools: list[ToolConfig] = Field(default_factory=list, description="Available tools") tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups") skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration") @@ -107,10 +119,34 @@ class AppConfig(BaseModel): loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") model_config = ConfigDict(extra="allow") - database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") - run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") - checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") - stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") + database: DatabaseConfig = Field( + default_factory=DatabaseConfig, + description=format_field_description( + "database", + field_doc="Unified database backend for run/feedback metadata (memory, sqlite, or postgres).", + ), + ) + run_events: RunEventsConfig = Field( + default_factory=RunEventsConfig, + description=format_field_description( + "run_events", + field_doc="Run-event store backend (memory for dev, db for production queries, jsonl for lightweight single-node persistence).", + ), + ) + checkpointer: CheckpointerConfig | None = Field( + default=None, + description=format_field_description( + "checkpointer", + field_doc="LangGraph state-persistence checkpointer configuration.", + ), + ) + stream_bridge: StreamBridgeConfig | None = Field( + default=None, + description=format_field_description( + "stream_bridge", + field_doc="Stream bridge connecting agent workers to SSE endpoints.", + ), + ) @classmethod def resolve_config_path(cls, config_path: str | None = None) -> Path: diff --git a/backend/packages/harness/deerflow/config/reload_boundary.py b/backend/packages/harness/deerflow/config/reload_boundary.py new file mode 100644 index 000000000..d39502776 --- /dev/null +++ b/backend/packages/harness/deerflow/config/reload_boundary.py @@ -0,0 +1,104 @@ +"""Single source of truth for the config hot-reload boundary. + +Bytedance/deer-flow issue #3144: gateway request dependencies resolve +``AppConfig`` through ``get_app_config()`` on every request, so per-run +fields take effect on the next message without restarting the gateway. +The fields listed in this module are the **infrastructure** subset that +the gateway captures once at startup — engines, singletons, IM clients, +the logging handler — and that therefore require a process restart to +change at runtime. + +The registry covers two kinds of entries: + +- Top-level ``AppConfig`` fields (``database``, ``checkpointer``, + ``run_events``, ``stream_bridge``, ``sandbox``, ``log_level``). For + these, :func:`format_field_description` produces the standardised + ``"startup-only: ..."`` prefix that the matching Pydantic + ``Field(description=...)`` carries, so the boundary surfaces in IDE + hover next to the field itself. +- Top-level ``config.yaml`` sections that are not part of the + ``AppConfig`` schema (``channels``). These cannot be standardised at + the schema level, so the registry is their only canonical location. + +Any future "needs restart" scanner — operator tooling, lint hooks, doc +generators — should drive off this registry rather than re-parsing +prose. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +#: The standardised prefix every restart-required field description starts +#: with. ``test_reload_boundary`` enforces both directions: registered +#: fields must use this prefix in the schema, and any schema field using +#: this prefix must be in the registry. +STARTUP_ONLY_PREFIX = "startup-only:" + + +#: Restart-required field paths mapped to the human-readable reason. +#: +#: The reason text is what surfaces in ``Field(description=...)``, so it +#: must explain *what* code captures the snapshot — not just that the +#: field is restart-required — so an operator changing the value knows +#: which subsystem to restart. +STARTUP_ONLY_FIELDS: dict[str, str] = { + "database": ("init_engine_from_config() runs once during langgraph_runtime() startup; the SQLAlchemy engine holds the connection pool and is not rebuilt on config.yaml edits."), + "checkpointer": ("make_checkpointer() binds the persistent checkpointer once at startup, including SQLite WAL / busy_timeout settings."), + "run_events": ("make_run_event_store() picks the memory- vs SQL-backed implementation at startup and is frozen onto app.state.run_events_config to stay paired with the underlying event store."), + "stream_bridge": ("make_stream_bridge() constructs the stream-bridge singleton once during startup."), + "sandbox": ("get_sandbox_provider() caches the provider singleton (``_default_sandbox_provider``); a different ``sandbox.use`` class path only takes effect on next process start."), + "log_level": ( + "apply_logging_level() runs only during app.py startup; it sets the deerflow/app logger levels and may lower root handler thresholds so configured messages can propagate. A freshly reloaded AppConfig does not retrigger it." + ), + # Not part of the AppConfig Pydantic schema — channel credentials are + # consumed directly by ``start_channel_service()`` once at lifespan + # startup and the live channel clients are not rebuilt on + # config.yaml edits. + "channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."), +} + + +def iter_startup_only_field_paths() -> Iterator[str]: + """Yield every registered restart-required field path.""" + return iter(STARTUP_ONLY_FIELDS) + + +def is_startup_only_field(field_path: str) -> bool: + """Return ``True`` when *field_path* is registered as restart-required. + + Accepts only top-level paths (``"database"``, ``"sandbox"`` etc.); + nested keys like ``"database.url"`` are not modelled here because the + boundary is per-section, not per-leaf. + """ + return field_path in STARTUP_ONLY_FIELDS + + +def format_field_description(field_path: str, *, field_doc: str | None = None) -> str: + """Build the standardised description for a registered field. + + Used inside ``AppConfig`` ``Field(description=...)`` so the hover + text in IDEs matches the registry and the drift tests can pin one + side against the other. + + Args: + field_path: A registered top-level field path (e.g. ``"log_level"``). + field_doc: Optional human-facing description for the field itself + (allowed values, semantics, etc.). When supplied, it is + appended after the ``startup-only:`` marker block separated by + a blank line so IDE hover shows both the restart-required + reason *and* the field's normal documentation. Composition + keeps the marker as the leading token machine-readable tooling + pivots on while restoring the prose that ``Field(description=)`` + used to carry before the registry took over. + + Raises: + KeyError: when *field_path* is not registered. This is deliberate + — silently returning a placeholder would let a typo bypass + the drift coverage. + """ + reason = STARTUP_ONLY_FIELDS[field_path] + header = f"{STARTUP_ONLY_PREFIX} {reason}" + if field_doc is None: + return header + return f"{header}\n\n{field_doc.strip()}" diff --git a/backend/packages/harness/deerflow/mcp/__init__.py b/backend/packages/harness/deerflow/mcp/__init__.py index 74195c195..839cce1ee 100644 --- a/backend/packages/harness/deerflow/mcp/__init__.py +++ b/backend/packages/harness/deerflow/mcp/__init__.py @@ -1,6 +1,10 @@ """MCP (Model Context Protocol) integration using langchain-mcp-adapters.""" -from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache +from .cache import ( + get_cached_mcp_tools, + initialize_mcp_tools, + reset_mcp_tools_cache, +) from .client import build_server_params, build_servers_config from .tools import get_mcp_tools diff --git a/backend/packages/harness/deerflow/mcp/cache.py b/backend/packages/harness/deerflow/mcp/cache.py index 20cc48b1e..333a034c8 100644 --- a/backend/packages/harness/deerflow/mcp/cache.py +++ b/backend/packages/harness/deerflow/mcp/cache.py @@ -143,11 +143,20 @@ def reset_mcp_tools_cache() -> None: # Close persistent sessions – they will be recreated by the next # get_mcp_tools() call with the (possibly updated) connection config. + # + # close_all_sync() already picks the correct strategy per owning loop: + # * sessions owned by the *current* running loop are only *signalled* + # (their owner task runs __aexit__ once the loop regains control – + # this is correct and leak-free, since the loop keeps the task alive), + # * sessions on other threads' loops are torn down deterministically, + # * idle/closed loops are handled or skipped. + # We deliberately do NOT try to synchronously wait for the current running + # loop to finish teardown here: that is a self-deadlock (the loop can only + # run the teardown after this synchronous call returns control to it). try: from deerflow.mcp.session_pool import get_session_pool - pool = get_session_pool() - pool.close_all_sync() + get_session_pool().close_all_sync() except Exception: logger.debug("Could not close MCP session pool on cache reset", exc_info=True) diff --git a/backend/packages/harness/deerflow/mcp/session_pool.py b/backend/packages/harness/deerflow/mcp/session_pool.py index 8450cac8e..63703ba25 100644 --- a/backend/packages/harness/deerflow/mcp/session_pool.py +++ b/backend/packages/harness/deerflow/mcp/session_pool.py @@ -8,6 +8,27 @@ This module provides a session pool that maintains persistent MCP sessions, scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id — so that consecutive tool calls share the same session and server-side state. Sessions are evicted in LRU order when the pool reaches capacity. + +Lifecycle model (owner task) +---------------------------- +An MCP ``ClientSession`` is implemented on top of an ``anyio`` task group, and +anyio enforces that a cancel scope must be exited from the *same task* that +entered it. Calling ``cm.__aexit__`` from any task other than the one that ran +``cm.__aenter__`` raises:: + + RuntimeError: Attempted to exit cancel scope in a different task than it + was entered in + +The sync-tool path (``make_sync_tool_wrapper``) drives each call through a fresh +``asyncio.run`` event loop, so a session entered while answering one call would +otherwise be exited while answering another — from a different task — and crash +(GitHub issue #3379). + +To make this impossible, every pooled session is owned by a dedicated +``_run_session`` task. That task enters the context manager, hands the live +session back to the caller, and then *waits* on a close event. All shutdown +paths only ever **signal** that event; the owner task performs ``__aexit__`` +itself, guaranteeing enter and exit always happen in the same task. """ from __future__ import annotations @@ -27,18 +48,81 @@ class MCPSessionPool: """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.""" MAX_SESSIONS = 256 - SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe + SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session on a foreign loop def __init__(self) -> None: + # Each entry: (session, owning_loop, owner_task, close_event). self._entries: OrderedDict[ tuple[str, str], - tuple[ClientSession, asyncio.AbstractEventLoop], + tuple[ + ClientSession, + asyncio.AbstractEventLoop, + asyncio.Task[Any], + asyncio.Event, + ], ] = OrderedDict() - self._context_managers: dict[tuple[str, str], Any] = {} + # In-flight creations, keyed by (server, scope). Lets concurrent callers + # on the same loop share a single creation instead of each spawning a + # duplicate session. Value: (loop, ready_future, owner_task, close_event). + self._inflight: dict[ + tuple[str, str], + tuple[ + asyncio.AbstractEventLoop, + asyncio.Future[ClientSession], + asyncio.Task[Any], + asyncio.Event, + ], + ] = {} # threading.Lock is not bound to any event loop, so it is safe to # acquire from both async paths and sync/worker-thread paths. self._lock = threading.Lock() + # ------------------------------------------------------------------ + # Session owner task + # ------------------------------------------------------------------ + + async def _run_session( + self, + connection: dict[str, Any], + ready: asyncio.Future[ClientSession], + close_evt: asyncio.Event, + ) -> None: + """Own a single MCP session for its entire lifetime. + + Enters the session context manager, initializes it, publishes the live + session via ``ready``, then blocks until ``close_evt`` is set. The + context manager is *always* exited from this task, satisfying anyio's + cancel-scope same-task requirement. + """ + from langchain_mcp_adapters.sessions import create_session + + cm = create_session(connection) + try: + session = await cm.__aenter__() + except BaseException as e: + # Never entered the cancel scope, so there is nothing to exit. + if not ready.done(): + ready.set_exception(e) + return + + # The context manager is now entered. From here on __aexit__ MUST run in + # this task — on init failure, on cancellation, or on the close signal — + # to satisfy anyio's same-task cancel-scope requirement and to avoid + # leaking the session/subprocess. + try: + await session.initialize() + if not ready.done(): + ready.set_result(session) + await close_evt.wait() + except BaseException as e: + if not ready.done(): + ready.set_exception(e) + finally: + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session", exc_info=True) + async def get_session( self, server_name: str, @@ -47,9 +131,9 @@ class MCPSessionPool: ) -> ClientSession: """Get or create a persistent MCP session. - If an existing session was created in a different event loop (e.g. - the sync-wrapper path), it is closed and replaced with a fresh one - in the current loop. + If an existing session was created in a different (or closed) event + loop, it is evicted and replaced with a fresh one owned by a task on + the current loop. Args: server_name: MCP server name. @@ -63,44 +147,118 @@ class MCPSessionPool: current_loop = asyncio.get_running_loop() # Phase 1: inspect/mutate the registry under the thread lock (no awaits). - cms_to_close: list[tuple[tuple[str, str], Any]] = [] + # Decide one of three outcomes atomically: return an existing session, + # join an in-flight creation, or become the creator for this key. + # Each item: (loop, owner_task, close_event, cancel). ``cancel`` is True + # for in-flight creations, whose owner may be blocked inside + # ``initialize()`` where close_evt cannot wake it — it must be cancelled. + evicted: list[tuple[asyncio.AbstractEventLoop, asyncio.Task[Any], asyncio.Event, bool]] = [] + join: asyncio.Future[ClientSession] | None = None + ready: asyncio.Future[ClientSession] | None = None + close_evt: asyncio.Event | None = None + task: asyncio.Task[Any] | None = None with self._lock: if key in self._entries: - session, loop = self._entries[key] - if loop is current_loop: + session, loop, ent_task, ent_close = self._entries[key] + if loop is current_loop and not loop.is_closed(): self._entries.move_to_end(key) return session - # Session belongs to a different event loop – evict it. - cm = self._context_managers.pop(key, None) + # Session belongs to a different/closed event loop – evict it. self._entries.pop(key) - if cm is not None: - cms_to_close.append((key, cm)) + evicted.append((loop, ent_task, ent_close, False)) + + inflight = self._inflight.get(key) + if inflight is not None and inflight[0] is current_loop and not inflight[0].is_closed(): + # Another caller on this loop is already creating the session; + # wait for the same result instead of building a duplicate. + join = inflight[1] + else: + if inflight is not None: + # Stale in-flight creation owned by a different/closed loop. + # Drop the record and tear its owner down; because that owner + # may be blocked inside initialize() (where close_evt cannot + # wake it), it must be cancelled. We then create a fresh + # session here. + self._inflight.pop(key) + evicted.append((inflight[0], inflight[2], inflight[3], True)) + # Become the creator: publish an in-flight record before any + # await so concurrent callers join us instead of racing. + ready = current_loop.create_future() + close_evt = asyncio.Event() + task = current_loop.create_task(self._run_session(connection, ready, close_evt)) + self._inflight[key] = (current_loop, ready, task, close_evt) # Evict LRU entries when at capacity. while len(self._entries) >= self.MAX_SESSIONS: - oldest_key = next(iter(self._entries)) - cm = self._context_managers.pop(oldest_key, None) + oldest_key, (_, loop, ent_task, ent_close) = next(iter(self._entries.items())) self._entries.pop(oldest_key) - if cm is not None: - cms_to_close.append((oldest_key, cm)) + evicted.append((loop, ent_task, ent_close, False)) - # Phase 2: async cleanup outside the lock so we never await while holding it. - for close_key, cm in cms_to_close: + # Phase 2: shut down evicted sessions/creations. Same-loop owners are + # awaited so they finish deterministically; foreign-loop owners are + # routed to their own loop. In every case the owner task — never this + # one — runs __aexit__. In-flight owners are cancelled (cancel=True) so a + # blocking initialize() cannot leave them hung. + for loop, ent_task, ent_close, cancel in evicted: + if loop is current_loop and not loop.is_closed(): + await self._shutdown(ent_close, ent_task, cancel) + elif cancel: + await self._shutdown_entry(loop, ent_task, ent_close, cancel=True) + else: + self._signal_close(loop, ent_close) + + # Phase 2b: a concurrent creation for this key is already in progress on + # this loop — share its result rather than create a duplicate session. + if join is not None: + return await asyncio.shield(join) + + assert ready is not None and close_evt is not None and task is not None + + # Phase 3: wait for our owner task to publish the initialized session. + try: + session = await asyncio.shield(ready) + except BaseException: + # Two distinct cases reach here: + # + # 1. The owner task failed (e.g. connect/initialize error) and + # reported it via ready.set_exception(). It is *already* in its + # finally block running cm.__aexit__ in its own task, so we must + # NOT cancel it — doing so would interrupt that cleanup. We only + # wait for it to finish unwinding. + # 2. This call itself was cancelled (CancelledError). Because of the + # shield, `ready` is still pending and the owner task is alive and + # blocked. We signal close and cancel it so it exits the cancel + # scope in its own task, then wait for it to finish. + # + # The session is never registered yet, so nobody else can close it; + # waiting here guarantees we never leak a session or owner task. + owner_already_failed = ready.done() and not ready.cancelled() and ready.exception() is not None + if not owner_already_failed: + close_evt.set() + task.cancel() try: - await cm.__aexit__(None, None, None) - except Exception: - logger.warning("Error closing MCP session %s", close_key, exc_info=True) + await asyncio.shield(task) + except BaseException: + logger.debug("Owner task ended during get_session unwind", exc_info=True) + with self._lock: + if self._inflight.get(key) == (current_loop, ready, task, close_evt): + self._inflight.pop(key) + raise - from langchain_mcp_adapters.sessions import create_session - - cm = create_session(connection) - session = await cm.__aenter__() - await session.initialize() - - # Phase 3: register the new session under the lock. + # Phase 4: promote the in-flight creation to a registered entry — but + # only if our in-flight record is still the live one. A concurrent + # close_* / close_all may have removed it while we were initializing; in + # that case we must NOT resurrect the session into _entries. Instead we + # own the teardown: signal our owner task and wait for it to run + # __aexit__ in its own task, then surface the cancellation. with self._lock: - self._entries[key] = (session, current_loop) - self._context_managers[key] = cm + still_ours = self._inflight.get(key) == (current_loop, ready, task, close_evt) + if still_ours: + self._inflight.pop(key) + self._entries[key] = (session, current_loop, task, close_evt) + if not still_ours: + await self._shutdown(close_evt, task) + raise asyncio.CancelledError("MCP session pool was closed while the session was being created") logger.info("Created persistent MCP session for %s/%s", server_name, scope_key) return session @@ -108,70 +266,169 @@ class MCPSessionPool: # Cleanup helpers # ------------------------------------------------------------------ - async def _close_cm(self, key: tuple[str, str], cm: Any) -> None: - """Close a single context manager (must be called WITHOUT the lock).""" + @staticmethod + def _signal_close(loop: asyncio.AbstractEventLoop, close_evt: asyncio.Event) -> None: + """Ask an owner task to shut down without waiting. + + ``asyncio.Event.set`` is not thread-safe, so it is scheduled on the + owning loop. A closed loop means the owner task is already gone. + """ + if loop.is_closed(): + return try: - await cm.__aexit__(None, None, None) - except Exception: - logger.warning("Error closing MCP session %s", key, exc_info=True) + loop.call_soon_threadsafe(close_evt.set) + except RuntimeError: + # Loop was closed between the is_closed() check and now. + pass + + async def _shutdown( + self, + close_evt: asyncio.Event, + task: asyncio.Task[Any], + cancel: bool = False, + ) -> None: + """Signal an owner task and wait for it to finish (runs on its loop). + + ``cancel=True`` is used for in-flight creations: the owner task may be + blocked inside ``initialize()`` where ``close_evt`` cannot wake it, so it + must be cancelled. Its ``finally`` block still runs ``__aexit__`` in its + own task, satisfying anyio's same-task cancel-scope requirement. + """ + close_evt.set() + if cancel: + task.cancel() + try: + await task + except (Exception, asyncio.CancelledError): + logger.debug("Owner task ended during shutdown", exc_info=True) + + async def _shutdown_entry( + self, + loop: asyncio.AbstractEventLoop, + task: asyncio.Task[Any], + close_evt: asyncio.Event, + cancel: bool = False, + ) -> None: + """Shut down one entry, routing the close to its owning loop.""" + if loop.is_closed(): + return + current_loop = asyncio.get_running_loop() + if loop is current_loop: + await self._shutdown(close_evt, task, cancel) + elif loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop) + try: + await asyncio.wrap_future(future) + except Exception: + logger.warning("Error closing MCP session on owning loop", exc_info=True) + else: + # Owning loop exists but is neither the current loop nor running. + # We are inside an async context here, so run_until_complete() would + # raise "Cannot run the event loop while another loop is running"; + # and the loop may belong to another thread, where driving it from + # here is unsafe. This branch is not expected in practice — a + # session's owning loop is either the long-lived gateway loop (which + # is running) or a short-lived asyncio.run loop (which is closed and + # caught above). Fall back to a best-effort thread-safe signal so the + # owner task tears down if/when its loop runs again. + logger.warning("Owning loop for MCP session is idle; signalling close best-effort. Session may leak until the loop runs again.") + self._signal_close(loop, close_evt) + if cancel: + try: + loop.call_soon_threadsafe(task.cancel) + except RuntimeError: + pass async def close_scope(self, scope_key: str) -> None: """Close all sessions for a given scope (e.g. thread_id).""" with self._lock: keys = [k for k in self._entries if k[1] == scope_key] - cms = [(k, self._context_managers.pop(k, None)) for k in keys] - for k in keys: - self._entries.pop(k, None) - for key, cm in cms: - if cm is not None: - await self._close_cm(key, cm) + entries = [(self._entries.pop(k)) for k in keys] + inflight_keys = [k for k in self._inflight if k[1] == scope_key] + inflight = [self._inflight.pop(k) for k in inflight_keys] + for _session, loop, task, close_evt in entries: + await self._shutdown_entry(loop, task, close_evt) + for loop, _ready, task, close_evt in inflight: + await self._shutdown_entry(loop, task, close_evt, cancel=True) async def close_server(self, server_name: str) -> None: """Close all sessions for a given server.""" with self._lock: keys = [k for k in self._entries if k[0] == server_name] - cms = [(k, self._context_managers.pop(k, None)) for k in keys] - for k in keys: - self._entries.pop(k, None) - for key, cm in cms: - if cm is not None: - await self._close_cm(key, cm) + entries = [(self._entries.pop(k)) for k in keys] + inflight_keys = [k for k in self._inflight if k[0] == server_name] + inflight = [self._inflight.pop(k) for k in inflight_keys] + for _session, loop, task, close_evt in entries: + await self._shutdown_entry(loop, task, close_evt) + for loop, _ready, task, close_evt in inflight: + await self._shutdown_entry(loop, task, close_evt, cancel=True) async def close_all(self) -> None: """Close every managed session.""" with self._lock: - cms = list(self._context_managers.items()) - self._context_managers.clear() + entries = list(self._entries.values()) self._entries.clear() - for key, cm in cms: - await self._close_cm(key, cm) + inflight = list(self._inflight.values()) + self._inflight.clear() + for _session, loop, task, close_evt in entries: + await self._shutdown_entry(loop, task, close_evt) + for loop, _ready, task, close_evt in inflight: + await self._shutdown_entry(loop, task, close_evt, cancel=True) def close_all_sync(self) -> None: - """Close all sessions using their owning event loops (synchronous). + """Close all sessions on their owning event loops (synchronous). - Each session is closed on the loop it was created in, avoiding - cross-loop resource leaks. Safe to call from any thread without an - active event loop. + Each session is closed by its owner task on the loop it was created in, + avoiding cross-loop and cross-task errors. Safe to call from any thread + without an active event loop. + + Closing semantics differ by where the owning loop runs: + + * Owning loop is idle, or running on another thread — this call blocks + until teardown completes (or ``SESSION_CLOSE_TIMEOUT`` elapses). + * Owning loop is the one currently running on *this* thread — we cannot + block on it without deadlocking, so teardown is only *signalled* here + and completes asynchronously once control returns to that loop. The + caller must therefore keep that loop running afterwards; if it stops + the loop immediately, the owner task's ``__aexit__`` may not run. When + a deterministic close is required from inside a running loop, ``await + close_all()`` instead. """ with self._lock: - entries = list(self._entries.items()) - cms = dict(self._context_managers) + entries = list(self._entries.values()) self._entries.clear() - self._context_managers.clear() + inflight = list(self._inflight.values()) + self._inflight.clear() - for key, (_, loop) in entries: - cm = cms.get(key) - if cm is None or loop.is_closed(): + # Entries are initialized (gentle close_evt path). In-flight creations + # may be blocked mid-init, so they are cancelled to unblock teardown. + owners = [(loop, task, close_evt, False) for _s, loop, task, close_evt in entries] + owners += [(loop, task, close_evt, True) for loop, _r, task, close_evt in inflight] + try: + current_running_loop = asyncio.get_running_loop() + except RuntimeError: + current_running_loop = None + for loop, task, close_evt, cancel in owners: + if loop.is_closed(): continue try: - if loop.is_running(): - # Schedule on the owning loop from this (different) thread. - future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop) + if loop is current_running_loop: + # We are executing inside this loop's thread, so synchronously + # waiting on run_coroutine_threadsafe(...).result() would + # deadlock until timeout. Signal the owner task directly and + # let it finish once this synchronous call returns control to + # the running loop. + close_evt.set() + if cancel: + task.cancel() + elif loop.is_running(): + # Schedule the shutdown on the owning loop from this thread. + future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop) future.result(timeout=self.SESSION_CLOSE_TIMEOUT) else: - loop.run_until_complete(cm.__aexit__(None, None, None)) + loop.run_until_complete(self._shutdown(close_evt, task, cancel)) except Exception: - logger.debug("Error closing MCP session %s during sync close", key, exc_info=True) + logger.debug("Error closing MCP session during sync close", exc_info=True) # ------------------------------------------------------------------ diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py index 40d02e61c..852cfd861 100644 --- a/backend/tests/test_mcp_session_pool.py +++ b/backend/tests/test_mcp_session_pool.py @@ -1,5 +1,7 @@ """Tests for the MCP persistent-session pool.""" +import asyncio +import threading from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -614,3 +616,585 @@ async def test_http_transport_tools_not_pooled(): stdio_tools = [t for t in tools if t.name == "playwright_navigate"] assert len(stdio_tools) == 1 assert stdio_tools[0].coroutine is not stdio_tool.coroutine + + +# --------------------------------------------------------------------------- +# Regression for #3379: cancel scope must be exited in the entering task +# --------------------------------------------------------------------------- + + +class _CancelScopeCm: + """Fake session context manager that mimics anyio's cancel-scope rule. + + ``ClientSession`` is built on an anyio task group, which requires the cancel + scope to be exited from the *same asyncio task* that entered it. This fake + records the task that runs ``__aenter__`` and raises the exact RuntimeError + anyio would raise if ``__aexit__`` runs in a different task — reproducing the + crash reported in GitHub issue #3379. + """ + + def __init__(self) -> None: + self.enter_task: object | None = None + self.closed = False + + async def __aenter__(self): + self.enter_task = asyncio.current_task() + return AsyncMock() + + async def __aexit__(self, *args): + if asyncio.current_task() is not self.enter_task: + raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") + self.closed = True + return False + + +async def _get_session_in_own_task(pool, *args): + """Create a pooled session from a *dedicated* child task. + + In production every stdio session is entered from its own short-lived task + (the sync-tool path runs each call through a fresh ``asyncio.run``). This + helper reproduces that so the close paths are exercised from a *different* + task than the one that entered the session — the exact condition that + triggered #3379. + """ + return await asyncio.create_task(pool.get_session(*args)) + + +@pytest.mark.asyncio +async def test_close_all_does_not_cross_tasks(): + """close_all must not raise the cross-task cancel-scope RuntimeError (#3379).""" + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await _get_session_in_own_task(pool, "s1", "t1", {"transport": "stdio", "command": "x", "args": []}) + await _get_session_in_own_task(pool, "s2", "t2", {"transport": "stdio", "command": "x", "args": []}) + + # close_all runs in this task, which is *not* the task that entered either + # session. The owner task must perform __aexit__ so each CM closes cleanly. + await pool.close_all() + + assert all(cm.closed for cm in cms) + assert len(pool._entries) == 0 + + +@pytest.mark.asyncio +async def test_close_scope_does_not_cross_tasks(): + """close_scope must respect the same-task cancel-scope rule (#3379).""" + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await _get_session_in_own_task(pool, "s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await _get_session_in_own_task(pool, "s", "t2", {"transport": "stdio", "command": "x", "args": []}) + + await pool.close_scope("t1") + + assert cms[0].closed is True + assert cms[1].closed is False + assert ("s", "t2") in pool._entries + + +@pytest.mark.asyncio +async def test_lru_eviction_does_not_cross_tasks(): + """LRU eviction must close the victim without a cross-task RuntimeError (#3379).""" + pool = MCPSessionPool() + pool.MAX_SESSIONS = 2 + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await _get_session_in_own_task(pool, "s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await _get_session_in_own_task(pool, "s", "t2", {"transport": "stdio", "command": "x", "args": []}) + # Adding t3 evicts t1 — its own owner task must run __aexit__, even + # though the eviction is driven from t3's get_session call. + await _get_session_in_own_task(pool, "s", "t3", {"transport": "stdio", "command": "x", "args": []}) + + assert cms[0].closed is True + assert cms[1].closed is False + assert cms[2].closed is False + + +def test_close_all_sync_across_loops_does_not_cross_tasks(): + """close_all_sync, the path hit by the sync tool wrapper, must close sessions + created in earlier (now-finished) asyncio.run loops without crashing (#3379). + """ + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + # Simulate the sync-tool path: a session created inside one short-lived + # event loop, then a second one in a different loop. + asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + asyncio.run(pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})) + + # The owning loops are already closed; close_all_sync must not raise. + pool.close_all_sync() + + assert len(pool._entries) == 0 + + +def test_get_session_replaces_session_from_closed_loop(): + """A pooled session whose owning loop has closed is evicted and recreated.""" + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + # First session created in a throwaway loop that is torn down by + # asyncio.run (mirrors the sync-tool path). asyncio.run cancels the + # pending owner task and runs its __aexit__ on the same loop. + asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + assert ("s", "t1") in pool._entries + + # Now request the same key from a fresh loop: the stale entry (closed + # loop) must be evicted and replaced with a fresh session. + session = asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + + assert session is not None + assert len(cms) == 2 + assert pool._entries[("s", "t1")][0] is session + + +class _BlockingInitCm: + """Fake session CM whose ``initialize`` blocks until released. + + Lets a test cancel ``get_session`` while the owner task is still + initializing, reproducing the caller-cancellation window. + """ + + def __init__(self, gate: asyncio.Event) -> None: + self._gate = gate + self.entered = False + self.closed = False + + async def __aenter__(self): + self.entered = True + session = MagicMock() + session.initialize = self._initialize + return session + + async def _initialize(self): + await self._gate.wait() + + async def __aexit__(self, *args): + self.closed = True + return False + + +@pytest.mark.asyncio +async def test_get_session_cancelled_while_initializing_does_not_leak(): + """Cancelling get_session mid-init must not leak the owner task/session (#3379 CR). + + The session is not registered yet, so if cancellation skipped the cleanup + the owner task would block forever on close_evt.wait() and the CM's + __aexit__ would never run — an unreachable, unclosable session. + """ + pool = MCPSessionPool() + gate = asyncio.Event() + cms: list[_BlockingInitCm] = [] + + def make_cm(*a, **kw): + cm = _BlockingInitCm(gate) + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + call = asyncio.create_task(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + # Let the owner task enter the CM and reach the blocking initialize(). + await asyncio.sleep(0.01) + call.cancel() + with pytest.raises(asyncio.CancelledError): + await call + + # Release initialize() so the owner task can finish its shutdown path. + gate.set() + # Give the owner task a chance to run __aexit__ and complete. + for _ in range(10): + if cms and cms[0].closed: + break + await asyncio.sleep(0.01) + + assert len(cms) == 1 + assert cms[0].entered is True + assert cms[0].closed is True, "owner task must run __aexit__ after cancellation" + assert len(pool._entries) == 0 + + current = asyncio.current_task() + leaked = [t for t in asyncio.all_tasks() if t is not current and not t.done() and "_run_session" in str(t.get_coro())] + assert not leaked, "owner task must not be left pending after cancellation" + + +class _InitFailCm: + """Fake session CM whose ``initialize`` fails, with a slow ``__aexit__``. + + The slow __aexit__ lets a test observe whether cleanup is allowed to run to + completion (closed=True) or is interrupted by a stray cancellation. + """ + + def __init__(self) -> None: + self.entered = False + self.exit_started = False + self.closed = False + + async def __aenter__(self): + self.entered = True + session = MagicMock() + session.initialize = self._initialize + return session + + async def _initialize(self): + raise RuntimeError("init boom") + + async def __aexit__(self, *args): + self.exit_started = True + # Yield control so a buggy double-cancel would interrupt us here. + await asyncio.sleep(0.02) + self.closed = True + return False + + +@pytest.mark.asyncio +async def test_get_session_init_failure_runs_full_cleanup(): + """On initialize() failure the owner task's __aexit__ must complete (#3379 CR P1). + + The caller must NOT cancel the owner task on a reported failure, otherwise + the in-progress __aexit__ cleanup gets interrupted and leaks resources. + """ + pool = MCPSessionPool() + cms: list[_InitFailCm] = [] + + def make_cm(*a, **kw): + cm = _InitFailCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + with pytest.raises(RuntimeError, match="init boom"): + await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}) + + assert len(cms) == 1 + assert cms[0].entered is True + assert cms[0].exit_started is True + assert cms[0].closed is True, "__aexit__ must run to completion, not be interrupted" + assert len(pool._entries) == 0 + assert len(pool._inflight) == 0 + + +@pytest.mark.asyncio +async def test_concurrent_get_session_same_key_creates_single_session(): + """Concurrent get_session for the same key must share one session (#3379 CR P1).""" + pool = MCPSessionPool() + gate = asyncio.Event() + cms: list[_BlockingInitCm] = [] + + def make_cm(*a, **kw): + cm = _BlockingInitCm(gate) + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + conn = {"transport": "stdio", "command": "x", "args": []} + t1 = asyncio.create_task(pool.get_session("s", "same", conn)) + t2 = asyncio.create_task(pool.get_session("s", "same", conn)) + # Let both calls pass Phase 1 and reach the (gated) initialize(). + await asyncio.sleep(0.02) + gate.set() + s1, s2 = await asyncio.gather(t1, t2) + + # Only one CM/session created, both callers got the same object. + assert len(cms) == 1, "concurrent same-key calls must not create duplicate sessions" + assert s1 is s2 + assert len(pool._entries) == 1 + assert len(pool._inflight) == 0 + + +@pytest.mark.asyncio +async def test_close_all_during_in_flight_creation_does_not_resurrect_session(): + """close_all while a creation is in-flight must not leave a live session (#3379 CR P1). + + The in-flight record must be removed and its owner task torn down, so when + the (blocked) creator finishes initializing it does NOT register the session + back into _entries — otherwise the pool resurrects an unclosable session. + """ + pool = MCPSessionPool() + gate = asyncio.Event() + cms: list[_BlockingInitCm] = [] + + def make_cm(*a, **kw): + cm = _BlockingInitCm(gate) + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + conn = {"transport": "stdio", "command": "x", "args": []} + call = asyncio.create_task(pool.get_session("s", "t1", conn)) + # Let the owner task enter the CM and reach the blocking initialize(). + await asyncio.sleep(0.01) + assert ("s", "t1") in pool._inflight + + # Close everything while the creation is still in-flight. + await pool.close_all() + + # The in-flight creation must be gone, not promoted to an entry. + assert len(pool._inflight) == 0 + assert len(pool._entries) == 0 + + # Even if the gate is released afterwards, nothing must come back. + gate.set() + with pytest.raises(asyncio.CancelledError): + await call + + assert len(pool._entries) == 0 + assert len(pool._inflight) == 0 + assert cms[0].closed is True, "in-flight session's __aexit__ must run on teardown" + + current = asyncio.current_task() + leaked = [t for t in asyncio.all_tasks() if t is not current and not t.done() and "_run_session" in str(t.get_coro())] + assert not leaked, "in-flight owner task must not leak after close_all" + + +def test_get_session_cross_loop_in_flight_does_not_raise_assertion(): + """A same-key request from another loop must not hit the in-flight assertion (#3379 CR P1). + + Loop A starts (and leaves running) an in-flight creation, then loop B + requests the same key. The stale in-flight record (owned by loop A) must be + dropped and loop B must become a fresh creator — never fall through to an + AssertionError. + """ + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + conn = {"transport": "stdio", "command": "x", "args": []} + results: list[object] = [] + errors: list[BaseException] = [] + + def run_in_own_loop(): + try: + results.append(asyncio.run(pool.get_session("s", "t1", conn))) + except BaseException as e: # noqa: BLE001 - capture for assertion + errors.append(e) + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + # First loop creates and registers an entry, then its loop is torn down + # by asyncio.run, leaving a stale (closed-loop) record behind. + t1 = threading.Thread(target=run_in_own_loop) + t1.start() + t1.join() + + # Second loop requests the same key. It must evict the stale record and + # create a fresh session instead of raising AssertionError. + t2 = threading.Thread(target=run_in_own_loop) + t2.start() + t2.join() + + assert not errors, f"cross-loop same-key request must not raise: {errors}" + assert len(results) == 2 + assert all(r is not None for r in results) + + +def test_cross_loop_preempting_blocked_in_flight_does_not_hang_owner(): + """A foreign-loop request must not leave a still-initializing owner hung (#3379 CR P1). + + Loop A starts a creation that blocks inside initialize() (the in-flight + record stays live). Loop B then requests the same key. B must tear A's owner + down — cancelling it, because close_evt alone cannot wake a task blocked in + initialize() — so that A's get_session unwinds instead of hanging forever. + """ + pool = MCPSessionPool() + conn = {"transport": "stdio", "command": "x", "args": []} + first_gate = threading.Event() + entered = threading.Event() + results: list[tuple[str, object]] = [] + errors: list[tuple[str, BaseException]] = [] + closed: list[str] = [] + + class _BlockingForeverCm: + async def __aenter__(self): + session = MagicMock() + session.initialize = self._initialize + entered.set() + return session + + async def _initialize(self): + # Block until released, simulating a slow/stuck server handshake. + while not first_gate.is_set(): + await asyncio.sleep(0.005) + + async def __aexit__(self, *args): + closed.append("blocking") + return False + + class _FastCm: + async def __aenter__(self): + session = MagicMock() + + async def init(): + return None + + session.initialize = init + return session + + async def __aexit__(self, *args): + return False + + cms: list[object] = [_BlockingForeverCm(), _FastCm()] + + def make_cm(*a, **kw): + return cms.pop(0) + + def run_get(name): + try: + results.append((name, asyncio.run(pool.get_session("s", "t1", conn)))) + except BaseException as e: # noqa: BLE001 - capture for assertion + errors.append((name, e)) + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + ta = threading.Thread(target=run_get, args=("A",)) + ta.start() + assert entered.wait(2), "owner A must enter the CM and start initializing" + + tb = threading.Thread(target=run_get, args=("B",)) + tb.start() + tb.join(3) + + # B must complete without depending on A's blocked initialize(). + assert not tb.is_alive(), "foreign-loop request B must not hang" + # A must already be unwound (cancelled), not waiting on the dead gate. + ta.join(3) + assert not ta.is_alive(), "preempted owner A must not hang forever" + + assert [n for n, _ in results] == ["B"], "only B produces a usable session" + assert any(isinstance(e, asyncio.CancelledError) for _, e in errors), "preempted A must unwind via CancelledError" + assert "blocking" in closed, "preempted owner's __aexit__ must run on teardown" + + +@pytest.mark.asyncio +async def test_close_all_sync_from_running_loop_does_not_wait_on_itself(): + """close_all_sync must not block on the current running loop (#3379 CR P1). + + When called from code already executing inside the owner loop's thread, + close_all_sync cannot synchronously wait for that loop to run the shutdown + coroutine. It must signal the owner task and return promptly, then the owner + task closes itself once the loop regains control. + """ + pool = MCPSessionPool() + pool.SESSION_CLOSE_TIMEOUT = 0.2 + conn = {"transport": "stdio", "command": "x", "args": []} + + cm = _CloseTrackingCm() + + def make_cm(*a, **kw): + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s", "t1", conn) + start = asyncio.get_running_loop().time() + pool.close_all_sync() + elapsed = asyncio.get_running_loop().time() - start + + assert elapsed < 0.1, "close_all_sync must not stall until timeout on the current loop" + assert len(pool._entries) == 0 + assert len(pool._inflight) == 0 + assert cm.closed is False, "owner task has not run yet while close_all_sync is still executing" + + for _ in range(10): + if cm.closed: + break + await asyncio.sleep(0.01) + + assert cm.closed is True, "owner task must close itself after the loop regains control" + + +# --------------------------------------------------------------------------- +# reset_mcp_tools_cache deadlock regression +# --------------------------------------------------------------------------- + + +class _CloseTrackingCm: + """A create_session() context manager that records when __aexit__ runs.""" + + def __init__(self) -> None: + self.closed = False + + async def __aenter__(self): + session = MagicMock() + + async def init(): + return None + + session.initialize = init + return session + + async def __aexit__(self, *args): + self.closed = True + return False + + +def test_reset_mcp_tools_cache_from_running_loop_is_bounded(): + """reset_mcp_tools_cache() must not deadlock when called from inside a + running loop that owns sessions (#3392 CR blocker). + + The previous implementation spun up a worker thread running + ``asyncio.run(pool.close_all())`` and blocked the loop thread on + ``.result()``. close_all() then routed teardown of the current loop's + sessions back onto that blocked loop via run_coroutine_threadsafe(...), + so neither side could make progress. This test drives the exact scenario + on a daemon thread and asserts the call returns within a bounded time. + """ + from deerflow.mcp.cache import reset_mcp_tools_cache + from deerflow.mcp.session_pool import get_session_pool + + conn = {"transport": "stdio", "command": "x", "args": []} + cm = _CloseTrackingCm() + done = threading.Event() + + async def scenario(): + pool = get_session_pool() + # Entry owned by THIS loop — the deadlock-prone case. + await pool.get_session("s", "t1", conn) + # Synchronous call: asyncio.get_running_loop() succeeds inside it, so + # it takes the "running loop" branch in reset_mcp_tools_cache(). + reset_mcp_tools_cache() + # Signal-only teardown completes once the loop regains control. + await asyncio.sleep(0.05) + + def run(): + asyncio.run(scenario()) + done.set() + + t = threading.Thread(target=run, daemon=True) + with patch("langchain_mcp_adapters.sessions.create_session", return_value=cm): + t.start() + t.join(timeout=5) + + assert done.is_set(), "reset_mcp_tools_cache() deadlocked inside a running loop" + assert cm.closed is True, "owner task must run __aexit__ once the loop regains control" diff --git a/backend/tests/test_reload_boundary.py b/backend/tests/test_reload_boundary.py new file mode 100644 index 000000000..5610ccafb --- /dev/null +++ b/backend/tests/test_reload_boundary.py @@ -0,0 +1,140 @@ +"""Regression tests for the config reload boundary registry. + +Bytedance/deer-flow issue #3144: the hot-reload boundary is the contract +between gateway dependencies that resolve ``AppConfig`` every request and the +infrastructure that captures the snapshot once at startup. The registry in +``deerflow.config.reload_boundary`` is the machine-readable source of truth; +these tests pin the registry against the actual Pydantic schema so a future +field rename / addition / boundary change cannot silently drift. +""" + +from __future__ import annotations + +import pytest + +from deerflow.config.app_config import AppConfig +from deerflow.config.reload_boundary import ( + STARTUP_ONLY_FIELDS, + STARTUP_ONLY_PREFIX, + format_field_description, + is_startup_only_field, + iter_startup_only_field_paths, +) + + +def test_registry_has_a_reason_for_every_field(): + """Every registry entry must explain *why* the field is restart-required. + + The reason text is what surfaces in IDE hover and in the AppConfig schema + description, so an empty / placeholder value would defeat the purpose. + """ + for field_path, reason in STARTUP_ONLY_FIELDS.items(): + assert reason.strip(), f"empty reason for {field_path}" + assert len(reason) > 20, f"reason for {field_path} too short to be useful: {reason!r}" + + +def test_iter_startup_only_field_paths_matches_registry(): + """Iterator stays in sync with the registry mapping.""" + assert sorted(iter_startup_only_field_paths()) == sorted(STARTUP_ONLY_FIELDS) + + +def test_is_startup_only_field_recognises_registered_fields(): + """The membership helper accepts every registered field path.""" + for field_path in STARTUP_ONLY_FIELDS: + assert is_startup_only_field(field_path) + assert not is_startup_only_field("memory") # hot-reloadable + assert not is_startup_only_field("models") + assert not is_startup_only_field("nonexistent_field") + + +def test_format_field_description_prefixes_with_marker(): + """The formatter produces a description that machine-readable tooling can + pivot on (drift tests, future "needs-restart" scanners).""" + for field_path in STARTUP_ONLY_FIELDS: + text = format_field_description(field_path) + assert text.startswith(STARTUP_ONLY_PREFIX), text + # The reason is appended after the prefix; the formatter must not + # silently drop it. + assert STARTUP_ONLY_FIELDS[field_path] in text + + +def test_format_field_description_rejects_unknown_field(): + with pytest.raises(KeyError): + format_field_description("not_in_registry") + + +def test_format_field_description_appends_optional_field_doc(): + """The formatter composes the startup-only marker with the field's own + human-facing description when supplied. + + The original ``Field(description=)`` used to document allowed values + (e.g. ``log_level`` listed ``debug/info/warning/error``); registry + adoption must not drop that. The composed output keeps the marker as + the leading token so machine-readable tooling still pivots on it, + then appends the prose after a blank line. + """ + text = format_field_description("log_level", field_doc="Logging level (debug/info/warning/error).") + assert text.startswith(STARTUP_ONLY_PREFIX) + assert STARTUP_ONLY_FIELDS["log_level"] in text + assert "debug/info/warning/error" in text + + +def test_appconfig_descriptions_retain_original_field_documentation(): + """``AppConfig.model_fields[name].description`` for restart-required + fields should still carry the original human-facing field doc so IDE + hover documents what the field is *and* why a restart is needed.""" + descriptions = { + "log_level": "debug/info/warning/error", + "database": "memory, sqlite, or postgres", + "sandbox": "Sandbox provider", + "run_events": "memory for dev", + "checkpointer": "state-persistence checkpointer", + "stream_bridge": "Stream bridge", + } + for field_name, expected_substring in descriptions.items(): + description = AppConfig.model_fields[field_name].description or "" + assert description.startswith(STARTUP_ONLY_PREFIX), f"AppConfig.{field_name} missing startup-only marker" + assert expected_substring in description, f"AppConfig.{field_name} description lost original field doc; got {description!r}" + + +def test_appconfig_schema_marks_registered_fields_with_prefix(): + """Every registry entry that corresponds to a top-level AppConfig field + must carry the standardized ``startup-only:`` prefix in its Pydantic + ``Field(description=...)``. This is the contract IDE hover relies on. + """ + schema_fields = AppConfig.model_fields + for field_path in STARTUP_ONLY_FIELDS: + if field_path not in schema_fields: + # Some entries (e.g. ``channels``) live outside the AppConfig + # schema. The registry still owns them, but the schema-prefix + # assertion does not apply. + continue + description = schema_fields[field_path].description or "" + assert description.startswith(STARTUP_ONLY_PREFIX), f"AppConfig.{field_path} should have Field(description=) starting with {STARTUP_ONLY_PREFIX!r}, got {description!r}" + + +def test_no_appconfig_field_uses_prefix_without_registration(): + """Reverse drift check: if a future schema edit adds the + ``startup-only:`` prefix to a new field, the registry must list it. + + This catches the silent-drift case where someone marks a field + restart-required in the schema but forgets to update the registry + that the operator-facing scanners and docs consume. + """ + for name, info in AppConfig.model_fields.items(): + description = info.description or "" + if not description.startswith(STARTUP_ONLY_PREFIX): + continue + assert name in STARTUP_ONLY_FIELDS, f"AppConfig.{name} schema description starts with {STARTUP_ONLY_PREFIX!r} but the field is not listed in reload_boundary.STARTUP_ONLY_FIELDS — update the registry." + + +def test_pydantic_field_descriptions_are_introspectable_at_runtime(): + """``AppConfig.model_fields[name].description`` is the IDE-hover source. + + If this read ever breaks (e.g. Pydantic deprecation, schema swap), the + IDE-hover guarantee #3144 promises silently regresses. Pin it. + """ + assert "database" in AppConfig.model_fields + description = AppConfig.model_fields["database"].description + assert description is not None + assert description.startswith(STARTUP_ONLY_PREFIX)