From befe334f10158a4c2588d569405b8b42510b07fc Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Sun, 7 Jun 2026 21:27:14 +0800 Subject: [PATCH 1/2] fix(config): make the reload boundary discoverable from code (#3144) (#3153) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(config): make the reload boundary discoverable from code, not just docs Closes #3144. The hot-reload contract — per-run fields are resolved through `get_app_config()` on every request, infrastructure fields snapshot at gateway startup — landed in `backend/CLAUDE.md` as part of #3131. A maintainer reading `get_config()` or an `AppConfig` field still had to context-switch to that document to know which fields require a process restart, and there was no enforcement that the prose list stayed in sync with the code. This commit moves the boundary to a machine-readable single source of truth and surfaces it where the code lives: - New `deerflow.config.reload_boundary` module owns the registry of restart-required fields (`STARTUP_ONLY_FIELDS`) and a tiny helper API (`is_startup_only_field`, `iter_startup_only_field_paths`, `format_field_description`). The standardised `"startup-only:"` prefix is exported as `STARTUP_ONLY_PREFIX` so future scanners / lint hooks / doc generators can pivot off it without re-parsing prose. - `AppConfig`'s `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, and `log_level` fields now build their `Field(description=...)` from `format_field_description(...)`. The same text shows up in IDE hover (Pydantic v2 exposes `description` via `model_fields[...]`). - `channels` is restart-required too but lives outside the AppConfig Pydantic schema (the config section is consumed directly by `start_channel_service`). The registry owns it so the boundary is not split between two places. - `get_config()` docstring points to the registry instead of leaving the reader to find `CLAUDE.md`. The `CLAUDE.md` table collapses to a one-liner pointing back at `reload_boundary.py` so the boundary has one canonical location, not two. Drift coverage in `tests/test_reload_boundary.py`: - Every registered field has a non-trivial reason. - Iterator / membership helpers stay in sync with the dict. - Every registry entry that maps to an `AppConfig` field also carries the `"startup-only:"` prefix in the schema (catches "forgot to update the schema"). - Reverse drift: any AppConfig field whose description starts with the prefix must be registered (catches "marked restart-required in the schema but forgot the registry"). - The runtime introspection that IDE hover depends on (`AppConfig.model_fields["database"].description`) is pinned, so a future Pydantic upgrade or schema swap that breaks the hover surface shows up as a test failure rather than a silent regression. Refs: bytedance/deer-flow#3138 (split summary), #3107 (origin), #3131 (prior boundary fix in prose form). * fix(config): preserve field doc and correct log_level reload reason Two follow-ups on the PR #3153 review: 1. The `log_level` STARTUP_ONLY_FIELDS reason previously claimed `apply_logging_level()` mutates the root logger level. It does not: only the `deerflow` / `app` logger levels are set, and root handler thresholds are conditionally lowered so messages from those loggers can propagate. Reword to match the actual behavior so operators reading IDE hover get accurate restart guidance. 2. `format_field_description(field_path)` was the sole `Field(description=)` for every restart-required field, which silently overwrote the original human-facing documentation — most visibly the `log_level` field that used to list debug/info/warning/error and clarify that third-party libraries are not affected. Extend the helper with a keyword-only `field_doc` parameter that composes the startup-only marker with the original prose so IDE hover documents both *why* the field is restart-required and *what* it actually accepts. Updated all six restart-required AppConfig fields (`log_level`, `database`, `sandbox`, `run_events`, `checkpointer`, `stream_bridge`) to pass their original descriptions through the helper. Tests: two new cases in `test_reload_boundary.py` pin (a) the helper composition and (b) every AppConfig restart-required field still surfaces a recognisable substring of its original documentation. --------- Co-authored-by: Willem Jiang --- backend/CLAUDE.md | 12 +- backend/app/gateway/deps.py | 10 ++ .../harness/deerflow/config/app_config.py | 48 +++++- .../deerflow/config/reload_boundary.py | 104 +++++++++++++ backend/tests/test_reload_boundary.py | 140 ++++++++++++++++++ 5 files changed, 298 insertions(+), 16 deletions(-) create mode 100644 backend/packages/harness/deerflow/config/reload_boundary.py create mode 100644 backend/tests/test_reload_boundary.py diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 6cc0aebb1..caa36f579 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -223,17 +223,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc **Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart. -**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**: +**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`. -| Field | Why a restart is required | -|---|---| -| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. | -| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. | -| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. | -| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. | -| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. | -| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. | -| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. | +Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`. Configuration priority: 1. Explicit `config_path` argument diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 24af1deeb..5739d217d 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -119,6 +119,16 @@ def get_config() -> AppConfig: split-brain where the worker / lead-agent thread saw a stale startup snapshot. + Hot-reload boundary: fields backed by startup-time singletons + (engines, sandbox provider, IM channels, logging handler) require a + process restart to change at runtime. The authoritative list lives in + :mod:`deerflow.config.reload_boundary` and is mirrored by the + standardised ``"startup-only:"`` prefix on the matching + ``Field(description=...)`` in :class:`AppConfig` — IDE hover on those + fields will surface the boundary inline. See + ``backend/CLAUDE.md`` "Config Hot-Reload Boundary" for the operator + summary. + Any failure to materialise the config (missing file, permission denied, YAML parse error, validation error) is reported as 503 — semantically "the gateway cannot serve requests without a usable configuration" — and diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 8fcc564a8..842b49d7a 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -18,6 +18,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_ from deerflow.config.loop_detection_config import LoopDetectionConfig from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.model_config import ModelConfig +from deerflow.config.reload_boundary import format_field_description from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.runtime_paths import existing_project_file from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig @@ -85,10 +86,21 @@ def apply_logging_level(name: str | None) -> None: class AppConfig(BaseModel): """Config for the DeerFlow application""" - log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected") + log_level: str = Field( + default="info", + description=format_field_description( + "log_level", + field_doc="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected.", + ), + ) token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration") models: list[ModelConfig] = Field(default_factory=list, description="Available models") - sandbox: SandboxConfig = Field(description="Sandbox configuration") + sandbox: SandboxConfig = Field( + description=format_field_description( + "sandbox", + field_doc="Sandbox provider configuration (local filesystem or Docker-based aio sandbox).", + ), + ) tools: list[ToolConfig] = Field(default_factory=list, description="Available tools") tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups") skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration") @@ -107,10 +119,34 @@ class AppConfig(BaseModel): loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") model_config = ConfigDict(extra="allow") - database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") - run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") - checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") - stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") + database: DatabaseConfig = Field( + default_factory=DatabaseConfig, + description=format_field_description( + "database", + field_doc="Unified database backend for run/feedback metadata (memory, sqlite, or postgres).", + ), + ) + run_events: RunEventsConfig = Field( + default_factory=RunEventsConfig, + description=format_field_description( + "run_events", + field_doc="Run-event store backend (memory for dev, db for production queries, jsonl for lightweight single-node persistence).", + ), + ) + checkpointer: CheckpointerConfig | None = Field( + default=None, + description=format_field_description( + "checkpointer", + field_doc="LangGraph state-persistence checkpointer configuration.", + ), + ) + stream_bridge: StreamBridgeConfig | None = Field( + default=None, + description=format_field_description( + "stream_bridge", + field_doc="Stream bridge connecting agent workers to SSE endpoints.", + ), + ) @classmethod def resolve_config_path(cls, config_path: str | None = None) -> Path: diff --git a/backend/packages/harness/deerflow/config/reload_boundary.py b/backend/packages/harness/deerflow/config/reload_boundary.py new file mode 100644 index 000000000..d39502776 --- /dev/null +++ b/backend/packages/harness/deerflow/config/reload_boundary.py @@ -0,0 +1,104 @@ +"""Single source of truth for the config hot-reload boundary. + +Bytedance/deer-flow issue #3144: gateway request dependencies resolve +``AppConfig`` through ``get_app_config()`` on every request, so per-run +fields take effect on the next message without restarting the gateway. +The fields listed in this module are the **infrastructure** subset that +the gateway captures once at startup — engines, singletons, IM clients, +the logging handler — and that therefore require a process restart to +change at runtime. + +The registry covers two kinds of entries: + +- Top-level ``AppConfig`` fields (``database``, ``checkpointer``, + ``run_events``, ``stream_bridge``, ``sandbox``, ``log_level``). For + these, :func:`format_field_description` produces the standardised + ``"startup-only: ..."`` prefix that the matching Pydantic + ``Field(description=...)`` carries, so the boundary surfaces in IDE + hover next to the field itself. +- Top-level ``config.yaml`` sections that are not part of the + ``AppConfig`` schema (``channels``). These cannot be standardised at + the schema level, so the registry is their only canonical location. + +Any future "needs restart" scanner — operator tooling, lint hooks, doc +generators — should drive off this registry rather than re-parsing +prose. +""" + +from __future__ import annotations + +from collections.abc import Iterator + +#: The standardised prefix every restart-required field description starts +#: with. ``test_reload_boundary`` enforces both directions: registered +#: fields must use this prefix in the schema, and any schema field using +#: this prefix must be in the registry. +STARTUP_ONLY_PREFIX = "startup-only:" + + +#: Restart-required field paths mapped to the human-readable reason. +#: +#: The reason text is what surfaces in ``Field(description=...)``, so it +#: must explain *what* code captures the snapshot — not just that the +#: field is restart-required — so an operator changing the value knows +#: which subsystem to restart. +STARTUP_ONLY_FIELDS: dict[str, str] = { + "database": ("init_engine_from_config() runs once during langgraph_runtime() startup; the SQLAlchemy engine holds the connection pool and is not rebuilt on config.yaml edits."), + "checkpointer": ("make_checkpointer() binds the persistent checkpointer once at startup, including SQLite WAL / busy_timeout settings."), + "run_events": ("make_run_event_store() picks the memory- vs SQL-backed implementation at startup and is frozen onto app.state.run_events_config to stay paired with the underlying event store."), + "stream_bridge": ("make_stream_bridge() constructs the stream-bridge singleton once during startup."), + "sandbox": ("get_sandbox_provider() caches the provider singleton (``_default_sandbox_provider``); a different ``sandbox.use`` class path only takes effect on next process start."), + "log_level": ( + "apply_logging_level() runs only during app.py startup; it sets the deerflow/app logger levels and may lower root handler thresholds so configured messages can propagate. A freshly reloaded AppConfig does not retrigger it." + ), + # Not part of the AppConfig Pydantic schema — channel credentials are + # consumed directly by ``start_channel_service()`` once at lifespan + # startup and the live channel clients are not rebuilt on + # config.yaml edits. + "channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."), +} + + +def iter_startup_only_field_paths() -> Iterator[str]: + """Yield every registered restart-required field path.""" + return iter(STARTUP_ONLY_FIELDS) + + +def is_startup_only_field(field_path: str) -> bool: + """Return ``True`` when *field_path* is registered as restart-required. + + Accepts only top-level paths (``"database"``, ``"sandbox"`` etc.); + nested keys like ``"database.url"`` are not modelled here because the + boundary is per-section, not per-leaf. + """ + return field_path in STARTUP_ONLY_FIELDS + + +def format_field_description(field_path: str, *, field_doc: str | None = None) -> str: + """Build the standardised description for a registered field. + + Used inside ``AppConfig`` ``Field(description=...)`` so the hover + text in IDEs matches the registry and the drift tests can pin one + side against the other. + + Args: + field_path: A registered top-level field path (e.g. ``"log_level"``). + field_doc: Optional human-facing description for the field itself + (allowed values, semantics, etc.). When supplied, it is + appended after the ``startup-only:`` marker block separated by + a blank line so IDE hover shows both the restart-required + reason *and* the field's normal documentation. Composition + keeps the marker as the leading token machine-readable tooling + pivots on while restoring the prose that ``Field(description=)`` + used to carry before the registry took over. + + Raises: + KeyError: when *field_path* is not registered. This is deliberate + — silently returning a placeholder would let a typo bypass + the drift coverage. + """ + reason = STARTUP_ONLY_FIELDS[field_path] + header = f"{STARTUP_ONLY_PREFIX} {reason}" + if field_doc is None: + return header + return f"{header}\n\n{field_doc.strip()}" diff --git a/backend/tests/test_reload_boundary.py b/backend/tests/test_reload_boundary.py new file mode 100644 index 000000000..5610ccafb --- /dev/null +++ b/backend/tests/test_reload_boundary.py @@ -0,0 +1,140 @@ +"""Regression tests for the config reload boundary registry. + +Bytedance/deer-flow issue #3144: the hot-reload boundary is the contract +between gateway dependencies that resolve ``AppConfig`` every request and the +infrastructure that captures the snapshot once at startup. The registry in +``deerflow.config.reload_boundary`` is the machine-readable source of truth; +these tests pin the registry against the actual Pydantic schema so a future +field rename / addition / boundary change cannot silently drift. +""" + +from __future__ import annotations + +import pytest + +from deerflow.config.app_config import AppConfig +from deerflow.config.reload_boundary import ( + STARTUP_ONLY_FIELDS, + STARTUP_ONLY_PREFIX, + format_field_description, + is_startup_only_field, + iter_startup_only_field_paths, +) + + +def test_registry_has_a_reason_for_every_field(): + """Every registry entry must explain *why* the field is restart-required. + + The reason text is what surfaces in IDE hover and in the AppConfig schema + description, so an empty / placeholder value would defeat the purpose. + """ + for field_path, reason in STARTUP_ONLY_FIELDS.items(): + assert reason.strip(), f"empty reason for {field_path}" + assert len(reason) > 20, f"reason for {field_path} too short to be useful: {reason!r}" + + +def test_iter_startup_only_field_paths_matches_registry(): + """Iterator stays in sync with the registry mapping.""" + assert sorted(iter_startup_only_field_paths()) == sorted(STARTUP_ONLY_FIELDS) + + +def test_is_startup_only_field_recognises_registered_fields(): + """The membership helper accepts every registered field path.""" + for field_path in STARTUP_ONLY_FIELDS: + assert is_startup_only_field(field_path) + assert not is_startup_only_field("memory") # hot-reloadable + assert not is_startup_only_field("models") + assert not is_startup_only_field("nonexistent_field") + + +def test_format_field_description_prefixes_with_marker(): + """The formatter produces a description that machine-readable tooling can + pivot on (drift tests, future "needs-restart" scanners).""" + for field_path in STARTUP_ONLY_FIELDS: + text = format_field_description(field_path) + assert text.startswith(STARTUP_ONLY_PREFIX), text + # The reason is appended after the prefix; the formatter must not + # silently drop it. + assert STARTUP_ONLY_FIELDS[field_path] in text + + +def test_format_field_description_rejects_unknown_field(): + with pytest.raises(KeyError): + format_field_description("not_in_registry") + + +def test_format_field_description_appends_optional_field_doc(): + """The formatter composes the startup-only marker with the field's own + human-facing description when supplied. + + The original ``Field(description=)`` used to document allowed values + (e.g. ``log_level`` listed ``debug/info/warning/error``); registry + adoption must not drop that. The composed output keeps the marker as + the leading token so machine-readable tooling still pivots on it, + then appends the prose after a blank line. + """ + text = format_field_description("log_level", field_doc="Logging level (debug/info/warning/error).") + assert text.startswith(STARTUP_ONLY_PREFIX) + assert STARTUP_ONLY_FIELDS["log_level"] in text + assert "debug/info/warning/error" in text + + +def test_appconfig_descriptions_retain_original_field_documentation(): + """``AppConfig.model_fields[name].description`` for restart-required + fields should still carry the original human-facing field doc so IDE + hover documents what the field is *and* why a restart is needed.""" + descriptions = { + "log_level": "debug/info/warning/error", + "database": "memory, sqlite, or postgres", + "sandbox": "Sandbox provider", + "run_events": "memory for dev", + "checkpointer": "state-persistence checkpointer", + "stream_bridge": "Stream bridge", + } + for field_name, expected_substring in descriptions.items(): + description = AppConfig.model_fields[field_name].description or "" + assert description.startswith(STARTUP_ONLY_PREFIX), f"AppConfig.{field_name} missing startup-only marker" + assert expected_substring in description, f"AppConfig.{field_name} description lost original field doc; got {description!r}" + + +def test_appconfig_schema_marks_registered_fields_with_prefix(): + """Every registry entry that corresponds to a top-level AppConfig field + must carry the standardized ``startup-only:`` prefix in its Pydantic + ``Field(description=...)``. This is the contract IDE hover relies on. + """ + schema_fields = AppConfig.model_fields + for field_path in STARTUP_ONLY_FIELDS: + if field_path not in schema_fields: + # Some entries (e.g. ``channels``) live outside the AppConfig + # schema. The registry still owns them, but the schema-prefix + # assertion does not apply. + continue + description = schema_fields[field_path].description or "" + assert description.startswith(STARTUP_ONLY_PREFIX), f"AppConfig.{field_path} should have Field(description=) starting with {STARTUP_ONLY_PREFIX!r}, got {description!r}" + + +def test_no_appconfig_field_uses_prefix_without_registration(): + """Reverse drift check: if a future schema edit adds the + ``startup-only:`` prefix to a new field, the registry must list it. + + This catches the silent-drift case where someone marks a field + restart-required in the schema but forgets to update the registry + that the operator-facing scanners and docs consume. + """ + for name, info in AppConfig.model_fields.items(): + description = info.description or "" + if not description.startswith(STARTUP_ONLY_PREFIX): + continue + assert name in STARTUP_ONLY_FIELDS, f"AppConfig.{name} schema description starts with {STARTUP_ONLY_PREFIX!r} but the field is not listed in reload_boundary.STARTUP_ONLY_FIELDS — update the registry." + + +def test_pydantic_field_descriptions_are_introspectable_at_runtime(): + """``AppConfig.model_fields[name].description`` is the IDE-hover source. + + If this read ever breaks (e.g. Pydantic deprecation, schema swap), the + IDE-hover guarantee #3144 promises silently regresses. Pin it. + """ + assert "database" in AppConfig.model_fields + description = AppConfig.model_fields["database"].description + assert description is not None + assert description.startswith(STARTUP_ONLY_PREFIX) From d8b728f7cbcf15d0d48d308abfe0df9e0ab655e1 Mon Sep 17 00:00:00 2001 From: Ryker_Feng <90562015+18062706139fcz@users.noreply.github.com> Date: Sun, 7 Jun 2026 21:37:30 +0800 Subject: [PATCH 2/2] fix(mcp): close stdio sessions on their owning loop to avoid cross-task cancel-scope error (#3379) (#3392) * fix(mcp): close stdio sessions on their owning loop to avoid cross-task cancel-scope error (#3379) Adopt an owner-task lifecycle for pooled MCP ClientSessions so each session is entered, initialized, and exited within a single asyncio task on its owning event loop. This eliminates the anyio "Attempted to exit cancel scope in a different task than it was entered in" RuntimeError that surfaced when stdio MCP tools were used via the sync tool wrapper (which spins up and tears down event loops across tasks). Also harden the pool lifecycle: - track in-flight session creation per (server, scope) to dedupe concurrent get_session() calls for the same key - make close_scope/close_server/close_all/close_all_sync cover both established entries and in-flight creations so sessions cannot be resurrected or leaked after close - handle cross-loop preemption of an in-flight creation by cancelling the stale owner task instead of only signalling it - define close_all_sync() semantics for a running loop on the current thread (signal-only, async completion) and route reset_mcp_tools_cache through a deterministic async close in that case * fix(mcp): avoid reset deadlock on running loop cache reset * fix(mcp): address session pool review feedback --- .../packages/harness/deerflow/mcp/__init__.py | 6 +- .../packages/harness/deerflow/mcp/cache.py | 13 +- .../harness/deerflow/mcp/session_pool.py | 391 ++++++++++-- backend/tests/test_mcp_session_pool.py | 584 ++++++++++++++++++ 4 files changed, 924 insertions(+), 70 deletions(-) diff --git a/backend/packages/harness/deerflow/mcp/__init__.py b/backend/packages/harness/deerflow/mcp/__init__.py index 74195c195..839cce1ee 100644 --- a/backend/packages/harness/deerflow/mcp/__init__.py +++ b/backend/packages/harness/deerflow/mcp/__init__.py @@ -1,6 +1,10 @@ """MCP (Model Context Protocol) integration using langchain-mcp-adapters.""" -from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache +from .cache import ( + get_cached_mcp_tools, + initialize_mcp_tools, + reset_mcp_tools_cache, +) from .client import build_server_params, build_servers_config from .tools import get_mcp_tools diff --git a/backend/packages/harness/deerflow/mcp/cache.py b/backend/packages/harness/deerflow/mcp/cache.py index 20cc48b1e..333a034c8 100644 --- a/backend/packages/harness/deerflow/mcp/cache.py +++ b/backend/packages/harness/deerflow/mcp/cache.py @@ -143,11 +143,20 @@ def reset_mcp_tools_cache() -> None: # Close persistent sessions – they will be recreated by the next # get_mcp_tools() call with the (possibly updated) connection config. + # + # close_all_sync() already picks the correct strategy per owning loop: + # * sessions owned by the *current* running loop are only *signalled* + # (their owner task runs __aexit__ once the loop regains control – + # this is correct and leak-free, since the loop keeps the task alive), + # * sessions on other threads' loops are torn down deterministically, + # * idle/closed loops are handled or skipped. + # We deliberately do NOT try to synchronously wait for the current running + # loop to finish teardown here: that is a self-deadlock (the loop can only + # run the teardown after this synchronous call returns control to it). try: from deerflow.mcp.session_pool import get_session_pool - pool = get_session_pool() - pool.close_all_sync() + get_session_pool().close_all_sync() except Exception: logger.debug("Could not close MCP session pool on cache reset", exc_info=True) diff --git a/backend/packages/harness/deerflow/mcp/session_pool.py b/backend/packages/harness/deerflow/mcp/session_pool.py index 8450cac8e..63703ba25 100644 --- a/backend/packages/harness/deerflow/mcp/session_pool.py +++ b/backend/packages/harness/deerflow/mcp/session_pool.py @@ -8,6 +8,27 @@ This module provides a session pool that maintains persistent MCP sessions, scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id — so that consecutive tool calls share the same session and server-side state. Sessions are evicted in LRU order when the pool reaches capacity. + +Lifecycle model (owner task) +---------------------------- +An MCP ``ClientSession`` is implemented on top of an ``anyio`` task group, and +anyio enforces that a cancel scope must be exited from the *same task* that +entered it. Calling ``cm.__aexit__`` from any task other than the one that ran +``cm.__aenter__`` raises:: + + RuntimeError: Attempted to exit cancel scope in a different task than it + was entered in + +The sync-tool path (``make_sync_tool_wrapper``) drives each call through a fresh +``asyncio.run`` event loop, so a session entered while answering one call would +otherwise be exited while answering another — from a different task — and crash +(GitHub issue #3379). + +To make this impossible, every pooled session is owned by a dedicated +``_run_session`` task. That task enters the context manager, hands the live +session back to the caller, and then *waits* on a close event. All shutdown +paths only ever **signal** that event; the owner task performs ``__aexit__`` +itself, guaranteeing enter and exit always happen in the same task. """ from __future__ import annotations @@ -27,18 +48,81 @@ class MCPSessionPool: """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.""" MAX_SESSIONS = 256 - SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe + SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session on a foreign loop def __init__(self) -> None: + # Each entry: (session, owning_loop, owner_task, close_event). self._entries: OrderedDict[ tuple[str, str], - tuple[ClientSession, asyncio.AbstractEventLoop], + tuple[ + ClientSession, + asyncio.AbstractEventLoop, + asyncio.Task[Any], + asyncio.Event, + ], ] = OrderedDict() - self._context_managers: dict[tuple[str, str], Any] = {} + # In-flight creations, keyed by (server, scope). Lets concurrent callers + # on the same loop share a single creation instead of each spawning a + # duplicate session. Value: (loop, ready_future, owner_task, close_event). + self._inflight: dict[ + tuple[str, str], + tuple[ + asyncio.AbstractEventLoop, + asyncio.Future[ClientSession], + asyncio.Task[Any], + asyncio.Event, + ], + ] = {} # threading.Lock is not bound to any event loop, so it is safe to # acquire from both async paths and sync/worker-thread paths. self._lock = threading.Lock() + # ------------------------------------------------------------------ + # Session owner task + # ------------------------------------------------------------------ + + async def _run_session( + self, + connection: dict[str, Any], + ready: asyncio.Future[ClientSession], + close_evt: asyncio.Event, + ) -> None: + """Own a single MCP session for its entire lifetime. + + Enters the session context manager, initializes it, publishes the live + session via ``ready``, then blocks until ``close_evt`` is set. The + context manager is *always* exited from this task, satisfying anyio's + cancel-scope same-task requirement. + """ + from langchain_mcp_adapters.sessions import create_session + + cm = create_session(connection) + try: + session = await cm.__aenter__() + except BaseException as e: + # Never entered the cancel scope, so there is nothing to exit. + if not ready.done(): + ready.set_exception(e) + return + + # The context manager is now entered. From here on __aexit__ MUST run in + # this task — on init failure, on cancellation, or on the close signal — + # to satisfy anyio's same-task cancel-scope requirement and to avoid + # leaking the session/subprocess. + try: + await session.initialize() + if not ready.done(): + ready.set_result(session) + await close_evt.wait() + except BaseException as e: + if not ready.done(): + ready.set_exception(e) + finally: + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session", exc_info=True) + async def get_session( self, server_name: str, @@ -47,9 +131,9 @@ class MCPSessionPool: ) -> ClientSession: """Get or create a persistent MCP session. - If an existing session was created in a different event loop (e.g. - the sync-wrapper path), it is closed and replaced with a fresh one - in the current loop. + If an existing session was created in a different (or closed) event + loop, it is evicted and replaced with a fresh one owned by a task on + the current loop. Args: server_name: MCP server name. @@ -63,44 +147,118 @@ class MCPSessionPool: current_loop = asyncio.get_running_loop() # Phase 1: inspect/mutate the registry under the thread lock (no awaits). - cms_to_close: list[tuple[tuple[str, str], Any]] = [] + # Decide one of three outcomes atomically: return an existing session, + # join an in-flight creation, or become the creator for this key. + # Each item: (loop, owner_task, close_event, cancel). ``cancel`` is True + # for in-flight creations, whose owner may be blocked inside + # ``initialize()`` where close_evt cannot wake it — it must be cancelled. + evicted: list[tuple[asyncio.AbstractEventLoop, asyncio.Task[Any], asyncio.Event, bool]] = [] + join: asyncio.Future[ClientSession] | None = None + ready: asyncio.Future[ClientSession] | None = None + close_evt: asyncio.Event | None = None + task: asyncio.Task[Any] | None = None with self._lock: if key in self._entries: - session, loop = self._entries[key] - if loop is current_loop: + session, loop, ent_task, ent_close = self._entries[key] + if loop is current_loop and not loop.is_closed(): self._entries.move_to_end(key) return session - # Session belongs to a different event loop – evict it. - cm = self._context_managers.pop(key, None) + # Session belongs to a different/closed event loop – evict it. self._entries.pop(key) - if cm is not None: - cms_to_close.append((key, cm)) + evicted.append((loop, ent_task, ent_close, False)) + + inflight = self._inflight.get(key) + if inflight is not None and inflight[0] is current_loop and not inflight[0].is_closed(): + # Another caller on this loop is already creating the session; + # wait for the same result instead of building a duplicate. + join = inflight[1] + else: + if inflight is not None: + # Stale in-flight creation owned by a different/closed loop. + # Drop the record and tear its owner down; because that owner + # may be blocked inside initialize() (where close_evt cannot + # wake it), it must be cancelled. We then create a fresh + # session here. + self._inflight.pop(key) + evicted.append((inflight[0], inflight[2], inflight[3], True)) + # Become the creator: publish an in-flight record before any + # await so concurrent callers join us instead of racing. + ready = current_loop.create_future() + close_evt = asyncio.Event() + task = current_loop.create_task(self._run_session(connection, ready, close_evt)) + self._inflight[key] = (current_loop, ready, task, close_evt) # Evict LRU entries when at capacity. while len(self._entries) >= self.MAX_SESSIONS: - oldest_key = next(iter(self._entries)) - cm = self._context_managers.pop(oldest_key, None) + oldest_key, (_, loop, ent_task, ent_close) = next(iter(self._entries.items())) self._entries.pop(oldest_key) - if cm is not None: - cms_to_close.append((oldest_key, cm)) + evicted.append((loop, ent_task, ent_close, False)) - # Phase 2: async cleanup outside the lock so we never await while holding it. - for close_key, cm in cms_to_close: + # Phase 2: shut down evicted sessions/creations. Same-loop owners are + # awaited so they finish deterministically; foreign-loop owners are + # routed to their own loop. In every case the owner task — never this + # one — runs __aexit__. In-flight owners are cancelled (cancel=True) so a + # blocking initialize() cannot leave them hung. + for loop, ent_task, ent_close, cancel in evicted: + if loop is current_loop and not loop.is_closed(): + await self._shutdown(ent_close, ent_task, cancel) + elif cancel: + await self._shutdown_entry(loop, ent_task, ent_close, cancel=True) + else: + self._signal_close(loop, ent_close) + + # Phase 2b: a concurrent creation for this key is already in progress on + # this loop — share its result rather than create a duplicate session. + if join is not None: + return await asyncio.shield(join) + + assert ready is not None and close_evt is not None and task is not None + + # Phase 3: wait for our owner task to publish the initialized session. + try: + session = await asyncio.shield(ready) + except BaseException: + # Two distinct cases reach here: + # + # 1. The owner task failed (e.g. connect/initialize error) and + # reported it via ready.set_exception(). It is *already* in its + # finally block running cm.__aexit__ in its own task, so we must + # NOT cancel it — doing so would interrupt that cleanup. We only + # wait for it to finish unwinding. + # 2. This call itself was cancelled (CancelledError). Because of the + # shield, `ready` is still pending and the owner task is alive and + # blocked. We signal close and cancel it so it exits the cancel + # scope in its own task, then wait for it to finish. + # + # The session is never registered yet, so nobody else can close it; + # waiting here guarantees we never leak a session or owner task. + owner_already_failed = ready.done() and not ready.cancelled() and ready.exception() is not None + if not owner_already_failed: + close_evt.set() + task.cancel() try: - await cm.__aexit__(None, None, None) - except Exception: - logger.warning("Error closing MCP session %s", close_key, exc_info=True) + await asyncio.shield(task) + except BaseException: + logger.debug("Owner task ended during get_session unwind", exc_info=True) + with self._lock: + if self._inflight.get(key) == (current_loop, ready, task, close_evt): + self._inflight.pop(key) + raise - from langchain_mcp_adapters.sessions import create_session - - cm = create_session(connection) - session = await cm.__aenter__() - await session.initialize() - - # Phase 3: register the new session under the lock. + # Phase 4: promote the in-flight creation to a registered entry — but + # only if our in-flight record is still the live one. A concurrent + # close_* / close_all may have removed it while we were initializing; in + # that case we must NOT resurrect the session into _entries. Instead we + # own the teardown: signal our owner task and wait for it to run + # __aexit__ in its own task, then surface the cancellation. with self._lock: - self._entries[key] = (session, current_loop) - self._context_managers[key] = cm + still_ours = self._inflight.get(key) == (current_loop, ready, task, close_evt) + if still_ours: + self._inflight.pop(key) + self._entries[key] = (session, current_loop, task, close_evt) + if not still_ours: + await self._shutdown(close_evt, task) + raise asyncio.CancelledError("MCP session pool was closed while the session was being created") logger.info("Created persistent MCP session for %s/%s", server_name, scope_key) return session @@ -108,70 +266,169 @@ class MCPSessionPool: # Cleanup helpers # ------------------------------------------------------------------ - async def _close_cm(self, key: tuple[str, str], cm: Any) -> None: - """Close a single context manager (must be called WITHOUT the lock).""" + @staticmethod + def _signal_close(loop: asyncio.AbstractEventLoop, close_evt: asyncio.Event) -> None: + """Ask an owner task to shut down without waiting. + + ``asyncio.Event.set`` is not thread-safe, so it is scheduled on the + owning loop. A closed loop means the owner task is already gone. + """ + if loop.is_closed(): + return try: - await cm.__aexit__(None, None, None) - except Exception: - logger.warning("Error closing MCP session %s", key, exc_info=True) + loop.call_soon_threadsafe(close_evt.set) + except RuntimeError: + # Loop was closed between the is_closed() check and now. + pass + + async def _shutdown( + self, + close_evt: asyncio.Event, + task: asyncio.Task[Any], + cancel: bool = False, + ) -> None: + """Signal an owner task and wait for it to finish (runs on its loop). + + ``cancel=True`` is used for in-flight creations: the owner task may be + blocked inside ``initialize()`` where ``close_evt`` cannot wake it, so it + must be cancelled. Its ``finally`` block still runs ``__aexit__`` in its + own task, satisfying anyio's same-task cancel-scope requirement. + """ + close_evt.set() + if cancel: + task.cancel() + try: + await task + except (Exception, asyncio.CancelledError): + logger.debug("Owner task ended during shutdown", exc_info=True) + + async def _shutdown_entry( + self, + loop: asyncio.AbstractEventLoop, + task: asyncio.Task[Any], + close_evt: asyncio.Event, + cancel: bool = False, + ) -> None: + """Shut down one entry, routing the close to its owning loop.""" + if loop.is_closed(): + return + current_loop = asyncio.get_running_loop() + if loop is current_loop: + await self._shutdown(close_evt, task, cancel) + elif loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop) + try: + await asyncio.wrap_future(future) + except Exception: + logger.warning("Error closing MCP session on owning loop", exc_info=True) + else: + # Owning loop exists but is neither the current loop nor running. + # We are inside an async context here, so run_until_complete() would + # raise "Cannot run the event loop while another loop is running"; + # and the loop may belong to another thread, where driving it from + # here is unsafe. This branch is not expected in practice — a + # session's owning loop is either the long-lived gateway loop (which + # is running) or a short-lived asyncio.run loop (which is closed and + # caught above). Fall back to a best-effort thread-safe signal so the + # owner task tears down if/when its loop runs again. + logger.warning("Owning loop for MCP session is idle; signalling close best-effort. Session may leak until the loop runs again.") + self._signal_close(loop, close_evt) + if cancel: + try: + loop.call_soon_threadsafe(task.cancel) + except RuntimeError: + pass async def close_scope(self, scope_key: str) -> None: """Close all sessions for a given scope (e.g. thread_id).""" with self._lock: keys = [k for k in self._entries if k[1] == scope_key] - cms = [(k, self._context_managers.pop(k, None)) for k in keys] - for k in keys: - self._entries.pop(k, None) - for key, cm in cms: - if cm is not None: - await self._close_cm(key, cm) + entries = [(self._entries.pop(k)) for k in keys] + inflight_keys = [k for k in self._inflight if k[1] == scope_key] + inflight = [self._inflight.pop(k) for k in inflight_keys] + for _session, loop, task, close_evt in entries: + await self._shutdown_entry(loop, task, close_evt) + for loop, _ready, task, close_evt in inflight: + await self._shutdown_entry(loop, task, close_evt, cancel=True) async def close_server(self, server_name: str) -> None: """Close all sessions for a given server.""" with self._lock: keys = [k for k in self._entries if k[0] == server_name] - cms = [(k, self._context_managers.pop(k, None)) for k in keys] - for k in keys: - self._entries.pop(k, None) - for key, cm in cms: - if cm is not None: - await self._close_cm(key, cm) + entries = [(self._entries.pop(k)) for k in keys] + inflight_keys = [k for k in self._inflight if k[0] == server_name] + inflight = [self._inflight.pop(k) for k in inflight_keys] + for _session, loop, task, close_evt in entries: + await self._shutdown_entry(loop, task, close_evt) + for loop, _ready, task, close_evt in inflight: + await self._shutdown_entry(loop, task, close_evt, cancel=True) async def close_all(self) -> None: """Close every managed session.""" with self._lock: - cms = list(self._context_managers.items()) - self._context_managers.clear() + entries = list(self._entries.values()) self._entries.clear() - for key, cm in cms: - await self._close_cm(key, cm) + inflight = list(self._inflight.values()) + self._inflight.clear() + for _session, loop, task, close_evt in entries: + await self._shutdown_entry(loop, task, close_evt) + for loop, _ready, task, close_evt in inflight: + await self._shutdown_entry(loop, task, close_evt, cancel=True) def close_all_sync(self) -> None: - """Close all sessions using their owning event loops (synchronous). + """Close all sessions on their owning event loops (synchronous). - Each session is closed on the loop it was created in, avoiding - cross-loop resource leaks. Safe to call from any thread without an - active event loop. + Each session is closed by its owner task on the loop it was created in, + avoiding cross-loop and cross-task errors. Safe to call from any thread + without an active event loop. + + Closing semantics differ by where the owning loop runs: + + * Owning loop is idle, or running on another thread — this call blocks + until teardown completes (or ``SESSION_CLOSE_TIMEOUT`` elapses). + * Owning loop is the one currently running on *this* thread — we cannot + block on it without deadlocking, so teardown is only *signalled* here + and completes asynchronously once control returns to that loop. The + caller must therefore keep that loop running afterwards; if it stops + the loop immediately, the owner task's ``__aexit__`` may not run. When + a deterministic close is required from inside a running loop, ``await + close_all()`` instead. """ with self._lock: - entries = list(self._entries.items()) - cms = dict(self._context_managers) + entries = list(self._entries.values()) self._entries.clear() - self._context_managers.clear() + inflight = list(self._inflight.values()) + self._inflight.clear() - for key, (_, loop) in entries: - cm = cms.get(key) - if cm is None or loop.is_closed(): + # Entries are initialized (gentle close_evt path). In-flight creations + # may be blocked mid-init, so they are cancelled to unblock teardown. + owners = [(loop, task, close_evt, False) for _s, loop, task, close_evt in entries] + owners += [(loop, task, close_evt, True) for loop, _r, task, close_evt in inflight] + try: + current_running_loop = asyncio.get_running_loop() + except RuntimeError: + current_running_loop = None + for loop, task, close_evt, cancel in owners: + if loop.is_closed(): continue try: - if loop.is_running(): - # Schedule on the owning loop from this (different) thread. - future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop) + if loop is current_running_loop: + # We are executing inside this loop's thread, so synchronously + # waiting on run_coroutine_threadsafe(...).result() would + # deadlock until timeout. Signal the owner task directly and + # let it finish once this synchronous call returns control to + # the running loop. + close_evt.set() + if cancel: + task.cancel() + elif loop.is_running(): + # Schedule the shutdown on the owning loop from this thread. + future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop) future.result(timeout=self.SESSION_CLOSE_TIMEOUT) else: - loop.run_until_complete(cm.__aexit__(None, None, None)) + loop.run_until_complete(self._shutdown(close_evt, task, cancel)) except Exception: - logger.debug("Error closing MCP session %s during sync close", key, exc_info=True) + logger.debug("Error closing MCP session during sync close", exc_info=True) # ------------------------------------------------------------------ diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py index 40d02e61c..852cfd861 100644 --- a/backend/tests/test_mcp_session_pool.py +++ b/backend/tests/test_mcp_session_pool.py @@ -1,5 +1,7 @@ """Tests for the MCP persistent-session pool.""" +import asyncio +import threading from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -614,3 +616,585 @@ async def test_http_transport_tools_not_pooled(): stdio_tools = [t for t in tools if t.name == "playwright_navigate"] assert len(stdio_tools) == 1 assert stdio_tools[0].coroutine is not stdio_tool.coroutine + + +# --------------------------------------------------------------------------- +# Regression for #3379: cancel scope must be exited in the entering task +# --------------------------------------------------------------------------- + + +class _CancelScopeCm: + """Fake session context manager that mimics anyio's cancel-scope rule. + + ``ClientSession`` is built on an anyio task group, which requires the cancel + scope to be exited from the *same asyncio task* that entered it. This fake + records the task that runs ``__aenter__`` and raises the exact RuntimeError + anyio would raise if ``__aexit__`` runs in a different task — reproducing the + crash reported in GitHub issue #3379. + """ + + def __init__(self) -> None: + self.enter_task: object | None = None + self.closed = False + + async def __aenter__(self): + self.enter_task = asyncio.current_task() + return AsyncMock() + + async def __aexit__(self, *args): + if asyncio.current_task() is not self.enter_task: + raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") + self.closed = True + return False + + +async def _get_session_in_own_task(pool, *args): + """Create a pooled session from a *dedicated* child task. + + In production every stdio session is entered from its own short-lived task + (the sync-tool path runs each call through a fresh ``asyncio.run``). This + helper reproduces that so the close paths are exercised from a *different* + task than the one that entered the session — the exact condition that + triggered #3379. + """ + return await asyncio.create_task(pool.get_session(*args)) + + +@pytest.mark.asyncio +async def test_close_all_does_not_cross_tasks(): + """close_all must not raise the cross-task cancel-scope RuntimeError (#3379).""" + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await _get_session_in_own_task(pool, "s1", "t1", {"transport": "stdio", "command": "x", "args": []}) + await _get_session_in_own_task(pool, "s2", "t2", {"transport": "stdio", "command": "x", "args": []}) + + # close_all runs in this task, which is *not* the task that entered either + # session. The owner task must perform __aexit__ so each CM closes cleanly. + await pool.close_all() + + assert all(cm.closed for cm in cms) + assert len(pool._entries) == 0 + + +@pytest.mark.asyncio +async def test_close_scope_does_not_cross_tasks(): + """close_scope must respect the same-task cancel-scope rule (#3379).""" + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await _get_session_in_own_task(pool, "s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await _get_session_in_own_task(pool, "s", "t2", {"transport": "stdio", "command": "x", "args": []}) + + await pool.close_scope("t1") + + assert cms[0].closed is True + assert cms[1].closed is False + assert ("s", "t2") in pool._entries + + +@pytest.mark.asyncio +async def test_lru_eviction_does_not_cross_tasks(): + """LRU eviction must close the victim without a cross-task RuntimeError (#3379).""" + pool = MCPSessionPool() + pool.MAX_SESSIONS = 2 + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await _get_session_in_own_task(pool, "s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await _get_session_in_own_task(pool, "s", "t2", {"transport": "stdio", "command": "x", "args": []}) + # Adding t3 evicts t1 — its own owner task must run __aexit__, even + # though the eviction is driven from t3's get_session call. + await _get_session_in_own_task(pool, "s", "t3", {"transport": "stdio", "command": "x", "args": []}) + + assert cms[0].closed is True + assert cms[1].closed is False + assert cms[2].closed is False + + +def test_close_all_sync_across_loops_does_not_cross_tasks(): + """close_all_sync, the path hit by the sync tool wrapper, must close sessions + created in earlier (now-finished) asyncio.run loops without crashing (#3379). + """ + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + # Simulate the sync-tool path: a session created inside one short-lived + # event loop, then a second one in a different loop. + asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + asyncio.run(pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})) + + # The owning loops are already closed; close_all_sync must not raise. + pool.close_all_sync() + + assert len(pool._entries) == 0 + + +def test_get_session_replaces_session_from_closed_loop(): + """A pooled session whose owning loop has closed is evicted and recreated.""" + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + # First session created in a throwaway loop that is torn down by + # asyncio.run (mirrors the sync-tool path). asyncio.run cancels the + # pending owner task and runs its __aexit__ on the same loop. + asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + assert ("s", "t1") in pool._entries + + # Now request the same key from a fresh loop: the stale entry (closed + # loop) must be evicted and replaced with a fresh session. + session = asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + + assert session is not None + assert len(cms) == 2 + assert pool._entries[("s", "t1")][0] is session + + +class _BlockingInitCm: + """Fake session CM whose ``initialize`` blocks until released. + + Lets a test cancel ``get_session`` while the owner task is still + initializing, reproducing the caller-cancellation window. + """ + + def __init__(self, gate: asyncio.Event) -> None: + self._gate = gate + self.entered = False + self.closed = False + + async def __aenter__(self): + self.entered = True + session = MagicMock() + session.initialize = self._initialize + return session + + async def _initialize(self): + await self._gate.wait() + + async def __aexit__(self, *args): + self.closed = True + return False + + +@pytest.mark.asyncio +async def test_get_session_cancelled_while_initializing_does_not_leak(): + """Cancelling get_session mid-init must not leak the owner task/session (#3379 CR). + + The session is not registered yet, so if cancellation skipped the cleanup + the owner task would block forever on close_evt.wait() and the CM's + __aexit__ would never run — an unreachable, unclosable session. + """ + pool = MCPSessionPool() + gate = asyncio.Event() + cms: list[_BlockingInitCm] = [] + + def make_cm(*a, **kw): + cm = _BlockingInitCm(gate) + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + call = asyncio.create_task(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})) + # Let the owner task enter the CM and reach the blocking initialize(). + await asyncio.sleep(0.01) + call.cancel() + with pytest.raises(asyncio.CancelledError): + await call + + # Release initialize() so the owner task can finish its shutdown path. + gate.set() + # Give the owner task a chance to run __aexit__ and complete. + for _ in range(10): + if cms and cms[0].closed: + break + await asyncio.sleep(0.01) + + assert len(cms) == 1 + assert cms[0].entered is True + assert cms[0].closed is True, "owner task must run __aexit__ after cancellation" + assert len(pool._entries) == 0 + + current = asyncio.current_task() + leaked = [t for t in asyncio.all_tasks() if t is not current and not t.done() and "_run_session" in str(t.get_coro())] + assert not leaked, "owner task must not be left pending after cancellation" + + +class _InitFailCm: + """Fake session CM whose ``initialize`` fails, with a slow ``__aexit__``. + + The slow __aexit__ lets a test observe whether cleanup is allowed to run to + completion (closed=True) or is interrupted by a stray cancellation. + """ + + def __init__(self) -> None: + self.entered = False + self.exit_started = False + self.closed = False + + async def __aenter__(self): + self.entered = True + session = MagicMock() + session.initialize = self._initialize + return session + + async def _initialize(self): + raise RuntimeError("init boom") + + async def __aexit__(self, *args): + self.exit_started = True + # Yield control so a buggy double-cancel would interrupt us here. + await asyncio.sleep(0.02) + self.closed = True + return False + + +@pytest.mark.asyncio +async def test_get_session_init_failure_runs_full_cleanup(): + """On initialize() failure the owner task's __aexit__ must complete (#3379 CR P1). + + The caller must NOT cancel the owner task on a reported failure, otherwise + the in-progress __aexit__ cleanup gets interrupted and leaks resources. + """ + pool = MCPSessionPool() + cms: list[_InitFailCm] = [] + + def make_cm(*a, **kw): + cm = _InitFailCm() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + with pytest.raises(RuntimeError, match="init boom"): + await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}) + + assert len(cms) == 1 + assert cms[0].entered is True + assert cms[0].exit_started is True + assert cms[0].closed is True, "__aexit__ must run to completion, not be interrupted" + assert len(pool._entries) == 0 + assert len(pool._inflight) == 0 + + +@pytest.mark.asyncio +async def test_concurrent_get_session_same_key_creates_single_session(): + """Concurrent get_session for the same key must share one session (#3379 CR P1).""" + pool = MCPSessionPool() + gate = asyncio.Event() + cms: list[_BlockingInitCm] = [] + + def make_cm(*a, **kw): + cm = _BlockingInitCm(gate) + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + conn = {"transport": "stdio", "command": "x", "args": []} + t1 = asyncio.create_task(pool.get_session("s", "same", conn)) + t2 = asyncio.create_task(pool.get_session("s", "same", conn)) + # Let both calls pass Phase 1 and reach the (gated) initialize(). + await asyncio.sleep(0.02) + gate.set() + s1, s2 = await asyncio.gather(t1, t2) + + # Only one CM/session created, both callers got the same object. + assert len(cms) == 1, "concurrent same-key calls must not create duplicate sessions" + assert s1 is s2 + assert len(pool._entries) == 1 + assert len(pool._inflight) == 0 + + +@pytest.mark.asyncio +async def test_close_all_during_in_flight_creation_does_not_resurrect_session(): + """close_all while a creation is in-flight must not leave a live session (#3379 CR P1). + + The in-flight record must be removed and its owner task torn down, so when + the (blocked) creator finishes initializing it does NOT register the session + back into _entries — otherwise the pool resurrects an unclosable session. + """ + pool = MCPSessionPool() + gate = asyncio.Event() + cms: list[_BlockingInitCm] = [] + + def make_cm(*a, **kw): + cm = _BlockingInitCm(gate) + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + conn = {"transport": "stdio", "command": "x", "args": []} + call = asyncio.create_task(pool.get_session("s", "t1", conn)) + # Let the owner task enter the CM and reach the blocking initialize(). + await asyncio.sleep(0.01) + assert ("s", "t1") in pool._inflight + + # Close everything while the creation is still in-flight. + await pool.close_all() + + # The in-flight creation must be gone, not promoted to an entry. + assert len(pool._inflight) == 0 + assert len(pool._entries) == 0 + + # Even if the gate is released afterwards, nothing must come back. + gate.set() + with pytest.raises(asyncio.CancelledError): + await call + + assert len(pool._entries) == 0 + assert len(pool._inflight) == 0 + assert cms[0].closed is True, "in-flight session's __aexit__ must run on teardown" + + current = asyncio.current_task() + leaked = [t for t in asyncio.all_tasks() if t is not current and not t.done() and "_run_session" in str(t.get_coro())] + assert not leaked, "in-flight owner task must not leak after close_all" + + +def test_get_session_cross_loop_in_flight_does_not_raise_assertion(): + """A same-key request from another loop must not hit the in-flight assertion (#3379 CR P1). + + Loop A starts (and leaves running) an in-flight creation, then loop B + requests the same key. The stale in-flight record (owned by loop A) must be + dropped and loop B must become a fresh creator — never fall through to an + AssertionError. + """ + pool = MCPSessionPool() + cms: list[_CancelScopeCm] = [] + + def make_cm(*a, **kw): + cm = _CancelScopeCm() + cms.append(cm) + return cm + + conn = {"transport": "stdio", "command": "x", "args": []} + results: list[object] = [] + errors: list[BaseException] = [] + + def run_in_own_loop(): + try: + results.append(asyncio.run(pool.get_session("s", "t1", conn))) + except BaseException as e: # noqa: BLE001 - capture for assertion + errors.append(e) + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + # First loop creates and registers an entry, then its loop is torn down + # by asyncio.run, leaving a stale (closed-loop) record behind. + t1 = threading.Thread(target=run_in_own_loop) + t1.start() + t1.join() + + # Second loop requests the same key. It must evict the stale record and + # create a fresh session instead of raising AssertionError. + t2 = threading.Thread(target=run_in_own_loop) + t2.start() + t2.join() + + assert not errors, f"cross-loop same-key request must not raise: {errors}" + assert len(results) == 2 + assert all(r is not None for r in results) + + +def test_cross_loop_preempting_blocked_in_flight_does_not_hang_owner(): + """A foreign-loop request must not leave a still-initializing owner hung (#3379 CR P1). + + Loop A starts a creation that blocks inside initialize() (the in-flight + record stays live). Loop B then requests the same key. B must tear A's owner + down — cancelling it, because close_evt alone cannot wake a task blocked in + initialize() — so that A's get_session unwinds instead of hanging forever. + """ + pool = MCPSessionPool() + conn = {"transport": "stdio", "command": "x", "args": []} + first_gate = threading.Event() + entered = threading.Event() + results: list[tuple[str, object]] = [] + errors: list[tuple[str, BaseException]] = [] + closed: list[str] = [] + + class _BlockingForeverCm: + async def __aenter__(self): + session = MagicMock() + session.initialize = self._initialize + entered.set() + return session + + async def _initialize(self): + # Block until released, simulating a slow/stuck server handshake. + while not first_gate.is_set(): + await asyncio.sleep(0.005) + + async def __aexit__(self, *args): + closed.append("blocking") + return False + + class _FastCm: + async def __aenter__(self): + session = MagicMock() + + async def init(): + return None + + session.initialize = init + return session + + async def __aexit__(self, *args): + return False + + cms: list[object] = [_BlockingForeverCm(), _FastCm()] + + def make_cm(*a, **kw): + return cms.pop(0) + + def run_get(name): + try: + results.append((name, asyncio.run(pool.get_session("s", "t1", conn)))) + except BaseException as e: # noqa: BLE001 - capture for assertion + errors.append((name, e)) + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + ta = threading.Thread(target=run_get, args=("A",)) + ta.start() + assert entered.wait(2), "owner A must enter the CM and start initializing" + + tb = threading.Thread(target=run_get, args=("B",)) + tb.start() + tb.join(3) + + # B must complete without depending on A's blocked initialize(). + assert not tb.is_alive(), "foreign-loop request B must not hang" + # A must already be unwound (cancelled), not waiting on the dead gate. + ta.join(3) + assert not ta.is_alive(), "preempted owner A must not hang forever" + + assert [n for n, _ in results] == ["B"], "only B produces a usable session" + assert any(isinstance(e, asyncio.CancelledError) for _, e in errors), "preempted A must unwind via CancelledError" + assert "blocking" in closed, "preempted owner's __aexit__ must run on teardown" + + +@pytest.mark.asyncio +async def test_close_all_sync_from_running_loop_does_not_wait_on_itself(): + """close_all_sync must not block on the current running loop (#3379 CR P1). + + When called from code already executing inside the owner loop's thread, + close_all_sync cannot synchronously wait for that loop to run the shutdown + coroutine. It must signal the owner task and return promptly, then the owner + task closes itself once the loop regains control. + """ + pool = MCPSessionPool() + pool.SESSION_CLOSE_TIMEOUT = 0.2 + conn = {"transport": "stdio", "command": "x", "args": []} + + cm = _CloseTrackingCm() + + def make_cm(*a, **kw): + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s", "t1", conn) + start = asyncio.get_running_loop().time() + pool.close_all_sync() + elapsed = asyncio.get_running_loop().time() - start + + assert elapsed < 0.1, "close_all_sync must not stall until timeout on the current loop" + assert len(pool._entries) == 0 + assert len(pool._inflight) == 0 + assert cm.closed is False, "owner task has not run yet while close_all_sync is still executing" + + for _ in range(10): + if cm.closed: + break + await asyncio.sleep(0.01) + + assert cm.closed is True, "owner task must close itself after the loop regains control" + + +# --------------------------------------------------------------------------- +# reset_mcp_tools_cache deadlock regression +# --------------------------------------------------------------------------- + + +class _CloseTrackingCm: + """A create_session() context manager that records when __aexit__ runs.""" + + def __init__(self) -> None: + self.closed = False + + async def __aenter__(self): + session = MagicMock() + + async def init(): + return None + + session.initialize = init + return session + + async def __aexit__(self, *args): + self.closed = True + return False + + +def test_reset_mcp_tools_cache_from_running_loop_is_bounded(): + """reset_mcp_tools_cache() must not deadlock when called from inside a + running loop that owns sessions (#3392 CR blocker). + + The previous implementation spun up a worker thread running + ``asyncio.run(pool.close_all())`` and blocked the loop thread on + ``.result()``. close_all() then routed teardown of the current loop's + sessions back onto that blocked loop via run_coroutine_threadsafe(...), + so neither side could make progress. This test drives the exact scenario + on a daemon thread and asserts the call returns within a bounded time. + """ + from deerflow.mcp.cache import reset_mcp_tools_cache + from deerflow.mcp.session_pool import get_session_pool + + conn = {"transport": "stdio", "command": "x", "args": []} + cm = _CloseTrackingCm() + done = threading.Event() + + async def scenario(): + pool = get_session_pool() + # Entry owned by THIS loop — the deadlock-prone case. + await pool.get_session("s", "t1", conn) + # Synchronous call: asyncio.get_running_loop() succeeds inside it, so + # it takes the "running loop" branch in reset_mcp_tools_cache(). + reset_mcp_tools_cache() + # Signal-only teardown completes once the loop regains control. + await asyncio.sleep(0.05) + + def run(): + asyncio.run(scenario()) + done.set() + + t = threading.Thread(target=run, daemon=True) + with patch("langchain_mcp_adapters.sessions.create_session", return_value=cm): + t.start() + t.join(timeout=5) + + assert done.is_set(), "reset_mcp_tools_cache() deadlocked inside a running loop" + assert cm.closed is True, "owner task must run __aexit__ once the loop regains control"