mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(token-usage): enable stream usage for openai-compatible models (#2217)
* fix(token-usage): enable stream usage for openai-compatible models
* fix(token-usage): narrow stream_usage default to ChatOpenAI
This commit is contained in:
parent
05f1da03e5
commit
c99865f53d
@ -30,6 +30,22 @@ def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
|
||||
return disable_kwargs
|
||||
|
||||
|
||||
def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_config: dict) -> None:
|
||||
"""Enable stream usage for OpenAI-compatible models unless explicitly configured.
|
||||
|
||||
LangChain only auto-enables ``stream_usage`` for OpenAI models when no custom
|
||||
base URL or client is configured. DeerFlow frequently uses OpenAI-compatible
|
||||
gateways, so token usage tracking would otherwise stay empty and the
|
||||
TokenUsageMiddleware would have nothing to log.
|
||||
"""
|
||||
if model_use_path != "langchain_openai:ChatOpenAI":
|
||||
return
|
||||
if "stream_usage" in model_settings_from_config:
|
||||
return
|
||||
if "base_url" in model_settings_from_config or "openai_api_base" in model_settings_from_config:
|
||||
model_settings_from_config["stream_usage"] = True
|
||||
|
||||
|
||||
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
|
||||
"""Create a chat model instance from the config.
|
||||
|
||||
@ -97,6 +113,8 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
kwargs.pop("reasoning_effort", None)
|
||||
model_settings_from_config.pop("reasoning_effort", None)
|
||||
|
||||
_enable_stream_usage_by_default(model_config.use, model_settings_from_config)
|
||||
|
||||
# For Codex Responses API models: map thinking mode to reasoning_effort
|
||||
from deerflow.models.openai_codex_provider import CodexChatModel
|
||||
|
||||
|
||||
@ -597,6 +597,99 @@ def test_openai_compatible_provider_passes_base_url(monkeypatch):
|
||||
assert captured.get("api_key") == "test-key"
|
||||
assert captured.get("temperature") == 1.0
|
||||
assert captured.get("max_tokens") == 4096
|
||||
assert captured.get("stream_usage") is True
|
||||
|
||||
|
||||
def test_openai_compatible_provider_respects_explicit_stream_usage(monkeypatch):
    """An explicit ``stream_usage=False`` in config must survive the factory default."""
    recorded_kwargs: dict = {}

    class KwargsRecorder(FakeChatModel):
        # Capture every constructor kwarg so the test can inspect exactly
        # what the factory passed through to the model class.
        def __init__(self, **kwargs):
            recorded_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    model = ModelConfig(
        name="minimax-m2.5",
        display_name="MiniMax M2.5",
        description=None,
        use="langchain_openai:ChatOpenAI",
        model="MiniMax-M2.5",
        base_url="https://api.minimax.io/v1",
        api_key="test-key",
        stream_usage=False,
        supports_vision=True,
        supports_thinking=False,
    )
    cfg = _make_app_config([model])
    _patch_factory(monkeypatch, cfg)
    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: KwargsRecorder)

    factory_module.create_chat_model(name="minimax-m2.5")

    # The explicitly configured value must not be overwritten by the default.
    assert recorded_kwargs.get("stream_usage") is False
|
||||
|
||||
|
||||
def test_openai_compatible_provider_enables_stream_usage_for_openai_api_base(monkeypatch):
    """openai_api_base should trigger stream_usage default for ChatOpenAI."""
    recorded_kwargs: dict = {}

    class KwargsRecorder(FakeChatModel):
        # Capture every constructor kwarg so the test can inspect exactly
        # what the factory passed through to the model class.
        def __init__(self, **kwargs):
            recorded_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    model = ModelConfig(
        name="openai-compatible",
        display_name="OpenAI-Compatible",
        description=None,
        use="langchain_openai:ChatOpenAI",
        model="example-model",
        openai_api_base="https://example.com/v1",
        api_key="test-key",
        supports_vision=False,
        supports_thinking=False,
    )
    cfg = _make_app_config([model])
    _patch_factory(monkeypatch, cfg)
    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: KwargsRecorder)

    factory_module.create_chat_model(name="openai-compatible")

    # The legacy base-URL alias must both be forwarded and count as a
    # custom endpoint for the stream_usage default.
    assert recorded_kwargs.get("openai_api_base") == "https://example.com/v1"
    assert recorded_kwargs.get("stream_usage") is True
|
||||
|
||||
|
||||
def test_non_openai_provider_does_not_receive_stream_usage_default(monkeypatch):
    """Non-OpenAI providers with base_url should not receive stream_usage by default."""
    recorded_kwargs: dict = {}

    class KwargsRecorder(FakeChatModel):
        # Capture every constructor kwarg so the test can inspect exactly
        # what the factory passed through to the model class.
        def __init__(self, **kwargs):
            recorded_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    model = ModelConfig(
        name="ollama-local",
        display_name="Ollama Local",
        description=None,
        use="langchain_ollama:ChatOllama",
        model="qwen2.5",
        base_url="http://127.0.0.1:11434",
        supports_vision=False,
        supports_thinking=False,
    )
    cfg = _make_app_config([model])
    _patch_factory(monkeypatch, cfg)
    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: KwargsRecorder)

    factory_module.create_chat_model(name="ollama-local")

    # base_url is forwarded as-is, but the stream_usage default applies
    # only to ChatOpenAI — other providers may not accept the kwarg.
    assert recorded_kwargs.get("base_url") == "http://127.0.0.1:11434"
    assert "stream_usage" not in recorded_kwargs
|
||||
|
||||
|
||||
def test_openai_compatible_provider_multiple_models(monkeypatch):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user