mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-06 00:38:23 +00:00
* fix(llm): filter unexpected config keys to prevent LangChain warnings (#411) Add allowlist validation for LLM configuration keys to prevent unexpected parameters like SEARCH_ENGINE from being passed to LLM constructors. Changes: - Add ALLOWED_LLM_CONFIG_KEYS set with valid LLM configuration parameters - Filter out unexpected keys before creating LLM instances - Log clear warning messages when unexpected keys are removed - Add unit test for configuration key filtering This fixes the confusing LangChain warning "WARNING! SEARCH_ENGINE is not default parameter. SEARCH_ENGINE was transferred to model_kwargs" that occurred when users accidentally placed configuration keys in wrong sections of conf.yaml. * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
342 lines
11 KiB
Python
342 lines
11 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, get_args
|
|
|
|
import httpx
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_deepseek import ChatDeepSeek
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
|
|
|
from src.config import load_yaml_config
|
|
from src.config.agents import LLMType
|
|
from src.llms.providers.dashscope import ChatDashscope
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Cache of already-constructed LLM clients, keyed by LLM type.
# Filled lazily by get_llm_by_type so each type is built at most once.
_llm_cache: dict[LLMType, BaseChatModel] = {}
|
|
|
|
# Allowlist of keys accepted in an LLM configuration section. Keys outside this
# set (for example a misplaced SEARCH_ENGINE entry) are dropped before an LLM
# constructor is called, which avoids LangChain's "is not default parameter"
# warning (Issue #411).
ALLOWED_LLM_CONFIG_KEYS = {
    # Connection and generic request parameters
    "model", "api_key", "base_url", "api_base",
    "max_retries", "timeout", "max_tokens",
    # Sampling / generation controls
    "temperature", "top_p", "frequency_penalty", "presence_penalty",
    "stop", "n", "stream", "logprobs", "echo", "best_of",
    "logit_bias", "user", "seed",
    # SSL and HTTP client settings
    "verify_ssl", "http_client", "http_async_client",
    # Platform selection and Google-specific credentials
    "platform", "google_api_key",
    # Azure-specific settings
    "azure_endpoint", "azure_deployment", "api_version",
    "azure_ad_token", "azure_ad_token_provider",
    # Dashscope/Doubao-specific request body additions
    "extra_body",
    # Context-compression budget; stripped before the LLM is constructed
    "token_limit",
    # Extra HTTP headers / query parameters
    "default_headers", "default_query",
}
|
|
|
|
|
|
def _get_config_file_path() -> str:
|
|
"""Get the path to the configuration file."""
|
|
return str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())
|
|
|
|
|
|
def _get_llm_type_config_keys() -> dict[str, str]:
|
|
"""Get mapping of LLM types to their configuration keys."""
|
|
return {
|
|
"reasoning": "REASONING_MODEL",
|
|
"basic": "BASIC_MODEL",
|
|
"vision": "VISION_MODEL",
|
|
"code": "CODE_MODEL",
|
|
}
|
|
|
|
|
|
def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
|
|
"""
|
|
Get LLM configuration from environment variables.
|
|
Environment variables should follow the format: {LLM_TYPE}__{KEY}
|
|
e.g., BASIC_MODEL__api_key, BASIC_MODEL__base_url
|
|
"""
|
|
prefix = f"{llm_type.upper()}_MODEL__"
|
|
conf = {}
|
|
for key, value in os.environ.items():
|
|
if key.startswith(prefix):
|
|
conf_key = key[len(prefix) :].lower()
|
|
conf[conf_key] = value
|
|
return conf
|
|
|
|
|
|
def _parse_verify_ssl(value: Any) -> bool:
    """Interpret a ``verify_ssl`` setting coming from YAML or the environment.

    Environment variables always arrive as strings, so the strings "false",
    "0", "no" and "off" (case-insensitive, surrounding whitespace ignored) are
    treated as a request to disable verification. Any other string keeps
    verification enabled; non-string values use ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"false", "0", "no", "off"}
    return bool(value)


def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatModel:
    """Create an LLM instance for ``llm_type`` from the loaded configuration.

    The section for ``llm_type`` in ``conf`` is merged with matching
    environment variables (environment wins), filtered against
    ``ALLOWED_LLM_CONFIG_KEYS``, and then handed to one of the chat-model
    constructors:

    * ``platform`` of ``google_aistudio`` / ``google-aistudio``
      -> ``ChatGoogleGenerativeAI``
    * ``azure_endpoint`` configured or ``AZURE_OPENAI_ENDPOINT`` set
      -> ``AzureChatOpenAI``
    * ``base_url`` containing ``"dashscope."`` -> ``ChatDashscope``
    * ``llm_type == "reasoning"`` -> ``ChatDeepSeek``
    * anything else -> ``ChatOpenAI``

    Args:
        llm_type: One of the LLM type names known to
            ``_get_llm_type_config_keys``.
        conf: Parsed configuration mapping (top-level keys such as
            ``BASIC_MODEL``).

    Returns:
        The constructed chat model.

    Raises:
        ValueError: If ``llm_type`` is unknown, its section is not a mapping,
            or no usable configuration remains after filtering.
    """
    config_key = _get_llm_type_config_keys().get(llm_type)
    if not config_key:
        raise ValueError(f"Unknown LLM type: {llm_type}")

    llm_conf = conf.get(config_key, {})
    if not isinstance(llm_conf, dict):
        raise ValueError(f"Invalid LLM configuration for {llm_type}: {llm_conf}")

    # Merge configurations, with environment variables taking precedence.
    merged_conf = {**llm_conf, **_get_env_llm_conf(llm_type)}

    # Filter out unexpected parameters to prevent LangChain warnings (Issue #411),
    # e.g. a SEARCH_ENGINE key placed in the wrong section of conf.yaml.
    # str() keeps the check from crashing on non-string YAML keys.
    allowed_keys_lower = {k.lower() for k in ALLOWED_LLM_CONFIG_KEYS}
    unexpected_keys = [key for key in merged_conf if str(key).lower() not in allowed_keys_lower]
    for key in unexpected_keys:
        merged_conf.pop(key)
        logger.warning(
            f"Removed unexpected LLM configuration key '{key}'. "
            f"This key is not a valid LLM parameter and may have been placed in the wrong section of conf.yaml. "
            f"Valid LLM config keys include: model, api_key, base_url, max_retries, temperature, etc."
        )

    # token_limit is only used for context compression, never by the client.
    merged_conf.pop("token_limit", None)

    if not merged_conf:
        raise ValueError(f"No configuration found for LLM type: {llm_type}")

    # Default max_retries so rate-limit errors are retried.
    merged_conf.setdefault("max_retries", 3)

    # verify_ssl and platform are routing/transport switches for this function,
    # not constructor parameters, so both are removed from the kwargs here.
    verify_ssl = _parse_verify_ssl(merged_conf.pop("verify_ssl", True))
    platform = str(merged_conf.pop("platform", "") or "").lower()

    if platform in ("google_aistudio", "google-aistudio"):
        gemini_conf = dict(merged_conf)

        # Map the common api_key onto the Google-specific parameter name.
        if "api_key" in gemini_conf:
            gemini_conf["google_api_key"] = gemini_conf.pop("api_key")

        # Google AI Studio does not use base_url or custom HTTP clients.
        for unsupported in ("base_url", "http_client", "http_async_client"):
            gemini_conf.pop(unsupported, None)

        return ChatGoogleGenerativeAI(**gemini_conf)

    # Only build the non-verifying HTTP clients on paths that actually use
    # them, so no client is created just to be thrown away.
    if not verify_ssl:
        merged_conf["http_client"] = httpx.Client(verify=False)
        merged_conf["http_async_client"] = httpx.AsyncClient(verify=False)

    if "azure_endpoint" in merged_conf or os.getenv("AZURE_OPENAI_ENDPOINT"):
        return AzureChatOpenAI(**merged_conf)

    # Dashscope endpoints are detected from the base_url.
    base_url = merged_conf.get("base_url")
    if isinstance(base_url, str) and "dashscope." in base_url:
        # Keep any user-supplied extra_body entries; enable_thinking is always
        # set from the LLM type.
        configured_extra = merged_conf.get("extra_body")
        extra_body = dict(configured_extra) if isinstance(configured_extra, dict) else {}
        extra_body["enable_thinking"] = llm_type == "reasoning"
        merged_conf["extra_body"] = extra_body
        return ChatDashscope(**merged_conf)

    if llm_type == "reasoning":
        # ChatDeepSeek is given the endpoint as api_base. base_url wins when
        # present; otherwise an explicitly configured api_base is kept rather
        # than being overwritten with None.
        merged_conf["api_base"] = merged_conf.pop("base_url", merged_conf.get("api_base"))
        return ChatDeepSeek(**merged_conf)

    return ChatOpenAI(**merged_conf)
|
|
|
|
|
|
def get_llm_by_type(llm_type: LLMType) -> BaseChatModel:
    """
    Return the LLM instance for ``llm_type``, building it on first use.

    A previously built instance is served from ``_llm_cache``; otherwise the
    YAML configuration is loaded, the instance is created and then cached.
    """
    try:
        return _llm_cache[llm_type]
    except KeyError:
        pass

    yaml_conf = load_yaml_config(_get_config_file_path())
    instance = _create_llm_use_conf(llm_type, yaml_conf)
    _llm_cache[llm_type] = instance
    return instance
|
|
|
|
|
|
def get_configured_llm_models() -> dict[str, list[str]]:
    """
    Get all configured LLM models grouped by type.

    For every LLM type in ``LLMType`` the YAML section and the matching
    environment variables are merged (environment wins). A type is included
    only when the merged configuration has a non-empty ``model`` value.

    Returns:
        Dictionary mapping LLM type to list of configured model names. An
        empty dictionary is returned if the configuration cannot be loaded;
        the failure is logged rather than raised so callers keep working.
    """
    try:
        conf = load_yaml_config(_get_config_file_path())
        llm_type_config_keys = _get_llm_type_config_keys()

        configured_models: dict[str, list[str]] = {}

        for llm_type in get_args(LLMType):
            # Configuration from the YAML file. A missing, empty (null) or
            # otherwise non-mapping section is treated as "nothing configured"
            # so one malformed section does not hide the other types.
            config_key = llm_type_config_keys.get(llm_type, "")
            yaml_conf = conf.get(config_key, {}) if config_key else {}
            if not isinstance(yaml_conf, dict):
                yaml_conf = {}

            # Environment variables take precedence over the YAML values.
            merged_conf = {**yaml_conf, **_get_env_llm_conf(llm_type)}

            model_name = merged_conf.get("model")
            if model_name:
                configured_models.setdefault(llm_type, []).append(model_name)

        return configured_models

    except Exception as e:
        # Deliberate best-effort boundary: log and return an empty dict to
        # avoid breaking the application.
        logger.warning(f"Failed to load LLM configuration: {e}")
        return {}
|
|
|
|
|
|
def _get_model_token_limit_defaults() -> dict[str, int]:
|
|
"""
|
|
Get default token limits for common LLM models.
|
|
These are conservative limits to prevent token overflow errors (Issue #721).
|
|
Users can override by setting token_limit in their config.
|
|
"""
|
|
return {
|
|
# OpenAI models
|
|
"gpt-4o": 120000,
|
|
"gpt-4-turbo": 120000,
|
|
"gpt-4": 8000,
|
|
"gpt-3.5-turbo": 4000,
|
|
# Anthropic Claude
|
|
"claude-3": 180000,
|
|
"claude-2": 100000,
|
|
# Google Gemini
|
|
"gemini-2": 180000,
|
|
"gemini-1.5-pro": 180000,
|
|
"gemini-1.5-flash": 180000,
|
|
"gemini-pro": 30000,
|
|
# Bytedance Doubao
|
|
"doubao": 200000,
|
|
# DeepSeek
|
|
"deepseek": 100000,
|
|
# Ollama/local
|
|
"qwen": 30000,
|
|
"llama": 4000,
|
|
# Default fallback for unknown models
|
|
"default": 100000,
|
|
}
|
|
|
|
|
|
def _infer_token_limit_from_model(model_name: str) -> int:
|
|
"""
|
|
Infer a reasonable token limit from the model name.
|
|
This helps protect against token overflow errors when token_limit is not explicitly configured.
|
|
|
|
Args:
|
|
model_name: The model name from configuration
|
|
|
|
Returns:
|
|
A conservative token limit based on known model capabilities
|
|
"""
|
|
if not model_name:
|
|
return 100000 # Safe default
|
|
|
|
model_name_lower = model_name.lower()
|
|
defaults = _get_model_token_limit_defaults()
|
|
|
|
# Try exact or prefix matches
|
|
for key, limit in defaults.items():
|
|
if key in model_name_lower:
|
|
return limit
|
|
|
|
# Return safe default if no match found
|
|
return defaults["default"]
|
|
|
|
|
|
def get_llm_token_limit_by_type(llm_type: str) -> int:
    """
    Get the maximum token limit for a given LLM type.

    The effective configuration is the YAML section for ``llm_type`` merged
    with matching environment variables (environment wins), the same way the
    LLM itself is configured.

    Priority order:
    1. Explicitly configured token_limit
    2. Inferred from model name based on known model capabilities
    3. Safe default (100,000 tokens)

    This helps prevent token overflow errors (Issue #721) even when
    token_limit is not configured.

    Args:
        llm_type (str): The type of LLM (e.g., 'basic', 'reasoning', 'vision', 'code').

    Returns:
        int: The maximum token limit for the specified LLM type (conservative estimate).
    """
    config_key = _get_llm_type_config_keys().get(llm_type)
    conf = load_yaml_config(_get_config_file_path())

    model_config: dict = {}
    if config_key:
        yaml_conf = conf.get(config_key, {})
        # A section that is present but empty (null) or not a mapping is
        # treated as "nothing configured" instead of raising TypeError.
        if not isinstance(yaml_conf, dict):
            yaml_conf = {}
        model_config = {**yaml_conf, **_get_env_llm_conf(llm_type)}

    # First priority: explicitly configured token_limit.
    configured_limit = model_config.get("token_limit")
    if configured_limit is not None:
        if not isinstance(configured_limit, str):
            return configured_limit
        # Environment variables (and quoted YAML values) arrive as strings.
        try:
            return int(configured_limit.strip())
        except ValueError:
            logger.warning(
                f"Ignoring invalid token_limit {configured_limit!r} for LLM type '{llm_type}'; "
                f"falling back to a limit inferred from the model name."
            )

    # Second priority: infer from model name.
    model_name = model_config.get("model")
    if model_name:
        return _infer_token_limit_from_model(model_name)

    # Fallback: safe default.
    return _get_model_token_limit_defaults()["default"]
|
|
|
|
|
|
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
|
# reasoning_llm = get_llm_by_type("reasoning")
|
|
# vl_llm = get_llm_by_type("vision")
|