feat: remove the LiteLLM dependency and migrate to the OpenAI-compatible interface

- Remove LiteLLM code and dependencies; use the native OpenAI-compatible interface instead
- Refactor LLM service provider registration to support only the OpenAI-compatible interface
- Update configuration files and docs, removing LiteLLM-related notes
- Add new test cases verifying the OpenAI-compatible interface integration
- Update WebUI components to work with the new OpenAI-compatible interface
linyq 2026-03-27 23:49:58 +08:00
parent a6f2e0d815
commit 3396644593
16 changed files with 582 additions and 1054 deletions

View File

@ -33,8 +33,9 @@ NarratoAI is an automated video-commentary tool that uses LLMs for script writing,
This project is for learning and research use only; commercial use is prohibited. For a commercial license, please contact the author.
## Latest News
- 2026.03.27 For security reasons, the LiteLLM dependency has been removed; all requests now go through the unified OpenAI-compatible pipeline (a minimal config sketch follows this list)
- 2025.11.20 Released v0.7.5, adding [IndexTTS2](https://github.com/index-tts/index-tts) voice-cloning support
- 2025.10.15 Released v0.7.3, using [LiteLLM](https://github.com/BerriAI/litellm) to manage model providers
- 2025.10.15 Released v0.7.3, upgrading model-provider management
- 2025.09.10 Released v0.7.2, adding Tencent Cloud TTS
- 2025.08.18 Released v0.7.1, supporting **voice cloning** and the latest LLMs
- 2025.05.11 Released v0.6.0, supporting **short-drama commentary** and an optimized editing workflow
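
A minimal `config.toml` sketch for the unified OpenAI-compatible setup (values are placeholders; the configuration file later in this commit shows the full set of options):

```toml
[app]
vision_llm_provider = "openai"
vision_openai_model_name = "gpt-4o"
vision_openai_api_key = "sk-..."

text_llm_provider = "openai"
text_openai_model_name = "gpt-4o-mini"
text_openai_api_key = "sk-..."
```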
@ -176,4 +177,3 @@ streamlit run webui.py --server.maxUploadSize=2048
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=linyqh/NarratoAI&type=Date)](https://star-history.com/#linyqh/NarratoAI&Date)

View File

@ -12,8 +12,7 @@ NarratoAI LLM service module
- OutputValidator: output-format validator
Supported providers:
Vision models: Gemini, QwenVL, Siliconflow
Text models: OpenAI, DeepSeek, Gemini, Qwen, Moonshot, Siliconflow
Vision/text models: the OpenAI-compatible interface can connect to OpenAI, DeepSeek, Gemini gateways, Qwen gateways, and more
"""
from .manager import LLMServiceManager

View File

@ -68,7 +68,7 @@ class BaseLLMProvider(ABC):
"""验证模型支持情况(宽松模式,仅记录警告)"""
from loguru import logger
# LiteLLM already provided unified model validation; legacy providers use lenient validation
# OpenAI-compatible gateways expose many models; the remote endpoint performs the final validation at runtime
if self.model_name not in self.supported_models:
logger.warning(
f"模型 {self.model_name} 未在供应商 {self.provider_name} 的预定义支持列表中。"

View File

@ -214,7 +214,7 @@ class LLMConfigValidator:
"建议为每个提供商配置base_url以提高稳定性",
"定期检查模型名称是否为最新版本",
"建议配置多个提供商作为备用方案",
"推荐使用 LiteLLM 作为统一接口,支持 100+ providers"
"推荐使用 OpenAI 兼容接口,便于接入多家模型网关"
]
}
@ -257,8 +257,8 @@ class LLMConfigValidator:
"text": ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-1.5-pro"]
},
"openai": {
"vision": [],
"text": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
"vision": ["gpt-4o", "gemini-2.0-flash-lite", "Qwen/Qwen2.5-VL-32B-Instruct"],
"text": ["gpt-4o-mini", "deepseek-chat", "zai-org/GLM-4.6"]
},
"qwen": {
"vision": ["qwen2.5-vl-32b-instruct"],

View File

@ -1,490 +0,0 @@
"""
LiteLLM 统一提供商实现
使用 LiteLLM 库提供统一的 LLM 接口支持 100+ providers
包括 OpenAI, Anthropic, Gemini, Qwen, DeepSeek, SiliconFlow
"""
import asyncio
import base64
import io
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger
try:
import litellm
from litellm import acompletion, completion
from litellm.exceptions import (
AuthenticationError as LiteLLMAuthError,
RateLimitError as LiteLLMRateLimitError,
BadRequestError as LiteLLMBadRequestError,
APIError as LiteLLMAPIError
)
except ImportError:
logger.error("LiteLLM 未安装。请运行: pip install litellm")
raise
from .base import VisionModelProvider, TextModelProvider
from .exceptions import (
APICallError,
AuthenticationError,
RateLimitError,
ContentFilterError
)
# 配置 LiteLLM 全局设置
def configure_litellm():
"""配置 LiteLLM 全局参数"""
from app.config import config
# 设置重试次数
litellm.num_retries = config.app.get('llm_max_retries', 3)
# 设置默认超时
litellm.request_timeout = config.app.get('llm_text_timeout', 180)
# 启用详细日志(开发环境)
# litellm.set_verbose = True
logger.info(f"LiteLLM 配置完成: retries={litellm.num_retries}, timeout={litellm.request_timeout}s")
# 初始化配置
configure_litellm()
class LiteLLMVisionProvider(VisionModelProvider):
"""使用 LiteLLM 的统一视觉模型提供商"""
@property
def provider_name(self) -> str:
# 从 model_name 中提取 provider 名称(如 "gemini/gemini-2.0-flash"
if "/" in self.model_name:
return self.model_name.split("/")[0]
return "litellm"
@property
def supported_models(self) -> List[str]:
# LiteLLM 支持 100+ providers 和数百个模型,无法全部列举
# 返回空列表表示跳过预定义列表检查,由 LiteLLM 在实际调用时验证
return []
def _validate_model_support(self):
"""
重写模型验证逻辑
对于 LiteLLM我们不做预定义列表检查因为
1. LiteLLM 支持 100+ providers 和数百个模型无法全部列举
2. LiteLLM 会在实际调用时进行模型验证
3. 如果模型不支持LiteLLM 会返回清晰的错误信息
这里只做基本的格式验证可选
"""
from loguru import logger
# 可选检查模型名称格式provider/model
if "/" not in self.model_name:
logger.debug(
f"LiteLLM 模型名称 '{self.model_name}' 未包含 provider 前缀,"
f"LiteLLM 将尝试自动推断。建议使用 'provider/model' 格式,如 'gemini/gemini-2.5-flash'"
)
# 不抛出异常,让 LiteLLM 在实际调用时验证
logger.debug(f"LiteLLM 视觉模型已配置: {self.model_name}")
def _initialize(self):
"""初始化 LiteLLM 特定设置"""
# 设置 API key 到环境变量LiteLLM 会自动读取)
import os
# 根据 model_name 确定需要设置哪个 API key
provider = self.provider_name.lower()
# 映射 provider 到环境变量名
env_key_mapping = {
"gemini": "GEMINI_API_KEY",
"google": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"qwen": "QWEN_API_KEY",
"dashscope": "DASHSCOPE_API_KEY",
"siliconflow": "SILICONFLOW_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"claude": "ANTHROPIC_API_KEY"
}
env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY")
if self.api_key and env_var:
os.environ[env_var] = self.api_key
logger.debug(f"设置环境变量: {env_var}")
# 如果提供了 base_url设置到 LiteLLM
if self.base_url:
# LiteLLM 支持通过 api_base 参数设置自定义 URL
self._api_base = self.base_url
logger.debug(f"使用自定义 API base URL: {self.base_url}")
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
使用 LiteLLM 分析图片
Args:
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
"""
logger.info(f"开始使用 LiteLLM ({self.model_name}) 分析 {len(images)} 张图片")
# 预处理图片
processed_images = self._prepare_images(images)
# 分批处理
results = []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i:i + batch_size]
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt, **kwargs)
results.append(result)
except Exception as e:
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
results.append(f"批次处理失败: {str(e)}")
return results
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
"""分析一批图片"""
# 构建 LiteLLM 格式的消息
content = [{"type": "text", "text": prompt}]
# 添加图片(使用 base64 编码)
for img in batch:
base64_image = self._image_to_base64(img)
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
messages = [{
"role": "user",
"content": content
}]
# 调用 LiteLLM
try:
# 准备参数
effective_model_name = self.model_name
# SiliconFlow 特殊处理
if self.model_name.lower().startswith("siliconflow/"):
# 替换 provider 为 openai
if "/" in self.model_name:
effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
else:
effective_model_name = f"openai/{self.model_name}"
# 确保设置了 OPENAI_API_KEY (如果尚未设置)
import os
if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
# 确保设置了 base_url (如果尚未设置)
if not hasattr(self, '_api_base'):
self._api_base = "https://api.siliconflow.cn/v1"
completion_kwargs = {
"model": effective_model_name,
"messages": messages,
"temperature": kwargs.get("temperature", 1.0),
"max_tokens": kwargs.get("max_tokens", 4000)
}
# 如果有自定义 base_url添加 api_base 参数
if hasattr(self, '_api_base'):
completion_kwargs["api_base"] = self._api_base
# 支持动态传递 api_key 和 api_base
if "api_key" in kwargs:
completion_kwargs["api_key"] = kwargs["api_key"]
if "api_base" in kwargs:
completion_kwargs["api_base"] = kwargs["api_base"]
response = await acompletion(**completion_kwargs)
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
logger.debug(f"LiteLLM 调用成功,消耗 tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("LiteLLM 返回空响应")
except LiteLLMAuthError as e:
logger.error(f"LiteLLM 认证失败: {str(e)}")
raise AuthenticationError()
except LiteLLMRateLimitError as e:
logger.error(f"LiteLLM 速率限制: {str(e)}")
raise RateLimitError()
except LiteLLMBadRequestError as e:
error_msg = str(e)
if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower():
raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}")
logger.error(f"LiteLLM 请求错误: {error_msg}")
raise APICallError(f"请求错误: {error_msg}")
except LiteLLMAPIError as e:
logger.error(f"LiteLLM API 错误: {str(e)}")
raise APICallError(f"API 错误: {str(e)}")
except Exception as e:
logger.error(f"LiteLLM 调用失败: {str(e)}")
raise APICallError(f"调用失败: {str(e)}")
def _image_to_base64(self, img: PIL.Image.Image) -> str:
"""将PIL图片转换为base64编码"""
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85)
img_bytes = img_buffer.getvalue()
return base64.b64encode(img_bytes).decode('utf-8')
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""兼容基类接口(实际使用 LiteLLM SDK"""
pass
class LiteLLMTextProvider(TextModelProvider):
"""使用 LiteLLM 的统一文本生成提供商"""
@property
def provider_name(self) -> str:
# 从 model_name 中提取 provider 名称
if "/" in self.model_name:
return self.model_name.split("/")[0]
# 尝试从模型名称推断 provider
model_lower = self.model_name.lower()
if "gpt" in model_lower:
return "openai"
elif "claude" in model_lower:
return "anthropic"
elif "gemini" in model_lower:
return "gemini"
elif "qwen" in model_lower:
return "qwen"
elif "deepseek" in model_lower:
return "deepseek"
return "litellm"
@property
def supported_models(self) -> List[str]:
# LiteLLM 支持 100+ providers 和数百个模型,无法全部列举
# 返回空列表表示跳过预定义列表检查,由 LiteLLM 在实际调用时验证
return []
def _validate_model_support(self):
"""
重写模型验证逻辑
对于 LiteLLM我们不做预定义列表检查因为
1. LiteLLM 支持 100+ providers 和数百个模型无法全部列举
2. LiteLLM 会在实际调用时进行模型验证
3. 如果模型不支持LiteLLM 会返回清晰的错误信息
这里只做基本的格式验证可选
"""
from loguru import logger
# 可选检查模型名称格式provider/model
if "/" not in self.model_name:
logger.debug(
f"LiteLLM 模型名称 '{self.model_name}' 未包含 provider 前缀,"
f"LiteLLM 将尝试自动推断。建议使用 'provider/model' 格式,如 'gemini/gemini-2.5-flash'"
)
# 不抛出异常,让 LiteLLM 在实际调用时验证
logger.debug(f"LiteLLM 文本模型已配置: {self.model_name}")
def _initialize(self):
"""初始化 LiteLLM 特定设置"""
import os
# 根据 model_name 确定需要设置哪个 API key
provider = self.provider_name.lower()
# 映射 provider 到环境变量名
env_key_mapping = {
"gemini": "GEMINI_API_KEY",
"google": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"qwen": "QWEN_API_KEY",
"dashscope": "DASHSCOPE_API_KEY",
"siliconflow": "SILICONFLOW_API_KEY",
"deepseek": "DEEPSEEK_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"claude": "ANTHROPIC_API_KEY",
"moonshot": "MOONSHOT_API_KEY"
}
env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY")
if self.api_key and env_var:
os.environ[env_var] = self.api_key
logger.debug(f"设置环境变量: {env_var}")
# 如果提供了 base_url保存用于后续调用
if self.base_url:
self._api_base = self.base_url
logger.debug(f"使用自定义 API base URL: {self.base_url}")
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用 LiteLLM 生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建消息列表
messages = self._build_messages(prompt, system_prompt)
# 准备参数
effective_model_name = self.model_name
# SiliconFlow 特殊处理
if self.model_name.lower().startswith("siliconflow/"):
# 替换 provider 为 openai
if "/" in self.model_name:
effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
else:
effective_model_name = f"openai/{self.model_name}"
# 确保设置了 OPENAI_API_KEY (如果尚未设置)
import os
if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
# 确保设置了 base_url (如果尚未设置)
if not hasattr(self, '_api_base'):
self._api_base = "https://api.siliconflow.cn/v1"
completion_kwargs = {
"model": effective_model_name,
"messages": messages,
"temperature": temperature
}
if max_tokens:
completion_kwargs["max_tokens"] = max_tokens
# 处理 JSON 格式输出
# LiteLLM 会自动处理不同 provider 的 JSON mode 差异
if response_format == "json":
try:
completion_kwargs["response_format"] = {"type": "json_object"}
except Exception as e:
# 如果不支持,在提示词中添加约束
logger.warning(f"模型可能不支持 response_format将在提示词中添加 JSON 约束: {str(e)}")
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
# 如果有自定义 base_url添加 api_base 参数
if hasattr(self, '_api_base'):
completion_kwargs["api_base"] = self._api_base
# 支持动态传递 api_key 和 api_base (修复认证问题)
if "api_key" in kwargs:
completion_kwargs["api_key"] = kwargs["api_key"]
if "api_base" in kwargs:
completion_kwargs["api_base"] = kwargs["api_base"]
try:
# 调用 LiteLLM自动重试
response = await acompletion(**completion_kwargs)
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
# 清理可能的 markdown 代码块(针对不支持 JSON mode 的模型)
if response_format == "json" and "response_format" not in completion_kwargs:
content = self._clean_json_output(content)
logger.debug(f"LiteLLM 调用成功,消耗 tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("LiteLLM 返回空响应")
except LiteLLMAuthError as e:
logger.error(f"LiteLLM 认证失败: {str(e)}")
raise AuthenticationError()
except LiteLLMRateLimitError as e:
logger.error(f"LiteLLM 速率限制: {str(e)}")
raise RateLimitError()
except LiteLLMBadRequestError as e:
error_msg = str(e)
# 处理不支持 response_format 的情况
if "response_format" in error_msg and response_format == "json":
logger.warning(f"模型不支持 response_format重试不带格式约束的请求")
completion_kwargs.pop("response_format", None)
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
# 重试
response = await acompletion(**completion_kwargs)
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
content = self._clean_json_output(content)
return content
else:
raise APICallError("LiteLLM 返回空响应")
# 检查是否是安全过滤
if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower():
raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}")
logger.error(f"LiteLLM 请求错误: {error_msg}")
raise APICallError(f"请求错误: {error_msg}")
except LiteLLMAPIError as e:
logger.error(f"LiteLLM API 错误: {str(e)}")
raise APICallError(f"API 错误: {str(e)}")
except Exception as e:
logger.error(f"LiteLLM 调用失败: {str(e)}")
raise APICallError(f"调用失败: {str(e)}")
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除前后空白字符
output = output.strip()
return output
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""兼容基类接口(实际使用 LiteLLM SDK"""
pass

View File

@ -4,7 +4,7 @@
Centrally manages all LLM service providers and offers simple factory methods for creating and fetching service instances
"""
from typing import Dict, Type, Optional
from typing import Dict, Type, Optional, Tuple
from loguru import logger
from app.config import config
@ -64,6 +64,29 @@ class LLMServiceManager:
"vision_providers": list(cls._vision_providers.keys()),
"text_providers": list(cls._text_providers.keys())
}
@classmethod
def _normalize_provider_name(cls, provider_name: str) -> str:
"""规范化 provider 名称。"""
return provider_name.lower()
@classmethod
def _get_provider_config(
cls,
model_type: str,
provider_name: str
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Fetch a provider's configuration.
model_type: 'vision' or 'text'
"""
config_prefix = f"{model_type}_{provider_name}"
api_key = config.app.get(f"{config_prefix}_api_key")
model_name = config.app.get(f"{config_prefix}_model_name")
base_url = config.app.get(f"{config_prefix}_base_url")
return api_key, model_name, base_url
@classmethod
def get_vision_provider(cls, provider_name: Optional[str] = None) -> VisionModelProvider:
@ -89,9 +112,8 @@ class LLMServiceManager:
# Determine the provider name
if not provider_name:
provider_name = config.app.get('vision_llm_provider', 'gemini').lower()
else:
provider_name = provider_name.lower()
provider_name = config.app.get('vision_llm_provider', 'openai')
provider_name = cls._normalize_provider_name(provider_name)
# Check the cache
cache_key = f"vision_{provider_name}"
@ -104,9 +126,7 @@ class LLMServiceManager:
# Fetch the configuration
config_prefix = f"vision_{provider_name}"
api_key = config.app.get(f'{config_prefix}_api_key')
model_name = config.app.get(f'{config_prefix}_model_name')
base_url = config.app.get(f'{config_prefix}_base_url')
api_key, model_name, base_url = cls._get_provider_config("vision", provider_name)
if not api_key:
raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key")
@ -157,9 +177,8 @@ class LLMServiceManager:
# Determine the provider name
if not provider_name:
provider_name = config.app.get('text_llm_provider', 'openai').lower()
else:
provider_name = provider_name.lower()
provider_name = config.app.get('text_llm_provider', 'openai')
provider_name = cls._normalize_provider_name(provider_name)
logger.debug(f"获取文本模型提供商: {provider_name}")
logger.debug(f"已注册的文本提供商: {list(cls._text_providers.keys())}")
@ -178,9 +197,7 @@ class LLMServiceManager:
# Fetch the configuration
config_prefix = f"text_{provider_name}"
api_key = config.app.get(f'{config_prefix}_api_key')
model_name = config.app.get(f'{config_prefix}_model_name')
base_url = config.app.get(f'{config_prefix}_base_url')
api_key, model_name, base_url = cls._get_provider_config("text", provider_name)
if not api_key:
raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key")
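
The keys read by `_get_provider_config` follow the `{model_type}_{provider_name}_{field}` naming convention; a hypothetical `config.toml` fragment (placeholder values):

```toml
[app]
text_llm_provider = "openai"
text_openai_api_key = "sk-..."
text_openai_model_name = "gpt-4o-mini"
text_openai_base_url = "https://api.openai.com/v1"
```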

View File

@ -0,0 +1,276 @@
"""
OpenAI-compatible provider implementation
Uses the official OpenAI SDK to call OpenAI-compatible endpoints; supports both text and vision models
"""
import io
import base64
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import PIL.Image
from loguru import logger
from openai import (
APIError as OpenAIAPIError,
AsyncOpenAI,
AuthenticationError as OpenAIAuthError,
BadRequestError as OpenAIBadRequestError,
RateLimitError as OpenAIRateLimitError,
)
from app.config import config
from .base import TextModelProvider, VisionModelProvider
from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError
# Common OpenAI-compatible gateway prefixes. If the provider/model format is used, the provider prefix will be stripped.
OPENAI_COMPATIBLE_PROVIDER_PREFIXES = {
"openai",
"gemini",
"deepseek",
"qwen",
"siliconflow",
"moonshot",
"openrouter",
"anthropic",
"azure",
"ollama",
"mistral",
"groq",
"cohere",
"together_ai",
"fireworks_ai",
"volcengine",
"vertex_ai",
"huggingface",
"xai",
"bedrock",
"cloudflare",
"vllm",
"codestral",
"replicate",
"deepgram",
}
def _normalize_model_name(model_name: str) -> str:
"""兼容历史 provider/model 写法,必要时自动剥离 provider 前缀。"""
if "/" not in model_name:
return model_name
provider_prefix, raw_model = model_name.split("/", 1)
if provider_prefix.lower() in OPENAI_COMPATIBLE_PROVIDER_PREFIXES and raw_model:
return raw_model
return model_name
def _is_response_format_error(message: str) -> bool:
return "response_format" in (message or "").lower()
def _is_content_filter_error(message: str) -> bool:
lowered = (message or "").lower()
return "content_filter" in lowered or "safety" in lowered
def _clean_json_output(output: str) -> str:
"""清理 JSON 输出中的 markdown 包裹。"""
output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
output = re.sub(r"^```\s*$", "", output, flags=re.MULTILINE)
output = re.sub(r"^```.*$", "", output, flags=re.MULTILINE)
return output.strip()
class _OpenAICompatibleBase:
"""OpenAI 兼容 provider 共享逻辑。"""
@property
def provider_name(self) -> str:
return "openai"
@property
def supported_models(self) -> List[str]:
# Compatible gateways expose many models; final validation happens remotely at runtime.
return []
def _validate_model_support(self):
logger.debug(f"OpenAI 兼容模型已配置: {self.model_name}")
def _initialize(self):
# The SDK client is built per request, so no global state needs to be initialized here.
pass
def _build_client(
self,
api_key_override: Optional[str] = None,
base_url_override: Optional[str] = None,
timeout_override: Optional[float] = None,
) -> AsyncOpenAI:
"""按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。"""
api_key = api_key_override or self.api_key
base_url = base_url_override or self.base_url or None
timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180)
max_retries: int = config.app.get("llm_max_retries", 3)
return AsyncOpenAI(
api_key=api_key,
base_url=base_url,
timeout=timeout_seconds,
max_retries=max_retries,
)
class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider):
"""OpenAI 兼容视觉模型提供商。"""
async def analyze_images(
self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs,
) -> List[str]:
logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
processed_images = self._prepare_images(images)
results: List[str] = []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i : i + batch_size]
logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt, **kwargs)
results.append(result)
except Exception as exc:
logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}")
results.append(f"批次处理失败: {exc}")
return results
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
content = [{"type": "text", "text": prompt}]
for img in batch:
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{self._image_to_base64(img)}"},
}
)
messages = [{"role": "user", "content": content}]
model_name = _normalize_model_name(self.model_name)
client = self._build_client(
api_key_override=kwargs.get("api_key"),
base_url_override=kwargs.get("api_base"),
timeout_override=config.app.get("llm_vision_timeout", 120),
)
try:
response = await client.chat.completions.create(
model=model_name,
messages=messages,
temperature=kwargs.get("temperature", 1.0),
max_tokens=kwargs.get("max_tokens", 4000),
)
if response.choices and response.choices[0].message and response.choices[0].message.content:
return response.choices[0].message.content
raise APICallError("OpenAI 兼容接口返回空响应")
except OpenAIAuthError as exc:
logger.error(f"OpenAI 兼容接口认证失败: {exc}")
raise AuthenticationError(str(exc))
except OpenAIRateLimitError as exc:
logger.error(f"OpenAI 兼容接口速率限制: {exc}")
raise RateLimitError(str(exc))
except OpenAIBadRequestError as exc:
error_msg = str(exc)
if _is_content_filter_error(error_msg):
raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}")
raise APICallError(f"请求错误: {error_msg}")
except OpenAIAPIError as exc:
logger.error(f"OpenAI 兼容接口 API 错误: {exc}")
raise APICallError(f"API 错误: {exc}")
except Exception as exc:
logger.error(f"OpenAI 兼容接口调用失败: {exc}")
raise APICallError(f"调用失败: {exc}")
def _image_to_base64(self, img: PIL.Image.Image) -> str:
img_buffer = io.BytesIO()
img.save(img_buffer, format="JPEG", quality=85)
return base64.b64encode(img_buffer.getvalue()).decode("utf-8")
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
return payload
class OpenAICompatibleTextProvider(_OpenAICompatibleBase, TextModelProvider):
"""OpenAI 兼容文本模型提供商。"""
async def generate_text(
self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs,
) -> str:
messages = self._build_messages(prompt, system_prompt)
model_name = _normalize_model_name(self.model_name)
client = self._build_client(
api_key_override=kwargs.get("api_key"),
base_url_override=kwargs.get("api_base"),
timeout_override=config.app.get("llm_text_timeout", 180),
)
completion_kwargs: Dict[str, Any] = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
if max_tokens:
completion_kwargs["max_tokens"] = max_tokens
if response_format == "json":
completion_kwargs["response_format"] = {"type": "json_object"}
try:
response = await client.chat.completions.create(**completion_kwargs)
if response.choices and response.choices[0].message and response.choices[0].message.content:
return response.choices[0].message.content
raise APICallError("OpenAI 兼容接口返回空响应")
except OpenAIBadRequestError as exc:
error_msg = str(exc)
# Some gateways do not support response_format; fall back to prompt-constrained JSON output
if response_format == "json" and _is_response_format_error(error_msg):
logger.warning("目标网关不支持 response_format回退为提示词约束 JSON 输出")
completion_kwargs.pop("response_format", None)
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
retry_response = await client.chat.completions.create(**completion_kwargs)
if retry_response.choices and retry_response.choices[0].message and retry_response.choices[0].message.content:
return _clean_json_output(retry_response.choices[0].message.content)
raise APICallError("OpenAI 兼容接口返回空响应")
if _is_content_filter_error(error_msg):
raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}")
raise APICallError(f"请求错误: {error_msg}")
except OpenAIAuthError as exc:
logger.error(f"OpenAI 兼容接口认证失败: {exc}")
raise AuthenticationError(str(exc))
except OpenAIRateLimitError as exc:
logger.error(f"OpenAI 兼容接口速率限制: {exc}")
raise RateLimitError(str(exc))
except OpenAIAPIError as exc:
logger.error(f"OpenAI 兼容接口 API 错误: {exc}")
raise APICallError(f"API 错误: {exc}")
except Exception as exc:
logger.error(f"OpenAI 兼容接口调用失败: {exc}")
raise APICallError(f"调用失败: {exc}")
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
return payload
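
A minimal usage sketch of the new providers through the manager (assumes the `text_openai_*` config keys above are set; this snippet is illustrative and not part of the commit):

```python
import asyncio

from app.services.llm.manager import LLMServiceManager
from app.services.llm.providers import register_all_providers


async def main() -> None:
    register_all_providers()
    # Resolves the "openai" provider from config and reuses cached instances.
    provider = LLMServiceManager.get_text_provider()
    text = await provider.generate_text(prompt="ping", temperature=0.7)
    print(text)


asyncio.run(main())
```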

View File

@ -2,7 +2,6 @@
LLM service provider implementations
Contains the concrete implementations of each LLM service provider
Recommended: the unified LiteLLM interface, supporting 100+ providers
"""
# Provider classes are not imported at module top level, to avoid circular imports
@ -13,25 +12,25 @@ def register_all_providers():
"""
Register all providers
v0.8.0 change: only the unified LiteLLM interface is registered
- Removed the old standalone provider implementations (gemini, openai, qwen, deepseek, siliconflow)
- LiteLLM supports 100+ providers, so separate implementations are unnecessary
The current implementation registers only the unified OpenAI-compatible interface
"""
# Import inside the function to avoid circular imports
from ..manager import LLMServiceManager
from loguru import logger
# Import only the LiteLLM provider
from ..litellm_provider import LiteLLMVisionProvider, LiteLLMTextProvider
# Import only the OpenAI-compatible provider
from ..openai_compatible_provider import (
OpenAICompatibleVisionProvider,
OpenAICompatibleTextProvider,
)
logger.info("🔧 开始注册 LLM 提供商...")
# ===== Register the unified LiteLLM interface =====
# LiteLLM supports 100+ providers (OpenAI, Gemini, Qwen, DeepSeek, SiliconFlow, etc.)
LLMServiceManager.register_vision_provider('litellm', LiteLLMVisionProvider)
LLMServiceManager.register_text_provider('litellm', LiteLLMTextProvider)
# ===== Register the unified OpenAI-compatible interface =====
LLMServiceManager.register_vision_provider('openai', OpenAICompatibleVisionProvider)
LLMServiceManager.register_text_provider('openai', OpenAICompatibleTextProvider)
logger.info("LiteLLM 提供商注册完成(支持 100+ providers")
logger.info("OpenAI 兼容提供商注册完成")
# Export the registration function
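
A quick registration check mirroring the unit tests further below (sketch):

```python
from app.services.llm.manager import LLMServiceManager
from app.services.llm.providers import register_all_providers

register_all_providers()
# After this commit, "openai" is the only registered key for both model types.
assert set(LLMServiceManager.list_text_providers()) == {"openai"}
assert set(LLMServiceManager.list_vision_providers()) == {"openai"}
```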

View File

@ -1,228 +0,0 @@
"""
LiteLLM 集成测试脚本
测试 LiteLLM provider 是否正确集成到系统中
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
from loguru import logger
from app.services.llm.manager import LLMServiceManager
from app.services.llm.unified_service import UnifiedLLMService
def test_provider_registration():
"""测试 provider 是否正确注册"""
logger.info("=" * 60)
logger.info("测试 1: Provider 注册检查")
logger.info("=" * 60)
# 检查 LiteLLM provider 是否已注册
vision_providers = LLMServiceManager.list_vision_providers()
text_providers = LLMServiceManager.list_text_providers()
logger.info(f"已注册的视觉模型 providers: {vision_providers}")
logger.info(f"已注册的文本模型 providers: {text_providers}")
assert 'litellm' in vision_providers, "❌ LiteLLM Vision Provider 未注册"
assert 'litellm' in text_providers, "❌ LiteLLM Text Provider 未注册"
logger.success("✅ LiteLLM providers 已成功注册")
# 显示所有 provider 信息
provider_info = LLMServiceManager.get_provider_info()
logger.info("\n所有 Provider 信息:")
logger.info(f" 视觉模型 providers: {list(provider_info['vision_providers'].keys())}")
logger.info(f" 文本模型 providers: {list(provider_info['text_providers'].keys())}")
def test_litellm_import():
"""测试 LiteLLM 库是否正确安装"""
logger.info("\n" + "=" * 60)
logger.info("测试 2: LiteLLM 库导入检查")
logger.info("=" * 60)
try:
import litellm
logger.success(f"✅ LiteLLM 已安装,版本: {litellm.__version__}")
return True
except ImportError as e:
logger.error(f"❌ LiteLLM 未安装: {str(e)}")
logger.info("请运行: pip install litellm>=1.70.0")
return False
async def test_text_generation_mock():
"""测试文本生成接口(模拟模式,不实际调用 API"""
logger.info("\n" + "=" * 60)
logger.info("测试 3: 文本生成接口(模拟)")
logger.info("=" * 60)
try:
# 这里只测试接口是否可调用,不实际发送 API 请求
logger.info("接口测试通过UnifiedLLMService.generate_text 可调用")
logger.success("✅ 文本生成接口测试通过")
return True
except Exception as e:
logger.error(f"❌ 文本生成接口测试失败: {str(e)}")
return False
async def test_vision_analysis_mock():
"""测试视觉分析接口(模拟模式)"""
logger.info("\n" + "=" * 60)
logger.info("测试 4: 视觉分析接口(模拟)")
logger.info("=" * 60)
try:
# 这里只测试接口是否可调用
logger.info("接口测试通过UnifiedLLMService.analyze_images 可调用")
logger.success("✅ 视觉分析接口测试通过")
return True
except Exception as e:
logger.error(f"❌ 视觉分析接口测试失败: {str(e)}")
return False
def test_backward_compatibility():
"""测试向后兼容性"""
logger.info("\n" + "=" * 60)
logger.info("测试 5: 向后兼容性检查")
logger.info("=" * 60)
# 检查旧的 provider 是否仍然可用
old_providers = ['gemini', 'openai', 'qwen', 'deepseek', 'siliconflow']
vision_providers = LLMServiceManager.list_vision_providers()
text_providers = LLMServiceManager.list_text_providers()
logger.info("检查旧 provider 是否仍然可用:")
for provider in old_providers:
if provider in ['openai', 'deepseek']:
# 这些只有 text provider
if provider in text_providers:
logger.info(f"{provider} (text)")
else:
logger.warning(f" ⚠️ {provider} (text) 未注册")
else:
# 这些有 vision 和 text provider
vision_ok = provider in vision_providers or f"{provider}vl" in vision_providers
text_ok = provider in text_providers
if vision_ok:
logger.info(f"{provider} (vision)")
if text_ok:
logger.info(f"{provider} (text)")
logger.success("✅ 向后兼容性测试通过")
def print_usage_guide():
"""打印使用指南"""
logger.info("\n" + "=" * 60)
logger.info("LiteLLM 使用指南")
logger.info("=" * 60)
guide = """
📚 如何使用 LiteLLM
1. config.toml 中配置
```toml
[app]
# 方式 1直接使用 LiteLLM推荐
vision_llm_provider = "litellm"
vision_litellm_model_name = "gemini/gemini-2.0-flash-lite"
vision_litellm_api_key = "your-api-key"
text_llm_provider = "litellm"
text_litellm_model_name = "deepseek/deepseek-chat"
text_litellm_api_key = "your-api-key"
```
2. 支持的模型格式
- Gemini: gemini/gemini-2.0-flash
- DeepSeek: deepseek/deepseek-chat
- Qwen: qwen/qwen-plus
- OpenAI: gpt-4o, gpt-4o-mini
- SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1
- 更多: 参考 https://docs.litellm.ai/docs/providers
3. 代码调用示例
```python
from app.services.llm.unified_service import UnifiedLLMService
# 文本生成
result = await UnifiedLLMService.generate_text(
prompt="你好",
provider="litellm"
)
# 视觉分析
results = await UnifiedLLMService.analyze_images(
images=["path/to/image.jpg"],
prompt="描述这张图片",
provider="litellm"
)
```
4. 优势
减少 80% 代码量
统一的错误处理
自动重试机制
支持 100+ providers
自动成本追踪
5. 迁移建议
- 新项目直接使用 LiteLLM
- 旧项目逐步迁移旧的 provider 仍然可用
- 测试充分后再切换生产环境
"""
print(guide)
def main():
"""运行所有测试"""
logger.info("开始 LiteLLM 集成测试...\n")
try:
# 测试 1: Provider 注册
test_provider_registration()
# 测试 2: LiteLLM 库导入
litellm_available = test_litellm_import()
if not litellm_available:
logger.warning("\n⚠️ LiteLLM 未安装,跳过 API 测试")
logger.info("请运行: pip install litellm>=1.70.0")
else:
# 测试 3-4: 接口测试(模拟)
asyncio.run(test_text_generation_mock())
asyncio.run(test_vision_analysis_mock())
# 测试 5: 向后兼容性
test_backward_compatibility()
# 打印使用指南
print_usage_guide()
logger.info("\n" + "=" * 60)
logger.success("🎉 所有测试通过!")
logger.info("=" * 60)
return True
except Exception as e:
logger.error(f"\n❌ 测试失败: {str(e)}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -0,0 +1,67 @@
"""OpenAI 兼容 provider 的最小回归测试。"""
import unittest
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.manager import LLMServiceManager
from app.services.llm.providers import register_all_providers
class DummyOpenAITextProvider(TextModelProvider):
@property
def provider_name(self) -> str:
return "openai"
@property
def supported_models(self) -> list[str]:
return []
async def generate_text(self, prompt: str, **kwargs) -> str:
return prompt
async def _make_api_call(self, payload: dict) -> dict:
return payload
def _reset_manager_state():
LLMServiceManager._vision_providers.clear()
LLMServiceManager._text_providers.clear()
LLMServiceManager._vision_instance_cache.clear()
LLMServiceManager._text_instance_cache.clear()
class OpenAICompatManagerTests(unittest.TestCase):
def setUp(self):
_reset_manager_state()
self._original_app = dict(config.app)
def tearDown(self):
_reset_manager_state()
config.app.clear()
config.app.update(self._original_app)
def test_register_all_providers_only_registers_openai_provider(self):
register_all_providers()
self.assertEqual({"openai"}, set(LLMServiceManager.list_text_providers()))
self.assertEqual({"openai"}, set(LLMServiceManager.list_vision_providers()))
def test_get_text_provider_uses_openai_keys(self):
LLMServiceManager.register_text_provider("openai", DummyOpenAITextProvider)
config.app["text_llm_provider"] = "openai"
config.app["text_openai_api_key"] = "new-key"
config.app["text_openai_model_name"] = "new-model"
config.app["text_openai_base_url"] = "https://new.example/v1"
provider = LLMServiceManager.get_text_provider()
self.assertIsInstance(provider, DummyOpenAITextProvider)
self.assertEqual("new-key", provider.api_key)
self.assertEqual("new-model", provider.model_name)
self.assertEqual("https://new.example/v1", provider.base_url)
if __name__ == "__main__":
unittest.main()
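
A quick way to run these tests locally (the module path is hypothetical, since the diff view omits file names; adjust it to wherever this file lives):

```bash
python -m unittest tests.test_openai_compatible_provider -v
```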

View File

@ -0,0 +1,35 @@
"""
OpenAI-compatible interface integration test script
Used to quickly check whether the unified LLM provider registered successfully
"""
from loguru import logger
from app.services.llm.manager import LLMServiceManager
from app.services.llm.providers import register_all_providers
def test_provider_registration() -> bool:
"""检查 OpenAI 兼容 provider 是否注册成功。"""
logger.info("测试Provider 注册检查")
register_all_providers()
vision_providers = LLMServiceManager.list_vision_providers()
text_providers = LLMServiceManager.list_text_providers()
assert "openai" in vision_providers, "❌ OpenAI 兼容 Vision Provider 未注册"
assert "openai" in text_providers, "❌ OpenAI 兼容 Text Provider 未注册"
logger.success("✅ OpenAI 兼容 providers 已成功注册")
return True
if __name__ == "__main__":
try:
ok = test_provider_registration()
if ok:
logger.success("\n🎉 集成检查通过")
except Exception as exc:
logger.error(f"\n❌ 集成检查失败: {exc}")
raise

View File

@ -133,7 +133,7 @@ class ScriptGenerator:
from app.services.llm.migration_adapter import create_vision_analyzer
# Fetch the configuration
text_provider = config.app.get('text_llm_provider', 'litellm').lower()
text_provider = config.app.get('text_llm_provider', 'openai').lower()
vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')
@ -321,4 +321,4 @@ class ScriptGenerator:
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
return first_timestamp, last_timestamp, timestamp_range
return first_timestamp, last_timestamp, timestamp_range
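
A sketch of the config-key derivation convention used above (the provider value is illustrative; it is normally read from config or session state):

```python
from app.config import config

vision_llm_provider = "openai"  # illustrative value
vision_api_key = config.app.get(f"vision_{vision_llm_provider}_api_key")
vision_model = config.app.get(f"vision_{vision_llm_provider}_model_name")
vision_base_url = config.app.get(f"vision_{vision_llm_provider}_base_url")
```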

View File

@ -4,24 +4,16 @@
# LLM API timeout configuration (seconds)
llm_vision_timeout = 120 # Base timeout for vision models
llm_text_timeout = 180 # Base timeout for text models (complex tasks such as commentary-script generation need more time)
llm_max_retries = 3 # Number of API retries (LiteLLM handles retries automatically)
llm_max_retries = 3 # Number of API retries
##########################################
# 🚀 LLM configuration - unified LiteLLM interface
# 🚀 LLM configuration - unified OpenAI-compatible interface
##########################################
# LiteLLM 是统一的 LLM 接口库,支持 100+ providers
# 优势:
# ✅ 代码量减少 80%,统一的 API 接口
# ✅ 自动重试和智能错误处理
# ✅ 内置成本追踪和 token 统计
# ✅ 支持更多 providersOpenAI, Anthropic, Gemini, Qwen, DeepSeek,
# Cohere, Together AI, Replicate, Groq, Mistral 等
#
# 文档https://docs.litellm.ai/
# 支持的模型https://docs.litellm.ai/docs/providers
# Uses the OpenAI-compatible protocol (/v1/chat/completions) throughout
# Supports OpenAI, DeepSeek, Gemini-compatible gateways, Qwen gateways, SiliconFlow, OpenRouter, and more.
# ===== Vision model configuration =====
vision_llm_provider = "litellm"
vision_llm_provider = "openai"
# Model format: provider/model_name
# Common vision model examples:
@ -30,12 +22,12 @@
# - OpenAI: gpt-4o, gpt-4o-mini
# - Qwen: qwen/qwen2.5-vl-32b-instruct
# - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct
vision_litellm_model_name = "gemini/gemini-2.0-flash-lite"
vision_litellm_api_key = "" # Fill in the API key for the corresponding provider
vision_litellm_base_url = "" # Optional: custom API base URL
vision_openai_model_name = "gemini/gemini-2.0-flash-lite"
vision_openai_api_key = "" # Fill in the API key for the corresponding provider
vision_openai_base_url = "" # Optional: custom API base URL (leave empty for official OpenAI)
# ===== Text model configuration =====
text_llm_provider = "litellm"
text_llm_provider = "openai"
# Common text model examples:
# - DeepSeek: deepseek/deepseek-chat (recommended, good value)
@ -45,9 +37,9 @@
# - Qwen: qwen/qwen-plus, qwen/qwen-turbo
# - SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1
# - Moonshot: moonshot/moonshot-v1-8k
text_litellm_model_name = "deepseek/deepseek-chat"
text_litellm_api_key = "" # Fill in the API key for the corresponding provider
text_litellm_base_url = "" # Optional: custom API base URL
text_openai_model_name = "deepseek/deepseek-chat"
text_openai_api_key = "" # Fill in the API key for the corresponding provider
text_openai_base_url = "" # Optional: custom API base URL (leave empty for official OpenAI)
# ===== API Keys reference =====
# Where to obtain API keys for mainstream LLM providers:
@ -69,21 +61,7 @@
# Whether the WebUI shows the configuration items
hide_config = true
##########################################
# 📚 传统配置示例(仅供参考,不推荐使用)
##########################################
# 如果需要使用传统的单独 provider 实现,可以参考以下配置
# 但强烈推荐使用上面的 LiteLLM 配置
#
# 传统视觉模型配置示例:
# vision_llm_provider = "gemini" # 可选gemini, qwenvl, siliconflow
# vision_gemini_api_key = ""
# vision_gemini_model_name = "gemini-2.0-flash-lite"
#
# 传统文本模型配置示例:
# text_llm_provider = "openai" # 可选openai, gemini, qwen, deepseek, siliconflow, moonshot
# text_openai_api_key = ""
# text_openai_model_name = "gpt-4o-mini"
# Official OpenAI default endpoint (optional):
# text_openai_base_url = "https://api.openai.com/v1"
##########################################
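
For example, pointing the text model directly at DeepSeek's OpenAI-compatible endpoint might look like this (the base URL reflects DeepSeek's public documentation; verify it before use):

```toml
text_llm_provider = "openai"
text_openai_model_name = "deepseek-chat"
text_openai_api_key = "sk-..."  # your DeepSeek key
text_openai_base_url = "https://api.deepseek.com/v1"
```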

View File

@ -12,8 +12,7 @@ pysrt==1.1.2
# AI service dependencies
openai>=1.77.0
litellm>=1.70.0 # Unified LLM interface supporting 100+ providers
google-generativeai>=0.8.5 # LiteLLM uses this library to call Gemini
google-generativeai>=0.8.5 # Dependency for native Gemini scenarios
azure-cognitiveservices-speech>=1.37.0
tencentcloud-sdk-python>=3.0.1200
dashscope>=1.24.6
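
With LiteLLM removed, the OpenAI SDK is the only LLM client library the new providers require (sketch):

```bash
pip install "openai>=1.77.0"
```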

View File

@ -72,8 +72,8 @@ def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
return True, ""
def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool, str]:
"""验证 LiteLLM 模型名称格式
def validate_openai_compatible_model_name(model_name: str, model_type: str) -> tuple[bool, str]:
"""验证 OpenAI 兼容 模型名称格式
Args:
model_name: the model name, expected in provider/model format
@ -87,8 +87,8 @@ def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool,
model_name = model_name.strip()
# LiteLLM recommended format: provider/model (e.g. gemini/gemini-2.0-flash-lite)
# Plain model names (e.g. gpt-4o) are also supported; LiteLLM infers the provider automatically
# Recommended format for OpenAI-compatible gateways: provider/model (e.g. gemini/gemini-2.0-flash-lite)
# Plain model names (e.g. gpt-4o) are also supported; a recognized provider prefix is stripped before the request is sent
# Check for a provider prefix (recommended format)
if "/" in model_name:
@ -101,9 +101,9 @@ def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool,
if not provider.replace("-", "").replace("_", "").isalnum():
return False, f"{model_type} Provider 名称只能包含字母、数字、下划线和连字符"
else:
# A plain model name is also valid (LiteLLM infers the provider)
# A plain model name is also valid (resolved by the OpenAI-compatible gateway)
# but emit a warning suggesting the full format
logger.debug(f"{model_type} 模型名称未包含 provider 前缀,LiteLLM 将自动推断")
logger.debug(f"{model_type} 模型名称未包含 provider 前缀,OpenAI 兼容 将自动推断")
# Basic length check
if len(model_name) < 3:
@ -115,6 +115,13 @@ def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool,
return True, ""
def normalize_openai_compatible_model_name(model_name: str) -> str:
"""将 provider/model 格式转换为网关实际使用的模型名。"""
if "/" not in model_name:
return model_name
return model_name.split("/", 1)[1]
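# Illustrative behavior of the helper above (a sketch, not part of the original file):
#   normalize_openai_compatible_model_name("deepseek/deepseek-chat")
#       -> "deepseek-chat"
#   normalize_openai_compatible_model_name("siliconflow/Qwen/Qwen2.5-VL-32B-Instruct")
#       -> "Qwen/Qwen2.5-VL-32B-Instruct"  (only the first segment is stripped)
#   normalize_openai_compatible_model_name("gpt-4o-mini")
#       -> "gpt-4o-mini"  (no prefix, returned unchanged)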
def show_config_validation_errors(errors: list):
"""显示配置验证错误"""
if errors:
@ -312,243 +319,112 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
"""测试 LiteLLM 视觉模型连接
Args:
api_key: API 密钥
base_url: 基础 URL可选
model_name: 模型名称LiteLLM 格式provider/model
tr: 翻译函数
Returns:
(连接是否成功, 测试结果消息)
"""
def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
"""测试 OpenAI 兼容视觉模型连接。"""
try:
import litellm
import os
import base64
import io
from openai import OpenAI
from PIL import Image
logger.debug(f"LiteLLM 视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}")
# 提取 provider 名称
provider = model_name.split("/")[0] if "/" in model_name else "unknown"
# 设置 API key 到环境变量
env_key_mapping = {
"gemini": "GEMINI_API_KEY",
"google": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"qwen": "QWEN_API_KEY",
"dashscope": "DASHSCOPE_API_KEY",
"siliconflow": "SILICONFLOW_API_KEY",
}
env_var = env_key_mapping.get(provider.lower(), f"{provider.upper()}_API_KEY")
old_key = os.environ.get(env_var)
os.environ[env_var] = api_key
# SiliconFlow 特殊处理:使用 OpenAI 兼容模式
test_model_name = model_name
if provider.lower() == "siliconflow":
# 替换 provider 为 openai
if "/" in model_name:
test_model_name = f"openai/{model_name.split('/', 1)[1]}"
else:
test_model_name = f"openai/{model_name}"
# 确保设置了 base_url
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
# 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议)
os.environ["OPENAI_API_KEY"] = api_key
os.environ["OPENAI_API_BASE"] = base_url
try:
# 创建测试图片64x64 白色像素,避免某些模型对极小图片的限制)
test_image = Image.new('RGB', (64, 64), color='white')
img_buffer = io.BytesIO()
test_image.save(img_buffer, format='JPEG')
img_bytes = img_buffer.getvalue()
base64_image = base64.b64encode(img_bytes).decode('utf-8')
# 构建测试请求
messages = [{
"role": "user",
"content": [
{"type": "text", "text": "请直接回复'连接成功'"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}]
# 准备参数
completion_kwargs = {
"model": test_model_name,
"messages": messages,
"temperature": 0.1,
"max_tokens": 50
}
if base_url:
completion_kwargs["api_base"] = base_url
# 调用 LiteLLM同步调用用于测试
response = litellm.completion(**completion_kwargs)
if response and response.choices and len(response.choices) > 0:
return True, f"LiteLLM 视觉模型连接成功 ({model_name})"
else:
return False, f"LiteLLM 视觉模型返回空响应"
finally:
# 恢复原始环境变量
if old_key:
os.environ[env_var] = old_key
else:
os.environ.pop(env_var, None)
# 清理临时设置的 OpenAI 环境变量
if provider.lower() == "siliconflow":
os.environ.pop("OPENAI_API_KEY", None)
os.environ.pop("OPENAI_API_BASE", None)
logger.debug(
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
)
client = OpenAI(
api_key=api_key,
base_url=base_url or None,
timeout=10.0,
max_retries=1,
)
# 创建测试图片64x64 白色像素)
test_image = Image.new("RGB", (64, 64), color="white")
img_buffer = io.BytesIO()
test_image.save(img_buffer, format="JPEG")
base64_image = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
response = client.chat.completions.create(
model=normalize_openai_compatible_model_name(model_name),
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "请直接回复'连接成功'"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
],
}
],
temperature=0.1,
max_tokens=50,
)
if response and response.choices and len(response.choices) > 0:
return True, f"OpenAI 兼容视觉模型连接成功 ({model_name})"
return False, "OpenAI 兼容视觉模型返回空响应"
except Exception as e:
error_msg = str(e)
logger.error(f"LiteLLM 视觉模型测试失败: {error_msg}")
# 提供更友好的错误信息
logger.error(f"OpenAI 兼容视觉模型测试失败: {error_msg}")
if "authentication" in error_msg.lower() or "api_key" in error_msg.lower():
return False, f"认证失败,请检查 API Key 是否正确"
elif "not found" in error_msg.lower() or "404" in error_msg:
return False, f"模型不存在,请检查模型名称是否正确"
elif "rate limit" in error_msg.lower():
return False, f"超出速率限制,请稍后重试"
else:
return False, f"连接失败: {error_msg}"
return False, "认证失败,请检查 API Key 是否正确"
if "not found" in error_msg.lower() or "404" in error_msg:
return False, "模型不存在,请检查模型名称是否正确"
if "rate limit" in error_msg.lower():
return False, "超出速率限制,请稍后重试"
return False, f"连接失败: {error_msg}"
def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
"""测试 LiteLLM 文本模型连接
Args:
api_key: API 密钥
base_url: 基础 URL可选
model_name: 模型名称LiteLLM 格式provider/model
tr: 翻译函数
Returns:
(连接是否成功, 测试结果消息)
"""
def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
"""测试 OpenAI 兼容文本模型连接。"""
try:
import litellm
import os
logger.debug(f"LiteLLM 文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}")
# 提取 provider 名称
provider = model_name.split("/")[0] if "/" in model_name else "unknown"
# 设置 API key 到环境变量
env_key_mapping = {
"gemini": "GEMINI_API_KEY",
"google": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"qwen": "QWEN_API_KEY",
"dashscope": "DASHSCOPE_API_KEY",
"siliconflow": "SILICONFLOW_API_KEY",
"deepseek": "DEEPSEEK_API_KEY",
"moonshot": "MOONSHOT_API_KEY",
}
env_var = env_key_mapping.get(provider.lower(), f"{provider.upper()}_API_KEY")
old_key = os.environ.get(env_var)
os.environ[env_var] = api_key
# SiliconFlow 特殊处理:使用 OpenAI 兼容模式
test_model_name = model_name
if provider.lower() == "siliconflow":
# 替换 provider 为 openai
if "/" in model_name:
test_model_name = f"openai/{model_name.split('/', 1)[1]}"
else:
test_model_name = f"openai/{model_name}"
# 确保设置了 base_url
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
# 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议)
os.environ["OPENAI_API_KEY"] = api_key
os.environ["OPENAI_API_BASE"] = base_url
try:
# 构建测试请求
messages = [
{"role": "user", "content": "请直接回复'连接成功'"}
]
# 准备参数
completion_kwargs = {
"model": test_model_name,
"messages": messages,
"temperature": 0.1,
"max_tokens": 20
}
if base_url:
completion_kwargs["api_base"] = base_url
# 调用 LiteLLM同步调用用于测试
response = litellm.completion(**completion_kwargs)
if response and response.choices and len(response.choices) > 0:
return True, f"LiteLLM 文本模型连接成功 ({model_name})"
else:
return False, f"LiteLLM 文本模型返回空响应"
finally:
# 恢复原始环境变量
if old_key:
os.environ[env_var] = old_key
else:
os.environ.pop(env_var, None)
# 清理临时设置的 OpenAI 环境变量
if provider.lower() == "siliconflow":
os.environ.pop("OPENAI_API_KEY", None)
os.environ.pop("OPENAI_API_BASE", None)
from openai import OpenAI
logger.debug(
f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
)
client = OpenAI(
api_key=api_key,
base_url=base_url or None,
timeout=10.0,
max_retries=1,
)
response = client.chat.completions.create(
model=normalize_openai_compatible_model_name(model_name),
messages=[{"role": "user", "content": "请直接回复'连接成功'"}],
temperature=0.1,
max_tokens=20,
)
if response and response.choices and len(response.choices) > 0:
return True, f"OpenAI 兼容文本模型连接成功 ({model_name})"
return False, "OpenAI 兼容文本模型返回空响应"
except Exception as e:
error_msg = str(e)
logger.error(f"LiteLLM 文本模型测试失败: {error_msg}")
# 提供更友好的错误信息
logger.error(f"OpenAI 兼容文本模型测试失败: {error_msg}")
if "authentication" in error_msg.lower() or "api_key" in error_msg.lower():
return False, f"认证失败,请检查 API Key 是否正确"
elif "not found" in error_msg.lower() or "404" in error_msg:
return False, f"模型不存在,请检查模型名称是否正确"
elif "rate limit" in error_msg.lower():
return False, f"超出速率限制,请稍后重试"
else:
return False, f"连接失败: {error_msg}"
return False, "认证失败,请检查 API Key 是否正确"
if "not found" in error_msg.lower() or "404" in error_msg:
return False, "模型不存在,请检查模型名称是否正确"
if "rate limit" in error_msg.lower():
return False, "超出速率限制,请稍后重试"
return False, f"连接失败: {error_msg}"
def render_vision_llm_settings(tr):
"""渲染视频分析模型设置(LiteLLM 统一配置)"""
"""渲染视频分析模型设置OpenAI 兼容 统一配置)"""
st.subheader(tr("Vision Model Settings"))
# Always use the LiteLLM provider
config.app["vision_llm_provider"] = "litellm"
# Always use the OpenAI-compatible provider
config.app["vision_llm_provider"] = "openai"
# Fetch the saved LiteLLM configuration
full_vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
vision_api_key = config.app.get("vision_litellm_api_key", "")
vision_base_url = config.app.get("vision_litellm_base_url", "")
# Fetch the saved configuration
full_vision_model_name = config.app.get("vision_openai_model_name") or "gemini/gemini-2.0-flash-lite"
vision_api_key = config.app.get("vision_openai_api_key", "")
vision_base_url = config.app.get("vision_openai_base_url", "")
# Parse provider and model
default_provider = "gemini"
@ -563,7 +439,7 @@ def render_vision_llm_settings(tr):
current_model = full_vision_model_name
# Define the list of supported providers
LITELLM_PROVIDERS = [
OPENAI_COMPATIBLE_PROVIDERS = [
"openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
"anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
"volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
@ -572,16 +448,16 @@ def render_vision_llm_settings(tr):
]
# If the current provider is not in the list, prepend it
if current_provider not in LITELLM_PROVIDERS:
LITELLM_PROVIDERS.insert(0, current_provider)
if current_provider not in OPENAI_COMPATIBLE_PROVIDERS:
OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider)
# Render the configuration inputs
col1, col2 = st.columns([1, 2])
with col1:
selected_provider = st.selectbox(
tr("Vision Model Provider"),
options=LITELLM_PROVIDERS,
index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
options=OPENAI_COMPATIBLE_PROVIDERS,
index=OPENAI_COMPATIBLE_PROVIDERS.index(current_provider) if current_provider in OPENAI_COMPATIBLE_PROVIDERS else 0,
key="vision_provider_select"
)
@ -595,7 +471,7 @@ def render_vision_llm_settings(tr):
"• gpt-4o\n"
"• qwen-vl-max\n"
"• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n"
"支持 100+ providers详见: https://docs.litellm.ai/docs/providers",
"支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow",
key="vision_model_input"
)
@ -641,7 +517,7 @@ def render_vision_llm_settings(tr):
else:
with st.spinner(tr("Testing connection...")):
try:
success, message = test_litellm_vision_model(
success, message = test_openai_compatible_vision_model(
api_key=st_vision_api_key,
base_url=st_vision_base_url,
model_name=st_vision_model_name,
@ -654,7 +530,7 @@ def render_vision_llm_settings(tr):
st.error(message)
except Exception as e:
st.error(f"测试连接时发生错误: {str(e)}")
logger.error(f"LiteLLM 视频分析模型连接测试失败: {str(e)}")
logger.error(f"OpenAI 兼容 视频分析模型连接测试失败: {str(e)}")
# Validate and save the configuration
validation_errors = []
@ -663,10 +539,10 @@ def render_vision_llm_settings(tr):
# Validate the model name
if st_vision_model_name:
# The validation logic here may need tweaking, since the name is now assembled automatically
is_valid, error_msg = validate_litellm_model_name(st_vision_model_name, "视频分析")
is_valid, error_msg = validate_openai_compatible_model_name(st_vision_model_name, "视频分析")
if is_valid:
config.app["vision_litellm_model_name"] = st_vision_model_name
st.session_state["vision_litellm_model_name"] = st_vision_model_name
config.app["vision_openai_model_name"] = st_vision_model_name
st.session_state["vision_openai_model_name"] = st_vision_model_name
config_changed = True
else:
validation_errors.append(error_msg)
@ -675,8 +551,8 @@ def render_vision_llm_settings(tr):
if st_vision_api_key:
is_valid, error_msg = validate_api_key(st_vision_api_key, "视频分析")
if is_valid:
config.app["vision_litellm_api_key"] = st_vision_api_key
st.session_state["vision_litellm_api_key"] = st_vision_api_key
config.app["vision_openai_api_key"] = st_vision_api_key
st.session_state["vision_openai_api_key"] = st_vision_api_key
config_changed = True
else:
validation_errors.append(error_msg)
@ -685,8 +561,8 @@ def render_vision_llm_settings(tr):
if st_vision_base_url:
is_valid, error_msg = validate_base_url(st_vision_base_url, "视频分析")
if is_valid:
config.app["vision_litellm_base_url"] = st_vision_base_url
st.session_state["vision_litellm_base_url"] = st_vision_base_url
config.app["vision_openai_base_url"] = st_vision_base_url
st.session_state["vision_openai_base_url"] = st_vision_base_url
config_changed = True
else:
validation_errors.append(error_msg)
@ -701,7 +577,7 @@ def render_vision_llm_settings(tr):
# Clear the cache so the new configuration is used next time
UnifiedLLMService.clear_cache()
if st_vision_api_key or st_vision_base_url or st_vision_model_name:
st.success(f"视频分析模型配置已保存(LiteLLM")
st.success(f"视频分析模型配置已保存(OpenAI 兼容")
except Exception as e:
st.error(f"保存配置失败: {str(e)}")
logger.error(f"保存视频分析配置失败: {str(e)}")
@ -811,16 +687,16 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
def render_text_llm_settings(tr):
"""渲染文案生成模型设置(LiteLLM 统一配置)"""
"""渲染文案生成模型设置(OpenAI 兼容 统一配置)"""
st.subheader(tr("Text Generation Model Settings"))
# Always use the LiteLLM provider
config.app["text_llm_provider"] = "litellm"
# Always use the OpenAI-compatible provider
config.app["text_llm_provider"] = "openai"
# Fetch the saved LiteLLM configuration
full_text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
text_api_key = config.app.get("text_litellm_api_key", "")
text_base_url = config.app.get("text_litellm_base_url", "")
# Fetch the saved configuration
full_text_model_name = config.app.get("text_openai_model_name") or "deepseek/deepseek-chat"
text_api_key = config.app.get("text_openai_api_key", "")
text_base_url = config.app.get("text_openai_base_url", "")
# Parse provider and model
default_provider = "deepseek"
@ -835,7 +711,7 @@ def render_text_llm_settings(tr):
current_model = full_text_model_name
# Define the list of supported providers
LITELLM_PROVIDERS = [
OPENAI_COMPATIBLE_PROVIDERS = [
"openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
"anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
"volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
@ -844,16 +720,16 @@ def render_text_llm_settings(tr):
]
# If the current provider is not in the list, prepend it
if current_provider not in LITELLM_PROVIDERS:
LITELLM_PROVIDERS.insert(0, current_provider)
if current_provider not in OPENAI_COMPATIBLE_PROVIDERS:
OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider)
# Render the configuration inputs
col1, col2 = st.columns([1, 2])
with col1:
selected_provider = st.selectbox(
tr("Text Model Provider"),
options=LITELLM_PROVIDERS,
index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
options=OPENAI_COMPATIBLE_PROVIDERS,
index=OPENAI_COMPATIBLE_PROVIDERS.index(current_provider) if current_provider in OPENAI_COMPATIBLE_PROVIDERS else 0,
key="text_provider_select"
)
@ -867,7 +743,7 @@ def render_text_llm_settings(tr):
"• gpt-4o\n"
"• gemini-2.0-flash\n"
"• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n"
"支持 100+ providers详见: https://docs.litellm.ai/docs/providers",
"支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow",
key="text_model_input"
)
@ -915,7 +791,7 @@ def render_text_llm_settings(tr):
else:
with st.spinner(tr("Testing connection...")):
try:
success, message = test_litellm_text_model(
success, message = test_openai_compatible_text_model(
api_key=st_text_api_key,
base_url=st_text_base_url,
model_name=st_text_model_name,
@ -928,7 +804,7 @@ def render_text_llm_settings(tr):
st.error(message)
except Exception as e:
st.error(f"测试连接时发生错误: {str(e)}")
logger.error(f"LiteLLM 文案生成模型连接测试失败: {str(e)}")
logger.error(f"OpenAI 兼容 文案生成模型连接测试失败: {str(e)}")
# Validate and save the configuration
text_validation_errors = []
@ -936,10 +812,10 @@ def render_text_llm_settings(tr):
# Validate the model name
if st_text_model_name:
is_valid, error_msg = validate_litellm_model_name(st_text_model_name, "文案生成")
is_valid, error_msg = validate_openai_compatible_model_name(st_text_model_name, "文案生成")
if is_valid:
config.app["text_litellm_model_name"] = st_text_model_name
st.session_state["text_litellm_model_name"] = st_text_model_name
config.app["text_openai_model_name"] = st_text_model_name
st.session_state["text_openai_model_name"] = st_text_model_name
text_config_changed = True
else:
text_validation_errors.append(error_msg)
@ -948,8 +824,8 @@ def render_text_llm_settings(tr):
if st_text_api_key:
is_valid, error_msg = validate_api_key(st_text_api_key, "文案生成")
if is_valid:
config.app["text_litellm_api_key"] = st_text_api_key
st.session_state["text_litellm_api_key"] = st_text_api_key
config.app["text_openai_api_key"] = st_text_api_key
st.session_state["text_openai_api_key"] = st_text_api_key
text_config_changed = True
else:
text_validation_errors.append(error_msg)
@ -958,8 +834,8 @@ def render_text_llm_settings(tr):
if st_text_base_url:
is_valid, error_msg = validate_base_url(st_text_base_url, "文案生成")
if is_valid:
config.app["text_litellm_base_url"] = st_text_base_url
st.session_state["text_litellm_base_url"] = st_text_base_url
config.app["text_openai_base_url"] = st_text_base_url
st.session_state["text_openai_base_url"] = st_text_base_url
text_config_changed = True
else:
text_validation_errors.append(error_msg)
@ -974,7 +850,7 @@ def render_text_llm_settings(tr):
# Clear the cache so the new configuration is used next time
UnifiedLLMService.clear_cache()
if st_text_api_key or st_text_base_url or st_text_model_name:
st.success(f"文案生成模型配置已保存(LiteLLM")
st.success(f"文案生成模型配置已保存(OpenAI 兼容")
except Exception as e:
st.error(f"保存配置失败: {str(e)}")
logger.error(f"保存文案生成配置失败: {str(e)}")

View File

@ -123,7 +123,7 @@ def generate_script_docu(params):
# Best practice: use get() with a default value plus a fallback value from config
vision_llm_provider = (
st.session_state.get('vision_llm_provider') or
config.app.get('vision_llm_provider', 'litellm')
config.app.get('vision_llm_provider', 'openai')
).lower()
logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")