# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import logging
import os
from pathlib import Path
from typing import Any, Dict, get_args

import httpx
from langchain_core.language_models import BaseChatModel
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from src.config import load_yaml_config
from src.config.agents import LLMType
from src.llms.providers.dashscope import ChatDashscope

logger = logging.getLogger(__name__)

# Cache for LLM instances
_llm_cache: dict[LLMType, BaseChatModel] = {}

# Allowed LLM configuration keys to prevent unexpected parameters from being passed
# to LLM constructors (Issue #411 - SEARCH_ENGINE warning fix)
ALLOWED_LLM_CONFIG_KEYS = {
    # Common LLM configuration keys
    "model",
    "api_key",
    "base_url",
    "api_base",
    "max_retries",
    "timeout",
    "max_tokens",
    "temperature",
    "top_p",
    "frequency_penalty",
    "presence_penalty",
    "stop",
    "n",
    "stream",
    "logprobs",
    "echo",
    "best_of",
    "logit_bias",
    "user",
    "seed",
    # SSL and HTTP client settings
    "verify_ssl",
    "http_client",
    "http_async_client",
    # Platform-specific keys
    "platform",
    "google_api_key",
    # Azure-specific keys
    "azure_endpoint",
    "azure_deployment",
    "api_version",
    "azure_ad_token",
    "azure_ad_token_provider",
    # Dashscope/Doubao specific keys
    "extra_body",
    # Token limit for context compression (removed before passing to LLM)
    "token_limit",
    # Default headers
    "default_headers",
    "default_query",
}

# String spellings that mean "off" for boolean settings. Environment variables
# are always strings, so "false" must not be treated as truthy. The empty
# string is included because it was already falsy before this coercion existed.
_FALSE_STRINGS = frozenset({"", "false", "0", "no", "off"})


