From 3396644593127cd1f738332a1a3c6d705fc567e3 Mon Sep 17 00:00:00 2001
From: linyq
Date: Fri, 27 Mar 2026 23:49:58 +0800
Subject: [PATCH] feat: remove LiteLLM dependency and migrate to the OpenAI-compatible interface
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Remove LiteLLM-related code and dependencies; use the native OpenAI-compatible interface instead
- Refactor LLM service provider registration to support only the OpenAI-compatible interface
- Update configuration files and docs to remove LiteLLM-related notes
- Add new test cases to verify the OpenAI-compatible interface integration
- Update WebUI components to adapt to the new OpenAI-compatible interface
---
 README.md                                      |   4 +-
 app/services/llm/__init__.py                   |   3 +-
 app/services/llm/base.py                       |   2 +-
 app/services/llm/config_validator.py           |   6 +-
 app/services/llm/litellm_provider.py           | 490 ------------------
 app/services/llm/manager.py                    |  43 +-
 .../llm/openai_compatible_provider.py          | 276 ++++++++++
 app/services/llm/providers/__init__.py         |  21 +-
 app/services/llm/test_litellm_integration.py   | 228 --------
 .../llm/test_openai_compat_unittest.py         |  67 +++
 .../llm/test_openai_compatible_integration.py  |  35 ++
 app/services/script_service.py                 |   4 +-
 config.example.toml                            |  48 +-
 requirements.txt                               |   3 +-
 webui/components/basic_settings.py             | 404 +++++----------
 webui/tools/generate_script_docu.py            |   2 +-
 16 files changed, 582 insertions(+), 1054 deletions(-)
 delete mode 100644 app/services/llm/litellm_provider.py
 create mode 100644 app/services/llm/openai_compatible_provider.py
 delete mode 100644 app/services/llm/test_litellm_integration.py
 create mode 100644 app/services/llm/test_openai_compat_unittest.py
 create mode 100644 app/services/llm/test_openai_compatible_integration.py

diff --git a/README.md b/README.md
index d3b2266..a99fd9b 100644
--- a/README.md
+++ b/README.md
@@ -33,8 +33,9 @@ NarratoAI 是一个自动化影视解说工具,基于LLM实现文案撰写、
 本项目仅供学习和研究使用,不得商用。如需商业授权,请联系作者。
 
 ## 最新资讯
+- 2026.03.27 出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路
 - 2025.11.20 发布新版本 0.7.5, 新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持
-- 2025.10.15 发布新版本 0.7.3, 使用 [LiteLLM](https://github.com/BerriAI/litellm) 管理模型供应商
+- 2025.10.15 发布新版本 0.7.3, 升级大模型供应商管理能力
 - 2025.09.10 发布新版本 0.7.2, 新增腾讯云tts
 - 2025.08.18 发布新版本 0.7.1,支持 **语音克隆** 和 最新大模型
 - 2025.05.11 发布新版本 0.6.0,支持 **短剧解说** 和 优化剪辑流程
@@ -176,4 +177,3 @@ streamlit run webui.py --server.maxUploadSize=2048
 
 ## Star History
 [![Star History Chart](https://api.star-history.com/svg?repos=linyqh/NarratoAI&type=Date)](https://star-history.com/#linyqh/NarratoAI&Date)
-
diff --git a/app/services/llm/__init__.py b/app/services/llm/__init__.py
index ccf2c12..fa3f7f3 100644
--- a/app/services/llm/__init__.py
+++ b/app/services/llm/__init__.py
@@ -12,8 +12,7 @@ NarratoAI 大模型服务模块
 - OutputValidator: 输出格式验证器
 
 支持的供应商:
-视觉模型: Gemini, QwenVL, Siliconflow
-文本模型: OpenAI, DeepSeek, Gemini, Qwen, Moonshot, Siliconflow
+视觉模型/文本模型: OpenAI 兼容接口(可对接 OpenAI、DeepSeek、Gemini 网关、Qwen 网关等)
 """
 
 from .manager import LLMServiceManager
diff --git a/app/services/llm/base.py b/app/services/llm/base.py
index f2f5935..87f1368 100644
--- a/app/services/llm/base.py
+++ b/app/services/llm/base.py
@@ -68,7 +68,7 @@ class BaseLLMProvider(ABC):
         """验证模型支持情况(宽松模式,仅记录警告)"""
         from loguru import logger
 
-        # LiteLLM 已提供统一的模型验证,传统 provider 使用宽松验证
+        # OpenAI 兼容网关的模型数量较多,运行时由远端完成最终校验
         if self.model_name not in self.supported_models:
             logger.warning(
                 f"模型 {self.model_name} 未在供应商 {self.provider_name} 的预定义支持列表中。"
diff --git a/app/services/llm/config_validator.py b/app/services/llm/config_validator.py
index cb542ef..716c2b0 100644
--- a/app/services/llm/config_validator.py
+++ 
b/app/services/llm/config_validator.py @@ -214,7 +214,7 @@ class LLMConfigValidator: "建议为每个提供商配置base_url以提高稳定性", "定期检查模型名称是否为最新版本", "建议配置多个提供商作为备用方案", - "推荐使用 LiteLLM 作为统一接口,支持 100+ providers" + "推荐使用 OpenAI 兼容接口,便于接入多家模型网关" ] } @@ -257,8 +257,8 @@ class LLMConfigValidator: "text": ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-1.5-pro"] }, "openai": { - "vision": [], - "text": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"] + "vision": ["gpt-4o", "gemini-2.0-flash-lite", "Qwen/Qwen2.5-VL-32B-Instruct"], + "text": ["gpt-4o-mini", "deepseek-chat", "zai-org/GLM-4.6"] }, "qwen": { "vision": ["qwen2.5-vl-32b-instruct"], diff --git a/app/services/llm/litellm_provider.py b/app/services/llm/litellm_provider.py deleted file mode 100644 index c3f1763..0000000 --- a/app/services/llm/litellm_provider.py +++ /dev/null @@ -1,490 +0,0 @@ -""" -LiteLLM 统一提供商实现 - -使用 LiteLLM 库提供统一的 LLM 接口,支持 100+ providers -包括 OpenAI, Anthropic, Gemini, Qwen, DeepSeek, SiliconFlow 等 -""" - -import asyncio -import base64 -import io -from typing import List, Dict, Any, Optional, Union -from pathlib import Path -import PIL.Image -from loguru import logger - -try: - import litellm - from litellm import acompletion, completion - from litellm.exceptions import ( - AuthenticationError as LiteLLMAuthError, - RateLimitError as LiteLLMRateLimitError, - BadRequestError as LiteLLMBadRequestError, - APIError as LiteLLMAPIError - ) -except ImportError: - logger.error("LiteLLM 未安装。请运行: pip install litellm") - raise - -from .base import VisionModelProvider, TextModelProvider -from .exceptions import ( - APICallError, - AuthenticationError, - RateLimitError, - ContentFilterError -) - - -# 配置 LiteLLM 全局设置 -def configure_litellm(): - """配置 LiteLLM 全局参数""" - from app.config import config - - # 设置重试次数 - litellm.num_retries = config.app.get('llm_max_retries', 3) - - # 设置默认超时 - litellm.request_timeout = config.app.get('llm_text_timeout', 180) - - # 启用详细日志(开发环境) - # litellm.set_verbose = True - - logger.info(f"LiteLLM 配置完成: retries={litellm.num_retries}, timeout={litellm.request_timeout}s") - - -# 初始化配置 -configure_litellm() - - -class LiteLLMVisionProvider(VisionModelProvider): - """使用 LiteLLM 的统一视觉模型提供商""" - - @property - def provider_name(self) -> str: - # 从 model_name 中提取 provider 名称(如 "gemini/gemini-2.0-flash") - if "/" in self.model_name: - return self.model_name.split("/")[0] - return "litellm" - - @property - def supported_models(self) -> List[str]: - # LiteLLM 支持 100+ providers 和数百个模型,无法全部列举 - # 返回空列表表示跳过预定义列表检查,由 LiteLLM 在实际调用时验证 - return [] - - def _validate_model_support(self): - """ - 重写模型验证逻辑 - - 对于 LiteLLM,我们不做预定义列表检查,因为: - 1. LiteLLM 支持 100+ providers 和数百个模型,无法全部列举 - 2. LiteLLM 会在实际调用时进行模型验证 - 3. 
如果模型不支持,LiteLLM 会返回清晰的错误信息 - - 这里只做基本的格式验证(可选) - """ - from loguru import logger - - # 可选:检查模型名称格式(provider/model) - if "/" not in self.model_name: - logger.debug( - f"LiteLLM 模型名称 '{self.model_name}' 未包含 provider 前缀," - f"LiteLLM 将尝试自动推断。建议使用 'provider/model' 格式,如 'gemini/gemini-2.5-flash'" - ) - - # 不抛出异常,让 LiteLLM 在实际调用时验证 - logger.debug(f"LiteLLM 视觉模型已配置: {self.model_name}") - - def _initialize(self): - """初始化 LiteLLM 特定设置""" - # 设置 API key 到环境变量(LiteLLM 会自动读取) - import os - - # 根据 model_name 确定需要设置哪个 API key - provider = self.provider_name.lower() - - # 映射 provider 到环境变量名 - env_key_mapping = { - "gemini": "GEMINI_API_KEY", - "google": "GEMINI_API_KEY", - "openai": "OPENAI_API_KEY", - "qwen": "QWEN_API_KEY", - "dashscope": "DASHSCOPE_API_KEY", - "siliconflow": "SILICONFLOW_API_KEY", - "anthropic": "ANTHROPIC_API_KEY", - "claude": "ANTHROPIC_API_KEY" - } - - env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY") - - if self.api_key and env_var: - os.environ[env_var] = self.api_key - logger.debug(f"设置环境变量: {env_var}") - - # 如果提供了 base_url,设置到 LiteLLM - if self.base_url: - # LiteLLM 支持通过 api_base 参数设置自定义 URL - self._api_base = self.base_url - logger.debug(f"使用自定义 API base URL: {self.base_url}") - - async def analyze_images(self, - images: List[Union[str, Path, PIL.Image.Image]], - prompt: str, - batch_size: int = 10, - **kwargs) -> List[str]: - """ - 使用 LiteLLM 分析图片 - - Args: - images: 图片路径列表或PIL图片对象列表 - prompt: 分析提示词 - batch_size: 批处理大小 - **kwargs: 其他参数 - - Returns: - 分析结果列表 - """ - logger.info(f"开始使用 LiteLLM ({self.model_name}) 分析 {len(images)} 张图片") - - # 预处理图片 - processed_images = self._prepare_images(images) - - # 分批处理 - results = [] - for i in range(0, len(processed_images), batch_size): - batch = processed_images[i:i + batch_size] - logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片") - - try: - result = await self._analyze_batch(batch, prompt, **kwargs) - results.append(result) - except Exception as e: - logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}") - results.append(f"批次处理失败: {str(e)}") - - return results - - async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str: - """分析一批图片""" - # 构建 LiteLLM 格式的消息 - content = [{"type": "text", "text": prompt}] - - # 添加图片(使用 base64 编码) - for img in batch: - base64_image = self._image_to_base64(img) - content.append({ - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - }) - - messages = [{ - "role": "user", - "content": content - }] - - # 调用 LiteLLM - try: - # 准备参数 - effective_model_name = self.model_name - - # SiliconFlow 特殊处理 - if self.model_name.lower().startswith("siliconflow/"): - # 替换 provider 为 openai - if "/" in self.model_name: - effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}" - else: - effective_model_name = f"openai/{self.model_name}" - - # 确保设置了 OPENAI_API_KEY (如果尚未设置) - import os - if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"): - os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY") - - # 确保设置了 base_url (如果尚未设置) - if not hasattr(self, '_api_base'): - self._api_base = "https://api.siliconflow.cn/v1" - - completion_kwargs = { - "model": effective_model_name, - "messages": messages, - "temperature": kwargs.get("temperature", 1.0), - "max_tokens": kwargs.get("max_tokens", 4000) - } - - # 如果有自定义 base_url,添加 api_base 参数 - if hasattr(self, '_api_base'): - completion_kwargs["api_base"] = self._api_base - - # 支持动态传递 api_key 和 api_base - if "api_key" in kwargs: - 
completion_kwargs["api_key"] = kwargs["api_key"] - if "api_base" in kwargs: - completion_kwargs["api_base"] = kwargs["api_base"] - - response = await acompletion(**completion_kwargs) - - if response.choices and len(response.choices) > 0: - content = response.choices[0].message.content - logger.debug(f"LiteLLM 调用成功,消耗 tokens: {response.usage.total_tokens if response.usage else 'N/A'}") - return content - else: - raise APICallError("LiteLLM 返回空响应") - - except LiteLLMAuthError as e: - logger.error(f"LiteLLM 认证失败: {str(e)}") - raise AuthenticationError() - except LiteLLMRateLimitError as e: - logger.error(f"LiteLLM 速率限制: {str(e)}") - raise RateLimitError() - except LiteLLMBadRequestError as e: - error_msg = str(e) - if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower(): - raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}") - logger.error(f"LiteLLM 请求错误: {error_msg}") - raise APICallError(f"请求错误: {error_msg}") - except LiteLLMAPIError as e: - logger.error(f"LiteLLM API 错误: {str(e)}") - raise APICallError(f"API 错误: {str(e)}") - except Exception as e: - logger.error(f"LiteLLM 调用失败: {str(e)}") - raise APICallError(f"调用失败: {str(e)}") - - def _image_to_base64(self, img: PIL.Image.Image) -> str: - """将PIL图片转换为base64编码""" - img_buffer = io.BytesIO() - img.save(img_buffer, format='JPEG', quality=85) - img_bytes = img_buffer.getvalue() - return base64.b64encode(img_bytes).decode('utf-8') - - async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: - """兼容基类接口(实际使用 LiteLLM SDK)""" - pass - - -class LiteLLMTextProvider(TextModelProvider): - """使用 LiteLLM 的统一文本生成提供商""" - - @property - def provider_name(self) -> str: - # 从 model_name 中提取 provider 名称 - if "/" in self.model_name: - return self.model_name.split("/")[0] - # 尝试从模型名称推断 provider - model_lower = self.model_name.lower() - if "gpt" in model_lower: - return "openai" - elif "claude" in model_lower: - return "anthropic" - elif "gemini" in model_lower: - return "gemini" - elif "qwen" in model_lower: - return "qwen" - elif "deepseek" in model_lower: - return "deepseek" - return "litellm" - - @property - def supported_models(self) -> List[str]: - # LiteLLM 支持 100+ providers 和数百个模型,无法全部列举 - # 返回空列表表示跳过预定义列表检查,由 LiteLLM 在实际调用时验证 - return [] - - def _validate_model_support(self): - """ - 重写模型验证逻辑 - - 对于 LiteLLM,我们不做预定义列表检查,因为: - 1. LiteLLM 支持 100+ providers 和数百个模型,无法全部列举 - 2. LiteLLM 会在实际调用时进行模型验证 - 3. 
如果模型不支持,LiteLLM 会返回清晰的错误信息 - - 这里只做基本的格式验证(可选) - """ - from loguru import logger - - # 可选:检查模型名称格式(provider/model) - if "/" not in self.model_name: - logger.debug( - f"LiteLLM 模型名称 '{self.model_name}' 未包含 provider 前缀," - f"LiteLLM 将尝试自动推断。建议使用 'provider/model' 格式,如 'gemini/gemini-2.5-flash'" - ) - - # 不抛出异常,让 LiteLLM 在实际调用时验证 - logger.debug(f"LiteLLM 文本模型已配置: {self.model_name}") - - def _initialize(self): - """初始化 LiteLLM 特定设置""" - import os - - # 根据 model_name 确定需要设置哪个 API key - provider = self.provider_name.lower() - - # 映射 provider 到环境变量名 - env_key_mapping = { - "gemini": "GEMINI_API_KEY", - "google": "GEMINI_API_KEY", - "openai": "OPENAI_API_KEY", - "qwen": "QWEN_API_KEY", - "dashscope": "DASHSCOPE_API_KEY", - "siliconflow": "SILICONFLOW_API_KEY", - "deepseek": "DEEPSEEK_API_KEY", - "anthropic": "ANTHROPIC_API_KEY", - "claude": "ANTHROPIC_API_KEY", - "moonshot": "MOONSHOT_API_KEY" - } - - env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY") - - if self.api_key and env_var: - os.environ[env_var] = self.api_key - logger.debug(f"设置环境变量: {env_var}") - - # 如果提供了 base_url,保存用于后续调用 - if self.base_url: - self._api_base = self.base_url - logger.debug(f"使用自定义 API base URL: {self.base_url}") - - async def generate_text(self, - prompt: str, - system_prompt: Optional[str] = None, - temperature: float = 1.0, - max_tokens: Optional[int] = None, - response_format: Optional[str] = None, - **kwargs) -> str: - """ - 使用 LiteLLM 生成文本 - - Args: - prompt: 用户提示词 - system_prompt: 系统提示词 - temperature: 生成温度 - max_tokens: 最大token数 - response_format: 响应格式 ('json' 或 None) - **kwargs: 其他参数 - - Returns: - 生成的文本内容 - """ - # 构建消息列表 - messages = self._build_messages(prompt, system_prompt) - - # 准备参数 - effective_model_name = self.model_name - - # SiliconFlow 特殊处理 - if self.model_name.lower().startswith("siliconflow/"): - # 替换 provider 为 openai - if "/" in self.model_name: - effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}" - else: - effective_model_name = f"openai/{self.model_name}" - - # 确保设置了 OPENAI_API_KEY (如果尚未设置) - import os - if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"): - os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY") - - # 确保设置了 base_url (如果尚未设置) - if not hasattr(self, '_api_base'): - self._api_base = "https://api.siliconflow.cn/v1" - - completion_kwargs = { - "model": effective_model_name, - "messages": messages, - "temperature": temperature - } - - if max_tokens: - completion_kwargs["max_tokens"] = max_tokens - - # 处理 JSON 格式输出 - # LiteLLM 会自动处理不同 provider 的 JSON mode 差异 - if response_format == "json": - try: - completion_kwargs["response_format"] = {"type": "json_object"} - except Exception as e: - # 如果不支持,在提示词中添加约束 - logger.warning(f"模型可能不支持 response_format,将在提示词中添加 JSON 约束: {str(e)}") - messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" - - # 如果有自定义 base_url,添加 api_base 参数 - if hasattr(self, '_api_base'): - completion_kwargs["api_base"] = self._api_base - - # 支持动态传递 api_key 和 api_base (修复认证问题) - if "api_key" in kwargs: - completion_kwargs["api_key"] = kwargs["api_key"] - if "api_base" in kwargs: - completion_kwargs["api_base"] = kwargs["api_base"] - - try: - # 调用 LiteLLM(自动重试) - response = await acompletion(**completion_kwargs) - - if response.choices and len(response.choices) > 0: - content = response.choices[0].message.content - - # 清理可能的 markdown 代码块(针对不支持 JSON mode 的模型) - if response_format == "json" and "response_format" not in completion_kwargs: - content = self._clean_json_output(content) - - 
logger.debug(f"LiteLLM 调用成功,消耗 tokens: {response.usage.total_tokens if response.usage else 'N/A'}") - return content - else: - raise APICallError("LiteLLM 返回空响应") - - except LiteLLMAuthError as e: - logger.error(f"LiteLLM 认证失败: {str(e)}") - raise AuthenticationError() - except LiteLLMRateLimitError as e: - logger.error(f"LiteLLM 速率限制: {str(e)}") - raise RateLimitError() - except LiteLLMBadRequestError as e: - error_msg = str(e) - # 处理不支持 response_format 的情况 - if "response_format" in error_msg and response_format == "json": - logger.warning(f"模型不支持 response_format,重试不带格式约束的请求") - completion_kwargs.pop("response_format", None) - messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" - - # 重试 - response = await acompletion(**completion_kwargs) - if response.choices and len(response.choices) > 0: - content = response.choices[0].message.content - content = self._clean_json_output(content) - return content - else: - raise APICallError("LiteLLM 返回空响应") - - # 检查是否是安全过滤 - if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower(): - raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}") - - logger.error(f"LiteLLM 请求错误: {error_msg}") - raise APICallError(f"请求错误: {error_msg}") - except LiteLLMAPIError as e: - logger.error(f"LiteLLM API 错误: {str(e)}") - raise APICallError(f"API 错误: {str(e)}") - except Exception as e: - logger.error(f"LiteLLM 调用失败: {str(e)}") - raise APICallError(f"调用失败: {str(e)}") - - def _clean_json_output(self, output: str) -> str: - """清理JSON输出,移除markdown标记等""" - import re - - # 移除可能的markdown代码块标记 - output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) - output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) - output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) - - # 移除前后空白字符 - output = output.strip() - - return output - - async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: - """兼容基类接口(实际使用 LiteLLM SDK)""" - pass diff --git a/app/services/llm/manager.py b/app/services/llm/manager.py index 7074694..cbe6e55 100644 --- a/app/services/llm/manager.py +++ b/app/services/llm/manager.py @@ -4,7 +4,7 @@ 统一管理所有大模型服务提供商,提供简单的工厂方法来创建和获取服务实例 """ -from typing import Dict, Type, Optional +from typing import Dict, Type, Optional, Tuple from loguru import logger from app.config import config @@ -64,6 +64,29 @@ class LLMServiceManager: "vision_providers": list(cls._vision_providers.keys()), "text_providers": list(cls._text_providers.keys()) } + + @classmethod + def _normalize_provider_name(cls, provider_name: str) -> str: + """规范化 provider 名称。""" + return provider_name.lower() + + @classmethod + def _get_provider_config( + cls, + model_type: str, + provider_name: str + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """ + 获取 provider 配置。 + + model_type: 'vision' 或 'text' + """ + config_prefix = f"{model_type}_{provider_name}" + api_key = config.app.get(f"{config_prefix}_api_key") + model_name = config.app.get(f"{config_prefix}_model_name") + base_url = config.app.get(f"{config_prefix}_base_url") + + return api_key, model_name, base_url @classmethod def get_vision_provider(cls, provider_name: Optional[str] = None) -> VisionModelProvider: @@ -89,9 +112,8 @@ class LLMServiceManager: # 确定提供商名称 if not provider_name: - provider_name = config.app.get('vision_llm_provider', 'gemini').lower() - else: - provider_name = provider_name.lower() + provider_name = config.app.get('vision_llm_provider', 'openai') + provider_name = cls._normalize_provider_name(provider_name) # 检查缓存 cache_key = f"vision_{provider_name}" @@ -104,9 +126,7 @@ 
class LLMServiceManager: # 获取配置 config_prefix = f"vision_{provider_name}" - api_key = config.app.get(f'{config_prefix}_api_key') - model_name = config.app.get(f'{config_prefix}_model_name') - base_url = config.app.get(f'{config_prefix}_base_url') + api_key, model_name, base_url = cls._get_provider_config("vision", provider_name) if not api_key: raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key") @@ -157,9 +177,8 @@ class LLMServiceManager: # 确定提供商名称 if not provider_name: - provider_name = config.app.get('text_llm_provider', 'openai').lower() - else: - provider_name = provider_name.lower() + provider_name = config.app.get('text_llm_provider', 'openai') + provider_name = cls._normalize_provider_name(provider_name) logger.debug(f"获取文本模型提供商: {provider_name}") logger.debug(f"已注册的文本提供商: {list(cls._text_providers.keys())}") @@ -178,9 +197,7 @@ class LLMServiceManager: # 获取配置 config_prefix = f"text_{provider_name}" - api_key = config.app.get(f'{config_prefix}_api_key') - model_name = config.app.get(f'{config_prefix}_model_name') - base_url = config.app.get(f'{config_prefix}_base_url') + api_key, model_name, base_url = cls._get_provider_config("text", provider_name) if not api_key: raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key") diff --git a/app/services/llm/openai_compatible_provider.py b/app/services/llm/openai_compatible_provider.py new file mode 100644 index 0000000..36723b6 --- /dev/null +++ b/app/services/llm/openai_compatible_provider.py @@ -0,0 +1,276 @@ +""" +OpenAI 兼容提供商实现 + +使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口,支持文本和视觉模型。 +""" + +import io +import base64 +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import PIL.Image +from loguru import logger +from openai import ( + APIError as OpenAIAPIError, + AsyncOpenAI, + AuthenticationError as OpenAIAuthError, + BadRequestError as OpenAIBadRequestError, + RateLimitError as OpenAIRateLimitError, +) + +from app.config import config +from .base import TextModelProvider, VisionModelProvider +from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError + +# 常见 OpenAI 兼容网关前缀。若使用 provider/model 格式,将剥离 provider 前缀。 +OPENAI_COMPATIBLE_PROVIDER_PREFIXES = { + "openai", + "gemini", + "deepseek", + "qwen", + "siliconflow", + "moonshot", + "openrouter", + "anthropic", + "azure", + "ollama", + "mistral", + "groq", + "cohere", + "together_ai", + "fireworks_ai", + "volcengine", + "vertex_ai", + "huggingface", + "xai", + "bedrock", + "cloudflare", + "vllm", + "codestral", + "replicate", + "deepgram", +} + + +def _normalize_model_name(model_name: str) -> str: + """兼容历史 provider/model 写法,必要时自动剥离 provider 前缀。""" + if "/" not in model_name: + return model_name + + provider_prefix, raw_model = model_name.split("/", 1) + if provider_prefix.lower() in OPENAI_COMPATIBLE_PROVIDER_PREFIXES and raw_model: + return raw_model + return model_name + + +def _is_response_format_error(message: str) -> bool: + return "response_format" in (message or "").lower() + + +def _is_content_filter_error(message: str) -> bool: + lowered = (message or "").lower() + return "content_filter" in lowered or "safety" in lowered + + +def _clean_json_output(output: str) -> str: + """清理 JSON 输出中的 markdown 包裹。""" + output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE) + output = re.sub(r"^```\s*$", "", output, flags=re.MULTILINE) + output = re.sub(r"^```.*$", "", output, flags=re.MULTILINE) + return output.strip() + + +class _OpenAICompatibleBase: + """OpenAI 兼容 provider 共享逻辑。""" + + @property 
+ def provider_name(self) -> str: + return "openai" + + @property + def supported_models(self) -> List[str]: + # 兼容网关模型数量很多,运行时校验由远端完成。 + return [] + + def _validate_model_support(self): + logger.debug(f"OpenAI 兼容模型已配置: {self.model_name}") + + def _initialize(self): + # SDK client 按请求参数动态构建,这里无需初始化全局状态。 + pass + + def _build_client( + self, + api_key_override: Optional[str] = None, + base_url_override: Optional[str] = None, + timeout_override: Optional[float] = None, + ) -> AsyncOpenAI: + """按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。""" + api_key = api_key_override or self.api_key + base_url = base_url_override or self.base_url or None + + timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180) + max_retries: int = config.app.get("llm_max_retries", 3) + + return AsyncOpenAI( + api_key=api_key, + base_url=base_url, + timeout=timeout_seconds, + max_retries=max_retries, + ) + + +class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider): + """OpenAI 兼容视觉模型提供商。""" + + async def analyze_images( + self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + **kwargs, + ) -> List[str]: + logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片") + + processed_images = self._prepare_images(images) + results: List[str] = [] + + for i in range(0, len(processed_images), batch_size): + batch = processed_images[i : i + batch_size] + logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片") + try: + result = await self._analyze_batch(batch, prompt, **kwargs) + results.append(result) + except Exception as exc: + logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}") + results.append(f"批次处理失败: {exc}") + + return results + + async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str: + content = [{"type": "text", "text": prompt}] + for img in batch: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{self._image_to_base64(img)}"}, + } + ) + + messages = [{"role": "user", "content": content}] + model_name = _normalize_model_name(self.model_name) + + client = self._build_client( + api_key_override=kwargs.get("api_key"), + base_url_override=kwargs.get("api_base"), + timeout_override=config.app.get("llm_vision_timeout", 120), + ) + + try: + response = await client.chat.completions.create( + model=model_name, + messages=messages, + temperature=kwargs.get("temperature", 1.0), + max_tokens=kwargs.get("max_tokens", 4000), + ) + if response.choices and response.choices[0].message and response.choices[0].message.content: + return response.choices[0].message.content + raise APICallError("OpenAI 兼容接口返回空响应") + except OpenAIAuthError as exc: + logger.error(f"OpenAI 兼容接口认证失败: {exc}") + raise AuthenticationError(str(exc)) + except OpenAIRateLimitError as exc: + logger.error(f"OpenAI 兼容接口速率限制: {exc}") + raise RateLimitError(str(exc)) + except OpenAIBadRequestError as exc: + error_msg = str(exc) + if _is_content_filter_error(error_msg): + raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}") + raise APICallError(f"请求错误: {error_msg}") + except OpenAIAPIError as exc: + logger.error(f"OpenAI 兼容接口 API 错误: {exc}") + raise APICallError(f"API 错误: {exc}") + except Exception as exc: + logger.error(f"OpenAI 兼容接口调用失败: {exc}") + raise APICallError(f"调用失败: {exc}") + + def _image_to_base64(self, img: PIL.Image.Image) -> str: + img_buffer = io.BytesIO() + img.save(img_buffer, format="JPEG", quality=85) + return 
base64.b64encode(img_buffer.getvalue()).decode("utf-8") + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + return payload + + +class OpenAICompatibleTextProvider(_OpenAICompatibleBase, TextModelProvider): + """OpenAI 兼容文本模型提供商。""" + + async def generate_text( + self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs, + ) -> str: + messages = self._build_messages(prompt, system_prompt) + model_name = _normalize_model_name(self.model_name) + + client = self._build_client( + api_key_override=kwargs.get("api_key"), + base_url_override=kwargs.get("api_base"), + timeout_override=config.app.get("llm_text_timeout", 180), + ) + + completion_kwargs: Dict[str, Any] = { + "model": model_name, + "messages": messages, + "temperature": temperature, + } + if max_tokens: + completion_kwargs["max_tokens"] = max_tokens + if response_format == "json": + completion_kwargs["response_format"] = {"type": "json_object"} + + try: + response = await client.chat.completions.create(**completion_kwargs) + if response.choices and response.choices[0].message and response.choices[0].message.content: + return response.choices[0].message.content + raise APICallError("OpenAI 兼容接口返回空响应") + + except OpenAIBadRequestError as exc: + error_msg = str(exc) + # 某些网关不支持 response_format,回退到提示词约束模式 + if response_format == "json" and _is_response_format_error(error_msg): + logger.warning("目标网关不支持 response_format,回退为提示词约束 JSON 输出") + completion_kwargs.pop("response_format", None) + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + retry_response = await client.chat.completions.create(**completion_kwargs) + if retry_response.choices and retry_response.choices[0].message and retry_response.choices[0].message.content: + return _clean_json_output(retry_response.choices[0].message.content) + raise APICallError("OpenAI 兼容接口返回空响应") + + if _is_content_filter_error(error_msg): + raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}") + raise APICallError(f"请求错误: {error_msg}") + + except OpenAIAuthError as exc: + logger.error(f"OpenAI 兼容接口认证失败: {exc}") + raise AuthenticationError(str(exc)) + except OpenAIRateLimitError as exc: + logger.error(f"OpenAI 兼容接口速率限制: {exc}") + raise RateLimitError(str(exc)) + except OpenAIAPIError as exc: + logger.error(f"OpenAI 兼容接口 API 错误: {exc}") + raise APICallError(f"API 错误: {exc}") + except Exception as exc: + logger.error(f"OpenAI 兼容接口调用失败: {exc}") + raise APICallError(f"调用失败: {exc}") + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + return payload diff --git a/app/services/llm/providers/__init__.py b/app/services/llm/providers/__init__.py index f9bcbb0..d8ecc65 100644 --- a/app/services/llm/providers/__init__.py +++ b/app/services/llm/providers/__init__.py @@ -2,7 +2,6 @@ 大模型服务提供商实现 包含各种大模型服务提供商的具体实现 -推荐使用 LiteLLM 统一接口(支持 100+ providers) """ # 不在模块顶部导入 provider 类,避免循环依赖 @@ -13,25 +12,25 @@ def register_all_providers(): """ 注册所有提供商 - v0.8.0 变更:只注册 LiteLLM 统一接口 - - 移除了旧的单独 provider 实现 (gemini, openai, qwen, deepseek, siliconflow) - - LiteLLM 支持 100+ providers,无需单独实现 + 当前实现:只注册 OpenAI 兼容统一接口 """ # 在函数内部导入,避免循环依赖 from ..manager import LLMServiceManager from loguru import logger - # 只导入 LiteLLM provider - from ..litellm_provider import LiteLLMVisionProvider, LiteLLMTextProvider + # 只导入 OpenAI 兼容 provider + from ..openai_compatible_provider import ( + OpenAICompatibleVisionProvider, + OpenAICompatibleTextProvider, + ) 
logger.info("🔧 开始注册 LLM 提供商...") - # ===== 注册 LiteLLM 统一接口 ===== - # LiteLLM 支持 100+ providers(OpenAI, Gemini, Qwen, DeepSeek, SiliconFlow, 等) - LLMServiceManager.register_vision_provider('litellm', LiteLLMVisionProvider) - LLMServiceManager.register_text_provider('litellm', LiteLLMTextProvider) + # ===== 注册 OpenAI 兼容统一接口 ===== + LLMServiceManager.register_vision_provider('openai', OpenAICompatibleVisionProvider) + LLMServiceManager.register_text_provider('openai', OpenAICompatibleTextProvider) - logger.info("✅ LiteLLM 提供商注册完成(支持 100+ providers)") + logger.info("✅ OpenAI 兼容提供商注册完成") # 导出注册函数 diff --git a/app/services/llm/test_litellm_integration.py b/app/services/llm/test_litellm_integration.py deleted file mode 100644 index b354771..0000000 --- a/app/services/llm/test_litellm_integration.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -LiteLLM 集成测试脚本 - -测试 LiteLLM provider 是否正确集成到系统中 -""" - -import asyncio -import sys -from pathlib import Path - -# 添加项目根目录到 Python 路径 -project_root = Path(__file__).parent.parent.parent.parent -sys.path.insert(0, str(project_root)) - -from loguru import logger -from app.services.llm.manager import LLMServiceManager -from app.services.llm.unified_service import UnifiedLLMService - - -def test_provider_registration(): - """测试 provider 是否正确注册""" - logger.info("=" * 60) - logger.info("测试 1: Provider 注册检查") - logger.info("=" * 60) - - # 检查 LiteLLM provider 是否已注册 - vision_providers = LLMServiceManager.list_vision_providers() - text_providers = LLMServiceManager.list_text_providers() - - logger.info(f"已注册的视觉模型 providers: {vision_providers}") - logger.info(f"已注册的文本模型 providers: {text_providers}") - - assert 'litellm' in vision_providers, "❌ LiteLLM Vision Provider 未注册" - assert 'litellm' in text_providers, "❌ LiteLLM Text Provider 未注册" - - logger.success("✅ LiteLLM providers 已成功注册") - - # 显示所有 provider 信息 - provider_info = LLMServiceManager.get_provider_info() - logger.info("\n所有 Provider 信息:") - logger.info(f" 视觉模型 providers: {list(provider_info['vision_providers'].keys())}") - logger.info(f" 文本模型 providers: {list(provider_info['text_providers'].keys())}") - - -def test_litellm_import(): - """测试 LiteLLM 库是否正确安装""" - logger.info("\n" + "=" * 60) - logger.info("测试 2: LiteLLM 库导入检查") - logger.info("=" * 60) - - try: - import litellm - logger.success(f"✅ LiteLLM 已安装,版本: {litellm.__version__}") - return True - except ImportError as e: - logger.error(f"❌ LiteLLM 未安装: {str(e)}") - logger.info("请运行: pip install litellm>=1.70.0") - return False - - -async def test_text_generation_mock(): - """测试文本生成接口(模拟模式,不实际调用 API)""" - logger.info("\n" + "=" * 60) - logger.info("测试 3: 文本生成接口(模拟)") - logger.info("=" * 60) - - try: - # 这里只测试接口是否可调用,不实际发送 API 请求 - logger.info("接口测试通过:UnifiedLLMService.generate_text 可调用") - logger.success("✅ 文本生成接口测试通过") - return True - except Exception as e: - logger.error(f"❌ 文本生成接口测试失败: {str(e)}") - return False - - -async def test_vision_analysis_mock(): - """测试视觉分析接口(模拟模式)""" - logger.info("\n" + "=" * 60) - logger.info("测试 4: 视觉分析接口(模拟)") - logger.info("=" * 60) - - try: - # 这里只测试接口是否可调用 - logger.info("接口测试通过:UnifiedLLMService.analyze_images 可调用") - logger.success("✅ 视觉分析接口测试通过") - return True - except Exception as e: - logger.error(f"❌ 视觉分析接口测试失败: {str(e)}") - return False - - -def test_backward_compatibility(): - """测试向后兼容性""" - logger.info("\n" + "=" * 60) - logger.info("测试 5: 向后兼容性检查") - logger.info("=" * 60) - - # 检查旧的 provider 是否仍然可用 - old_providers = ['gemini', 'openai', 'qwen', 'deepseek', 'siliconflow'] - vision_providers = 
LLMServiceManager.list_vision_providers() - text_providers = LLMServiceManager.list_text_providers() - - logger.info("检查旧 provider 是否仍然可用:") - for provider in old_providers: - if provider in ['openai', 'deepseek']: - # 这些只有 text provider - if provider in text_providers: - logger.info(f" ✅ {provider} (text)") - else: - logger.warning(f" ⚠️ {provider} (text) 未注册") - else: - # 这些有 vision 和 text provider - vision_ok = provider in vision_providers or f"{provider}vl" in vision_providers - text_ok = provider in text_providers - - if vision_ok: - logger.info(f" ✅ {provider} (vision)") - if text_ok: - logger.info(f" ✅ {provider} (text)") - - logger.success("✅ 向后兼容性测试通过") - - -def print_usage_guide(): - """打印使用指南""" - logger.info("\n" + "=" * 60) - logger.info("LiteLLM 使用指南") - logger.info("=" * 60) - - guide = """ -📚 如何使用 LiteLLM: - -1. 在 config.toml 中配置: - ```toml - [app] - # 方式 1:直接使用 LiteLLM(推荐) - vision_llm_provider = "litellm" - vision_litellm_model_name = "gemini/gemini-2.0-flash-lite" - vision_litellm_api_key = "your-api-key" - - text_llm_provider = "litellm" - text_litellm_model_name = "deepseek/deepseek-chat" - text_litellm_api_key = "your-api-key" - ``` - -2. 支持的模型格式: - - Gemini: gemini/gemini-2.0-flash - - DeepSeek: deepseek/deepseek-chat - - Qwen: qwen/qwen-plus - - OpenAI: gpt-4o, gpt-4o-mini - - SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1 - - 更多: 参考 https://docs.litellm.ai/docs/providers - -3. 代码调用示例: - ```python - from app.services.llm.unified_service import UnifiedLLMService - - # 文本生成 - result = await UnifiedLLMService.generate_text( - prompt="你好", - provider="litellm" - ) - - # 视觉分析 - results = await UnifiedLLMService.analyze_images( - images=["path/to/image.jpg"], - prompt="描述这张图片", - provider="litellm" - ) - ``` - -4. 优势: - ✅ 减少 80% 代码量 - ✅ 统一的错误处理 - ✅ 自动重试机制 - ✅ 支持 100+ providers - ✅ 自动成本追踪 - -5. 
迁移建议: - - 新项目:直接使用 LiteLLM - - 旧项目:逐步迁移,旧的 provider 仍然可用 - - 测试充分后再切换生产环境 -""" - print(guide) - - -def main(): - """运行所有测试""" - logger.info("开始 LiteLLM 集成测试...\n") - - try: - # 测试 1: Provider 注册 - test_provider_registration() - - # 测试 2: LiteLLM 库导入 - litellm_available = test_litellm_import() - - if not litellm_available: - logger.warning("\n⚠️ LiteLLM 未安装,跳过 API 测试") - logger.info("请运行: pip install litellm>=1.70.0") - else: - # 测试 3-4: 接口测试(模拟) - asyncio.run(test_text_generation_mock()) - asyncio.run(test_vision_analysis_mock()) - - # 测试 5: 向后兼容性 - test_backward_compatibility() - - # 打印使用指南 - print_usage_guide() - - logger.info("\n" + "=" * 60) - logger.success("🎉 所有测试通过!") - logger.info("=" * 60) - - return True - - except Exception as e: - logger.error(f"\n❌ 测试失败: {str(e)}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py new file mode 100644 index 0000000..faa4e80 --- /dev/null +++ b/app/services/llm/test_openai_compat_unittest.py @@ -0,0 +1,67 @@ +"""OpenAI 兼容 provider 的最小回归测试。""" + +import unittest + +from app.config import config +from app.services.llm.base import TextModelProvider +from app.services.llm.manager import LLMServiceManager +from app.services.llm.providers import register_all_providers + + +class DummyOpenAITextProvider(TextModelProvider): + @property + def provider_name(self) -> str: + return "openai" + + @property + def supported_models(self) -> list[str]: + return [] + + async def generate_text(self, prompt: str, **kwargs) -> str: + return prompt + + async def _make_api_call(self, payload: dict) -> dict: + return payload + + +def _reset_manager_state(): + LLMServiceManager._vision_providers.clear() + LLMServiceManager._text_providers.clear() + LLMServiceManager._vision_instance_cache.clear() + LLMServiceManager._text_instance_cache.clear() + + +class OpenAICompatManagerTests(unittest.TestCase): + def setUp(self): + _reset_manager_state() + self._original_app = dict(config.app) + + def tearDown(self): + _reset_manager_state() + config.app.clear() + config.app.update(self._original_app) + + def test_register_all_providers_only_registers_openai_provider(self): + register_all_providers() + + self.assertEqual({"openai"}, set(LLMServiceManager.list_text_providers())) + self.assertEqual({"openai"}, set(LLMServiceManager.list_vision_providers())) + + def test_get_text_provider_uses_openai_keys(self): + LLMServiceManager.register_text_provider("openai", DummyOpenAITextProvider) + + config.app["text_llm_provider"] = "openai" + config.app["text_openai_api_key"] = "new-key" + config.app["text_openai_model_name"] = "new-model" + config.app["text_openai_base_url"] = "https://new.example/v1" + + provider = LLMServiceManager.get_text_provider() + + self.assertIsInstance(provider, DummyOpenAITextProvider) + self.assertEqual("new-key", provider.api_key) + self.assertEqual("new-model", provider.model_name) + self.assertEqual("https://new.example/v1", provider.base_url) + + +if __name__ == "__main__": + unittest.main() diff --git a/app/services/llm/test_openai_compatible_integration.py b/app/services/llm/test_openai_compatible_integration.py new file mode 100644 index 0000000..fbc726d --- /dev/null +++ b/app/services/llm/test_openai_compatible_integration.py @@ -0,0 +1,35 @@ +""" +OpenAI 兼容接口集成测试脚本 + +用于快速检查统一 LLM Provider 是否注册成功。 +""" + +from loguru import logger + +from 
app.services.llm.manager import LLMServiceManager +from app.services.llm.providers import register_all_providers + + +def test_provider_registration() -> bool: + """检查 OpenAI 兼容 provider 是否注册成功。""" + logger.info("测试:Provider 注册检查") + register_all_providers() + + vision_providers = LLMServiceManager.list_vision_providers() + text_providers = LLMServiceManager.list_text_providers() + + assert "openai" in vision_providers, "❌ OpenAI 兼容 Vision Provider 未注册" + assert "openai" in text_providers, "❌ OpenAI 兼容 Text Provider 未注册" + + logger.success("✅ OpenAI 兼容 providers 已成功注册") + return True + + +if __name__ == "__main__": + try: + ok = test_provider_registration() + if ok: + logger.success("\n🎉 集成检查通过") + except Exception as exc: + logger.error(f"\n❌ 集成检查失败: {exc}") + raise diff --git a/app/services/script_service.py b/app/services/script_service.py index 1cc27ab..34a17a6 100644 --- a/app/services/script_service.py +++ b/app/services/script_service.py @@ -133,7 +133,7 @@ class ScriptGenerator: from app.services.llm.migration_adapter import create_vision_analyzer # 获取配置 - text_provider = config.app.get('text_llm_provider', 'litellm').lower() + text_provider = config.app.get('text_llm_provider', 'openai').lower() vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key') vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name') vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url') @@ -321,4 +321,4 @@ class ScriptGenerator: last_timestamp = format_timestamp(last_time) timestamp_range = f"{first_timestamp}-{last_timestamp}" - return first_timestamp, last_timestamp, timestamp_range \ No newline at end of file + return first_timestamp, last_timestamp, timestamp_range diff --git a/config.example.toml b/config.example.toml index d347629..cc12fb1 100644 --- a/config.example.toml +++ b/config.example.toml @@ -4,24 +4,16 @@ # LLM API 超时配置(秒) llm_vision_timeout = 120 # 视觉模型基础超时时间 llm_text_timeout = 180 # 文本模型基础超时时间(解说文案生成等复杂任务需要更长时间) - llm_max_retries = 3 # API 重试次数(LiteLLM 会自动处理重试) + llm_max_retries = 3 # API 重试次数 ########################################## - # 🚀 LLM 配置 - 使用 LiteLLM 统一接口 + # 🚀 LLM 配置 - 使用 OpenAI 兼容统一接口 ########################################## - # LiteLLM 是统一的 LLM 接口库,支持 100+ providers - # 优势: - # ✅ 代码量减少 80%,统一的 API 接口 - # ✅ 自动重试和智能错误处理 - # ✅ 内置成本追踪和 token 统计 - # ✅ 支持更多 providers:OpenAI, Anthropic, Gemini, Qwen, DeepSeek, - # Cohere, Together AI, Replicate, Groq, Mistral 等 - # - # 文档:https://docs.litellm.ai/ - # 支持的模型:https://docs.litellm.ai/docs/providers + # 统一使用 OpenAI 兼容协议(/v1/chat/completions) + # 支持接入 OpenAI、DeepSeek、Gemini 兼容网关、Qwen 网关、SiliconFlow、OpenRouter 等。 # ===== 视觉模型配置 ===== - vision_llm_provider = "litellm" + vision_llm_provider = "openai" # 模型格式:provider/model_name # 常用视觉模型示例: @@ -30,12 +22,12 @@ # - OpenAI: gpt-4o, gpt-4o-mini # - Qwen: qwen/qwen2.5-vl-32b-instruct # - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct - vision_litellm_model_name = "gemini/gemini-2.0-flash-lite" - vision_litellm_api_key = "" # 填入对应 provider 的 API key - vision_litellm_base_url = "" # 可选:自定义 API base URL + vision_openai_model_name = "gemini/gemini-2.0-flash-lite" + vision_openai_api_key = "" # 填入对应 provider 的 API key + vision_openai_base_url = "" # 可选:自定义 API base URL(官方 OpenAI 可留空) # ===== 文本模型配置 ===== - text_llm_provider = "litellm" + text_llm_provider = "openai" # 常用文本模型示例: # - DeepSeek: deepseek/deepseek-chat (推荐,性价比高) @@ -45,9 +37,9 @@ # - Qwen: qwen/qwen-plus, qwen/qwen-turbo # - SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1 # - 
Moonshot: moonshot/moonshot-v1-8k
-    text_litellm_model_name = "deepseek/deepseek-chat"
-    text_litellm_api_key = "" # 填入对应 provider 的 API key
-    text_litellm_base_url = "" # 可选:自定义 API base URL
+    text_openai_model_name = "deepseek/deepseek-chat"
+    text_openai_api_key = "" # 填入对应 provider 的 API key
+    text_openai_base_url = "" # 可选:自定义 API base URL(官方 OpenAI 可留空)
 
     # ===== API Keys 参考 =====
     # 主流 LLM Providers API Key 获取地址:
@@ -69,21 +61,7 @@
     # WebUI 界面是否显示配置项
     hide_config = true
 
-    ##########################################
-    # 📚 传统配置示例(仅供参考,不推荐使用)
-    ##########################################
-    # 如果需要使用传统的单独 provider 实现,可以参考以下配置
-    # 但强烈推荐使用上面的 LiteLLM 配置
-    #
-    # 传统视觉模型配置示例:
-    # vision_llm_provider = "gemini"  # 可选:gemini, qwenvl, siliconflow
-    # vision_gemini_api_key = ""
-    # vision_gemini_model_name = "gemini-2.0-flash-lite"
-    #
-    # 传统文本模型配置示例:
-    # text_llm_provider = "openai"  # 可选:openai, gemini, qwen, deepseek, siliconflow, moonshot
-    # text_openai_api_key = ""
-    # text_openai_model_name = "gpt-4o-mini"
+    # 官方 OpenAI 默认端点(可选):
     # text_openai_base_url = "https://api.openai.com/v1"
 
     ##########################################
diff --git a/requirements.txt b/requirements.txt
index 2d09a05..12def4d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,8 +12,7 @@ pysrt==1.1.2
 
 # AI 服务依赖
 openai>=1.77.0
-litellm>=1.70.0  # 统一的 LLM 接口,支持 100+ providers
-google-generativeai>=0.8.5  # LiteLLM 会使用此库调用 Gemini
+google-generativeai>=0.8.5  # 原生 Gemini 场景依赖
 azure-cognitiveservices-speech>=1.37.0
 tencentcloud-sdk-python>=3.0.1200
 dashscope>=1.24.6
diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py
index 60ac96d..48ef976 100644
--- a/webui/components/basic_settings.py
+++ b/webui/components/basic_settings.py
@@ -72,8 +72,8 @@ def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
     return True, ""
 
 
-def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool, str]:
-    """验证 LiteLLM 模型名称格式
+def validate_openai_compatible_model_name(model_name: str, model_type: str) -> tuple[bool, str]:
+    """验证 OpenAI 兼容模型名称格式
 
     Args:
         model_name: 模型名称,应为 provider/model 格式
@@ -87,8 +87,8 @@
 
     model_name = model_name.strip()
 
-    # LiteLLM 推荐格式:provider/model(如 gemini/gemini-2.0-flash-lite)
-    # 但也支持直接的模型名称(如 gpt-4o,LiteLLM 会自动推断 provider)
+    # 推荐格式:provider/model(如 gemini/gemini-2.0-flash-lite)
+    # 也支持直接的模型名称(如 gpt-4o);provider/model 写法会在发送请求前剥离前缀
 
     # 检查是否包含 provider 前缀(推荐格式)
     if "/" in model_name:
@@ -101,9 +101,9 @@
         if not provider.replace("-", "").replace("_", "").isalnum():
             return False, f"{model_type} Provider 名称只能包含字母、数字、下划线和连字符"
     else:
-        # 直接模型名称也是有效的(LiteLLM 会自动推断)
+        # 直接模型名称也是有效的(将按原样发送给网关)
         # 但给出警告建议使用完整格式
-        logger.debug(f"{model_type} 模型名称未包含 provider 前缀,LiteLLM 将自动推断")
+        logger.debug(f"{model_type} 模型名称未包含 provider 前缀,将按原样发送给网关")
 
     # 基本长度检查
     if len(model_name) < 3:
@@ -115,6 +115,13 @@
     return True, ""
 
+def normalize_openai_compatible_model_name(model_name: str) -> str:
+    """将 provider/model 格式转换为网关实际使用的模型名。"""
+    if "/" not in model_name:
+        return model_name
+    return model_name.split("/", 1)[1]
+
+
 def show_config_validation_errors(errors: list):
     """显示配置验证错误"""
     if errors:
@@ -312,243 +319,112 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
 
 
-def 
test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]: - """测试 LiteLLM 视觉模型连接 - - Args: - api_key: API 密钥 - base_url: 基础 URL(可选) - model_name: 模型名称(LiteLLM 格式:provider/model) - tr: 翻译函数 - - Returns: - (连接是否成功, 测试结果消息) - """ +def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]: + """测试 OpenAI 兼容视觉模型连接。""" try: - import litellm - import os import base64 import io + from openai import OpenAI from PIL import Image - - logger.debug(f"LiteLLM 视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}") - - # 提取 provider 名称 - provider = model_name.split("/")[0] if "/" in model_name else "unknown" - - # 设置 API key 到环境变量 - env_key_mapping = { - "gemini": "GEMINI_API_KEY", - "google": "GEMINI_API_KEY", - "openai": "OPENAI_API_KEY", - "qwen": "QWEN_API_KEY", - "dashscope": "DASHSCOPE_API_KEY", - "siliconflow": "SILICONFLOW_API_KEY", - } - env_var = env_key_mapping.get(provider.lower(), f"{provider.upper()}_API_KEY") - old_key = os.environ.get(env_var) - os.environ[env_var] = api_key - - # SiliconFlow 特殊处理:使用 OpenAI 兼容模式 - test_model_name = model_name - if provider.lower() == "siliconflow": - # 替换 provider 为 openai - if "/" in model_name: - test_model_name = f"openai/{model_name.split('/', 1)[1]}" - else: - test_model_name = f"openai/{model_name}" - - # 确保设置了 base_url - if not base_url: - base_url = "https://api.siliconflow.cn/v1" - - # 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议) - os.environ["OPENAI_API_KEY"] = api_key - os.environ["OPENAI_API_BASE"] = base_url - - try: - # 创建测试图片(64x64 白色像素,避免某些模型对极小图片的限制) - test_image = Image.new('RGB', (64, 64), color='white') - img_buffer = io.BytesIO() - test_image.save(img_buffer, format='JPEG') - img_bytes = img_buffer.getvalue() - base64_image = base64.b64encode(img_bytes).decode('utf-8') - - # 构建测试请求 - messages = [{ - "role": "user", - "content": [ - {"type": "text", "text": "请直接回复'连接成功'"}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - } - ] - }] - - # 准备参数 - completion_kwargs = { - "model": test_model_name, - "messages": messages, - "temperature": 0.1, - "max_tokens": 50 - } - - if base_url: - completion_kwargs["api_base"] = base_url - - # 调用 LiteLLM(同步调用用于测试) - response = litellm.completion(**completion_kwargs) - - if response and response.choices and len(response.choices) > 0: - return True, f"LiteLLM 视觉模型连接成功 ({model_name})" - else: - return False, f"LiteLLM 视觉模型返回空响应" - - finally: - # 恢复原始环境变量 - if old_key: - os.environ[env_var] = old_key - else: - os.environ.pop(env_var, None) - - # 清理临时设置的 OpenAI 环境变量 - if provider.lower() == "siliconflow": - os.environ.pop("OPENAI_API_KEY", None) - os.environ.pop("OPENAI_API_BASE", None) - + + logger.debug( + f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}" + ) + + client = OpenAI( + api_key=api_key, + base_url=base_url or None, + timeout=10.0, + max_retries=1, + ) + + # 创建测试图片(64x64 白色像素) + test_image = Image.new("RGB", (64, 64), color="white") + img_buffer = io.BytesIO() + test_image.save(img_buffer, format="JPEG") + base64_image = base64.b64encode(img_buffer.getvalue()).decode("utf-8") + + response = client.chat.completions.create( + model=normalize_openai_compatible_model_name(model_name), + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "请直接回复'连接成功'"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, + }, + ], + } + ], + 
temperature=0.1, + max_tokens=50, + ) + + if response and response.choices and len(response.choices) > 0: + return True, f"OpenAI 兼容视觉模型连接成功 ({model_name})" + return False, "OpenAI 兼容视觉模型返回空响应" except Exception as e: error_msg = str(e) - logger.error(f"LiteLLM 视觉模型测试失败: {error_msg}") - - # 提供更友好的错误信息 + logger.error(f"OpenAI 兼容视觉模型测试失败: {error_msg}") if "authentication" in error_msg.lower() or "api_key" in error_msg.lower(): - return False, f"认证失败,请检查 API Key 是否正确" - elif "not found" in error_msg.lower() or "404" in error_msg: - return False, f"模型不存在,请检查模型名称是否正确" - elif "rate limit" in error_msg.lower(): - return False, f"超出速率限制,请稍后重试" - else: - return False, f"连接失败: {error_msg}" + return False, "认证失败,请检查 API Key 是否正确" + if "not found" in error_msg.lower() or "404" in error_msg: + return False, "模型不存在,请检查模型名称是否正确" + if "rate limit" in error_msg.lower(): + return False, "超出速率限制,请稍后重试" + return False, f"连接失败: {error_msg}" -def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]: - """测试 LiteLLM 文本模型连接 - - Args: - api_key: API 密钥 - base_url: 基础 URL(可选) - model_name: 模型名称(LiteLLM 格式:provider/model) - tr: 翻译函数 - - Returns: - (连接是否成功, 测试结果消息) - """ +def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]: + """测试 OpenAI 兼容文本模型连接。""" try: - import litellm - import os - - logger.debug(f"LiteLLM 文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}") - - # 提取 provider 名称 - provider = model_name.split("/")[0] if "/" in model_name else "unknown" - - # 设置 API key 到环境变量 - env_key_mapping = { - "gemini": "GEMINI_API_KEY", - "google": "GEMINI_API_KEY", - "openai": "OPENAI_API_KEY", - "qwen": "QWEN_API_KEY", - "dashscope": "DASHSCOPE_API_KEY", - "siliconflow": "SILICONFLOW_API_KEY", - "deepseek": "DEEPSEEK_API_KEY", - "moonshot": "MOONSHOT_API_KEY", - } - env_var = env_key_mapping.get(provider.lower(), f"{provider.upper()}_API_KEY") - old_key = os.environ.get(env_var) - os.environ[env_var] = api_key - - # SiliconFlow 特殊处理:使用 OpenAI 兼容模式 - test_model_name = model_name - if provider.lower() == "siliconflow": - # 替换 provider 为 openai - if "/" in model_name: - test_model_name = f"openai/{model_name.split('/', 1)[1]}" - else: - test_model_name = f"openai/{model_name}" - - # 确保设置了 base_url - if not base_url: - base_url = "https://api.siliconflow.cn/v1" - - # 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议) - os.environ["OPENAI_API_KEY"] = api_key - os.environ["OPENAI_API_BASE"] = base_url - - try: - # 构建测试请求 - messages = [ - {"role": "user", "content": "请直接回复'连接成功'"} - ] - - # 准备参数 - completion_kwargs = { - "model": test_model_name, - "messages": messages, - "temperature": 0.1, - "max_tokens": 20 - } - - if base_url: - completion_kwargs["api_base"] = base_url - - # 调用 LiteLLM(同步调用用于测试) - response = litellm.completion(**completion_kwargs) - - if response and response.choices and len(response.choices) > 0: - return True, f"LiteLLM 文本模型连接成功 ({model_name})" - else: - return False, f"LiteLLM 文本模型返回空响应" - - finally: - # 恢复原始环境变量 - if old_key: - os.environ[env_var] = old_key - else: - os.environ.pop(env_var, None) - - # 清理临时设置的 OpenAI 环境变量 - if provider.lower() == "siliconflow": - os.environ.pop("OPENAI_API_KEY", None) - os.environ.pop("OPENAI_API_BASE", None) - + from openai import OpenAI + + logger.debug( + f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}" + ) + + client = OpenAI( + api_key=api_key, + base_url=base_url or None, + timeout=10.0, + max_retries=1, + ) + 
+ response = client.chat.completions.create( + model=normalize_openai_compatible_model_name(model_name), + messages=[{"role": "user", "content": "请直接回复'连接成功'"}], + temperature=0.1, + max_tokens=20, + ) + + if response and response.choices and len(response.choices) > 0: + return True, f"OpenAI 兼容文本模型连接成功 ({model_name})" + return False, "OpenAI 兼容文本模型返回空响应" except Exception as e: error_msg = str(e) - logger.error(f"LiteLLM 文本模型测试失败: {error_msg}") - - # 提供更友好的错误信息 + logger.error(f"OpenAI 兼容文本模型测试失败: {error_msg}") if "authentication" in error_msg.lower() or "api_key" in error_msg.lower(): - return False, f"认证失败,请检查 API Key 是否正确" - elif "not found" in error_msg.lower() or "404" in error_msg: - return False, f"模型不存在,请检查模型名称是否正确" - elif "rate limit" in error_msg.lower(): - return False, f"超出速率限制,请稍后重试" - else: - return False, f"连接失败: {error_msg}" + return False, "认证失败,请检查 API Key 是否正确" + if "not found" in error_msg.lower() or "404" in error_msg: + return False, "模型不存在,请检查模型名称是否正确" + if "rate limit" in error_msg.lower(): + return False, "超出速率限制,请稍后重试" + return False, f"连接失败: {error_msg}" def render_vision_llm_settings(tr): - """渲染视频分析模型设置(LiteLLM 统一配置)""" + """渲染视频分析模型设置(OpenAI 兼容 统一配置)""" st.subheader(tr("Vision Model Settings")) - # 固定使用 LiteLLM 提供商 - config.app["vision_llm_provider"] = "litellm" + # 固定使用 OpenAI 兼容 提供商 + config.app["vision_llm_provider"] = "openai" - # 获取已保存的 LiteLLM 配置 - full_vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite") - vision_api_key = config.app.get("vision_litellm_api_key", "") - vision_base_url = config.app.get("vision_litellm_base_url", "") + # 获取已保存的配置 + full_vision_model_name = config.app.get("vision_openai_model_name") or "gemini/gemini-2.0-flash-lite" + vision_api_key = config.app.get("vision_openai_api_key", "") + vision_base_url = config.app.get("vision_openai_base_url", "") # 解析 provider 和 model default_provider = "gemini" @@ -563,7 +439,7 @@ def render_vision_llm_settings(tr): current_model = full_vision_model_name # 定义支持的 provider 列表 - LITELLM_PROVIDERS = [ + OPENAI_COMPATIBLE_PROVIDERS = [ "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot", "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral", "volcengine", "groq", "cohere", "together_ai", "fireworks_ai", @@ -572,16 +448,16 @@ def render_vision_llm_settings(tr): ] # 如果当前 provider 不在列表中,添加到列表头部 - if current_provider not in LITELLM_PROVIDERS: - LITELLM_PROVIDERS.insert(0, current_provider) + if current_provider not in OPENAI_COMPATIBLE_PROVIDERS: + OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider) # 渲染配置输入框 col1, col2 = st.columns([1, 2]) with col1: selected_provider = st.selectbox( tr("Vision Model Provider"), - options=LITELLM_PROVIDERS, - index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0, + options=OPENAI_COMPATIBLE_PROVIDERS, + index=OPENAI_COMPATIBLE_PROVIDERS.index(current_provider) if current_provider in OPENAI_COMPATIBLE_PROVIDERS else 0, key="vision_provider_select" ) @@ -595,7 +471,7 @@ def render_vision_llm_settings(tr): "• gpt-4o\n" "• qwen-vl-max\n" "• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n" - "支持 100+ providers,详见: https://docs.litellm.ai/docs/providers", + "支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow)", key="vision_model_input" ) @@ -641,7 +517,7 @@ def render_vision_llm_settings(tr): else: with st.spinner(tr("Testing connection...")): try: - success, message = test_litellm_vision_model( + success, message = test_openai_compatible_vision_model( 
api_key=st_vision_api_key, base_url=st_vision_base_url, model_name=st_vision_model_name, @@ -654,7 +530,7 @@ def render_vision_llm_settings(tr): st.error(message) except Exception as e: st.error(f"测试连接时发生错误: {str(e)}") - logger.error(f"LiteLLM 视频分析模型连接测试失败: {str(e)}") + logger.error(f"OpenAI 兼容 视频分析模型连接测试失败: {str(e)}") # 验证和保存配置 validation_errors = [] @@ -663,10 +539,10 @@ def render_vision_llm_settings(tr): # 验证模型名称 if st_vision_model_name: # 这里的验证逻辑可能需要微调,因为我们现在是自动组合的 - is_valid, error_msg = validate_litellm_model_name(st_vision_model_name, "视频分析") + is_valid, error_msg = validate_openai_compatible_model_name(st_vision_model_name, "视频分析") if is_valid: - config.app["vision_litellm_model_name"] = st_vision_model_name - st.session_state["vision_litellm_model_name"] = st_vision_model_name + config.app["vision_openai_model_name"] = st_vision_model_name + st.session_state["vision_openai_model_name"] = st_vision_model_name config_changed = True else: validation_errors.append(error_msg) @@ -675,8 +551,8 @@ def render_vision_llm_settings(tr): if st_vision_api_key: is_valid, error_msg = validate_api_key(st_vision_api_key, "视频分析") if is_valid: - config.app["vision_litellm_api_key"] = st_vision_api_key - st.session_state["vision_litellm_api_key"] = st_vision_api_key + config.app["vision_openai_api_key"] = st_vision_api_key + st.session_state["vision_openai_api_key"] = st_vision_api_key config_changed = True else: validation_errors.append(error_msg) @@ -685,8 +561,8 @@ def render_vision_llm_settings(tr): if st_vision_base_url: is_valid, error_msg = validate_base_url(st_vision_base_url, "视频分析") if is_valid: - config.app["vision_litellm_base_url"] = st_vision_base_url - st.session_state["vision_litellm_base_url"] = st_vision_base_url + config.app["vision_openai_base_url"] = st_vision_base_url + st.session_state["vision_openai_base_url"] = st_vision_base_url config_changed = True else: validation_errors.append(error_msg) @@ -701,7 +577,7 @@ def render_vision_llm_settings(tr): # 清除缓存,确保下次使用新配置 UnifiedLLMService.clear_cache() if st_vision_api_key or st_vision_base_url or st_vision_model_name: - st.success(f"视频分析模型配置已保存(LiteLLM)") + st.success(f"视频分析模型配置已保存(OpenAI 兼容)") except Exception as e: st.error(f"保存配置失败: {str(e)}") logger.error(f"保存视频分析配置失败: {str(e)}") @@ -811,16 +687,16 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr): def render_text_llm_settings(tr): - """渲染文案生成模型设置(LiteLLM 统一配置)""" + """渲染文案生成模型设置(OpenAI 兼容 统一配置)""" st.subheader(tr("Text Generation Model Settings")) - # 固定使用 LiteLLM 提供商 - config.app["text_llm_provider"] = "litellm" + # 固定使用 OpenAI 兼容 提供商 + config.app["text_llm_provider"] = "openai" - # 获取已保存的 LiteLLM 配置 - full_text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat") - text_api_key = config.app.get("text_litellm_api_key", "") - text_base_url = config.app.get("text_litellm_base_url", "") + # 获取已保存的配置 + full_text_model_name = config.app.get("text_openai_model_name") or "deepseek/deepseek-chat" + text_api_key = config.app.get("text_openai_api_key", "") + text_base_url = config.app.get("text_openai_base_url", "") # 解析 provider 和 model default_provider = "deepseek" @@ -835,7 +711,7 @@ def render_text_llm_settings(tr): current_model = full_text_model_name # 定义支持的 provider 列表 - LITELLM_PROVIDERS = [ + OPENAI_COMPATIBLE_PROVIDERS = [ "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot", "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral", "volcengine", "groq", "cohere", "together_ai", "fireworks_ai", 
@@ -835,7 +711,7 @@ def render_text_llm_settings(tr):
     current_model = full_text_model_name
 
     # 定义支持的 provider 列表
-    LITELLM_PROVIDERS = [
+    OPENAI_COMPATIBLE_PROVIDERS = [
         "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
         "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
         "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
@@ -844,16 +720,16 @@ def render_text_llm_settings(tr):
     ]
 
     # 如果当前 provider 不在列表中,添加到列表头部
-    if current_provider not in LITELLM_PROVIDERS:
-        LITELLM_PROVIDERS.insert(0, current_provider)
+    if current_provider not in OPENAI_COMPATIBLE_PROVIDERS:
+        OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider)
 
     # 渲染配置输入框
     col1, col2 = st.columns([1, 2])
     with col1:
         selected_provider = st.selectbox(
             tr("Text Model Provider"),
-            options=LITELLM_PROVIDERS,
-            index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
+            options=OPENAI_COMPATIBLE_PROVIDERS,
+            index=OPENAI_COMPATIBLE_PROVIDERS.index(current_provider) if current_provider in OPENAI_COMPATIBLE_PROVIDERS else 0,
             key="text_provider_select"
         )
 
@@ -867,7 +743,7 @@ def render_text_llm_settings(tr):
             "• gpt-4o\n"
             "• gemini-2.0-flash\n"
             "• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n"
-            "支持 100+ providers,详见: https://docs.litellm.ai/docs/providers",
+            "支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow)",
             key="text_model_input"
         )
 
@@ -915,7 +791,7 @@ def render_text_llm_settings(tr):
         else:
             with st.spinner(tr("Testing connection...")):
                 try:
-                    success, message = test_litellm_text_model(
+                    success, message = test_openai_compatible_text_model(
                         api_key=st_text_api_key,
                         base_url=st_text_base_url,
                         model_name=st_text_model_name,
@@ -928,7 +804,7 @@ def render_text_llm_settings(tr):
                         st.error(message)
             except Exception as e:
                 st.error(f"测试连接时发生错误: {str(e)}")
-                logger.error(f"LiteLLM 文案生成模型连接测试失败: {str(e)}")
+                logger.error(f"OpenAI 兼容文案生成模型连接测试失败: {str(e)}")
 
     # 验证和保存配置
     text_validation_errors = []
@@ -936,10 +812,10 @@ def render_text_llm_settings(tr):
 
     # 验证模型名称
     if st_text_model_name:
-        is_valid, error_msg = validate_litellm_model_name(st_text_model_name, "文案生成")
+        is_valid, error_msg = validate_openai_compatible_model_name(st_text_model_name, "文案生成")
         if is_valid:
-            config.app["text_litellm_model_name"] = st_text_model_name
-            st.session_state["text_litellm_model_name"] = st_text_model_name
+            config.app["text_openai_model_name"] = st_text_model_name
+            st.session_state["text_openai_model_name"] = st_text_model_name
             text_config_changed = True
         else:
             text_validation_errors.append(error_msg)
@@ -948,8 +824,8 @@ def render_text_llm_settings(tr):
     if st_text_api_key:
         is_valid, error_msg = validate_api_key(st_text_api_key, "文案生成")
         if is_valid:
-            config.app["text_litellm_api_key"] = st_text_api_key
-            st.session_state["text_litellm_api_key"] = st_text_api_key
+            config.app["text_openai_api_key"] = st_text_api_key
+            st.session_state["text_openai_api_key"] = st_text_api_key
             text_config_changed = True
         else:
             text_validation_errors.append(error_msg)
@@ -958,8 +834,8 @@ def render_text_llm_settings(tr):
     if st_text_base_url:
         is_valid, error_msg = validate_base_url(st_text_base_url, "文案生成")
         if is_valid:
-            config.app["text_litellm_base_url"] = st_text_base_url
-            st.session_state["text_litellm_base_url"] = st_text_base_url
+            config.app["text_openai_base_url"] = st_text_base_url
+            st.session_state["text_openai_base_url"] = st_text_base_url
             text_config_changed = True
         else:
             text_validation_errors.append(error_msg)
@@ -974,7 +850,7 @@ def render_text_llm_settings(tr):
             # 清除缓存,确保下次使用新配置
             UnifiedLLMService.clear_cache()
            if st_text_api_key or st_text_base_url or st_text_model_name:
-                st.success(f"文案生成模型配置已保存(LiteLLM)")
+                st.success("文案生成模型配置已保存(OpenAI 兼容)")
         except Exception as e:
             st.error(f"保存配置失败: {str(e)}")
             logger.error(f"保存文案生成配置失败: {str(e)}")
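NOTE: test_openai_compatible_text_model is called above but defined
outside this diff. Below is a minimal sketch of such a connection check
against an OpenAI-compatible endpoint, assuming the standard openai v1
SDK; the real helper may accept more parameters and map SDK exceptions
to friendlier messages:

from openai import OpenAI


def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: str, tr=None):
    """Verify key, endpoint and model name with a one-token completion."""
    try:
        # An empty base_url falls back to the official OpenAI endpoint
        client = OpenAI(api_key=api_key, base_url=base_url or None)
        client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
        )
        return True, "连接成功"
    except Exception as e:
        return False, f"连接失败: {e}"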
diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py
index d14c330..9f51c01 100644
--- a/webui/tools/generate_script_docu.py
+++ b/webui/tools/generate_script_docu.py
@@ -123,7 +123,7 @@ def generate_script_docu(params):
     # 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值
     vision_llm_provider = (
         st.session_state.get('vision_llm_provider') or
-        config.app.get('vision_llm_provider', 'litellm')
+        config.app.get('vision_llm_provider', 'openai')
     ).lower()
 
     logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")
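NOTE: both files resolve settings with the same falsy-aware fallback:
get() with a default argument only fires when the key is missing, while
chaining with "or" (as in the hunk above) also replaces empty strings
left behind by earlier configurations. Illustrated with a made-up dict:

config_app = {"vision_llm_provider": ""}  # key exists but holds ""

config_app.get("vision_llm_provider", "openai")    # -> "" (default unused)
config_app.get("vision_llm_provider") or "openai"  # -> "openai"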