mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-09 17:12:01 +00:00
Merge branch 'main' into copilot/fix-lint-frontend-job
This commit is contained in:
commit
0fdfbae435
@ -223,17 +223,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
|
||||
|
||||
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
||||
|
||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
|
||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`.
|
||||
|
||||
| Field | Why a restart is required |
|
||||
|---|---|
|
||||
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
|
||||
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
|
||||
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
|
||||
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
|
||||
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
|
||||
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
|
||||
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
|
||||
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
|
||||
|
||||
Configuration priority:
|
||||
1. Explicit `config_path` argument
|
||||
|
||||
@ -119,6 +119,16 @@ def get_config() -> AppConfig:
|
||||
split-brain where the worker / lead-agent thread saw a stale startup
|
||||
snapshot.
|
||||
|
||||
Hot-reload boundary: fields backed by startup-time singletons
|
||||
(engines, sandbox provider, IM channels, logging handler) require a
|
||||
process restart to change at runtime. The authoritative list lives in
|
||||
:mod:`deerflow.config.reload_boundary` and is mirrored by the
|
||||
standardised ``"startup-only:"`` prefix on the matching
|
||||
``Field(description=...)`` in :class:`AppConfig` — IDE hover on those
|
||||
fields will surface the boundary inline. See
|
||||
``backend/CLAUDE.md`` "Config Hot-Reload Boundary" for the operator
|
||||
summary.
|
||||
|
||||
Any failure to materialise the config (missing file, permission denied,
|
||||
YAML parse error, validation error) is reported as 503 — semantically
|
||||
"the gateway cannot serve requests without a usable configuration" — and
|
||||
|
||||
@ -18,6 +18,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.reload_boundary import format_field_description
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
from deerflow.config.runtime_paths import existing_project_file
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
@ -85,10 +86,21 @@ def apply_logging_level(name: str | None) -> None:
|
||||
class AppConfig(BaseModel):
|
||||
"""Config for the DeerFlow application"""
|
||||
|
||||
log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected")
|
||||
log_level: str = Field(
|
||||
default="info",
|
||||
description=format_field_description(
|
||||
"log_level",
|
||||
field_doc="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected.",
|
||||
),
|
||||
)
|
||||
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
||||
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
||||
sandbox: SandboxConfig = Field(
|
||||
description=format_field_description(
|
||||
"sandbox",
|
||||
field_doc="Sandbox provider configuration (local filesystem or Docker-based aio sandbox).",
|
||||
),
|
||||
)
|
||||
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
|
||||
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
|
||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||
@ -107,10 +119,34 @@ class AppConfig(BaseModel):
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
||||
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
|
||||
database: DatabaseConfig = Field(
|
||||
default_factory=DatabaseConfig,
|
||||
description=format_field_description(
|
||||
"database",
|
||||
field_doc="Unified database backend for run/feedback metadata (memory, sqlite, or postgres).",
|
||||
),
|
||||
)
|
||||
run_events: RunEventsConfig = Field(
|
||||
default_factory=RunEventsConfig,
|
||||
description=format_field_description(
|
||||
"run_events",
|
||||
field_doc="Run-event store backend (memory for dev, db for production queries, jsonl for lightweight single-node persistence).",
|
||||
),
|
||||
)
|
||||
checkpointer: CheckpointerConfig | None = Field(
|
||||
default=None,
|
||||
description=format_field_description(
|
||||
"checkpointer",
|
||||
field_doc="LangGraph state-persistence checkpointer configuration.",
|
||||
),
|
||||
)
|
||||
stream_bridge: StreamBridgeConfig | None = Field(
|
||||
default=None,
|
||||
description=format_field_description(
|
||||
"stream_bridge",
|
||||
field_doc="Stream bridge connecting agent workers to SSE endpoints.",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||
|
||||
104
backend/packages/harness/deerflow/config/reload_boundary.py
Normal file
104
backend/packages/harness/deerflow/config/reload_boundary.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""Single source of truth for the config hot-reload boundary.
|
||||
|
||||
Bytedance/deer-flow issue #3144: gateway request dependencies resolve
|
||||
``AppConfig`` through ``get_app_config()`` on every request, so per-run
|
||||
fields take effect on the next message without restarting the gateway.
|
||||
The fields listed in this module are the **infrastructure** subset that
|
||||
the gateway captures once at startup — engines, singletons, IM clients,
|
||||
the logging handler — and that therefore require a process restart to
|
||||
change at runtime.
|
||||
|
||||
The registry covers two kinds of entries:
|
||||
|
||||
- Top-level ``AppConfig`` fields (``database``, ``checkpointer``,
|
||||
``run_events``, ``stream_bridge``, ``sandbox``, ``log_level``). For
|
||||
these, :func:`format_field_description` produces the standardised
|
||||
``"startup-only: ..."`` prefix that the matching Pydantic
|
||||
``Field(description=...)`` carries, so the boundary surfaces in IDE
|
||||
hover next to the field itself.
|
||||
- Top-level ``config.yaml`` sections that are not part of the
|
||||
``AppConfig`` schema (``channels``). These cannot be standardised at
|
||||
the schema level, so the registry is their only canonical location.
|
||||
|
||||
Any future "needs restart" scanner — operator tooling, lint hooks, doc
|
||||
generators — should drive off this registry rather than re-parsing
|
||||
prose.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
#: The standardised prefix every restart-required field description starts
|
||||
#: with. ``test_reload_boundary`` enforces both directions: registered
|
||||
#: fields must use this prefix in the schema, and any schema field using
|
||||
#: this prefix must be in the registry.
|
||||
STARTUP_ONLY_PREFIX = "startup-only:"
|
||||
|
||||
|
||||
#: Restart-required field paths mapped to the human-readable reason.
|
||||
#:
|
||||
#: The reason text is what surfaces in ``Field(description=...)``, so it
|
||||
#: must explain *what* code captures the snapshot — not just that the
|
||||
#: field is restart-required — so an operator changing the value knows
|
||||
#: which subsystem to restart.
|
||||
STARTUP_ONLY_FIELDS: dict[str, str] = {
|
||||
"database": ("init_engine_from_config() runs once during langgraph_runtime() startup; the SQLAlchemy engine holds the connection pool and is not rebuilt on config.yaml edits."),
|
||||
"checkpointer": ("make_checkpointer() binds the persistent checkpointer once at startup, including SQLite WAL / busy_timeout settings."),
|
||||
"run_events": ("make_run_event_store() picks the memory- vs SQL-backed implementation at startup and is frozen onto app.state.run_events_config to stay paired with the underlying event store."),
|
||||
"stream_bridge": ("make_stream_bridge() constructs the stream-bridge singleton once during startup."),
|
||||
"sandbox": ("get_sandbox_provider() caches the provider singleton (``_default_sandbox_provider``); a different ``sandbox.use`` class path only takes effect on next process start."),
|
||||
"log_level": (
|
||||
"apply_logging_level() runs only during app.py startup; it sets the deerflow/app logger levels and may lower root handler thresholds so configured messages can propagate. A freshly reloaded AppConfig does not retrigger it."
|
||||
),
|
||||
# Not part of the AppConfig Pydantic schema — channel credentials are
|
||||
# consumed directly by ``start_channel_service()`` once at lifespan
|
||||
# startup and the live channel clients are not rebuilt on
|
||||
# config.yaml edits.
|
||||
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
|
||||
}
|
||||
|
||||
|
||||
def iter_startup_only_field_paths() -> Iterator[str]:
|
||||
"""Yield every registered restart-required field path."""
|
||||
return iter(STARTUP_ONLY_FIELDS)
|
||||
|
||||
|
||||
def is_startup_only_field(field_path: str) -> bool:
|
||||
"""Return ``True`` when *field_path* is registered as restart-required.
|
||||
|
||||
Accepts only top-level paths (``"database"``, ``"sandbox"`` etc.);
|
||||
nested keys like ``"database.url"`` are not modelled here because the
|
||||
boundary is per-section, not per-leaf.
|
||||
"""
|
||||
return field_path in STARTUP_ONLY_FIELDS
|
||||
|
||||
|
||||
def format_field_description(field_path: str, *, field_doc: str | None = None) -> str:
|
||||
"""Build the standardised description for a registered field.
|
||||
|
||||
Used inside ``AppConfig`` ``Field(description=...)`` so the hover
|
||||
text in IDEs matches the registry and the drift tests can pin one
|
||||
side against the other.
|
||||
|
||||
Args:
|
||||
field_path: A registered top-level field path (e.g. ``"log_level"``).
|
||||
field_doc: Optional human-facing description for the field itself
|
||||
(allowed values, semantics, etc.). When supplied, it is
|
||||
appended after the ``startup-only:`` marker block separated by
|
||||
a blank line so IDE hover shows both the restart-required
|
||||
reason *and* the field's normal documentation. Composition
|
||||
keeps the marker as the leading token machine-readable tooling
|
||||
pivots on while restoring the prose that ``Field(description=)``
|
||||
used to carry before the registry took over.
|
||||
|
||||
Raises:
|
||||
KeyError: when *field_path* is not registered. This is deliberate
|
||||
— silently returning a placeholder would let a typo bypass
|
||||
the drift coverage.
|
||||
"""
|
||||
reason = STARTUP_ONLY_FIELDS[field_path]
|
||||
header = f"{STARTUP_ONLY_PREFIX} {reason}"
|
||||
if field_doc is None:
|
||||
return header
|
||||
return f"{header}\n\n{field_doc.strip()}"
|
||||
@ -1,6 +1,10 @@
|
||||
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters."""
|
||||
|
||||
from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
|
||||
from .cache import (
|
||||
get_cached_mcp_tools,
|
||||
initialize_mcp_tools,
|
||||
reset_mcp_tools_cache,
|
||||
)
|
||||
from .client import build_server_params, build_servers_config
|
||||
from .tools import get_mcp_tools
|
||||
|
||||
|
||||
@ -143,11 +143,20 @@ def reset_mcp_tools_cache() -> None:
|
||||
|
||||
# Close persistent sessions – they will be recreated by the next
|
||||
# get_mcp_tools() call with the (possibly updated) connection config.
|
||||
#
|
||||
# close_all_sync() already picks the correct strategy per owning loop:
|
||||
# * sessions owned by the *current* running loop are only *signalled*
|
||||
# (their owner task runs __aexit__ once the loop regains control –
|
||||
# this is correct and leak-free, since the loop keeps the task alive),
|
||||
# * sessions on other threads' loops are torn down deterministically,
|
||||
# * idle/closed loops are handled or skipped.
|
||||
# We deliberately do NOT try to synchronously wait for the current running
|
||||
# loop to finish teardown here: that is a self-deadlock (the loop can only
|
||||
# run the teardown after this synchronous call returns control to it).
|
||||
try:
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
|
||||
pool = get_session_pool()
|
||||
pool.close_all_sync()
|
||||
get_session_pool().close_all_sync()
|
||||
except Exception:
|
||||
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
|
||||
|
||||
|
||||
@ -8,6 +8,27 @@ This module provides a session pool that maintains persistent MCP sessions,
|
||||
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
|
||||
so that consecutive tool calls share the same session and server-side state.
|
||||
Sessions are evicted in LRU order when the pool reaches capacity.
|
||||
|
||||
Lifecycle model (owner task)
|
||||
----------------------------
|
||||
An MCP ``ClientSession`` is implemented on top of an ``anyio`` task group, and
|
||||
anyio enforces that a cancel scope must be exited from the *same task* that
|
||||
entered it. Calling ``cm.__aexit__`` from any task other than the one that ran
|
||||
``cm.__aenter__`` raises::
|
||||
|
||||
RuntimeError: Attempted to exit cancel scope in a different task than it
|
||||
was entered in
|
||||
|
||||
The sync-tool path (``make_sync_tool_wrapper``) drives each call through a fresh
|
||||
``asyncio.run`` event loop, so a session entered while answering one call would
|
||||
otherwise be exited while answering another — from a different task — and crash
|
||||
(GitHub issue #3379).
|
||||
|
||||
To make this impossible, every pooled session is owned by a dedicated
|
||||
``_run_session`` task. That task enters the context manager, hands the live
|
||||
session back to the caller, and then *waits* on a close event. All shutdown
|
||||
paths only ever **signal** that event; the owner task performs ``__aexit__``
|
||||
itself, guaranteeing enter and exit always happen in the same task.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -27,18 +48,81 @@ class MCPSessionPool:
|
||||
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
|
||||
|
||||
MAX_SESSIONS = 256
|
||||
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
|
||||
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session on a foreign loop
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Each entry: (session, owning_loop, owner_task, close_event).
|
||||
self._entries: OrderedDict[
|
||||
tuple[str, str],
|
||||
tuple[ClientSession, asyncio.AbstractEventLoop],
|
||||
tuple[
|
||||
ClientSession,
|
||||
asyncio.AbstractEventLoop,
|
||||
asyncio.Task[Any],
|
||||
asyncio.Event,
|
||||
],
|
||||
] = OrderedDict()
|
||||
self._context_managers: dict[tuple[str, str], Any] = {}
|
||||
# In-flight creations, keyed by (server, scope). Lets concurrent callers
|
||||
# on the same loop share a single creation instead of each spawning a
|
||||
# duplicate session. Value: (loop, ready_future, owner_task, close_event).
|
||||
self._inflight: dict[
|
||||
tuple[str, str],
|
||||
tuple[
|
||||
asyncio.AbstractEventLoop,
|
||||
asyncio.Future[ClientSession],
|
||||
asyncio.Task[Any],
|
||||
asyncio.Event,
|
||||
],
|
||||
] = {}
|
||||
# threading.Lock is not bound to any event loop, so it is safe to
|
||||
# acquire from both async paths and sync/worker-thread paths.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session owner task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _run_session(
|
||||
self,
|
||||
connection: dict[str, Any],
|
||||
ready: asyncio.Future[ClientSession],
|
||||
close_evt: asyncio.Event,
|
||||
) -> None:
|
||||
"""Own a single MCP session for its entire lifetime.
|
||||
|
||||
Enters the session context manager, initializes it, publishes the live
|
||||
session via ``ready``, then blocks until ``close_evt`` is set. The
|
||||
context manager is *always* exited from this task, satisfying anyio's
|
||||
cancel-scope same-task requirement.
|
||||
"""
|
||||
from langchain_mcp_adapters.sessions import create_session
|
||||
|
||||
cm = create_session(connection)
|
||||
try:
|
||||
session = await cm.__aenter__()
|
||||
except BaseException as e:
|
||||
# Never entered the cancel scope, so there is nothing to exit.
|
||||
if not ready.done():
|
||||
ready.set_exception(e)
|
||||
return
|
||||
|
||||
# The context manager is now entered. From here on __aexit__ MUST run in
|
||||
# this task — on init failure, on cancellation, or on the close signal —
|
||||
# to satisfy anyio's same-task cancel-scope requirement and to avoid
|
||||
# leaking the session/subprocess.
|
||||
try:
|
||||
await session.initialize()
|
||||
if not ready.done():
|
||||
ready.set_result(session)
|
||||
await close_evt.wait()
|
||||
except BaseException as e:
|
||||
if not ready.done():
|
||||
ready.set_exception(e)
|
||||
finally:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session", exc_info=True)
|
||||
|
||||
async def get_session(
|
||||
self,
|
||||
server_name: str,
|
||||
@ -47,9 +131,9 @@ class MCPSessionPool:
|
||||
) -> ClientSession:
|
||||
"""Get or create a persistent MCP session.
|
||||
|
||||
If an existing session was created in a different event loop (e.g.
|
||||
the sync-wrapper path), it is closed and replaced with a fresh one
|
||||
in the current loop.
|
||||
If an existing session was created in a different (or closed) event
|
||||
loop, it is evicted and replaced with a fresh one owned by a task on
|
||||
the current loop.
|
||||
|
||||
Args:
|
||||
server_name: MCP server name.
|
||||
@ -63,44 +147,118 @@ class MCPSessionPool:
|
||||
current_loop = asyncio.get_running_loop()
|
||||
|
||||
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
|
||||
cms_to_close: list[tuple[tuple[str, str], Any]] = []
|
||||
# Decide one of three outcomes atomically: return an existing session,
|
||||
# join an in-flight creation, or become the creator for this key.
|
||||
# Each item: (loop, owner_task, close_event, cancel). ``cancel`` is True
|
||||
# for in-flight creations, whose owner may be blocked inside
|
||||
# ``initialize()`` where close_evt cannot wake it — it must be cancelled.
|
||||
evicted: list[tuple[asyncio.AbstractEventLoop, asyncio.Task[Any], asyncio.Event, bool]] = []
|
||||
join: asyncio.Future[ClientSession] | None = None
|
||||
ready: asyncio.Future[ClientSession] | None = None
|
||||
close_evt: asyncio.Event | None = None
|
||||
task: asyncio.Task[Any] | None = None
|
||||
with self._lock:
|
||||
if key in self._entries:
|
||||
session, loop = self._entries[key]
|
||||
if loop is current_loop:
|
||||
session, loop, ent_task, ent_close = self._entries[key]
|
||||
if loop is current_loop and not loop.is_closed():
|
||||
self._entries.move_to_end(key)
|
||||
return session
|
||||
# Session belongs to a different event loop – evict it.
|
||||
cm = self._context_managers.pop(key, None)
|
||||
# Session belongs to a different/closed event loop – evict it.
|
||||
self._entries.pop(key)
|
||||
if cm is not None:
|
||||
cms_to_close.append((key, cm))
|
||||
evicted.append((loop, ent_task, ent_close, False))
|
||||
|
||||
inflight = self._inflight.get(key)
|
||||
if inflight is not None and inflight[0] is current_loop and not inflight[0].is_closed():
|
||||
# Another caller on this loop is already creating the session;
|
||||
# wait for the same result instead of building a duplicate.
|
||||
join = inflight[1]
|
||||
else:
|
||||
if inflight is not None:
|
||||
# Stale in-flight creation owned by a different/closed loop.
|
||||
# Drop the record and tear its owner down; because that owner
|
||||
# may be blocked inside initialize() (where close_evt cannot
|
||||
# wake it), it must be cancelled. We then create a fresh
|
||||
# session here.
|
||||
self._inflight.pop(key)
|
||||
evicted.append((inflight[0], inflight[2], inflight[3], True))
|
||||
# Become the creator: publish an in-flight record before any
|
||||
# await so concurrent callers join us instead of racing.
|
||||
ready = current_loop.create_future()
|
||||
close_evt = asyncio.Event()
|
||||
task = current_loop.create_task(self._run_session(connection, ready, close_evt))
|
||||
self._inflight[key] = (current_loop, ready, task, close_evt)
|
||||
|
||||
# Evict LRU entries when at capacity.
|
||||
while len(self._entries) >= self.MAX_SESSIONS:
|
||||
oldest_key = next(iter(self._entries))
|
||||
cm = self._context_managers.pop(oldest_key, None)
|
||||
oldest_key, (_, loop, ent_task, ent_close) = next(iter(self._entries.items()))
|
||||
self._entries.pop(oldest_key)
|
||||
if cm is not None:
|
||||
cms_to_close.append((oldest_key, cm))
|
||||
evicted.append((loop, ent_task, ent_close, False))
|
||||
|
||||
# Phase 2: async cleanup outside the lock so we never await while holding it.
|
||||
for close_key, cm in cms_to_close:
|
||||
# Phase 2: shut down evicted sessions/creations. Same-loop owners are
|
||||
# awaited so they finish deterministically; foreign-loop owners are
|
||||
# routed to their own loop. In every case the owner task — never this
|
||||
# one — runs __aexit__. In-flight owners are cancelled (cancel=True) so a
|
||||
# blocking initialize() cannot leave them hung.
|
||||
for loop, ent_task, ent_close, cancel in evicted:
|
||||
if loop is current_loop and not loop.is_closed():
|
||||
await self._shutdown(ent_close, ent_task, cancel)
|
||||
elif cancel:
|
||||
await self._shutdown_entry(loop, ent_task, ent_close, cancel=True)
|
||||
else:
|
||||
self._signal_close(loop, ent_close)
|
||||
|
||||
# Phase 2b: a concurrent creation for this key is already in progress on
|
||||
# this loop — share its result rather than create a duplicate session.
|
||||
if join is not None:
|
||||
return await asyncio.shield(join)
|
||||
|
||||
assert ready is not None and close_evt is not None and task is not None
|
||||
|
||||
# Phase 3: wait for our owner task to publish the initialized session.
|
||||
try:
|
||||
session = await asyncio.shield(ready)
|
||||
except BaseException:
|
||||
# Two distinct cases reach here:
|
||||
#
|
||||
# 1. The owner task failed (e.g. connect/initialize error) and
|
||||
# reported it via ready.set_exception(). It is *already* in its
|
||||
# finally block running cm.__aexit__ in its own task, so we must
|
||||
# NOT cancel it — doing so would interrupt that cleanup. We only
|
||||
# wait for it to finish unwinding.
|
||||
# 2. This call itself was cancelled (CancelledError). Because of the
|
||||
# shield, `ready` is still pending and the owner task is alive and
|
||||
# blocked. We signal close and cancel it so it exits the cancel
|
||||
# scope in its own task, then wait for it to finish.
|
||||
#
|
||||
# The session is never registered yet, so nobody else can close it;
|
||||
# waiting here guarantees we never leak a session or owner task.
|
||||
owner_already_failed = ready.done() and not ready.cancelled() and ready.exception() is not None
|
||||
if not owner_already_failed:
|
||||
close_evt.set()
|
||||
task.cancel()
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
|
||||
await asyncio.shield(task)
|
||||
except BaseException:
|
||||
logger.debug("Owner task ended during get_session unwind", exc_info=True)
|
||||
with self._lock:
|
||||
if self._inflight.get(key) == (current_loop, ready, task, close_evt):
|
||||
self._inflight.pop(key)
|
||||
raise
|
||||
|
||||
from langchain_mcp_adapters.sessions import create_session
|
||||
|
||||
cm = create_session(connection)
|
||||
session = await cm.__aenter__()
|
||||
await session.initialize()
|
||||
|
||||
# Phase 3: register the new session under the lock.
|
||||
# Phase 4: promote the in-flight creation to a registered entry — but
|
||||
# only if our in-flight record is still the live one. A concurrent
|
||||
# close_* / close_all may have removed it while we were initializing; in
|
||||
# that case we must NOT resurrect the session into _entries. Instead we
|
||||
# own the teardown: signal our owner task and wait for it to run
|
||||
# __aexit__ in its own task, then surface the cancellation.
|
||||
with self._lock:
|
||||
self._entries[key] = (session, current_loop)
|
||||
self._context_managers[key] = cm
|
||||
still_ours = self._inflight.get(key) == (current_loop, ready, task, close_evt)
|
||||
if still_ours:
|
||||
self._inflight.pop(key)
|
||||
self._entries[key] = (session, current_loop, task, close_evt)
|
||||
if not still_ours:
|
||||
await self._shutdown(close_evt, task)
|
||||
raise asyncio.CancelledError("MCP session pool was closed while the session was being created")
|
||||
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
|
||||
return session
|
||||
|
||||
@ -108,70 +266,169 @@ class MCPSessionPool:
|
||||
# Cleanup helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
|
||||
"""Close a single context manager (must be called WITHOUT the lock)."""
|
||||
@staticmethod
|
||||
def _signal_close(loop: asyncio.AbstractEventLoop, close_evt: asyncio.Event) -> None:
|
||||
"""Ask an owner task to shut down without waiting.
|
||||
|
||||
``asyncio.Event.set`` is not thread-safe, so it is scheduled on the
|
||||
owning loop. A closed loop means the owner task is already gone.
|
||||
"""
|
||||
if loop.is_closed():
|
||||
return
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", key, exc_info=True)
|
||||
loop.call_soon_threadsafe(close_evt.set)
|
||||
except RuntimeError:
|
||||
# Loop was closed between the is_closed() check and now.
|
||||
pass
|
||||
|
||||
async def _shutdown(
|
||||
self,
|
||||
close_evt: asyncio.Event,
|
||||
task: asyncio.Task[Any],
|
||||
cancel: bool = False,
|
||||
) -> None:
|
||||
"""Signal an owner task and wait for it to finish (runs on its loop).
|
||||
|
||||
``cancel=True`` is used for in-flight creations: the owner task may be
|
||||
blocked inside ``initialize()`` where ``close_evt`` cannot wake it, so it
|
||||
must be cancelled. Its ``finally`` block still runs ``__aexit__`` in its
|
||||
own task, satisfying anyio's same-task cancel-scope requirement.
|
||||
"""
|
||||
close_evt.set()
|
||||
if cancel:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (Exception, asyncio.CancelledError):
|
||||
logger.debug("Owner task ended during shutdown", exc_info=True)
|
||||
|
||||
async def _shutdown_entry(
|
||||
self,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
task: asyncio.Task[Any],
|
||||
close_evt: asyncio.Event,
|
||||
cancel: bool = False,
|
||||
) -> None:
|
||||
"""Shut down one entry, routing the close to its owning loop."""
|
||||
if loop.is_closed():
|
||||
return
|
||||
current_loop = asyncio.get_running_loop()
|
||||
if loop is current_loop:
|
||||
await self._shutdown(close_evt, task, cancel)
|
||||
elif loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop)
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session on owning loop", exc_info=True)
|
||||
else:
|
||||
# Owning loop exists but is neither the current loop nor running.
|
||||
# We are inside an async context here, so run_until_complete() would
|
||||
# raise "Cannot run the event loop while another loop is running";
|
||||
# and the loop may belong to another thread, where driving it from
|
||||
# here is unsafe. This branch is not expected in practice — a
|
||||
# session's owning loop is either the long-lived gateway loop (which
|
||||
# is running) or a short-lived asyncio.run loop (which is closed and
|
||||
# caught above). Fall back to a best-effort thread-safe signal so the
|
||||
# owner task tears down if/when its loop runs again.
|
||||
logger.warning("Owning loop for MCP session is idle; signalling close best-effort. Session may leak until the loop runs again.")
|
||||
self._signal_close(loop, close_evt)
|
||||
if cancel:
|
||||
try:
|
||||
loop.call_soon_threadsafe(task.cancel)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def close_scope(self, scope_key: str) -> None:
|
||||
"""Close all sessions for a given scope (e.g. thread_id)."""
|
||||
with self._lock:
|
||||
keys = [k for k in self._entries if k[1] == scope_key]
|
||||
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
|
||||
for k in keys:
|
||||
self._entries.pop(k, None)
|
||||
for key, cm in cms:
|
||||
if cm is not None:
|
||||
await self._close_cm(key, cm)
|
||||
entries = [(self._entries.pop(k)) for k in keys]
|
||||
inflight_keys = [k for k in self._inflight if k[1] == scope_key]
|
||||
inflight = [self._inflight.pop(k) for k in inflight_keys]
|
||||
for _session, loop, task, close_evt in entries:
|
||||
await self._shutdown_entry(loop, task, close_evt)
|
||||
for loop, _ready, task, close_evt in inflight:
|
||||
await self._shutdown_entry(loop, task, close_evt, cancel=True)
|
||||
|
||||
async def close_server(self, server_name: str) -> None:
|
||||
"""Close all sessions for a given server."""
|
||||
with self._lock:
|
||||
keys = [k for k in self._entries if k[0] == server_name]
|
||||
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
|
||||
for k in keys:
|
||||
self._entries.pop(k, None)
|
||||
for key, cm in cms:
|
||||
if cm is not None:
|
||||
await self._close_cm(key, cm)
|
||||
entries = [(self._entries.pop(k)) for k in keys]
|
||||
inflight_keys = [k for k in self._inflight if k[0] == server_name]
|
||||
inflight = [self._inflight.pop(k) for k in inflight_keys]
|
||||
for _session, loop, task, close_evt in entries:
|
||||
await self._shutdown_entry(loop, task, close_evt)
|
||||
for loop, _ready, task, close_evt in inflight:
|
||||
await self._shutdown_entry(loop, task, close_evt, cancel=True)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close every managed session."""
|
||||
with self._lock:
|
||||
cms = list(self._context_managers.items())
|
||||
self._context_managers.clear()
|
||||
entries = list(self._entries.values())
|
||||
self._entries.clear()
|
||||
for key, cm in cms:
|
||||
await self._close_cm(key, cm)
|
||||
inflight = list(self._inflight.values())
|
||||
self._inflight.clear()
|
||||
for _session, loop, task, close_evt in entries:
|
||||
await self._shutdown_entry(loop, task, close_evt)
|
||||
for loop, _ready, task, close_evt in inflight:
|
||||
await self._shutdown_entry(loop, task, close_evt, cancel=True)
|
||||
|
||||
def close_all_sync(self) -> None:
|
||||
"""Close all sessions using their owning event loops (synchronous).
|
||||
"""Close all sessions on their owning event loops (synchronous).
|
||||
|
||||
Each session is closed on the loop it was created in, avoiding
|
||||
cross-loop resource leaks. Safe to call from any thread without an
|
||||
active event loop.
|
||||
Each session is closed by its owner task on the loop it was created in,
|
||||
avoiding cross-loop and cross-task errors. Safe to call from any thread
|
||||
without an active event loop.
|
||||
|
||||
Closing semantics differ by where the owning loop runs:
|
||||
|
||||
* Owning loop is idle, or running on another thread — this call blocks
|
||||
until teardown completes (or ``SESSION_CLOSE_TIMEOUT`` elapses).
|
||||
* Owning loop is the one currently running on *this* thread — we cannot
|
||||
block on it without deadlocking, so teardown is only *signalled* here
|
||||
and completes asynchronously once control returns to that loop. The
|
||||
caller must therefore keep that loop running afterwards; if it stops
|
||||
the loop immediately, the owner task's ``__aexit__`` may not run. When
|
||||
a deterministic close is required from inside a running loop, ``await
|
||||
close_all()`` instead.
|
||||
"""
|
||||
with self._lock:
|
||||
entries = list(self._entries.items())
|
||||
cms = dict(self._context_managers)
|
||||
entries = list(self._entries.values())
|
||||
self._entries.clear()
|
||||
self._context_managers.clear()
|
||||
inflight = list(self._inflight.values())
|
||||
self._inflight.clear()
|
||||
|
||||
for key, (_, loop) in entries:
|
||||
cm = cms.get(key)
|
||||
if cm is None or loop.is_closed():
|
||||
# Entries are initialized (gentle close_evt path). In-flight creations
|
||||
# may be blocked mid-init, so they are cancelled to unblock teardown.
|
||||
owners = [(loop, task, close_evt, False) for _s, loop, task, close_evt in entries]
|
||||
owners += [(loop, task, close_evt, True) for loop, _r, task, close_evt in inflight]
|
||||
try:
|
||||
current_running_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
current_running_loop = None
|
||||
for loop, task, close_evt, cancel in owners:
|
||||
if loop.is_closed():
|
||||
continue
|
||||
try:
|
||||
if loop.is_running():
|
||||
# Schedule on the owning loop from this (different) thread.
|
||||
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
|
||||
if loop is current_running_loop:
|
||||
# We are executing inside this loop's thread, so synchronously
|
||||
# waiting on run_coroutine_threadsafe(...).result() would
|
||||
# deadlock until timeout. Signal the owner task directly and
|
||||
# let it finish once this synchronous call returns control to
|
||||
# the running loop.
|
||||
close_evt.set()
|
||||
if cancel:
|
||||
task.cancel()
|
||||
elif loop.is_running():
|
||||
# Schedule the shutdown on the owning loop from this thread.
|
||||
future = asyncio.run_coroutine_threadsafe(self._shutdown(close_evt, task, cancel), loop)
|
||||
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
|
||||
else:
|
||||
loop.run_until_complete(cm.__aexit__(None, None, None))
|
||||
loop.run_until_complete(self._shutdown(close_evt, task, cancel))
|
||||
except Exception:
|
||||
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
|
||||
logger.debug("Error closing MCP session during sync close", exc_info=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Tests for the MCP persistent-session pool."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -614,3 +616,585 @@ async def test_http_transport_tools_not_pooled():
|
||||
stdio_tools = [t for t in tools if t.name == "playwright_navigate"]
|
||||
assert len(stdio_tools) == 1
|
||||
assert stdio_tools[0].coroutine is not stdio_tool.coroutine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression for #3379: cancel scope must be exited in the entering task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CancelScopeCm:
|
||||
"""Fake session context manager that mimics anyio's cancel-scope rule.
|
||||
|
||||
``ClientSession`` is built on an anyio task group, which requires the cancel
|
||||
scope to be exited from the *same asyncio task* that entered it. This fake
|
||||
records the task that runs ``__aenter__`` and raises the exact RuntimeError
|
||||
anyio would raise if ``__aexit__`` runs in a different task — reproducing the
|
||||
crash reported in GitHub issue #3379.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.enter_task: object | None = None
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
self.enter_task = asyncio.current_task()
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if asyncio.current_task() is not self.enter_task:
|
||||
raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
|
||||
async def _get_session_in_own_task(pool, *args):
|
||||
"""Create a pooled session from a *dedicated* child task.
|
||||
|
||||
In production every stdio session is entered from its own short-lived task
|
||||
(the sync-tool path runs each call through a fresh ``asyncio.run``). This
|
||||
helper reproduces that so the close paths are exercised from a *different*
|
||||
task than the one that entered the session — the exact condition that
|
||||
triggered #3379.
|
||||
"""
|
||||
return await asyncio.create_task(pool.get_session(*args))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all_does_not_cross_tasks():
|
||||
"""close_all must not raise the cross-task cancel-scope RuntimeError (#3379)."""
|
||||
pool = MCPSessionPool()
|
||||
cms: list[_CancelScopeCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _CancelScopeCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await _get_session_in_own_task(pool, "s1", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await _get_session_in_own_task(pool, "s2", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
# close_all runs in this task, which is *not* the task that entered either
|
||||
# session. The owner task must perform __aexit__ so each CM closes cleanly.
|
||||
await pool.close_all()
|
||||
|
||||
assert all(cm.closed for cm in cms)
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_scope_does_not_cross_tasks():
|
||||
"""close_scope must respect the same-task cancel-scope rule (#3379)."""
|
||||
pool = MCPSessionPool()
|
||||
cms: list[_CancelScopeCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _CancelScopeCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await _get_session_in_own_task(pool, "s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await _get_session_in_own_task(pool, "s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_scope("t1")
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
assert ("s", "t2") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_eviction_does_not_cross_tasks():
|
||||
"""LRU eviction must close the victim without a cross-task RuntimeError (#3379)."""
|
||||
pool = MCPSessionPool()
|
||||
pool.MAX_SESSIONS = 2
|
||||
cms: list[_CancelScopeCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _CancelScopeCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await _get_session_in_own_task(pool, "s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await _get_session_in_own_task(pool, "s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
# Adding t3 evicts t1 — its own owner task must run __aexit__, even
|
||||
# though the eviction is driven from t3's get_session call.
|
||||
await _get_session_in_own_task(pool, "s", "t3", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
assert cms[2].closed is False
|
||||
|
||||
|
||||
def test_close_all_sync_across_loops_does_not_cross_tasks():
|
||||
"""close_all_sync, the path hit by the sync tool wrapper, must close sessions
|
||||
created in earlier (now-finished) asyncio.run loops without crashing (#3379).
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
cms: list[_CancelScopeCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _CancelScopeCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
# Simulate the sync-tool path: a session created inside one short-lived
|
||||
# event loop, then a second one in a different loop.
|
||||
asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}))
|
||||
asyncio.run(pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []}))
|
||||
|
||||
# The owning loops are already closed; close_all_sync must not raise.
|
||||
pool.close_all_sync()
|
||||
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
|
||||
def test_get_session_replaces_session_from_closed_loop():
|
||||
"""A pooled session whose owning loop has closed is evicted and recreated."""
|
||||
pool = MCPSessionPool()
|
||||
cms: list[_CancelScopeCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _CancelScopeCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
# First session created in a throwaway loop that is torn down by
|
||||
# asyncio.run (mirrors the sync-tool path). asyncio.run cancels the
|
||||
# pending owner task and runs its __aexit__ on the same loop.
|
||||
asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}))
|
||||
assert ("s", "t1") in pool._entries
|
||||
|
||||
# Now request the same key from a fresh loop: the stale entry (closed
|
||||
# loop) must be evicted and replaced with a fresh session.
|
||||
session = asyncio.run(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}))
|
||||
|
||||
assert session is not None
|
||||
assert len(cms) == 2
|
||||
assert pool._entries[("s", "t1")][0] is session
|
||||
|
||||
|
||||
class _BlockingInitCm:
|
||||
"""Fake session CM whose ``initialize`` blocks until released.
|
||||
|
||||
Lets a test cancel ``get_session`` while the owner task is still
|
||||
initializing, reproducing the caller-cancellation window.
|
||||
"""
|
||||
|
||||
def __init__(self, gate: asyncio.Event) -> None:
|
||||
self._gate = gate
|
||||
self.entered = False
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
self.entered = True
|
||||
session = MagicMock()
|
||||
session.initialize = self._initialize
|
||||
return session
|
||||
|
||||
async def _initialize(self):
|
||||
await self._gate.wait()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_cancelled_while_initializing_does_not_leak():
|
||||
"""Cancelling get_session mid-init must not leak the owner task/session (#3379 CR).
|
||||
|
||||
The session is not registered yet, so if cancellation skipped the cleanup
|
||||
the owner task would block forever on close_evt.wait() and the CM's
|
||||
__aexit__ would never run — an unreachable, unclosable session.
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
gate = asyncio.Event()
|
||||
cms: list[_BlockingInitCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _BlockingInitCm(gate)
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
call = asyncio.create_task(pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}))
|
||||
# Let the owner task enter the CM and reach the blocking initialize().
|
||||
await asyncio.sleep(0.01)
|
||||
call.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await call
|
||||
|
||||
# Release initialize() so the owner task can finish its shutdown path.
|
||||
gate.set()
|
||||
# Give the owner task a chance to run __aexit__ and complete.
|
||||
for _ in range(10):
|
||||
if cms and cms[0].closed:
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert len(cms) == 1
|
||||
assert cms[0].entered is True
|
||||
assert cms[0].closed is True, "owner task must run __aexit__ after cancellation"
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
current = asyncio.current_task()
|
||||
leaked = [t for t in asyncio.all_tasks() if t is not current and not t.done() and "_run_session" in str(t.get_coro())]
|
||||
assert not leaked, "owner task must not be left pending after cancellation"
|
||||
|
||||
|
||||
class _InitFailCm:
|
||||
"""Fake session CM whose ``initialize`` fails, with a slow ``__aexit__``.
|
||||
|
||||
The slow __aexit__ lets a test observe whether cleanup is allowed to run to
|
||||
completion (closed=True) or is interrupted by a stray cancellation.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.entered = False
|
||||
self.exit_started = False
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
self.entered = True
|
||||
session = MagicMock()
|
||||
session.initialize = self._initialize
|
||||
return session
|
||||
|
||||
async def _initialize(self):
|
||||
raise RuntimeError("init boom")
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.exit_started = True
|
||||
# Yield control so a buggy double-cancel would interrupt us here.
|
||||
await asyncio.sleep(0.02)
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_init_failure_runs_full_cleanup():
|
||||
"""On initialize() failure the owner task's __aexit__ must complete (#3379 CR P1).
|
||||
|
||||
The caller must NOT cancel the owner task on a reported failure, otherwise
|
||||
the in-progress __aexit__ cleanup gets interrupted and leaks resources.
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
cms: list[_InitFailCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _InitFailCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
with pytest.raises(RuntimeError, match="init boom"):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert len(cms) == 1
|
||||
assert cms[0].entered is True
|
||||
assert cms[0].exit_started is True
|
||||
assert cms[0].closed is True, "__aexit__ must run to completion, not be interrupted"
|
||||
assert len(pool._entries) == 0
|
||||
assert len(pool._inflight) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_get_session_same_key_creates_single_session():
|
||||
"""Concurrent get_session for the same key must share one session (#3379 CR P1)."""
|
||||
pool = MCPSessionPool()
|
||||
gate = asyncio.Event()
|
||||
cms: list[_BlockingInitCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _BlockingInitCm(gate)
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
conn = {"transport": "stdio", "command": "x", "args": []}
|
||||
t1 = asyncio.create_task(pool.get_session("s", "same", conn))
|
||||
t2 = asyncio.create_task(pool.get_session("s", "same", conn))
|
||||
# Let both calls pass Phase 1 and reach the (gated) initialize().
|
||||
await asyncio.sleep(0.02)
|
||||
gate.set()
|
||||
s1, s2 = await asyncio.gather(t1, t2)
|
||||
|
||||
# Only one CM/session created, both callers got the same object.
|
||||
assert len(cms) == 1, "concurrent same-key calls must not create duplicate sessions"
|
||||
assert s1 is s2
|
||||
assert len(pool._entries) == 1
|
||||
assert len(pool._inflight) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all_during_in_flight_creation_does_not_resurrect_session():
|
||||
"""close_all while a creation is in-flight must not leave a live session (#3379 CR P1).
|
||||
|
||||
The in-flight record must be removed and its owner task torn down, so when
|
||||
the (blocked) creator finishes initializing it does NOT register the session
|
||||
back into _entries — otherwise the pool resurrects an unclosable session.
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
gate = asyncio.Event()
|
||||
cms: list[_BlockingInitCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _BlockingInitCm(gate)
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
conn = {"transport": "stdio", "command": "x", "args": []}
|
||||
call = asyncio.create_task(pool.get_session("s", "t1", conn))
|
||||
# Let the owner task enter the CM and reach the blocking initialize().
|
||||
await asyncio.sleep(0.01)
|
||||
assert ("s", "t1") in pool._inflight
|
||||
|
||||
# Close everything while the creation is still in-flight.
|
||||
await pool.close_all()
|
||||
|
||||
# The in-flight creation must be gone, not promoted to an entry.
|
||||
assert len(pool._inflight) == 0
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
# Even if the gate is released afterwards, nothing must come back.
|
||||
gate.set()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await call
|
||||
|
||||
assert len(pool._entries) == 0
|
||||
assert len(pool._inflight) == 0
|
||||
assert cms[0].closed is True, "in-flight session's __aexit__ must run on teardown"
|
||||
|
||||
current = asyncio.current_task()
|
||||
leaked = [t for t in asyncio.all_tasks() if t is not current and not t.done() and "_run_session" in str(t.get_coro())]
|
||||
assert not leaked, "in-flight owner task must not leak after close_all"
|
||||
|
||||
|
||||
def test_get_session_cross_loop_in_flight_does_not_raise_assertion():
|
||||
"""A same-key request from another loop must not hit the in-flight assertion (#3379 CR P1).
|
||||
|
||||
Loop A starts (and leaves running) an in-flight creation, then loop B
|
||||
requests the same key. The stale in-flight record (owned by loop A) must be
|
||||
dropped and loop B must become a fresh creator — never fall through to an
|
||||
AssertionError.
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
cms: list[_CancelScopeCm] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = _CancelScopeCm()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
conn = {"transport": "stdio", "command": "x", "args": []}
|
||||
results: list[object] = []
|
||||
errors: list[BaseException] = []
|
||||
|
||||
def run_in_own_loop():
|
||||
try:
|
||||
results.append(asyncio.run(pool.get_session("s", "t1", conn)))
|
||||
except BaseException as e: # noqa: BLE001 - capture for assertion
|
||||
errors.append(e)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
# First loop creates and registers an entry, then its loop is torn down
|
||||
# by asyncio.run, leaving a stale (closed-loop) record behind.
|
||||
t1 = threading.Thread(target=run_in_own_loop)
|
||||
t1.start()
|
||||
t1.join()
|
||||
|
||||
# Second loop requests the same key. It must evict the stale record and
|
||||
# create a fresh session instead of raising AssertionError.
|
||||
t2 = threading.Thread(target=run_in_own_loop)
|
||||
t2.start()
|
||||
t2.join()
|
||||
|
||||
assert not errors, f"cross-loop same-key request must not raise: {errors}"
|
||||
assert len(results) == 2
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
|
||||
def test_cross_loop_preempting_blocked_in_flight_does_not_hang_owner():
|
||||
"""A foreign-loop request must not leave a still-initializing owner hung (#3379 CR P1).
|
||||
|
||||
Loop A starts a creation that blocks inside initialize() (the in-flight
|
||||
record stays live). Loop B then requests the same key. B must tear A's owner
|
||||
down — cancelling it, because close_evt alone cannot wake a task blocked in
|
||||
initialize() — so that A's get_session unwinds instead of hanging forever.
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
conn = {"transport": "stdio", "command": "x", "args": []}
|
||||
first_gate = threading.Event()
|
||||
entered = threading.Event()
|
||||
results: list[tuple[str, object]] = []
|
||||
errors: list[tuple[str, BaseException]] = []
|
||||
closed: list[str] = []
|
||||
|
||||
class _BlockingForeverCm:
|
||||
async def __aenter__(self):
|
||||
session = MagicMock()
|
||||
session.initialize = self._initialize
|
||||
entered.set()
|
||||
return session
|
||||
|
||||
async def _initialize(self):
|
||||
# Block until released, simulating a slow/stuck server handshake.
|
||||
while not first_gate.is_set():
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
closed.append("blocking")
|
||||
return False
|
||||
|
||||
class _FastCm:
|
||||
async def __aenter__(self):
|
||||
session = MagicMock()
|
||||
|
||||
async def init():
|
||||
return None
|
||||
|
||||
session.initialize = init
|
||||
return session
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return False
|
||||
|
||||
cms: list[object] = [_BlockingForeverCm(), _FastCm()]
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
return cms.pop(0)
|
||||
|
||||
def run_get(name):
|
||||
try:
|
||||
results.append((name, asyncio.run(pool.get_session("s", "t1", conn))))
|
||||
except BaseException as e: # noqa: BLE001 - capture for assertion
|
||||
errors.append((name, e))
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
ta = threading.Thread(target=run_get, args=("A",))
|
||||
ta.start()
|
||||
assert entered.wait(2), "owner A must enter the CM and start initializing"
|
||||
|
||||
tb = threading.Thread(target=run_get, args=("B",))
|
||||
tb.start()
|
||||
tb.join(3)
|
||||
|
||||
# B must complete without depending on A's blocked initialize().
|
||||
assert not tb.is_alive(), "foreign-loop request B must not hang"
|
||||
# A must already be unwound (cancelled), not waiting on the dead gate.
|
||||
ta.join(3)
|
||||
assert not ta.is_alive(), "preempted owner A must not hang forever"
|
||||
|
||||
assert [n for n, _ in results] == ["B"], "only B produces a usable session"
|
||||
assert any(isinstance(e, asyncio.CancelledError) for _, e in errors), "preempted A must unwind via CancelledError"
|
||||
assert "blocking" in closed, "preempted owner's __aexit__ must run on teardown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all_sync_from_running_loop_does_not_wait_on_itself():
|
||||
"""close_all_sync must not block on the current running loop (#3379 CR P1).
|
||||
|
||||
When called from code already executing inside the owner loop's thread,
|
||||
close_all_sync cannot synchronously wait for that loop to run the shutdown
|
||||
coroutine. It must signal the owner task and return promptly, then the owner
|
||||
task closes itself once the loop regains control.
|
||||
"""
|
||||
pool = MCPSessionPool()
|
||||
pool.SESSION_CLOSE_TIMEOUT = 0.2
|
||||
conn = {"transport": "stdio", "command": "x", "args": []}
|
||||
|
||||
cm = _CloseTrackingCm()
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", conn)
|
||||
start = asyncio.get_running_loop().time()
|
||||
pool.close_all_sync()
|
||||
elapsed = asyncio.get_running_loop().time() - start
|
||||
|
||||
assert elapsed < 0.1, "close_all_sync must not stall until timeout on the current loop"
|
||||
assert len(pool._entries) == 0
|
||||
assert len(pool._inflight) == 0
|
||||
assert cm.closed is False, "owner task has not run yet while close_all_sync is still executing"
|
||||
|
||||
for _ in range(10):
|
||||
if cm.closed:
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert cm.closed is True, "owner task must close itself after the loop regains control"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_mcp_tools_cache deadlock regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CloseTrackingCm:
|
||||
"""A create_session() context manager that records when __aexit__ runs."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
session = MagicMock()
|
||||
|
||||
async def init():
|
||||
return None
|
||||
|
||||
session.initialize = init
|
||||
return session
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
|
||||
def test_reset_mcp_tools_cache_from_running_loop_is_bounded():
|
||||
"""reset_mcp_tools_cache() must not deadlock when called from inside a
|
||||
running loop that owns sessions (#3392 CR blocker).
|
||||
|
||||
The previous implementation spun up a worker thread running
|
||||
``asyncio.run(pool.close_all())`` and blocked the loop thread on
|
||||
``.result()``. close_all() then routed teardown of the current loop's
|
||||
sessions back onto that blocked loop via run_coroutine_threadsafe(...),
|
||||
so neither side could make progress. This test drives the exact scenario
|
||||
on a daemon thread and asserts the call returns within a bounded time.
|
||||
"""
|
||||
from deerflow.mcp.cache import reset_mcp_tools_cache
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
|
||||
conn = {"transport": "stdio", "command": "x", "args": []}
|
||||
cm = _CloseTrackingCm()
|
||||
done = threading.Event()
|
||||
|
||||
async def scenario():
|
||||
pool = get_session_pool()
|
||||
# Entry owned by THIS loop — the deadlock-prone case.
|
||||
await pool.get_session("s", "t1", conn)
|
||||
# Synchronous call: asyncio.get_running_loop() succeeds inside it, so
|
||||
# it takes the "running loop" branch in reset_mcp_tools_cache().
|
||||
reset_mcp_tools_cache()
|
||||
# Signal-only teardown completes once the loop regains control.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
def run():
|
||||
asyncio.run(scenario())
|
||||
done.set()
|
||||
|
||||
t = threading.Thread(target=run, daemon=True)
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=cm):
|
||||
t.start()
|
||||
t.join(timeout=5)
|
||||
|
||||
assert done.is_set(), "reset_mcp_tools_cache() deadlocked inside a running loop"
|
||||
assert cm.closed is True, "owner task must run __aexit__ once the loop regains control"
|
||||
|
||||
140
backend/tests/test_reload_boundary.py
Normal file
140
backend/tests/test_reload_boundary.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""Regression tests for the config reload boundary registry.
|
||||
|
||||
Bytedance/deer-flow issue #3144: the hot-reload boundary is the contract
|
||||
between gateway dependencies that resolve ``AppConfig`` every request and the
|
||||
infrastructure that captures the snapshot once at startup. The registry in
|
||||
``deerflow.config.reload_boundary`` is the machine-readable source of truth;
|
||||
these tests pin the registry against the actual Pydantic schema so a future
|
||||
field rename / addition / boundary change cannot silently drift.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.reload_boundary import (
|
||||
STARTUP_ONLY_FIELDS,
|
||||
STARTUP_ONLY_PREFIX,
|
||||
format_field_description,
|
||||
is_startup_only_field,
|
||||
iter_startup_only_field_paths,
|
||||
)
|
||||
|
||||
|
||||
def test_registry_has_a_reason_for_every_field():
|
||||
"""Every registry entry must explain *why* the field is restart-required.
|
||||
|
||||
The reason text is what surfaces in IDE hover and in the AppConfig schema
|
||||
description, so an empty / placeholder value would defeat the purpose.
|
||||
"""
|
||||
for field_path, reason in STARTUP_ONLY_FIELDS.items():
|
||||
assert reason.strip(), f"empty reason for {field_path}"
|
||||
assert len(reason) > 20, f"reason for {field_path} too short to be useful: {reason!r}"
|
||||
|
||||
|
||||
def test_iter_startup_only_field_paths_matches_registry():
|
||||
"""Iterator stays in sync with the registry mapping."""
|
||||
assert sorted(iter_startup_only_field_paths()) == sorted(STARTUP_ONLY_FIELDS)
|
||||
|
||||
|
||||
def test_is_startup_only_field_recognises_registered_fields():
|
||||
"""The membership helper accepts every registered field path."""
|
||||
for field_path in STARTUP_ONLY_FIELDS:
|
||||
assert is_startup_only_field(field_path)
|
||||
assert not is_startup_only_field("memory") # hot-reloadable
|
||||
assert not is_startup_only_field("models")
|
||||
assert not is_startup_only_field("nonexistent_field")
|
||||
|
||||
|
||||
def test_format_field_description_prefixes_with_marker():
|
||||
"""The formatter produces a description that machine-readable tooling can
|
||||
pivot on (drift tests, future "needs-restart" scanners)."""
|
||||
for field_path in STARTUP_ONLY_FIELDS:
|
||||
text = format_field_description(field_path)
|
||||
assert text.startswith(STARTUP_ONLY_PREFIX), text
|
||||
# The reason is appended after the prefix; the formatter must not
|
||||
# silently drop it.
|
||||
assert STARTUP_ONLY_FIELDS[field_path] in text
|
||||
|
||||
|
||||
def test_format_field_description_rejects_unknown_field():
|
||||
with pytest.raises(KeyError):
|
||||
format_field_description("not_in_registry")
|
||||
|
||||
|
||||
def test_format_field_description_appends_optional_field_doc():
|
||||
"""The formatter composes the startup-only marker with the field's own
|
||||
human-facing description when supplied.
|
||||
|
||||
The original ``Field(description=)`` used to document allowed values
|
||||
(e.g. ``log_level`` listed ``debug/info/warning/error``); registry
|
||||
adoption must not drop that. The composed output keeps the marker as
|
||||
the leading token so machine-readable tooling still pivots on it,
|
||||
then appends the prose after a blank line.
|
||||
"""
|
||||
text = format_field_description("log_level", field_doc="Logging level (debug/info/warning/error).")
|
||||
assert text.startswith(STARTUP_ONLY_PREFIX)
|
||||
assert STARTUP_ONLY_FIELDS["log_level"] in text
|
||||
assert "debug/info/warning/error" in text
|
||||
|
||||
|
||||
def test_appconfig_descriptions_retain_original_field_documentation():
|
||||
"""``AppConfig.model_fields[name].description`` for restart-required
|
||||
fields should still carry the original human-facing field doc so IDE
|
||||
hover documents what the field is *and* why a restart is needed."""
|
||||
descriptions = {
|
||||
"log_level": "debug/info/warning/error",
|
||||
"database": "memory, sqlite, or postgres",
|
||||
"sandbox": "Sandbox provider",
|
||||
"run_events": "memory for dev",
|
||||
"checkpointer": "state-persistence checkpointer",
|
||||
"stream_bridge": "Stream bridge",
|
||||
}
|
||||
for field_name, expected_substring in descriptions.items():
|
||||
description = AppConfig.model_fields[field_name].description or ""
|
||||
assert description.startswith(STARTUP_ONLY_PREFIX), f"AppConfig.{field_name} missing startup-only marker"
|
||||
assert expected_substring in description, f"AppConfig.{field_name} description lost original field doc; got {description!r}"
|
||||
|
||||
|
||||
def test_appconfig_schema_marks_registered_fields_with_prefix():
|
||||
"""Every registry entry that corresponds to a top-level AppConfig field
|
||||
must carry the standardized ``startup-only:`` prefix in its Pydantic
|
||||
``Field(description=...)``. This is the contract IDE hover relies on.
|
||||
"""
|
||||
schema_fields = AppConfig.model_fields
|
||||
for field_path in STARTUP_ONLY_FIELDS:
|
||||
if field_path not in schema_fields:
|
||||
# Some entries (e.g. ``channels``) live outside the AppConfig
|
||||
# schema. The registry still owns them, but the schema-prefix
|
||||
# assertion does not apply.
|
||||
continue
|
||||
description = schema_fields[field_path].description or ""
|
||||
assert description.startswith(STARTUP_ONLY_PREFIX), f"AppConfig.{field_path} should have Field(description=) starting with {STARTUP_ONLY_PREFIX!r}, got {description!r}"
|
||||
|
||||
|
||||
def test_no_appconfig_field_uses_prefix_without_registration():
|
||||
"""Reverse drift check: if a future schema edit adds the
|
||||
``startup-only:`` prefix to a new field, the registry must list it.
|
||||
|
||||
This catches the silent-drift case where someone marks a field
|
||||
restart-required in the schema but forgets to update the registry
|
||||
that the operator-facing scanners and docs consume.
|
||||
"""
|
||||
for name, info in AppConfig.model_fields.items():
|
||||
description = info.description or ""
|
||||
if not description.startswith(STARTUP_ONLY_PREFIX):
|
||||
continue
|
||||
assert name in STARTUP_ONLY_FIELDS, f"AppConfig.{name} schema description starts with {STARTUP_ONLY_PREFIX!r} but the field is not listed in reload_boundary.STARTUP_ONLY_FIELDS — update the registry."
|
||||
|
||||
|
||||
def test_pydantic_field_descriptions_are_introspectable_at_runtime():
|
||||
"""``AppConfig.model_fields[name].description`` is the IDE-hover source.
|
||||
|
||||
If this read ever breaks (e.g. Pydantic deprecation, schema swap), the
|
||||
IDE-hover guarantee #3144 promises silently regresses. Pin it.
|
||||
"""
|
||||
assert "database" in AppConfig.model_fields
|
||||
description = AppConfig.model_fields["database"].description
|
||||
assert description is not None
|
||||
assert description.startswith(STARTUP_ONLY_PREFIX)
|
||||
Loading…
x
Reference in New Issue
Block a user