From 16dbbf3461c3940a09ed530900418ad4fb6154c5 Mon Sep 17 00:00:00 2001 From: linyq Date: Sat, 28 Mar 2026 00:34:01 +0800 Subject: [PATCH] =?UTF-8?q?refactor(config):=20=E9=87=8D=E6=9E=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E7=B3=BB=E7=BB=9F=E4=BB=A5=E6=94=AF=E6=8C=81=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=80=BC=E5=92=8C=E6=A8=A1=E5=9E=8B=E5=90=8D=E7=A7=B0?= =?UTF-8?q?=E8=A7=84=E8=8C=83=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 defaults.py 提供共享默认配置和模型名称处理工具 - 重构 config.py 使用默认值填充缺失配置 - 修改 openai_compatible_provider.py 简化模型名称处理逻辑 - 更新 WebUI 组件使用新的默认值系统 - 添加测试用例验证配置引导和模型名称处理 --- README.md | 8 +- app/config/config.py | 44 ++++++-- app/config/defaults.py | 61 ++++++++++ app/config/test_config_bootstrap_unittest.py | 83 ++++++++++++++ .../llm/openai_compatible_provider.py | 40 +------ config.example.toml | 8 +- webui/components/basic_settings.py | 105 +++++++----------- 7 files changed, 234 insertions(+), 115 deletions(-) create mode 100644 app/config/defaults.py create mode 100644 app/config/test_config_bootstrap_unittest.py diff --git a/README.md b/README.md index a99fd9b..c668fba 100644 --- a/README.md +++ b/README.md @@ -33,10 +33,10 @@ NarratoAI 是一个自动化影视解说工具,基于LLM实现文案撰写、 本项目仅供学习和研究使用,不得商用。如需商业授权,请联系作者。 ## 最新资讯 -- 2026.03.27 出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路 -- 2025.11.20 发布新版本 0.7.5, 新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持 -- 2025.10.15 发布新版本 0.7.3, 升级大模型供应商管理能力 -- 2025.09.10 发布新版本 0.7.2, 新增腾讯云tts +- 2026.03.27 发布新版本 0.7.6,出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路 +- 2025.11.20 发布新版本 0.7.5,新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持 +- 2025.10.15 发布新版本 0.7.3,升级大模型供应商管理能力 +- 2025.09.10 发布新版本 0.7.2,新增腾讯云tts - 2025.08.18 发布新版本 0.7.1,支持 **语音克隆** 和 最新大模型 - 2025.05.11 发布新版本 0.6.0,支持 **短剧解说** 和 优化剪辑流程 - 2025.03.06 发布新版本 0.5.2,支持 DeepSeek R1 和 DeepSeek V3 模型进行短剧混剪 diff --git a/app/config/config.py b/app/config/config.py index c8957a4..21026e8 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -4,6 +4,8 @@ import toml import shutil from loguru import logger +from app.config.defaults import build_default_app_config, merge_missing_app_defaults + root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) config_file = f"{root_dir}/config.toml" version_file = f"{root_dir}/project_version" @@ -27,21 +29,47 @@ def load_config(): shutil.rmtree(config_file) if not os.path.isfile(config_file): - example_file = f"{root_dir}/config.example.toml" - if os.path.isfile(example_file): - shutil.copyfile(example_file, config_file) - logger.info(f"copy config.example.toml to config.toml") + _config_ = build_default_config() + write_config_file(_config_) + logger.info("create config.toml with shared defaults") + return _config_ logger.info(f"load config from file: {config_file}") + _config_ = load_toml_file(config_file) + _config_["app"] = merge_missing_app_defaults(_config_.get("app", {})) + return _config_ + + +def load_toml_file(file_path): + """Load a TOML file and fall back to utf-8-sig when needed.""" try: - _config_ = toml.load(config_file) + return toml.load(file_path) except Exception as e: logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig") - with open(config_file, mode="r", encoding="utf-8-sig") as fp: + with open(file_path, mode="r", encoding="utf-8-sig") as fp: _cfg_content = fp.read() - _config_ = toml.loads(_cfg_content) - return _config_ + return toml.loads(_cfg_content) + + +def build_default_config(): + """Build the initial config file content for a fresh installation.""" + example_file = f"{root_dir}/config.example.toml" + config_data = {} + if os.path.isfile(example_file): + config_data = load_toml_file(example_file) + + config_data["app"] = build_default_app_config(config_data.get("app", {})) + return config_data + + +def write_config_file(config_data): + parent_dir = os.path.dirname(config_file) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + + with open(config_file, "w", encoding="utf-8") as f: + f.write(toml.dumps(config_data)) def save_config(): diff --git a/app/config/defaults.py b/app/config/defaults.py new file mode 100644 index 0000000..859e121 --- /dev/null +++ b/app/config/defaults.py @@ -0,0 +1,61 @@ +"""Shared config defaults used by both bootstrap and WebUI fallbacks.""" + +DEFAULT_OPENAI_COMPATIBLE_BASE_URL = "https://api.siliconflow.cn/v1" +DEFAULT_OPENAI_COMPATIBLE_PROVIDER = "openai" + +DEFAULT_VISION_LLM_PROVIDER = DEFAULT_OPENAI_COMPATIBLE_PROVIDER +DEFAULT_VISION_OPENAI_MODEL_NAME = "Qwen/Qwen3.5-122B-A10B" + +DEFAULT_TEXT_LLM_PROVIDER = DEFAULT_OPENAI_COMPATIBLE_PROVIDER +DEFAULT_TEXT_OPENAI_MODEL_NAME = "Pro/zai-org/GLM-5" + +DEFAULT_LLM_APP_CONFIG = { + "vision_llm_provider": DEFAULT_VISION_LLM_PROVIDER, + "vision_openai_model_name": DEFAULT_VISION_OPENAI_MODEL_NAME, + "vision_openai_api_key": "", + "vision_openai_base_url": DEFAULT_OPENAI_COMPATIBLE_BASE_URL, + "text_llm_provider": DEFAULT_TEXT_LLM_PROVIDER, + "text_openai_model_name": DEFAULT_TEXT_OPENAI_MODEL_NAME, + "text_openai_api_key": "", + "text_openai_base_url": DEFAULT_OPENAI_COMPATIBLE_BASE_URL, +} + + +def build_default_app_config(app_config: dict | None = None) -> dict: + """Force the shared LLM defaults into a fresh app config.""" + merged = dict(app_config or {}) + merged.update(DEFAULT_LLM_APP_CONFIG) + return merged + + +def merge_missing_app_defaults(app_config: dict | None = None) -> dict: + """Backfill missing keys without overriding saved user values.""" + merged = dict(app_config or {}) + for key, value in DEFAULT_LLM_APP_CONFIG.items(): + merged.setdefault(key, value) + return merged + + +def normalize_openai_compatible_model_name( + model_name: str, + provider: str = DEFAULT_OPENAI_COMPATIBLE_PROVIDER, +) -> str: + """Strip only the internal OpenAI-compatible provider prefix if present.""" + normalized = (model_name or "").strip() + provider_prefix = f"{provider}/" + if normalized.lower().startswith(provider_prefix): + return normalized[len(provider_prefix):] + return normalized + + +def get_openai_compatible_ui_values( + full_model_name: str, + default_model: str, + provider: str = DEFAULT_OPENAI_COMPATIBLE_PROVIDER, +) -> tuple[str, str]: + """Keep the UI provider fixed while preserving the full model identifier.""" + current_model = normalize_openai_compatible_model_name( + full_model_name or default_model, + provider=provider, + ) + return provider, current_model or default_model diff --git a/app/config/test_config_bootstrap_unittest.py b/app/config/test_config_bootstrap_unittest.py new file mode 100644 index 0000000..c6844ed --- /dev/null +++ b/app/config/test_config_bootstrap_unittest.py @@ -0,0 +1,83 @@ +import tempfile +import unittest +from pathlib import Path + +import tomllib + +from app.config import config as cfg +from app.config.defaults import ( + get_openai_compatible_ui_values, + normalize_openai_compatible_model_name, +) + + +class ConfigBootstrapDefaultsTests(unittest.TestCase): + def test_load_config_bootstraps_webui_llm_defaults(self): + original_root_dir = cfg.root_dir + original_config_file = cfg.config_file + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + example_file = tmp_path / "config.example.toml" + example_file.write_text( + """ +[app] +vision_llm_provider = "openai" +vision_openai_model_name = "gemini/gemini-2.0-flash-lite" +vision_openai_api_key = "" +vision_openai_base_url = "" +text_llm_provider = "openai" +text_openai_model_name = "deepseek/deepseek-chat" +text_openai_api_key = "" +text_openai_base_url = "" +hide_config = true +""".strip() + + "\n", + encoding="utf-8", + ) + + config_path = tmp_path / "config.toml" + try: + cfg.root_dir = str(tmp_path) + cfg.config_file = str(config_path) + + config_data = cfg.load_config() + saved_config = tomllib.loads(config_path.read_text(encoding="utf-8")) + finally: + cfg.root_dir = original_root_dir + cfg.config_file = original_config_file + + self.assertEqual("openai", config_data["app"]["vision_llm_provider"]) + self.assertEqual("Qwen/Qwen3.5-122B-A10B", config_data["app"]["vision_openai_model_name"]) + self.assertEqual("https://api.siliconflow.cn/v1", config_data["app"]["vision_openai_base_url"]) + self.assertEqual("openai", config_data["app"]["text_llm_provider"]) + self.assertEqual("Pro/zai-org/GLM-5", config_data["app"]["text_openai_model_name"]) + self.assertEqual("https://api.siliconflow.cn/v1", config_data["app"]["text_openai_base_url"]) + self.assertEqual("Qwen/Qwen3.5-122B-A10B", saved_config["app"]["vision_openai_model_name"]) + self.assertEqual("Pro/zai-org/GLM-5", saved_config["app"]["text_openai_model_name"]) + self.assertTrue(saved_config["app"]["hide_config"]) + + +class OpenAICompatibleModelDefaultsTests(unittest.TestCase): + def test_ui_keeps_full_model_name_and_openai_provider(self): + provider, model_name = get_openai_compatible_ui_values( + "Qwen/Qwen3.5-122B-A10B", + "fallback-model", + ) + + self.assertEqual("openai", provider) + self.assertEqual("Qwen/Qwen3.5-122B-A10B", model_name) + + def test_normalize_only_strips_openai_prefix(self): + self.assertEqual( + "Qwen/Qwen3.5-122B-A10B", + normalize_openai_compatible_model_name("openai/Qwen/Qwen3.5-122B-A10B"), + ) + self.assertEqual( + "Qwen/Qwen3.5-122B-A10B", + normalize_openai_compatible_model_name("Qwen/Qwen3.5-122B-A10B"), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/app/services/llm/openai_compatible_provider.py b/app/services/llm/openai_compatible_provider.py index 36723b6..6423ec9 100644 --- a/app/services/llm/openai_compatible_provider.py +++ b/app/services/llm/openai_compatible_provider.py @@ -21,48 +21,14 @@ from openai import ( ) from app.config import config +from app.config.defaults import normalize_openai_compatible_model_name from .base import TextModelProvider, VisionModelProvider from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError -# 常见 OpenAI 兼容网关前缀。若使用 provider/model 格式,将剥离 provider 前缀。 -OPENAI_COMPATIBLE_PROVIDER_PREFIXES = { - "openai", - "gemini", - "deepseek", - "qwen", - "siliconflow", - "moonshot", - "openrouter", - "anthropic", - "azure", - "ollama", - "mistral", - "groq", - "cohere", - "together_ai", - "fireworks_ai", - "volcengine", - "vertex_ai", - "huggingface", - "xai", - "bedrock", - "cloudflare", - "vllm", - "codestral", - "replicate", - "deepgram", -} - def _normalize_model_name(model_name: str) -> str: - """兼容历史 provider/model 写法,必要时自动剥离 provider 前缀。""" - if "/" not in model_name: - return model_name - - provider_prefix, raw_model = model_name.split("/", 1) - if provider_prefix.lower() in OPENAI_COMPATIBLE_PROVIDER_PREFIXES and raw_model: - return raw_model - return model_name + """仅剥离误保存的 openai/ 前缀,保留完整模型名称。""" + return normalize_openai_compatible_model_name(model_name) def _is_response_format_error(message: str) -> bool: diff --git a/config.example.toml b/config.example.toml index cc12fb1..f226c34 100644 --- a/config.example.toml +++ b/config.example.toml @@ -22,9 +22,9 @@ # - OpenAI: gpt-4o, gpt-4o-mini # - Qwen: qwen/qwen2.5-vl-32b-instruct # - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct - vision_openai_model_name = "gemini/gemini-2.0-flash-lite" + vision_openai_model_name = "Qwen/Qwen3.5-122B-A10B" vision_openai_api_key = "" # 填入对应 provider 的 API key - vision_openai_base_url = "" # 可选:自定义 API base URL(官方 OpenAI 可留空) + vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空) # ===== 文本模型配置 ===== text_llm_provider = "openai" @@ -37,9 +37,9 @@ # - Qwen: qwen/qwen-plus, qwen/qwen-turbo # - SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1 # - Moonshot: moonshot/moonshot-v1-8k - text_openai_model_name = "deepseek/deepseek-chat" + text_openai_model_name = "Pro/zai-org/GLM-5" text_openai_api_key = "" # 填入对应 provider 的 API key - text_openai_base_url = "" # 可选:自定义 API base URL(官方 OpenAI 可留空) + text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空) # ===== API Keys 参考 ===== # 主流 LLM Providers API Key 获取地址: diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index 48ef976..7f72bbf 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -3,6 +3,16 @@ import traceback import streamlit as st import os from app.config import config +from app.config.defaults import ( + DEFAULT_OPENAI_COMPATIBLE_BASE_URL, + DEFAULT_OPENAI_COMPATIBLE_PROVIDER, + DEFAULT_TEXT_LLM_PROVIDER, + DEFAULT_TEXT_OPENAI_MODEL_NAME, + DEFAULT_VISION_LLM_PROVIDER, + DEFAULT_VISION_OPENAI_MODEL_NAME, + get_openai_compatible_ui_values, + normalize_openai_compatible_model_name as normalize_openai_compatible_model_id, +) from app.utils import utils from loguru import logger from app.services.llm.unified_service import UnifiedLLMService @@ -116,10 +126,11 @@ def validate_openai_compatible_model_name(model_name: str, model_type: str) -> t def normalize_openai_compatible_model_name(model_name: str) -> str: - """将 provider/model 格式转换为网关实际使用的模型名。""" - if "/" not in model_name: - return model_name - return model_name.split("/", 1)[1] + """仅剥离误保存的 openai/ 前缀,保留完整模型名称。""" + return normalize_openai_compatible_model_id( + model_name, + provider=DEFAULT_OPENAI_COMPATIBLE_PROVIDER, + ) def show_config_validation_errors(errors: list): @@ -419,37 +430,22 @@ def render_vision_llm_settings(tr): st.subheader(tr("Vision Model Settings")) # 固定使用 OpenAI 兼容 提供商 - config.app["vision_llm_provider"] = "openai" + config.app["vision_llm_provider"] = DEFAULT_VISION_LLM_PROVIDER # 获取已保存的配置 - full_vision_model_name = config.app.get("vision_openai_model_name") or "gemini/gemini-2.0-flash-lite" + full_vision_model_name = config.app.get("vision_openai_model_name") or DEFAULT_VISION_OPENAI_MODEL_NAME vision_api_key = config.app.get("vision_openai_api_key", "") - vision_base_url = config.app.get("vision_openai_base_url", "") - - # 解析 provider 和 model - default_provider = "gemini" - default_model = "gemini-2.0-flash-lite" + vision_base_url = config.app.get("vision_openai_base_url", DEFAULT_OPENAI_COMPATIBLE_BASE_URL) - if "/" in full_vision_model_name: - parts = full_vision_model_name.split("/", 1) - current_provider = parts[0] - current_model = parts[1] - else: - current_provider = default_provider - current_model = full_vision_model_name + # 固定 provider 为 openai,模型输入框保留完整模型名称 + current_provider, current_model = get_openai_compatible_ui_values( + full_vision_model_name, + DEFAULT_VISION_OPENAI_MODEL_NAME, + provider=DEFAULT_VISION_LLM_PROVIDER, + ) # 定义支持的 provider 列表 - OPENAI_COMPATIBLE_PROVIDERS = [ - "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot", - "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral", - "volcengine", "groq", "cohere", "together_ai", "fireworks_ai", - "openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm", - "bedrock", "cloudflare" - ] - - # 如果当前 provider 不在列表中,添加到列表头部 - if current_provider not in OPENAI_COMPATIBLE_PROVIDERS: - OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider) + OPENAI_COMPATIBLE_PROVIDERS = ["openai"] # 渲染配置输入框 col1, col2 = st.columns([1, 2]) @@ -465,18 +461,18 @@ def render_vision_llm_settings(tr): model_name_input = st.text_input( tr("Vision Model Name"), value=current_model, - help="输入模型名称(不包含 provider 前缀)\n\n" + help="输入完整模型名称\n\n" "常用示例:\n" - "• gemini-2.0-flash-lite\n" + "• Qwen/Qwen3.5-122B-A10B\n" + "• gemini/gemini-2.0-flash-lite\n" "• gpt-4o\n" - "• qwen-vl-max\n" "• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n" "支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow)", key="vision_model_input" ) # 组合完整的模型名称 - st_vision_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else "" + st_vision_model_name = normalize_openai_compatible_model_name(model_name_input) st_vision_api_key = st.text_input( tr("Vision API Key"), @@ -691,37 +687,22 @@ def render_text_llm_settings(tr): st.subheader(tr("Text Generation Model Settings")) # 固定使用 OpenAI 兼容 提供商 - config.app["text_llm_provider"] = "openai" + config.app["text_llm_provider"] = DEFAULT_TEXT_LLM_PROVIDER # 获取已保存的配置 - full_text_model_name = config.app.get("text_openai_model_name") or "deepseek/deepseek-chat" + full_text_model_name = config.app.get("text_openai_model_name") or DEFAULT_TEXT_OPENAI_MODEL_NAME text_api_key = config.app.get("text_openai_api_key", "") - text_base_url = config.app.get("text_openai_base_url", "") + text_base_url = config.app.get("text_openai_base_url", DEFAULT_OPENAI_COMPATIBLE_BASE_URL) - # 解析 provider 和 model - default_provider = "deepseek" - default_model = "deepseek-chat" - - if "/" in full_text_model_name: - parts = full_text_model_name.split("/", 1) - current_provider = parts[0] - current_model = parts[1] - else: - current_provider = default_provider - current_model = full_text_model_name + # 固定 provider 为 openai,模型输入框保留完整模型名称 + current_provider, current_model = get_openai_compatible_ui_values( + full_text_model_name, + DEFAULT_TEXT_OPENAI_MODEL_NAME, + provider=DEFAULT_TEXT_LLM_PROVIDER, + ) # 定义支持的 provider 列表 - OPENAI_COMPATIBLE_PROVIDERS = [ - "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot", - "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral", - "volcengine", "groq", "cohere", "together_ai", "fireworks_ai", - "openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm", - "bedrock", "cloudflare" - ] - - # 如果当前 provider 不在列表中,添加到列表头部 - if current_provider not in OPENAI_COMPATIBLE_PROVIDERS: - OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider) + OPENAI_COMPATIBLE_PROVIDERS = ["openai"] # 渲染配置输入框 col1, col2 = st.columns([1, 2]) @@ -737,18 +718,18 @@ def render_text_llm_settings(tr): model_name_input = st.text_input( tr("Text Model Name"), value=current_model, - help="输入模型名称(不包含 provider 前缀)\n\n" + help="输入完整模型名称\n\n" "常用示例:\n" - "• deepseek-chat\n" + "• Pro/zai-org/GLM-5\n" + "• deepseek/deepseek-chat\n" "• gpt-4o\n" - "• gemini-2.0-flash\n" "• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n" "支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow)", key="text_model_input" ) # 组合完整的模型名称 - st_text_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else "" + st_text_model_name = normalize_openai_compatible_model_name(model_name_input) st_text_api_key = st.text_input( tr("Text API Key"),