fix(middleware): offload memory injection off event loop to prevent tiktoken blocking (#3402) (#3411)

* fix(middleware): offload memory injection off event loop to prevent tiktoken blocking (#3402)

  DynamicContextMiddleware.abefore_agent() called _inject() synchronously
  on the asyncio event loop.  The first time memory is injected (second
  request), _inject() → format_memory_for_injection() → _count_tokens()
  → tiktoken.get_encoding("cl100k_base") needs to download the BPE data
  from openaipublic.blob.core.windows.net.  In network-restricted
  environments this download blocks until the OS TCP timeout (~26 min),
  starving ALL concurrent handlers including /api/v1/auth/me.

  Fix:
  - abefore_agent now uses asyncio.to_thread(self._inject, state) so
    file I/O and tiktoken never block the event loop.
  - Extract _get_tiktoken_encoding() with a module-level cache so
    tiktoken.get_encoding() is called at most once per encoding name.
  - Add warm_tiktoken_cache() startup helper; gateway lifespan pre-warms
    the cache via asyncio.to_thread so the first request never triggers a
    cold download.
  - _count_tokens falls back to len(text) // 4 on any encoding failure.

  Tests:
  - tests/test_tiktoken_cache_and_count_tokens.py (12 tests): cache
    hit/miss, fallback paths, warm-up helper.
  - tests/blocking_io/test_dynamic_context_middleware.py (2 tests):
    Blockbuster gate verifies abefore_agent does not block the event
    loop; async/sync parity check.

  Fixes #3402

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix the lint error

* fix(memory): use future annotations to avoid NameError when tiktoken is absent

Add `from __future__ import annotations` to prompt.py so that
tiktoken.Encoding type hints are never evaluated at runtime.  Without
this, environments where tiktoken is not installed could raise NameError
on the module-level cache and function return annotations.

Addresses Copilot review comment on PR #3411.

* fix(middleware): bound abefore_agent injection with timeout to prevent hung requests

Wrap the asyncio.to_thread(self._inject) offload in asyncio.wait_for()
with a 5-second cap.  If the startup warm-up failed silently (e.g.
network blip during deploy), a cold tiktoken BPE download on the first
request can block until the OS TCP timeout (~26 min).  The bounded
timeout ensures the request degrades gracefully (no memory/date context
for that turn) rather than hanging.

Adds test_abefore_agent_returns_none_on_timeout to the blocking-IO
regression anchors.

Addresses review feedback from xg-gh-25 on PR #3411.

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Willem Jiang 2026-06-08 12:21:55 +08:00 committed by GitHub
parent 40a371b88c
commit 519200728a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 372 additions and 3 deletions

View File

@ -179,6 +179,25 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
config = get_gateway_config()
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
# Pre-warm tiktoken encoding cache so the first memory-injection request
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
# that may be unreachable in restricted networks — see issue #3402).
try:
from deerflow.agents.memory.prompt import warm_tiktoken_cache
warmed = await asyncio.wait_for(
asyncio.to_thread(warm_tiktoken_cache),
timeout=5,
)
if warmed:
logger.info("tiktoken encoding cache warmed successfully")
else:
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
except TimeoutError:
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
except Exception:
logger.warning("tiktoken warm-up skipped", exc_info=True)
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app, startup_config):
logger.info("LangGraph runtime initialised")

View File

@ -1,9 +1,14 @@
"""Prompt templates for memory update and injection."""
from __future__ import annotations
import logging
import math
import re
from typing import Any
logger = logging.getLogger(__name__)
try:
import tiktoken
@ -160,6 +165,39 @@ Rules:
Return ONLY valid JSON."""
# Module-level tiktoken encoding cache. Populated lazily on first use;
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
# (potentially slow) first ``get_encoding`` call.
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
"""Return a cached tiktoken encoding, or ``None`` on failure / unavailability.
On the very first call for a given *encoding_name*, tiktoken may need to
download the BPE data from ``openaipublic.blob.core.windows.net``. In
network-restricted environments (e.g. deployments behind the GFW) this
download can block for tens of minutes before the OS TCP timeout kicks in.
The caller must therefore be prepared for this to block and should run it
off the event loop (e.g. via ``asyncio.to_thread``).
"""
if not TIKTOKEN_AVAILABLE:
return None
cached = _tiktoken_encoding_cache.get(encoding_name)
if cached is not None:
return cached
try:
encoding = tiktoken.get_encoding(encoding_name)
_tiktoken_encoding_cache[encoding_name] = encoding
return encoding
except Exception:
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
return None
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
"""Count tokens in text using tiktoken.
@ -170,18 +208,30 @@ def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
Returns:
The number of tokens in the text.
"""
if not TIKTOKEN_AVAILABLE:
encoding = _get_tiktoken_encoding(encoding_name)
if encoding is None:
# Fallback to character-based estimation if tiktoken is not available
# or the encoding failed to load.
return len(text) // 4
try:
encoding = tiktoken.get_encoding(encoding_name)
return len(encoding.encode(text))
except Exception:
# Fallback to character-based estimation on error
return len(text) // 4
def warm_tiktoken_cache() -> bool:
"""Pre-warm the tiktoken encoding cache.
Call at startup (off the event loop) so the first request never blocks
on the BPE download. Returns ``True`` if the encoding was loaded
successfully (or was already cached), ``False`` if tiktoken is
unavailable or the download failed.
"""
return _get_tiktoken_encoding("cl100k_base") is not None
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
"""Coerce a confidence-like value to a bounded float in [0, 1].

View File

@ -28,6 +28,7 @@ Date-update format:
from __future__ import annotations
import asyncio
import logging
import re
import uuid
@ -43,6 +44,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Upper bound (seconds) for a single _inject() offload. If the warm-up at
# gateway startup failed silently, the first request may still hit a cold
# tiktoken BPE download that blocks until the OS TCP timeout (~26 min).
# This cap ensures the request degrades gracefully instead of hanging.
_INJECT_TIMEOUT_SECONDS = 5.0
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
_SUMMARY_MESSAGE_NAME = "summary"
@ -201,4 +208,25 @@ class DynamicContextMiddleware(AgentMiddleware):
@override
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
# _inject() performs synchronous file I/O (memory JSON loading) and
# potentially blocking network calls (tiktoken encoding download on
# first use). Offload to a thread so the event loop is never blocked
# — a blocking call here starves all concurrent HTTP handlers (auth,
# SSE heartbeats, etc.). See issue #3402.
#
# Bounded timeout: if startup warm-up failed silently (e.g. network
# blip during deploy), the first request's cold tiktoken download can
# block for tens of minutes (OS TCP timeout). Time-box injection so
# the request degrades gracefully (no memory context) rather than
# hanging.
try:
return await asyncio.wait_for(
asyncio.to_thread(self._inject, state),
timeout=_INJECT_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"DynamicContextMiddleware: injection timed out (%.1fs); skipping memory/date injection for this turn",
_INJECT_TIMEOUT_SECONDS,
)
return None

View File

@ -0,0 +1,124 @@
"""Regression anchor: DynamicContextMiddleware must not block the event loop.
``_inject`` performs synchronous file I/O (memory JSON loading) and
potentially blocking network calls (tiktoken encoding download on first
use see issue #3402). ``abefore_agent`` offloads the call via
``asyncio.to_thread`` so the event loop stays responsive.
This anchor drives the real ``create_agent`` graph via ``ainvoke`` under
the strict Blockbuster gate. If the offload regresses and the blocking
I/O runs on the event loop, Blockbuster raises ``BlockingError`` and
this test fails.
"""
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest import mock
import pytest
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
pytestmark = pytest.mark.asyncio
class _FakeModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel with a no-op ``bind_tools`` for create_agent."""
def bind_tools(self, tools, **kwargs): # type: ignore[override]
return self
async def test_abefore_agent_does_not_block_event_loop() -> None:
"""``abefore_agent`` must offload _inject() to a thread pool."""
mw = DynamicContextMiddleware()
# Mock _build_full_reminder to simulate a slow synchronous operation
# (file I/O + tiktoken download). The mock sleeps briefly to make any
# event-loop blocking visible to the Blockbuster gate.
original_build = mw._build_full_reminder
def slow_build_reminder():
import time
time.sleep(0.05) # 50ms sync sleep — blocks the thread it runs on
return original_build()
with (
mock.patch.object(mw, "_build_full_reminder", slow_build_reminder),
mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""),
):
agent = await asyncio.to_thread(
lambda: create_agent(
model=_FakeModel(responses=[AIMessage(content="ok")]),
tools=[],
middleware=[mw],
)
)
result = await agent.ainvoke(
{"messages": [HumanMessage(content="hi")]},
{"configurable": {"thread_id": "test-thread"}},
)
assert result["messages"]
async def test_abefore_agent_returns_same_result_as_before_agent() -> None:
"""``abefore_agent`` (async, offloaded) must produce the same result as
``before_agent`` (sync, for backward compatibility)."""
mw = DynamicContextMiddleware()
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
runtime = SimpleNamespace(context={})
with (
mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""),
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
):
mock_dt.now.return_value.strftime.return_value = "2026-06-05, Friday"
# Sync path
sync_result = mw.before_agent(state, runtime)
# Async path (offloaded to thread)
async_result = await mw.abefore_agent(state, runtime)
assert sync_result is not None
assert async_result is not None
assert sync_result.keys() == async_result.keys()
# Both return 2 messages: reminder + user content
assert len(sync_result["messages"]) == 2
assert len(async_result["messages"]) == 2
# IDs match
assert sync_result["messages"][0].id == async_result["messages"][0].id
assert sync_result["messages"][1].id == async_result["messages"][1].id
async def test_abefore_agent_returns_none_on_timeout() -> None:
"""If _inject() exceeds the timeout, abefore_agent returns None gracefully."""
import time
mw = DynamicContextMiddleware()
def blocking_inject(state):
time.sleep(10) # Simulate a blocking call that far exceeds the timeout
return {"messages": [HumanMessage(content="should not reach")]}
with (
mock.patch.object(mw, "_inject", blocking_inject),
mock.patch(
"deerflow.agents.middlewares.dynamic_context_middleware._INJECT_TIMEOUT_SECONDS",
0.1,
),
):
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
runtime = SimpleNamespace(context={})
result = await mw.abefore_agent(state, runtime)
assert result is None

View File

@ -0,0 +1,148 @@
"""Tests for tiktoken encoding cache and _count_tokens fallback.
Verifies:
- Module-level cache avoids repeated ``get_encoding`` calls.
- ``_count_tokens`` falls back to character estimation when tiktoken is
unavailable or the encoding fails to load.
- ``warm_tiktoken_cache`` populates the cache on success.
"""
from __future__ import annotations
from unittest import mock
from deerflow.agents.memory.prompt import (
_count_tokens,
_get_tiktoken_encoding,
_tiktoken_encoding_cache,
warm_tiktoken_cache,
)
# ---------------------------------------------------------------------------
# _get_tiktoken_encoding
# ---------------------------------------------------------------------------
class TestGetTiktokenEncoding:
"""Tests for _get_tiktoken_encoding caching and fallback."""
def test_returns_none_when_tiktoken_unavailable(self, monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
assert _get_tiktoken_encoding("cl100k_base") is None
def test_returns_encoding_on_success(self, monkeypatch):
# Clear cache to ensure a fresh call
_tiktoken_encoding_cache.pop("cl100k_base", None)
fake_enc = mock.Mock()
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
enc = _get_tiktoken_encoding("cl100k_base")
assert enc is fake_enc
def test_populates_cache_on_success(self, monkeypatch):
_tiktoken_encoding_cache.pop("cl100k_base", None)
fake_enc = mock.Mock()
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
_get_tiktoken_encoding("cl100k_base")
assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc
def test_returns_cached_encoding_without_calling_get_encoding(self, monkeypatch):
fake_enc = mock.Mock()
monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
# Now patch tiktoken.get_encoding to raise if called
import tiktoken
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
# Cached path — should NOT call get_encoding
enc = _get_tiktoken_encoding("cl100k_base")
assert enc is fake_enc
tiktoken.get_encoding.assert_not_called()
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
_tiktoken_encoding_cache.pop("bogus_encoding", None)
import tiktoken
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
result = _get_tiktoken_encoding("bogus_encoding")
assert result is None
assert "bogus_encoding" not in _tiktoken_encoding_cache
# ---------------------------------------------------------------------------
# _count_tokens
# ---------------------------------------------------------------------------
class TestCountTokens:
"""Tests for _count_tokens fallback behaviour."""
def test_returns_character_estimate_when_tiktoken_unavailable(self, monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
text = "Hello, world! This is a test."
result = _count_tokens(text)
assert result == len(text) // 4
def test_returns_character_estimate_when_encoding_fails(self, monkeypatch):
monkeypatch.setattr(
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
lambda _name=None: None,
)
text = "Some text to count"
result = _count_tokens(text)
assert result == len(text) // 4
def test_returns_token_count_on_success(self, monkeypatch):
fake_enc = mock.Mock()
fake_enc.encode.return_value = [0, 1, 2, 3]
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", mock.Mock(return_value=fake_enc))
text = "Hello, world!"
result = _count_tokens(text)
assert result == 4
assert result <= len(text)
def test_falls_back_on_encode_exception(self, monkeypatch):
# Cache an encoding whose .encode raises
fake_enc = mock.Mock()
fake_enc.encode.side_effect = RuntimeError("encode failed")
monkeypatch.setitem(_tiktoken_encoding_cache, "test_enc", fake_enc)
text = "Fallback test"
result = _count_tokens(text, encoding_name="test_enc")
assert result == len(text) // 4
# ---------------------------------------------------------------------------
# warm_tiktoken_cache
# ---------------------------------------------------------------------------
class TestWarmTiktokenCache:
"""Tests for warm_tiktoken_cache startup helper."""
def test_returns_true_on_success(self, monkeypatch):
_tiktoken_encoding_cache.pop("cl100k_base", None)
fake_enc = mock.Mock()
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
assert warm_tiktoken_cache() is True
assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc
def test_returns_true_if_already_cached(self, monkeypatch):
fake_enc = mock.Mock()
monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
import tiktoken
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
assert warm_tiktoken_cache() is True
tiktoken.get_encoding.assert_not_called()
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
assert warm_tiktoken_cache() is False