def _coerce_bool(value: Any) -> bool:
    """
    Interpret a configuration value as a boolean.

    Strings are compared case-insensitively against the known "off" spellings
    in ``_FALSE_STRINGS``; every other value uses ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


def _get_config_file_path() -> str:
    """Get the path to the configuration file."""
    return str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())


def _get_llm_type_config_keys() -> dict[str, str]:
    """Get mapping of LLM types to their configuration keys."""
    return {
        "reasoning": "REASONING_MODEL",
        "basic": "BASIC_MODEL",
        "vision": "VISION_MODEL",
        "code": "CODE_MODEL",
    }


def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
    """
    Get LLM configuration from environment variables.

    Environment variables should follow the format: {LLM_TYPE}_MODEL__{KEY}
    e.g., BASIC_MODEL__api_key, BASIC_MODEL__base_url

    Args:
        llm_type: The LLM type (e.g. 'basic'); it is upper-cased to build the prefix.

    Returns:
        Mapping of lower-cased key (the part after the prefix) to the raw
        string value of the environment variable.
    """
    prefix = f"{llm_type.upper()}_MODEL__"
    return {
        name[len(prefix) :].lower(): value
        for name, value in os.environ.items()
        if name.startswith(prefix)
    }


def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatModel:
    """
    Create LLM instance using configuration.

    The YAML section for ``llm_type`` is merged with matching environment
    variables (environment wins), filtered against ``ALLOWED_LLM_CONFIG_KEYS``,
    and then dispatched to a provider-specific chat model class.

    Args:
        llm_type: One of the keys returned by ``_get_llm_type_config_keys``.
        conf: The parsed configuration mapping (top-level sections).

    Returns:
        A chat model instance for the selected provider.

    Raises:
        ValueError: If ``llm_type`` is unknown, its section is not a mapping,
            or no configuration remains after filtering.
    """
    config_key = _get_llm_type_config_keys().get(llm_type)
    if not config_key:
        raise ValueError(f"Unknown LLM type: {llm_type}")

    llm_conf = conf.get(config_key, {})
    if not isinstance(llm_conf, dict):
        raise ValueError(f"Invalid LLM configuration for {llm_type}: {llm_conf}")

    # Merge configurations, with environment variables taking precedence
    merged_conf = {**llm_conf, **_get_env_llm_conf(llm_type)}

    # Filter out unexpected parameters to prevent LangChain warnings (Issue #411)
    # This prevents configuration keys like SEARCH_ENGINE from being passed to LLM constructors
    allowed_keys_lower = {k.lower() for k in ALLOWED_LLM_CONFIG_KEYS}
    # str() guards against non-string YAML keys (e.g. integers).
    unexpected_keys = [
        key for key in merged_conf if str(key).lower() not in allowed_keys_lower
    ]
    for key in unexpected_keys:
        merged_conf.pop(key)
        logger.warning(
            f"Removed unexpected LLM configuration key '{key}'. "
            f"This key is not a valid LLM parameter and may have been placed in the wrong section of conf.yaml. "
            f"Valid LLM config keys include: model, api_key, base_url, max_retries, temperature, etc."
        )

    # token_limit is only used for context compression, never by the client
    merged_conf.pop("token_limit", None)

    if not merged_conf:
        raise ValueError(f"No configuration found for LLM type: {llm_type}")

    # Add max_retries to handle rate limit errors
    merged_conf.setdefault("max_retries", 3)

    # Environment variables deliver "false"/"0" as strings, which are truthy;
    # coerce so SSL verification can actually be disabled from the environment.
    verify_ssl = _coerce_bool(merged_conf.pop("verify_ssl", True))

    # Check if it's Google AI Studio platform based on configuration.
    # `or ""` tolerates an explicit null value for the key.
    platform = str(merged_conf.get("platform") or "").lower()
    if platform in ("google_aistudio", "google-aistudio"):
        gemini_conf = merged_conf.copy()

        # Map common keys to Google AI Studio specific keys
        if "api_key" in gemini_conf:
            gemini_conf["google_api_key"] = gemini_conf.pop("api_key")

        # base_url/platform are not used here, and custom HTTP clients are
        # treated as unsupported for this provider.
        for unsupported in ("base_url", "platform", "http_client", "http_async_client"):
            gemini_conf.pop(unsupported, None)

        return ChatGoogleGenerativeAI(**gemini_conf)

    # Create custom HTTP clients only for providers that receive them, so no
    # unclosed client is created and then discarded on the Google path.
    if not verify_ssl:
        merged_conf["http_client"] = httpx.Client(verify=False)
        merged_conf["http_async_client"] = httpx.AsyncClient(verify=False)

    if "azure_endpoint" in merged_conf or os.getenv("AZURE_OPENAI_ENDPOINT"):
        return AzureChatOpenAI(**merged_conf)

    # Check if base_url is dashscope endpoint
    base_url = merged_conf.get("base_url")
    if isinstance(base_url, str) and "dashscope." in base_url:
        # Keep any user-supplied extra_body entries; enable_thinking is always
        # forced according to the LLM type.
        user_extra_body = merged_conf.get("extra_body")
        extra_body = dict(user_extra_body) if isinstance(user_extra_body, dict) else {}
        extra_body["enable_thinking"] = llm_type == "reasoning"
        merged_conf["extra_body"] = extra_body
        return ChatDashscope(**merged_conf)

    if llm_type == "reasoning":
        # ChatDeepSeek is given the endpoint as api_base. Only rename when
        # base_url exists so an explicitly configured api_base is not replaced
        # by None.
        if "base_url" in merged_conf:
            merged_conf["api_base"] = merged_conf.pop("base_url")
        return ChatDeepSeek(**merged_conf)

    return ChatOpenAI(**merged_conf)


def get_llm_by_type(llm_type: LLMType) -> BaseChatModel:
    """
    Get LLM instance by type. Returns cached instance if available.

    On a cache miss the configuration file is loaded, the instance is created
    via ``_create_llm_use_conf`` and stored in the module-level cache.
    """
    if llm_type in _llm_cache:
        return _llm_cache[llm_type]

    conf = load_yaml_config(_get_config_file_path())
    llm = _create_llm_use_conf(llm_type, conf)
    _llm_cache[llm_type] = llm
    return llm


def get_configured_llm_models() -> dict[str, list[str]]:
    """
    Get all configured LLM models grouped by type.

    Returns:
        Dictionary mapping LLM type to list of configured model names.
        An empty dictionary is returned if the configuration cannot be loaded.
    """
    try:
        conf = load_yaml_config(_get_config_file_path())
        llm_type_config_keys = _get_llm_type_config_keys()

        configured_models: dict[str, list[str]] = {}

        for llm_type in get_args(LLMType):
            # Get configuration from YAML file
            config_key = llm_type_config_keys.get(llm_type, "")
            yaml_conf = conf.get(config_key, {}) if config_key else {}
            if not isinstance(yaml_conf, dict):
                # A malformed section for one type must not hide the others.
                yaml_conf = {}

            # Merge configurations, with environment variables taking precedence
            merged_conf = {**yaml_conf, **_get_env_llm_conf(llm_type)}

            # Check if model is configured
            model_name = merged_conf.get("model")
            if model_name:
                configured_models.setdefault(llm_type, []).append(model_name)

        return configured_models

    except Exception as e:
        # Best-effort: log and return empty dict to avoid breaking the application
        logger.warning("Failed to load LLM configuration: %s", e)
        return {}


def _get_model_token_limit_defaults() -> dict[str, int]:
    """
    Get default token limits for common LLM models.

    These are conservative limits to prevent token overflow errors (Issue #721).
    Users can override by setting token_limit in their config.

    Key order matters: ``_infer_token_limit_from_model`` returns the first key
    that is a substring of the model name, so more specific keys are listed
    before their shorter prefixes (e.g. "gpt-4o" before "gpt-4").
    """
    return {
        # OpenAI models
        "gpt-4o": 120000,
        "gpt-4-turbo": 120000,
        "gpt-4": 8000,
        "gpt-3.5-turbo": 4000,
        # Anthropic Claude
        "claude-3": 180000,
        "claude-2": 100000,
        # Google Gemini
        "gemini-2": 180000,
        "gemini-1.5-pro": 180000,
        "gemini-1.5-flash": 180000,
        "gemini-pro": 30000,
        # Bytedance Doubao
        "doubao": 200000,
        # DeepSeek
        "deepseek": 100000,
        # Ollama/local
        "qwen": 30000,
        "llama": 4000,
        # Default fallback for unknown models
        "default": 100000,
    }


def _infer_token_limit_from_model(model_name: str) -> int:
    """
    Infer a reasonable token limit from the model name.

    This helps protect against token overflow errors when token_limit is not
    explicitly configured.

    Args:
        model_name: The model name from configuration

    Returns:
        The limit of the first known key contained in the lower-cased model
        name, or the "default" limit when the name is empty or nothing matches.
    """
    defaults = _get_model_token_limit_defaults()
    fallback = defaults["default"]
    if not model_name:
        return fallback

    # str() tolerates a non-string model value coming from YAML.
    model_name_lower = str(model_name).lower()

    for key, limit in defaults.items():
        # "default" is a fallback sentinel, not a model-name fragment.
        if key != "default" and key in model_name_lower:
            return limit

    return fallback


def get_llm_token_limit_by_type(llm_type: str) -> int:
    """
    Get the maximum token limit for a given LLM type.

    Priority order:
    1. Explicitly configured token_limit (conf.yaml, or the matching
       ``{LLM_TYPE}_MODEL__token_limit`` environment variable, which wins)
    2. Inferred from model name based on known model capabilities
    3. Safe default (100,000 tokens)

    This helps prevent token overflow errors (Issue #721) even when token_limit
    is not configured.

    Args:
        llm_type (str): The type of LLM (e.g., 'basic', 'reasoning', 'vision', 'code').

    Returns:
        int: The maximum token limit for the specified LLM type (conservative estimate).
    """
    default_limit = _get_model_token_limit_defaults()["default"]

    config_key = _get_llm_type_config_keys().get(llm_type)
    if not config_key:
        # Unknown type: nothing can be configured for it.
        return default_limit

    conf = load_yaml_config(_get_config_file_path())
    yaml_conf = conf.get(config_key, {})
    if not isinstance(yaml_conf, dict):
        yaml_conf = {}

    # Same precedence as LLM creation: environment variables override YAML, so
    # the limit is derived from the model that is actually instantiated.
    model_config = {**yaml_conf, **_get_env_llm_conf(llm_type)}

    # First priority: explicitly configured token_limit
    configured_limit = model_config.get("token_limit")
    if configured_limit is not None:
        if not isinstance(configured_limit, str):
            return configured_limit
        # Environment values are strings; honour the declared int return type.
        try:
            return int(configured_limit.strip())
        except ValueError:
            logger.warning(
                "Ignoring non-integer token_limit %r for LLM type '%s'",
                configured_limit,
                llm_type,
            )

    # Second priority: infer from model name
    model_name = model_config.get("model")
    if model_name:
        return _infer_token_limit_from_model(model_name)

    # Fallback: safe default
    return default_limit


# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")