mirror of https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 06:08:16 +00:00

feat: remove the LiteLLM dependency and migrate to the OpenAI-compatible interface

- Removed the LiteLLM code and dependency in favor of the native OpenAI-compatible interface
- Reworked LLM service provider registration to support only the OpenAI-compatible interface
- Updated the configuration files and documentation to drop the LiteLLM references
- Added new test cases that verify the OpenAI-compatible integration
- Updated the WebUI components to work with the new OpenAI-compatible interface
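For orientation, a minimal sketch of how the migrated call path is used. `UnifiedLLMService.generate_text` and the `provider=` keyword are taken from the usage guide this commit removes; `"openai"` is the only name the new registration installs, and the exact call shape here is an assumption:

```python
# Hedged usage sketch (not part of the commit). UnifiedLLMService and the
# provider="..." keyword come from the removed usage guide below; "openai"
# is the provider name register_all_providers() now registers.
import asyncio

from app.services.llm.unified_service import UnifiedLLMService


async def main() -> None:
    text = await UnifiedLLMService.generate_text(
        prompt="Hello",
        provider="openai",  # was provider="litellm" before this commit
    )
    print(text)


asyncio.run(main())
```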
This commit is contained in:
parent a6f2e0d815
commit 3396644593
@@ -33,8 +33,9 @@ NarratoAI is an automated video-commentary tool: LLM-based script writing,
 This project is for learning and research only and may not be used commercially. For a commercial license, contact the author.

 ## Latest News

+- 2026.03.27 For security reasons, the LiteLLM dependency has been removed; all requests now go through the unified OpenAI-compatible path
 - 2025.11.20 Released v0.7.5 with [IndexTTS2](https://github.com/index-tts/index-tts) voice-cloning support
-- 2025.10.15 Released v0.7.3, using [LiteLLM](https://github.com/BerriAI/litellm) to manage model providers
+- 2025.10.15 Released v0.7.3, upgrading model-provider management
 - 2025.09.10 Released v0.7.2 with Tencent Cloud TTS support
 - 2025.08.18 Released v0.7.1 with **voice cloning** and the latest large models
 - 2025.05.11 Released v0.6.0 with **short-drama commentary** and an improved editing workflow
@@ -176,4 +177,3 @@ streamlit run webui.py --server.maxUploadSize=2048

 ## Star History

 [![Star History Chart](https://api.star-history.com/svg?repos=linyqh/NarratoAI&type=Date)](https://star-history.com/#linyqh/NarratoAI&Date)
-
@@ -12,8 +12,7 @@ NarratoAI LLM service module
 - OutputValidator: output-format validator

 Supported providers:
-    Vision models: Gemini, QwenVL, SiliconFlow
-    Text models: OpenAI, DeepSeek, Gemini, Qwen, Moonshot, SiliconFlow
+    Vision/text models: the OpenAI-compatible interface (works with OpenAI, DeepSeek, Gemini gateways, Qwen gateways, etc.)
 """

 from .manager import LLMServiceManager
@@ -68,7 +68,7 @@ class BaseLLMProvider(ABC):
         """Validate model support (lenient mode: only logs a warning)"""
         from loguru import logger

-        # LiteLLM already provides unified model validation; legacy providers use lenient validation
+        # OpenAI-compatible gateways expose many models; final validation happens remotely at call time
        if self.model_name not in self.supported_models:
            logger.warning(
                f"Model {self.model_name} is not in the predefined support list for provider {self.provider_name}."
@@ -214,7 +214,7 @@ class LLMConfigValidator:
         "Configure a base_url for each provider to improve stability",
         "Periodically check that model names are up to date",
         "Configure multiple providers as fallbacks",
-        "LiteLLM is recommended as a unified interface, supporting 100+ providers"
+        "The OpenAI-compatible interface is recommended; it makes it easy to plug in multiple model gateways"
     ]
 }
@@ -257,8 +257,8 @@ class LLMConfigValidator:
         "text": ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-1.5-pro"]
     },
     "openai": {
-        "vision": [],
-        "text": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
+        "vision": ["gpt-4o", "gemini-2.0-flash-lite", "Qwen/Qwen2.5-VL-32B-Instruct"],
+        "text": ["gpt-4o-mini", "deepseek-chat", "zai-org/GLM-4.6"]
     },
     "qwen": {
         "vision": ["qwen2.5-vl-32b-instruct"],
app/services/llm/litellm_provider.py (deleted; path per the imports in providers/__init__.py below)
@@ -1,490 +0,0 @@
"""
LiteLLM unified provider implementation

Uses the LiteLLM library to provide a unified LLM interface supporting 100+ providers,
including OpenAI, Anthropic, Gemini, Qwen, DeepSeek, SiliconFlow, etc.
"""

import asyncio
import base64
import io
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger

try:
    import litellm
    from litellm import acompletion, completion
    from litellm.exceptions import (
        AuthenticationError as LiteLLMAuthError,
        RateLimitError as LiteLLMRateLimitError,
        BadRequestError as LiteLLMBadRequestError,
        APIError as LiteLLMAPIError
    )
except ImportError:
    logger.error("LiteLLM is not installed. Run: pip install litellm")
    raise

from .base import VisionModelProvider, TextModelProvider
from .exceptions import (
    APICallError,
    AuthenticationError,
    RateLimitError,
    ContentFilterError
)


# Configure LiteLLM global settings
def configure_litellm():
    """Configure LiteLLM global parameters"""
    from app.config import config

    # Retry count
    litellm.num_retries = config.app.get('llm_max_retries', 3)

    # Default timeout
    litellm.request_timeout = config.app.get('llm_text_timeout', 180)

    # Verbose logging (development only)
    # litellm.set_verbose = True

    logger.info(f"LiteLLM configured: retries={litellm.num_retries}, timeout={litellm.request_timeout}s")


# Initialize the configuration
configure_litellm()


class LiteLLMVisionProvider(VisionModelProvider):
    """Unified vision-model provider backed by LiteLLM"""

    @property
    def provider_name(self) -> str:
        # Extract the provider name from model_name (e.g. "gemini/gemini-2.0-flash")
        if "/" in self.model_name:
            return self.model_name.split("/")[0]
        return "litellm"

    @property
    def supported_models(self) -> List[str]:
        # LiteLLM supports 100+ providers and hundreds of models; listing them all is impractical.
        # An empty list skips the predefined-list check; LiteLLM validates at call time.
        return []

    def _validate_model_support(self):
        """
        Override the model-validation logic.

        For LiteLLM we skip the predefined-list check because:
        1. LiteLLM supports 100+ providers and hundreds of models, which cannot all be enumerated
        2. LiteLLM validates the model at call time
        3. If the model is unsupported, LiteLLM returns a clear error message

        Only basic format validation is done here (optional).
        """
        from loguru import logger

        # Optional: check the model-name format (provider/model)
        if "/" not in self.model_name:
            logger.debug(
                f"LiteLLM model name '{self.model_name}' has no provider prefix; "
                f"LiteLLM will try to infer it. The 'provider/model' format is recommended, e.g. 'gemini/gemini-2.5-flash'"
            )

        # Do not raise; let LiteLLM validate at call time
        logger.debug(f"LiteLLM vision model configured: {self.model_name}")

    def _initialize(self):
        """Initialize LiteLLM-specific settings"""
        # Put the API key into an environment variable (LiteLLM reads it automatically)
        import os

        # Decide which API key to set based on model_name
        provider = self.provider_name.lower()

        # Map providers to environment-variable names
        env_key_mapping = {
            "gemini": "GEMINI_API_KEY",
            "google": "GEMINI_API_KEY",
            "openai": "OPENAI_API_KEY",
            "qwen": "QWEN_API_KEY",
            "dashscope": "DASHSCOPE_API_KEY",
            "siliconflow": "SILICONFLOW_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "claude": "ANTHROPIC_API_KEY"
        }

        env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY")

        if self.api_key and env_var:
            os.environ[env_var] = self.api_key
            logger.debug(f"Set environment variable: {env_var}")

        # If a base_url was provided, pass it to LiteLLM
        if self.base_url:
            # LiteLLM accepts a custom URL via the api_base parameter
            self._api_base = self.base_url
            logger.debug(f"Using custom API base URL: {self.base_url}")

    async def analyze_images(self,
                             images: List[Union[str, Path, PIL.Image.Image]],
                             prompt: str,
                             batch_size: int = 10,
                             **kwargs) -> List[str]:
        """
        Analyze images via LiteLLM

        Args:
            images: list of image paths or PIL image objects
            prompt: analysis prompt
            batch_size: batch size
            **kwargs: extra parameters

        Returns:
            list of analysis results
        """
        logger.info(f"Analyzing {len(images)} images with LiteLLM ({self.model_name})")

        # Preprocess images
        processed_images = self._prepare_images(images)

        # Process in batches
        results = []
        for i in range(0, len(processed_images), batch_size):
            batch = processed_images[i:i + batch_size]
            logger.info(f"Processing batch {i//batch_size + 1} with {len(batch)} images")

            try:
                result = await self._analyze_batch(batch, prompt, **kwargs)
                results.append(result)
            except Exception as e:
                logger.error(f"Batch {i//batch_size + 1} failed: {str(e)}")
                results.append(f"Batch processing failed: {str(e)}")

        return results

    async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
        """Analyze one batch of images"""
        # Build LiteLLM-format messages
        content = [{"type": "text", "text": prompt}]

        # Attach images (base64-encoded)
        for img in batch:
            base64_image = self._image_to_base64(img)
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        messages = [{
            "role": "user",
            "content": content
        }]

        # Call LiteLLM
        try:
            # Prepare the parameters
            effective_model_name = self.model_name

            # Special handling for SiliconFlow
            if self.model_name.lower().startswith("siliconflow/"):
                # Replace the provider with openai
                if "/" in self.model_name:
                    effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
                else:
                    effective_model_name = f"openai/{self.model_name}"

                # Make sure OPENAI_API_KEY is set (if not already)
                import os
                if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
                    os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")

                # Make sure base_url is set (if not already)
                if not hasattr(self, '_api_base'):
                    self._api_base = "https://api.siliconflow.cn/v1"

            completion_kwargs = {
                "model": effective_model_name,
                "messages": messages,
                "temperature": kwargs.get("temperature", 1.0),
                "max_tokens": kwargs.get("max_tokens", 4000)
            }

            # If a custom base_url exists, add the api_base parameter
            if hasattr(self, '_api_base'):
                completion_kwargs["api_base"] = self._api_base

            # Allow api_key and api_base to be passed dynamically
            if "api_key" in kwargs:
                completion_kwargs["api_key"] = kwargs["api_key"]
            if "api_base" in kwargs:
                completion_kwargs["api_base"] = kwargs["api_base"]

            response = await acompletion(**completion_kwargs)

            if response.choices and len(response.choices) > 0:
                content = response.choices[0].message.content
                logger.debug(f"LiteLLM call succeeded, tokens used: {response.usage.total_tokens if response.usage else 'N/A'}")
                return content
            else:
                raise APICallError("LiteLLM returned an empty response")

        except LiteLLMAuthError as e:
            logger.error(f"LiteLLM authentication failed: {str(e)}")
            raise AuthenticationError()
        except LiteLLMRateLimitError as e:
            logger.error(f"LiteLLM rate limited: {str(e)}")
            raise RateLimitError()
        except LiteLLMBadRequestError as e:
            error_msg = str(e)
            if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower():
                raise ContentFilterError(f"Content blocked by the safety filter: {error_msg}")
            logger.error(f"LiteLLM request error: {error_msg}")
            raise APICallError(f"Request error: {error_msg}")
        except LiteLLMAPIError as e:
            logger.error(f"LiteLLM API error: {str(e)}")
            raise APICallError(f"API error: {str(e)}")
        except Exception as e:
            logger.error(f"LiteLLM call failed: {str(e)}")
            raise APICallError(f"Call failed: {str(e)}")

    def _image_to_base64(self, img: PIL.Image.Image) -> str:
        """Convert a PIL image to base64"""
        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG', quality=85)
        img_bytes = img_buffer.getvalue()
        return base64.b64encode(img_bytes).decode('utf-8')

    async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Base-class interface compatibility (the LiteLLM SDK is used instead)"""
        pass


class LiteLLMTextProvider(TextModelProvider):
    """Unified text-generation provider backed by LiteLLM"""

    @property
    def provider_name(self) -> str:
        # Extract the provider name from model_name
        if "/" in self.model_name:
            return self.model_name.split("/")[0]
        # Try to infer the provider from the model name
        model_lower = self.model_name.lower()
        if "gpt" in model_lower:
            return "openai"
        elif "claude" in model_lower:
            return "anthropic"
        elif "gemini" in model_lower:
            return "gemini"
        elif "qwen" in model_lower:
            return "qwen"
        elif "deepseek" in model_lower:
            return "deepseek"
        return "litellm"

    @property
    def supported_models(self) -> List[str]:
        # LiteLLM supports 100+ providers and hundreds of models; listing them all is impractical.
        # An empty list skips the predefined-list check; LiteLLM validates at call time.
        return []

    def _validate_model_support(self):
        """
        Override the model-validation logic.

        For LiteLLM we skip the predefined-list check because:
        1. LiteLLM supports 100+ providers and hundreds of models, which cannot all be enumerated
        2. LiteLLM validates the model at call time
        3. If the model is unsupported, LiteLLM returns a clear error message

        Only basic format validation is done here (optional).
        """
        from loguru import logger

        # Optional: check the model-name format (provider/model)
        if "/" not in self.model_name:
            logger.debug(
                f"LiteLLM model name '{self.model_name}' has no provider prefix; "
                f"LiteLLM will try to infer it. The 'provider/model' format is recommended, e.g. 'gemini/gemini-2.5-flash'"
            )

        # Do not raise; let LiteLLM validate at call time
        logger.debug(f"LiteLLM text model configured: {self.model_name}")

    def _initialize(self):
        """Initialize LiteLLM-specific settings"""
        import os

        # Decide which API key to set based on model_name
        provider = self.provider_name.lower()

        # Map providers to environment-variable names
        env_key_mapping = {
            "gemini": "GEMINI_API_KEY",
            "google": "GEMINI_API_KEY",
            "openai": "OPENAI_API_KEY",
            "qwen": "QWEN_API_KEY",
            "dashscope": "DASHSCOPE_API_KEY",
            "siliconflow": "SILICONFLOW_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "claude": "ANTHROPIC_API_KEY",
            "moonshot": "MOONSHOT_API_KEY"
        }

        env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY")

        if self.api_key and env_var:
            os.environ[env_var] = self.api_key
            logger.debug(f"Set environment variable: {env_var}")

        # If a base_url was provided, keep it for later calls
        if self.base_url:
            self._api_base = self.base_url
            logger.debug(f"Using custom API base URL: {self.base_url}")

    async def generate_text(self,
                            prompt: str,
                            system_prompt: Optional[str] = None,
                            temperature: float = 1.0,
                            max_tokens: Optional[int] = None,
                            response_format: Optional[str] = None,
                            **kwargs) -> str:
        """
        Generate text via LiteLLM

        Args:
            prompt: user prompt
            system_prompt: system prompt
            temperature: sampling temperature
            max_tokens: maximum number of tokens
            response_format: response format ('json' or None)
            **kwargs: extra parameters

        Returns:
            the generated text
        """
        # Build the message list
        messages = self._build_messages(prompt, system_prompt)

        # Prepare the parameters
        effective_model_name = self.model_name

        # Special handling for SiliconFlow
        if self.model_name.lower().startswith("siliconflow/"):
            # Replace the provider with openai
            if "/" in self.model_name:
                effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
            else:
                effective_model_name = f"openai/{self.model_name}"

            # Make sure OPENAI_API_KEY is set (if not already)
            import os
            if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
                os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")

            # Make sure base_url is set (if not already)
            if not hasattr(self, '_api_base'):
                self._api_base = "https://api.siliconflow.cn/v1"

        completion_kwargs = {
            "model": effective_model_name,
            "messages": messages,
            "temperature": temperature
        }

        if max_tokens:
            completion_kwargs["max_tokens"] = max_tokens

        # Handle JSON-format output
        # LiteLLM papers over per-provider differences in JSON mode
        if response_format == "json":
            try:
                completion_kwargs["response_format"] = {"type": "json_object"}
            except Exception as e:
                # If unsupported, add the constraint to the prompt instead
                logger.warning(f"The model may not support response_format; adding a JSON constraint to the prompt: {str(e)}")
                messages[-1]["content"] += "\n\nPlease output strict JSON with no extra text or markup."

        # If a custom base_url exists, add the api_base parameter
        if hasattr(self, '_api_base'):
            completion_kwargs["api_base"] = self._api_base

        # Allow api_key and api_base to be passed dynamically (fixes an auth issue)
        if "api_key" in kwargs:
            completion_kwargs["api_key"] = kwargs["api_key"]
        if "api_base" in kwargs:
            completion_kwargs["api_base"] = kwargs["api_base"]

        try:
            # Call LiteLLM (with automatic retries)
            response = await acompletion(**completion_kwargs)

            if response.choices and len(response.choices) > 0:
                content = response.choices[0].message.content

                # Strip possible markdown code fences (for models without JSON mode)
                if response_format == "json" and "response_format" not in completion_kwargs:
                    content = self._clean_json_output(content)

                logger.debug(f"LiteLLM call succeeded, tokens used: {response.usage.total_tokens if response.usage else 'N/A'}")
                return content
            else:
                raise APICallError("LiteLLM returned an empty response")

        except LiteLLMAuthError as e:
            logger.error(f"LiteLLM authentication failed: {str(e)}")
            raise AuthenticationError()
        except LiteLLMRateLimitError as e:
            logger.error(f"LiteLLM rate limited: {str(e)}")
            raise RateLimitError()
        except LiteLLMBadRequestError as e:
            error_msg = str(e)
            # Handle the case where response_format is unsupported
            if "response_format" in error_msg and response_format == "json":
                logger.warning("The model does not support response_format; retrying without the format constraint")
                completion_kwargs.pop("response_format", None)
                messages[-1]["content"] += "\n\nPlease output strict JSON with no extra text or markup."

                # Retry
                response = await acompletion(**completion_kwargs)
                if response.choices and len(response.choices) > 0:
                    content = response.choices[0].message.content
                    content = self._clean_json_output(content)
                    return content
                else:
                    raise APICallError("LiteLLM returned an empty response")

            # Check for safety filtering
            if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower():
                raise ContentFilterError(f"Content blocked by the safety filter: {error_msg}")

            logger.error(f"LiteLLM request error: {error_msg}")
            raise APICallError(f"Request error: {error_msg}")
        except LiteLLMAPIError as e:
            logger.error(f"LiteLLM API error: {str(e)}")
            raise APICallError(f"API error: {str(e)}")
        except Exception as e:
            logger.error(f"LiteLLM call failed: {str(e)}")
            raise APICallError(f"Call failed: {str(e)}")

    def _clean_json_output(self, output: str) -> str:
        """Clean JSON output: strip markdown markers etc."""
        import re

        # Remove possible markdown code-fence markers
        output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
        output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
        output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)

        # Strip surrounding whitespace
        output = output.strip()

        return output

    async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Base-class interface compatibility (the LiteLLM SDK is used instead)"""
        pass
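Before moving on to the manager changes, the essential call-shape difference between the file just removed and its replacement later in this diff, reduced to the parameters both files actually pass (a commented sketch, not runnable as-is):

```python
# Old path (removed above): a single litellm.acompletion() call; the provider is
# chosen by the "provider/" prefix of the model name.
#   response = await acompletion(model="gemini/gemini-2.0-flash", messages=messages,
#                                api_key=api_key, api_base=base_url)
#
# New path (added below): an explicit AsyncOpenAI client built per request; the
# provider prefix is stripped and routing is decided by base_url alone.
#   client = AsyncOpenAI(api_key=api_key, base_url=base_url, timeout=180, max_retries=3)
#   response = await client.chat.completions.create(model="gemini-2.0-flash", messages=messages)
```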
@@ -4,7 +4,7 @@
 Centralizes all LLM service providers and offers simple factory methods to create and fetch service instances
 """

-from typing import Dict, Type, Optional
+from typing import Dict, Type, Optional, Tuple
 from loguru import logger

 from app.config import config
@@ -64,6 +64,29 @@ class LLMServiceManager:
             "vision_providers": list(cls._vision_providers.keys()),
             "text_providers": list(cls._text_providers.keys())
         }

+    @classmethod
+    def _normalize_provider_name(cls, provider_name: str) -> str:
+        """Normalize a provider name."""
+        return provider_name.lower()
+
+    @classmethod
+    def _get_provider_config(
+        cls,
+        model_type: str,
+        provider_name: str
+    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+        """
+        Fetch a provider's configuration.
+
+        model_type: 'vision' or 'text'
+        """
+        config_prefix = f"{model_type}_{provider_name}"
+        api_key = config.app.get(f"{config_prefix}_api_key")
+        model_name = config.app.get(f"{config_prefix}_model_name")
+        base_url = config.app.get(f"{config_prefix}_base_url")
+
+        return api_key, model_name, base_url

     @classmethod
     def get_vision_provider(cls, provider_name: Optional[str] = None) -> VisionModelProvider:
@@ -89,9 +112,8 @@ class LLMServiceManager:

         # Determine the provider name
         if not provider_name:
-            provider_name = config.app.get('vision_llm_provider', 'gemini').lower()
-        else:
-            provider_name = provider_name.lower()
+            provider_name = config.app.get('vision_llm_provider', 'openai')
+        provider_name = cls._normalize_provider_name(provider_name)

         # Check the cache
         cache_key = f"vision_{provider_name}"

@@ -104,9 +126,7 @@ class LLMServiceManager:

         # Fetch the configuration
         config_prefix = f"vision_{provider_name}"
-        api_key = config.app.get(f'{config_prefix}_api_key')
-        model_name = config.app.get(f'{config_prefix}_model_name')
-        base_url = config.app.get(f'{config_prefix}_base_url')
+        api_key, model_name, base_url = cls._get_provider_config("vision", provider_name)

         if not api_key:
             raise ConfigurationError(f"Missing API key configuration: {config_prefix}_api_key")

@@ -157,9 +177,8 @@ class LLMServiceManager:

         # Determine the provider name
         if not provider_name:
-            provider_name = config.app.get('text_llm_provider', 'openai').lower()
-        else:
-            provider_name = provider_name.lower()
+            provider_name = config.app.get('text_llm_provider', 'openai')
+        provider_name = cls._normalize_provider_name(provider_name)

         logger.debug(f"Fetching text-model provider: {provider_name}")
         logger.debug(f"Registered text providers: {list(cls._text_providers.keys())}")

@@ -178,9 +197,7 @@ class LLMServiceManager:

         # Fetch the configuration
         config_prefix = f"text_{provider_name}"
-        api_key = config.app.get(f'{config_prefix}_api_key')
-        model_name = config.app.get(f'{config_prefix}_model_name')
-        base_url = config.app.get(f'{config_prefix}_base_url')
+        api_key, model_name, base_url = cls._get_provider_config("text", provider_name)

         if not api_key:
             raise ConfigurationError(f"Missing API key configuration: {config_prefix}_api_key")
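To make the key scheme concrete, a small self-contained sketch of what `_get_provider_config("text", "openai")` resolves against the flat `config.app` mapping. The key names mirror config.example.toml later in this diff; the DeepSeek URL is illustrative only:

```python
# Sketch of the flat config keys that _get_provider_config resolves; key names
# follow the f"{model_type}_{provider_name}_..." scheme shown above.
config_app = {
    "text_llm_provider": "openai",
    "text_openai_api_key": "sk-...",                         # placeholder key
    "text_openai_model_name": "deepseek/deepseek-chat",      # prefix stripped later
    "text_openai_base_url": "https://api.deepseek.com/v1",   # example gateway URL
}

def get_provider_config(model_type: str, provider_name: str):
    prefix = f"{model_type}_{provider_name}"
    return (
        config_app.get(f"{prefix}_api_key"),
        config_app.get(f"{prefix}_model_name"),
        config_app.get(f"{prefix}_base_url"),
    )

print(get_provider_config("text", "openai"))
# -> ('sk-...', 'deepseek/deepseek-chat', 'https://api.deepseek.com/v1')
```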
app/services/llm/openai_compatible_provider.py (new file, 276 lines)
@@ -0,0 +1,276 @@
"""
|
||||
OpenAI 兼容提供商实现
|
||||
|
||||
使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口,支持文本和视觉模型。
|
||||
"""
|
||||
|
||||
import io
|
||||
import base64
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
from loguru import logger
|
||||
from openai import (
|
||||
APIError as OpenAIAPIError,
|
||||
AsyncOpenAI,
|
||||
AuthenticationError as OpenAIAuthError,
|
||||
BadRequestError as OpenAIBadRequestError,
|
||||
RateLimitError as OpenAIRateLimitError,
|
||||
)
|
||||
|
||||
from app.config import config
|
||||
from .base import TextModelProvider, VisionModelProvider
|
||||
from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError
|
||||
|
||||
# 常见 OpenAI 兼容网关前缀。若使用 provider/model 格式,将剥离 provider 前缀。
|
||||
OPENAI_COMPATIBLE_PROVIDER_PREFIXES = {
|
||||
"openai",
|
||||
"gemini",
|
||||
"deepseek",
|
||||
"qwen",
|
||||
"siliconflow",
|
||||
"moonshot",
|
||||
"openrouter",
|
||||
"anthropic",
|
||||
"azure",
|
||||
"ollama",
|
||||
"mistral",
|
||||
"groq",
|
||||
"cohere",
|
||||
"together_ai",
|
||||
"fireworks_ai",
|
||||
"volcengine",
|
||||
"vertex_ai",
|
||||
"huggingface",
|
||||
"xai",
|
||||
"bedrock",
|
||||
"cloudflare",
|
||||
"vllm",
|
||||
"codestral",
|
||||
"replicate",
|
||||
"deepgram",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_model_name(model_name: str) -> str:
|
||||
"""兼容历史 provider/model 写法,必要时自动剥离 provider 前缀。"""
|
||||
if "/" not in model_name:
|
||||
return model_name
|
||||
|
||||
provider_prefix, raw_model = model_name.split("/", 1)
|
||||
if provider_prefix.lower() in OPENAI_COMPATIBLE_PROVIDER_PREFIXES and raw_model:
|
||||
return raw_model
|
||||
return model_name
|
||||
|
||||
|
||||
def _is_response_format_error(message: str) -> bool:
|
||||
return "response_format" in (message or "").lower()
|
||||
|
||||
|
||||
def _is_content_filter_error(message: str) -> bool:
|
||||
lowered = (message or "").lower()
|
||||
return "content_filter" in lowered or "safety" in lowered
|
||||
|
||||
|
||||
def _clean_json_output(output: str) -> str:
|
||||
"""清理 JSON 输出中的 markdown 包裹。"""
|
||||
output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
|
||||
output = re.sub(r"^```\s*$", "", output, flags=re.MULTILINE)
|
||||
output = re.sub(r"^```.*$", "", output, flags=re.MULTILINE)
|
||||
return output.strip()
|
||||
|
||||
|
||||
class _OpenAICompatibleBase:
|
||||
"""OpenAI 兼容 provider 共享逻辑。"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
# 兼容网关模型数量很多,运行时校验由远端完成。
|
||||
return []
|
||||
|
||||
def _validate_model_support(self):
|
||||
logger.debug(f"OpenAI 兼容模型已配置: {self.model_name}")
|
||||
|
||||
def _initialize(self):
|
||||
# SDK client 按请求参数动态构建,这里无需初始化全局状态。
|
||||
pass
|
||||
|
||||
def _build_client(
|
||||
self,
|
||||
api_key_override: Optional[str] = None,
|
||||
base_url_override: Optional[str] = None,
|
||||
timeout_override: Optional[float] = None,
|
||||
) -> AsyncOpenAI:
|
||||
"""按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。"""
|
||||
api_key = api_key_override or self.api_key
|
||||
base_url = base_url_override or self.base_url or None
|
||||
|
||||
timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180)
|
||||
max_retries: int = config.app.get("llm_max_retries", 3)
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout_seconds,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
|
||||
class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider):
|
||||
"""OpenAI 兼容视觉模型提供商。"""
|
||||
|
||||
async def analyze_images(
|
||||
self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
**kwargs,
|
||||
) -> List[str]:
|
||||
logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
|
||||
|
||||
processed_images = self._prepare_images(images)
|
||||
results: List[str] = []
|
||||
|
||||
for i in range(0, len(processed_images), batch_size):
|
||||
batch = processed_images[i : i + batch_size]
|
||||
logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片")
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt, **kwargs)
|
||||
results.append(result)
|
||||
except Exception as exc:
|
||||
logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}")
|
||||
results.append(f"批次处理失败: {exc}")
|
||||
|
||||
return results
|
||||
|
||||
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
for img in batch:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{self._image_to_base64(img)}"},
|
||||
}
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": content}]
|
||||
model_name = _normalize_model_name(self.model_name)
|
||||
|
||||
client = self._build_client(
|
||||
api_key_override=kwargs.get("api_key"),
|
||||
base_url_override=kwargs.get("api_base"),
|
||||
timeout_override=config.app.get("llm_vision_timeout", 120),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=kwargs.get("temperature", 1.0),
|
||||
max_tokens=kwargs.get("max_tokens", 4000),
|
||||
)
|
||||
if response.choices and response.choices[0].message and response.choices[0].message.content:
|
||||
return response.choices[0].message.content
|
||||
raise APICallError("OpenAI 兼容接口返回空响应")
|
||||
except OpenAIAuthError as exc:
|
||||
logger.error(f"OpenAI 兼容接口认证失败: {exc}")
|
||||
raise AuthenticationError(str(exc))
|
||||
except OpenAIRateLimitError as exc:
|
||||
logger.error(f"OpenAI 兼容接口速率限制: {exc}")
|
||||
raise RateLimitError(str(exc))
|
||||
except OpenAIBadRequestError as exc:
|
||||
error_msg = str(exc)
|
||||
if _is_content_filter_error(error_msg):
|
||||
raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}")
|
||||
raise APICallError(f"请求错误: {error_msg}")
|
||||
except OpenAIAPIError as exc:
|
||||
logger.error(f"OpenAI 兼容接口 API 错误: {exc}")
|
||||
raise APICallError(f"API 错误: {exc}")
|
||||
except Exception as exc:
|
||||
logger.error(f"OpenAI 兼容接口调用失败: {exc}")
|
||||
raise APICallError(f"调用失败: {exc}")
|
||||
|
||||
def _image_to_base64(self, img: PIL.Image.Image) -> str:
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format="JPEG", quality=85)
|
||||
return base64.b64encode(img_buffer.getvalue()).decode("utf-8")
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return payload
|
||||
|
||||
|
||||
class OpenAICompatibleTextProvider(_OpenAICompatibleBase, TextModelProvider):
|
||||
"""OpenAI 兼容文本模型提供商。"""
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
model_name = _normalize_model_name(self.model_name)
|
||||
|
||||
client = self._build_client(
|
||||
api_key_override=kwargs.get("api_key"),
|
||||
base_url_override=kwargs.get("api_base"),
|
||||
timeout_override=config.app.get("llm_text_timeout", 180),
|
||||
)
|
||||
|
||||
completion_kwargs: Dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if max_tokens:
|
||||
completion_kwargs["max_tokens"] = max_tokens
|
||||
if response_format == "json":
|
||||
completion_kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(**completion_kwargs)
|
||||
if response.choices and response.choices[0].message and response.choices[0].message.content:
|
||||
return response.choices[0].message.content
|
||||
raise APICallError("OpenAI 兼容接口返回空响应")
|
||||
|
||||
except OpenAIBadRequestError as exc:
|
||||
error_msg = str(exc)
|
||||
# 某些网关不支持 response_format,回退到提示词约束模式
|
||||
if response_format == "json" and _is_response_format_error(error_msg):
|
||||
logger.warning("目标网关不支持 response_format,回退为提示词约束 JSON 输出")
|
||||
completion_kwargs.pop("response_format", None)
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
retry_response = await client.chat.completions.create(**completion_kwargs)
|
||||
if retry_response.choices and retry_response.choices[0].message and retry_response.choices[0].message.content:
|
||||
return _clean_json_output(retry_response.choices[0].message.content)
|
||||
raise APICallError("OpenAI 兼容接口返回空响应")
|
||||
|
||||
if _is_content_filter_error(error_msg):
|
||||
raise ContentFilterError(f"内容被安全过滤器阻止: {error_msg}")
|
||||
raise APICallError(f"请求错误: {error_msg}")
|
||||
|
||||
except OpenAIAuthError as exc:
|
||||
logger.error(f"OpenAI 兼容接口认证失败: {exc}")
|
||||
raise AuthenticationError(str(exc))
|
||||
except OpenAIRateLimitError as exc:
|
||||
logger.error(f"OpenAI 兼容接口速率限制: {exc}")
|
||||
raise RateLimitError(str(exc))
|
||||
except OpenAIAPIError as exc:
|
||||
logger.error(f"OpenAI 兼容接口 API 错误: {exc}")
|
||||
raise APICallError(f"API 错误: {exc}")
|
||||
except Exception as exc:
|
||||
logger.error(f"OpenAI 兼容接口调用失败: {exc}")
|
||||
raise APICallError(f"调用失败: {exc}")
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return payload
|
||||
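Two of the helpers above are pure functions, so their behavior is easy to pin down with a small self-contained sketch (demo values chosen for illustration; the prefix set is a subset of the real one):

```python
import re

OPENAI_COMPATIBLE_PROVIDER_PREFIXES = {"openai", "deepseek", "gemini"}  # subset for the demo

def normalize_model_name(model_name: str) -> str:
    # Mirrors _normalize_model_name above: strip a known provider prefix only.
    if "/" not in model_name:
        return model_name
    prefix, raw = model_name.split("/", 1)
    return raw if prefix.lower() in OPENAI_COMPATIBLE_PROVIDER_PREFIXES and raw else model_name

assert normalize_model_name("deepseek/deepseek-chat") == "deepseek-chat"
assert normalize_model_name("gpt-4o-mini") == "gpt-4o-mini"                   # no prefix: unchanged
assert normalize_model_name("mycorp/custom-model") == "mycorp/custom-model"  # unknown prefix: kept

def clean_json_output(output: str) -> str:
    # Mirrors _clean_json_output above: strip markdown code fences.
    output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
    output = re.sub(r"^```\s*$", "", output, flags=re.MULTILINE)
    output = re.sub(r"^```.*$", "", output, flags=re.MULTILINE)
    return output.strip()

assert clean_json_output("```json\n{\"ok\": true}\n```") == '{"ok": true}'
```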
@@ -2,7 +2,6 @@
 LLM service provider implementations

 Contains the concrete implementations of the LLM service providers
-The unified LiteLLM interface is recommended (supports 100+ providers)
 """

 # Provider classes are not imported at module top level, to avoid circular dependencies

@@ -13,25 +12,25 @@ def register_all_providers():
     """
     Register all providers

-    v0.8.0 change: only the unified LiteLLM interface is registered
-    - The old standalone provider implementations (gemini, openai, qwen, deepseek, siliconflow) were removed
-    - LiteLLM supports 100+ providers, so standalone implementations are unnecessary
+    Current behavior: only the unified OpenAI-compatible interface is registered
     """
     # Import inside the function to avoid circular dependencies
     from ..manager import LLMServiceManager
     from loguru import logger

-    # Import only the LiteLLM provider
-    from ..litellm_provider import LiteLLMVisionProvider, LiteLLMTextProvider
+    # Import only the OpenAI-compatible provider
+    from ..openai_compatible_provider import (
+        OpenAICompatibleVisionProvider,
+        OpenAICompatibleTextProvider,
+    )

     logger.info("🔧 Registering LLM providers...")

-    # ===== Register the unified LiteLLM interface =====
-    # LiteLLM supports 100+ providers (OpenAI, Gemini, Qwen, DeepSeek, SiliconFlow, etc.)
-    LLMServiceManager.register_vision_provider('litellm', LiteLLMVisionProvider)
-    LLMServiceManager.register_text_provider('litellm', LiteLLMTextProvider)
+    # ===== Register the unified OpenAI-compatible interface =====
+    LLMServiceManager.register_vision_provider('openai', OpenAICompatibleVisionProvider)
+    LLMServiceManager.register_text_provider('openai', OpenAICompatibleTextProvider)

-    logger.info("✅ LiteLLM providers registered (100+ providers supported)")
+    logger.info("✅ OpenAI-compatible providers registered")


# Export the registration function
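The registry API itself stays open, so a deployment could in principle register the same implementation under a second name. A hedged sketch, using only the registration calls the tests in this commit exercise; `"my_gateway"` is a hypothetical alias:

```python
# Hypothetical: register the OpenAI-compatible implementation under a second
# alias (e.g. for a dedicated gateway). register_text_provider and
# register_vision_provider are the calls the new tests below use.
from app.services.llm.manager import LLMServiceManager
from app.services.llm.openai_compatible_provider import (
    OpenAICompatibleTextProvider,
    OpenAICompatibleVisionProvider,
)

LLMServiceManager.register_text_provider("my_gateway", OpenAICompatibleTextProvider)
LLMServiceManager.register_vision_provider("my_gateway", OpenAICompatibleVisionProvider)
# Config keys would then follow the same scheme: text_my_gateway_api_key, etc.
```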
Deleted file (LiteLLM integration test script)
@@ -1,228 +0,0 @@
"""
LiteLLM integration test script

Checks that the LiteLLM provider is correctly integrated into the system
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from loguru import logger
from app.services.llm.manager import LLMServiceManager
from app.services.llm.unified_service import UnifiedLLMService


def test_provider_registration():
    """Check that the providers are registered"""
    logger.info("=" * 60)
    logger.info("Test 1: provider registration check")
    logger.info("=" * 60)

    # Check that the LiteLLM providers are registered
    vision_providers = LLMServiceManager.list_vision_providers()
    text_providers = LLMServiceManager.list_text_providers()

    logger.info(f"Registered vision-model providers: {vision_providers}")
    logger.info(f"Registered text-model providers: {text_providers}")

    assert 'litellm' in vision_providers, "❌ LiteLLM vision provider is not registered"
    assert 'litellm' in text_providers, "❌ LiteLLM text provider is not registered"

    logger.success("✅ LiteLLM providers registered successfully")

    # Show all provider information
    provider_info = LLMServiceManager.get_provider_info()
    logger.info("\nAll provider information:")
    logger.info(f"  Vision-model providers: {list(provider_info['vision_providers'].keys())}")
    logger.info(f"  Text-model providers: {list(provider_info['text_providers'].keys())}")


def test_litellm_import():
    """Check that the LiteLLM library is installed"""
    logger.info("\n" + "=" * 60)
    logger.info("Test 2: LiteLLM import check")
    logger.info("=" * 60)

    try:
        import litellm
        logger.success(f"✅ LiteLLM is installed, version: {litellm.__version__}")
        return True
    except ImportError as e:
        logger.error(f"❌ LiteLLM is not installed: {str(e)}")
        logger.info("Run: pip install litellm>=1.70.0")
        return False


async def test_text_generation_mock():
    """Test the text-generation interface (mock mode, no real API call)"""
    logger.info("\n" + "=" * 60)
    logger.info("Test 3: text-generation interface (mock)")
    logger.info("=" * 60)

    try:
        # Only checks that the interface is callable; no API request is sent
        logger.info("Interface check passed: UnifiedLLMService.generate_text is callable")
        logger.success("✅ Text-generation interface test passed")
        return True
    except Exception as e:
        logger.error(f"❌ Text-generation interface test failed: {str(e)}")
        return False


async def test_vision_analysis_mock():
    """Test the vision-analysis interface (mock mode)"""
    logger.info("\n" + "=" * 60)
    logger.info("Test 4: vision-analysis interface (mock)")
    logger.info("=" * 60)

    try:
        # Only checks that the interface is callable
        logger.info("Interface check passed: UnifiedLLMService.analyze_images is callable")
        logger.success("✅ Vision-analysis interface test passed")
        return True
    except Exception as e:
        logger.error(f"❌ Vision-analysis interface test failed: {str(e)}")
        return False


def test_backward_compatibility():
    """Test backward compatibility"""
    logger.info("\n" + "=" * 60)
    logger.info("Test 5: backward-compatibility check")
    logger.info("=" * 60)

    # Check whether the old providers are still available
    old_providers = ['gemini', 'openai', 'qwen', 'deepseek', 'siliconflow']
    vision_providers = LLMServiceManager.list_vision_providers()
    text_providers = LLMServiceManager.list_text_providers()

    logger.info("Checking whether the old providers are still available:")
    for provider in old_providers:
        if provider in ['openai', 'deepseek']:
            # These have text providers only
            if provider in text_providers:
                logger.info(f"  ✅ {provider} (text)")
            else:
                logger.warning(f"  ⚠️ {provider} (text) not registered")
        else:
            # These have both vision and text providers
            vision_ok = provider in vision_providers or f"{provider}vl" in vision_providers
            text_ok = provider in text_providers

            if vision_ok:
                logger.info(f"  ✅ {provider} (vision)")
            if text_ok:
                logger.info(f"  ✅ {provider} (text)")

    logger.success("✅ Backward-compatibility test passed")


def print_usage_guide():
    """Print the usage guide"""
    logger.info("\n" + "=" * 60)
    logger.info("LiteLLM usage guide")
    logger.info("=" * 60)

    guide = """
📚 How to use LiteLLM:

1. Configure config.toml:
   ```toml
   [app]
   # Option 1: use LiteLLM directly (recommended)
   vision_llm_provider = "litellm"
   vision_litellm_model_name = "gemini/gemini-2.0-flash-lite"
   vision_litellm_api_key = "your-api-key"

   text_llm_provider = "litellm"
   text_litellm_model_name = "deepseek/deepseek-chat"
   text_litellm_api_key = "your-api-key"
   ```

2. Supported model formats:
   - Gemini: gemini/gemini-2.0-flash
   - DeepSeek: deepseek/deepseek-chat
   - Qwen: qwen/qwen-plus
   - OpenAI: gpt-4o, gpt-4o-mini
   - SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1
   - More: see https://docs.litellm.ai/docs/providers

3. Code example:
   ```python
   from app.services.llm.unified_service import UnifiedLLMService

   # Text generation
   result = await UnifiedLLMService.generate_text(
       prompt="Hello",
       provider="litellm"
   )

   # Vision analysis
   results = await UnifiedLLMService.analyze_images(
       images=["path/to/image.jpg"],
       prompt="Describe this image",
       provider="litellm"
   )
   ```

4. Advantages:
   ✅ 80% less code
   ✅ Unified error handling
   ✅ Automatic retries
   ✅ 100+ providers supported
   ✅ Automatic cost tracking

5. Migration advice:
   - New projects: use LiteLLM directly
   - Existing projects: migrate gradually; the old providers still work
   - Switch production only after thorough testing
"""
    print(guide)


def main():
    """Run all tests"""
    logger.info("Starting the LiteLLM integration tests...\n")

    try:
        # Test 1: provider registration
        test_provider_registration()

        # Test 2: LiteLLM import
        litellm_available = test_litellm_import()

        if not litellm_available:
            logger.warning("\n⚠️ LiteLLM is not installed; skipping the API tests")
            logger.info("Run: pip install litellm>=1.70.0")
        else:
            # Tests 3-4: interface tests (mock)
            asyncio.run(test_text_generation_mock())
            asyncio.run(test_vision_analysis_mock())

        # Test 5: backward compatibility
        test_backward_compatibility()

        # Print the usage guide
        print_usage_guide()

        logger.info("\n" + "=" * 60)
        logger.success("🎉 All tests passed!")
        logger.info("=" * 60)

        return True

    except Exception as e:
        logger.error(f"\n❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
app/services/llm/test_openai_compat_unittest.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""OpenAI 兼容 provider 的最小回归测试。"""
|
||||
|
||||
import unittest
|
||||
|
||||
from app.config import config
|
||||
from app.services.llm.base import TextModelProvider
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.providers import register_all_providers
|
||||
|
||||
|
||||
class DummyOpenAITextProvider(TextModelProvider):
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> list[str]:
|
||||
return []
|
||||
|
||||
async def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
return prompt
|
||||
|
||||
async def _make_api_call(self, payload: dict) -> dict:
|
||||
return payload
|
||||
|
||||
|
||||
def _reset_manager_state():
|
||||
LLMServiceManager._vision_providers.clear()
|
||||
LLMServiceManager._text_providers.clear()
|
||||
LLMServiceManager._vision_instance_cache.clear()
|
||||
LLMServiceManager._text_instance_cache.clear()
|
||||
|
||||
|
||||
class OpenAICompatManagerTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_reset_manager_state()
|
||||
self._original_app = dict(config.app)
|
||||
|
||||
def tearDown(self):
|
||||
_reset_manager_state()
|
||||
config.app.clear()
|
||||
config.app.update(self._original_app)
|
||||
|
||||
def test_register_all_providers_only_registers_openai_provider(self):
|
||||
register_all_providers()
|
||||
|
||||
self.assertEqual({"openai"}, set(LLMServiceManager.list_text_providers()))
|
||||
self.assertEqual({"openai"}, set(LLMServiceManager.list_vision_providers()))
|
||||
|
||||
def test_get_text_provider_uses_openai_keys(self):
|
||||
LLMServiceManager.register_text_provider("openai", DummyOpenAITextProvider)
|
||||
|
||||
config.app["text_llm_provider"] = "openai"
|
||||
config.app["text_openai_api_key"] = "new-key"
|
||||
config.app["text_openai_model_name"] = "new-model"
|
||||
config.app["text_openai_base_url"] = "https://new.example/v1"
|
||||
|
||||
provider = LLMServiceManager.get_text_provider()
|
||||
|
||||
self.assertIsInstance(provider, DummyOpenAITextProvider)
|
||||
self.assertEqual("new-key", provider.api_key)
|
||||
self.assertEqual("new-model", provider.model_name)
|
||||
self.assertEqual("https://new.example/v1", provider.base_url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
app/services/llm/test_openai_compatible_integration.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""
|
||||
OpenAI 兼容接口集成测试脚本
|
||||
|
||||
用于快速检查统一 LLM Provider 是否注册成功。
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.providers import register_all_providers
|
||||
|
||||
|
||||
def test_provider_registration() -> bool:
|
||||
"""检查 OpenAI 兼容 provider 是否注册成功。"""
|
||||
logger.info("测试:Provider 注册检查")
|
||||
register_all_providers()
|
||||
|
||||
vision_providers = LLMServiceManager.list_vision_providers()
|
||||
text_providers = LLMServiceManager.list_text_providers()
|
||||
|
||||
assert "openai" in vision_providers, "❌ OpenAI 兼容 Vision Provider 未注册"
|
||||
assert "openai" in text_providers, "❌ OpenAI 兼容 Text Provider 未注册"
|
||||
|
||||
logger.success("✅ OpenAI 兼容 providers 已成功注册")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
ok = test_provider_registration()
|
||||
if ok:
|
||||
logger.success("\n🎉 集成检查通过")
|
||||
except Exception as exc:
|
||||
logger.error(f"\n❌ 集成检查失败: {exc}")
|
||||
raise
|
||||
@@ -133,7 +133,7 @@ class ScriptGenerator:
         from app.services.llm.migration_adapter import create_vision_analyzer

         # Fetch the configuration
-        text_provider = config.app.get('text_llm_provider', 'litellm').lower()
+        text_provider = config.app.get('text_llm_provider', 'openai').lower()
         vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
         vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
         vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')

@@ -321,4 +321,4 @@ class ScriptGenerator:
         last_timestamp = format_timestamp(last_time)
         timestamp_range = f"{first_timestamp}-{last_timestamp}"

-        return first_timestamp, last_timestamp, timestamp_range
\ No newline at end of file
+        return first_timestamp, last_timestamp, timestamp_range
@@ -4,24 +4,16 @@

 # LLM API timeouts (seconds)
 llm_vision_timeout = 120 # base timeout for vision models
 llm_text_timeout = 180 # base timeout for text models (complex tasks such as commentary-script generation need longer)
-llm_max_retries = 3 # API retry count (LiteLLM handles retries automatically)
+llm_max_retries = 3 # API retry count

 ##########################################
-# 🚀 LLM configuration - unified LiteLLM interface
+# 🚀 LLM configuration - unified OpenAI-compatible interface
 ##########################################
-# LiteLLM is a unified LLM interface library supporting 100+ providers
-# Advantages:
-# ✅ 80% less code, one unified API
-# ✅ Automatic retries and smart error handling
-# ✅ Built-in cost tracking and token statistics
-# ✅ More providers: OpenAI, Anthropic, Gemini, Qwen, DeepSeek,
-#    Cohere, Together AI, Replicate, Groq, Mistral, etc.
-#
-# Docs: https://docs.litellm.ai/
-# Supported models: https://docs.litellm.ai/docs/providers
+# All traffic uses the OpenAI-compatible protocol (/v1/chat/completions)
+# Works with OpenAI, DeepSeek, Gemini-compatible gateways, Qwen gateways, SiliconFlow, OpenRouter, etc.

 # ===== Vision model configuration =====
-vision_llm_provider = "litellm"
+vision_llm_provider = "openai"

 # Model format: provider/model_name
 # Common vision-model examples:

@@ -30,12 +22,12 @@
 # - OpenAI: gpt-4o, gpt-4o-mini
 # - Qwen: qwen/qwen2.5-vl-32b-instruct
 # - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct
-vision_litellm_model_name = "gemini/gemini-2.0-flash-lite"
-vision_litellm_api_key = "" # API key for the corresponding provider
-vision_litellm_base_url = "" # optional: custom API base URL
+vision_openai_model_name = "gemini/gemini-2.0-flash-lite"
+vision_openai_api_key = "" # API key for the corresponding provider
+vision_openai_base_url = "" # optional: custom API base URL (leave empty for official OpenAI)

 # ===== Text model configuration =====
-text_llm_provider = "litellm"
+text_llm_provider = "openai"

 # Common text-model examples:
 # - DeepSeek: deepseek/deepseek-chat (recommended; good value)

@@ -45,9 +37,9 @@
 # - Qwen: qwen/qwen-plus, qwen/qwen-turbo
 # - SiliconFlow: siliconflow/deepseek-ai/DeepSeek-R1
 # - Moonshot: moonshot/moonshot-v1-8k
-text_litellm_model_name = "deepseek/deepseek-chat"
-text_litellm_api_key = "" # API key for the corresponding provider
-text_litellm_base_url = "" # optional: custom API base URL
+text_openai_model_name = "deepseek/deepseek-chat"
+text_openai_api_key = "" # API key for the corresponding provider
+text_openai_base_url = "" # optional: custom API base URL (leave empty for official OpenAI)

 # ===== API key references =====
 # Where to get API keys for the mainstream LLM providers:

@@ -69,21 +61,7 @@
 # Whether the WebUI shows the configuration options
 hide_config = true

-##########################################
-# 📚 Legacy configuration examples (for reference only; not recommended)
-##########################################
-# To use the legacy standalone provider implementations, see the examples below,
-# but the LiteLLM configuration above is strongly recommended
-#
-# Legacy vision-model configuration example:
-# vision_llm_provider = "gemini" # options: gemini, qwenvl, siliconflow
-# vision_gemini_api_key = ""
-# vision_gemini_model_name = "gemini-2.0-flash-lite"
-#
-# Legacy text-model configuration example:
-# text_llm_provider = "openai" # options: openai, gemini, qwen, deepseek, siliconflow, moonshot
-# text_openai_api_key = ""
-# text_openai_model_name = "gpt-4o-mini"
+# Official OpenAI default endpoint (optional):
+# text_openai_base_url = "https://api.openai.com/v1"

 ##########################################
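To make the wiring concrete, a sketch (not part of the commit) of how these keys flow into the SDK client, following `_build_client` in the new provider module; the key names match the toml above:

```python
# Sketch: how the text_* keys above end up as AsyncOpenAI constructor arguments,
# mirroring _build_client in openai_compatible_provider.py.
from openai import AsyncOpenAI

app_config = {  # stand-in for config.app parsed from config.toml
    "text_openai_api_key": "your-api-key",
    "text_openai_base_url": "",   # empty -> None -> official OpenAI endpoint
    "llm_text_timeout": 180,
    "llm_max_retries": 3,
}

client = AsyncOpenAI(
    api_key=app_config["text_openai_api_key"],
    base_url=app_config["text_openai_base_url"] or None,
    timeout=app_config["llm_text_timeout"],
    max_retries=app_config["llm_max_retries"],
)
```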
@@ -12,8 +12,7 @@ pysrt==1.1.2

 # AI service dependencies
 openai>=1.77.0
-litellm>=1.70.0 # unified LLM interface supporting 100+ providers
-google-generativeai>=0.8.5 # used by LiteLLM to call Gemini
+google-generativeai>=0.8.5 # needed for native Gemini scenarios
 azure-cognitiveservices-speech>=1.37.0
 tencentcloud-sdk-python>=3.0.1200
 dashscope>=1.24.6
@@ -72,8 +72,8 @@ def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
     return True, ""


-def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool, str]:
-    """Validate a LiteLLM model-name format
+def validate_openai_compatible_model_name(model_name: str, model_type: str) -> tuple[bool, str]:
+    """Validate an OpenAI-compatible model-name format

     Args:
         model_name: the model name, in provider/model format

@@ -87,8 +87,8 @@ def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool,

     model_name = model_name.strip()

-    # LiteLLM's recommended format: provider/model (e.g. gemini/gemini-2.0-flash-lite)
-    # Bare model names (e.g. gpt-4o) also work; LiteLLM infers the provider automatically
+    # Recommended OpenAI-compatible format: provider/model (e.g. gemini/gemini-2.0-flash-lite)
+    # Bare model names (e.g. gpt-4o) also work; the OpenAI-compatible gateway resolves them at call time

     # Check for a provider prefix (recommended format)
     if "/" in model_name:

@@ -101,9 +101,9 @@ def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool,
         if not provider.replace("-", "").replace("_", "").isalnum():
             return False, f"{model_type} provider names may only contain letters, digits, underscores and hyphens"
     else:
-        # A bare model name is also valid (LiteLLM infers the provider)
+        # A bare model name is also valid (the OpenAI-compatible gateway resolves it)
         # But suggest the full format
-        logger.debug(f"{model_type} model name has no provider prefix; LiteLLM will infer it")
+        logger.debug(f"{model_type} model name has no provider prefix; it is passed to the OpenAI-compatible gateway as-is")

     # Basic length check
     if len(model_name) < 3:
@@ -115,6 +115,13 @@ def validate_litellm_model_name(model_name: str, model_type: str) -> tuple[bool,
     return True, ""


+def normalize_openai_compatible_model_name(model_name: str) -> str:
+    """Convert the provider/model format into the model name the gateway actually expects."""
+    if "/" not in model_name:
+        return model_name
+    return model_name.split("/", 1)[1]
+
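```python
# Illustration (not part of the commit): behavior of the helper above.
#   normalize_openai_compatible_model_name("deepseek/deepseek-chat") -> "deepseek-chat"
#   normalize_openai_compatible_model_name("gpt-4o-mini")            -> "gpt-4o-mini"
# Note: unlike _normalize_model_name in openai_compatible_provider.py, this variant
# strips any "prefix/" segment, not only the known gateway prefixes.
```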
def show_config_validation_errors(errors: list):
    """Display configuration validation errors"""
    if errors:
@ -312,243 +319,112 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
|
||||
|
||||
|
||||
def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
|
||||
"""测试 LiteLLM 视觉模型连接
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
base_url: 基础 URL(可选)
|
||||
model_name: 模型名称(LiteLLM 格式:provider/model)
|
||||
tr: 翻译函数
|
||||
|
||||
Returns:
|
||||
(连接是否成功, 测试结果消息)
|
||||
"""
|
||||
def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
|
||||
"""测试 OpenAI 兼容视觉模型连接。"""
|
||||
try:
|
||||
import litellm
|
||||
import os
|
||||
import base64
|
||||
import io
|
||||
from openai import OpenAI
|
||||
from PIL import Image
|
||||
|
||||
logger.debug(f"LiteLLM 视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}")
|
||||
|
||||
# 提取 provider 名称
|
||||
provider = model_name.split("/")[0] if "/" in model_name else "unknown"
|
||||
|
||||
# 设置 API key 到环境变量
|
||||
env_key_mapping = {
|
||||
"gemini": "GEMINI_API_KEY",
|
||||
"google": "GEMINI_API_KEY",
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"qwen": "QWEN_API_KEY",
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"siliconflow": "SILICONFLOW_API_KEY",
|
||||
}
|
||||
env_var = env_key_mapping.get(provider.lower(), f"{provider.upper()}_API_KEY")
|
||||
old_key = os.environ.get(env_var)
|
||||
os.environ[env_var] = api_key
|
||||
|
||||
# SiliconFlow 特殊处理:使用 OpenAI 兼容模式
|
||||
test_model_name = model_name
|
||||
if provider.lower() == "siliconflow":
|
||||
# 替换 provider 为 openai
|
||||
if "/" in model_name:
|
||||
test_model_name = f"openai/{model_name.split('/', 1)[1]}"
|
||||
else:
|
||||
test_model_name = f"openai/{model_name}"
|
||||
|
||||
# 确保设置了 base_url
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
|
||||
# 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议)
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
os.environ["OPENAI_API_BASE"] = base_url
|
||||
|
||||
try:
|
||||
# 创建测试图片(64x64 白色像素,避免某些模型对极小图片的限制)
|
||||
test_image = Image.new('RGB', (64, 64), color='white')
|
||||
img_buffer = io.BytesIO()
|
||||
test_image.save(img_buffer, format='JPEG')
|
||||
img_bytes = img_buffer.getvalue()
|
||||
base64_image = base64.b64encode(img_bytes).decode('utf-8')
|
||||
|
||||
# 构建测试请求
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请直接回复'连接成功'"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
|
||||
# 准备参数
|
||||
completion_kwargs = {
|
||||
"model": test_model_name,
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 50
|
||||
}
|
||||
|
||||
if base_url:
|
||||
completion_kwargs["api_base"] = base_url
|
||||
|
||||
# 调用 LiteLLM(同步调用用于测试)
|
||||
response = litellm.completion(**completion_kwargs)
|
||||
|
||||
if response and response.choices and len(response.choices) > 0:
|
||||
return True, f"LiteLLM 视觉模型连接成功 ({model_name})"
|
||||
else:
|
||||
return False, f"LiteLLM 视觉模型返回空响应"
|
||||
|
||||
finally:
|
||||
# 恢复原始环境变量
|
||||
if old_key:
|
||||
os.environ[env_var] = old_key
|
||||
else:
|
||||
os.environ.pop(env_var, None)
|
||||
|
||||
# 清理临时设置的 OpenAI 环境变量
|
||||
if provider.lower() == "siliconflow":
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_BASE", None)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
|
||||
)
|
||||
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
timeout=10.0,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
# 创建测试图片(64x64 白色像素)
|
||||
test_image = Image.new("RGB", (64, 64), color="white")
|
||||
img_buffer = io.BytesIO()
|
||||
test_image.save(img_buffer, format="JPEG")
|
||||
base64_image = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=normalize_openai_compatible_model_name(model_name),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请直接回复'连接成功'"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
if response and response.choices and len(response.choices) > 0:
|
||||
return True, f"OpenAI 兼容视觉模型连接成功 ({model_name})"
|
||||
return False, "OpenAI 兼容视觉模型返回空响应"
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"LiteLLM 视觉模型测试失败: {error_msg}")
|
||||
|
||||
# 提供更友好的错误信息
|
||||
logger.error(f"OpenAI 兼容视觉模型测试失败: {error_msg}")
|
||||
if "authentication" in error_msg.lower() or "api_key" in error_msg.lower():
|
||||
return False, f"认证失败,请检查 API Key 是否正确"
|
||||
elif "not found" in error_msg.lower() or "404" in error_msg:
|
||||
return False, f"模型不存在,请检查模型名称是否正确"
|
||||
elif "rate limit" in error_msg.lower():
|
||||
return False, f"超出速率限制,请稍后重试"
|
||||
else:
|
||||
return False, f"连接失败: {error_msg}"
|
||||
return False, "认证失败,请检查 API Key 是否正确"
|
||||
if "not found" in error_msg.lower() or "404" in error_msg:
|
||||
return False, "模型不存在,请检查模型名称是否正确"
|
||||
if "rate limit" in error_msg.lower():
|
||||
return False, "超出速率限制,请稍后重试"
|
||||
return False, f"连接失败: {error_msg}"
|
||||
|
||||
|
||||
def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
    """测试 LiteLLM 文本模型连接

    Args:
        api_key: API 密钥
        base_url: 基础 URL(可选)
        model_name: 模型名称(LiteLLM 格式:provider/model)
        tr: 翻译函数

    Returns:
        (连接是否成功, 测试结果消息)
    """
def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: str, tr) -> tuple[bool, str]:
    """测试 OpenAI 兼容文本模型连接。"""
    try:
        import litellm
        import os

        logger.debug(f"LiteLLM 文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}")

        # 提取 provider 名称
        provider = model_name.split("/")[0] if "/" in model_name else "unknown"

        # 设置 API key 到环境变量
        env_key_mapping = {
            "gemini": "GEMINI_API_KEY",
            "google": "GEMINI_API_KEY",
            "openai": "OPENAI_API_KEY",
            "qwen": "QWEN_API_KEY",
            "dashscope": "DASHSCOPE_API_KEY",
            "siliconflow": "SILICONFLOW_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "moonshot": "MOONSHOT_API_KEY",
        }
        env_var = env_key_mapping.get(provider.lower(), f"{provider.upper()}_API_KEY")
        old_key = os.environ.get(env_var)
        os.environ[env_var] = api_key

        # SiliconFlow 特殊处理:使用 OpenAI 兼容模式
        test_model_name = model_name
        if provider.lower() == "siliconflow":
            # 替换 provider 为 openai
            if "/" in model_name:
                test_model_name = f"openai/{model_name.split('/', 1)[1]}"
            else:
                test_model_name = f"openai/{model_name}"

            # 确保设置了 base_url
            if not base_url:
                base_url = "https://api.siliconflow.cn/v1"

            # 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议)
            os.environ["OPENAI_API_KEY"] = api_key
            os.environ["OPENAI_API_BASE"] = base_url

        try:
            # 构建测试请求
            messages = [
                {"role": "user", "content": "请直接回复'连接成功'"}
            ]

            # 准备参数
            completion_kwargs = {
                "model": test_model_name,
                "messages": messages,
                "temperature": 0.1,
                "max_tokens": 20
            }

            if base_url:
                completion_kwargs["api_base"] = base_url

            # 调用 LiteLLM(同步调用用于测试)
            response = litellm.completion(**completion_kwargs)

            if response and response.choices and len(response.choices) > 0:
                return True, f"LiteLLM 文本模型连接成功 ({model_name})"
            else:
                return False, f"LiteLLM 文本模型返回空响应"

        finally:
            # 恢复原始环境变量
            if old_key:
                os.environ[env_var] = old_key
            else:
                os.environ.pop(env_var, None)

            # 清理临时设置的 OpenAI 环境变量
            if provider.lower() == "siliconflow":
                os.environ.pop("OPENAI_API_KEY", None)
                os.environ.pop("OPENAI_API_BASE", None)

        from openai import OpenAI

        logger.debug(
            f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
        )

        client = OpenAI(
            api_key=api_key,
            base_url=base_url or None,
            timeout=10.0,
            max_retries=1,
        )

        response = client.chat.completions.create(
            model=normalize_openai_compatible_model_name(model_name),
            messages=[{"role": "user", "content": "请直接回复'连接成功'"}],
            temperature=0.1,
            max_tokens=20,
        )

        if response and response.choices and len(response.choices) > 0:
            return True, f"OpenAI 兼容文本模型连接成功 ({model_name})"
        return False, "OpenAI 兼容文本模型返回空响应"
    except Exception as e:
        error_msg = str(e)
        logger.error(f"LiteLLM 文本模型测试失败: {error_msg}")

        # 提供更友好的错误信息
        logger.error(f"OpenAI 兼容文本模型测试失败: {error_msg}")
        if "authentication" in error_msg.lower() or "api_key" in error_msg.lower():
            return False, f"认证失败,请检查 API Key 是否正确"
        elif "not found" in error_msg.lower() or "404" in error_msg:
            return False, f"模型不存在,请检查模型名称是否正确"
        elif "rate limit" in error_msg.lower():
            return False, f"超出速率限制,请稍后重试"
        else:
            return False, f"连接失败: {error_msg}"
        return False, "认证失败,请检查 API Key 是否正确"
        if "not found" in error_msg.lower() or "404" in error_msg:
            return False, "模型不存在,请检查模型名称是否正确"
        if "rate limit" in error_msg.lower():
            return False, "超出速率限制,请稍后重试"
        return False, f"连接失败: {error_msg}"

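For a quick sanity check outside Streamlit, the new connectivity helper can be driven directly. A usage sketch under assumptions: the key and endpoint below are placeholders to replace with real values, and an identity lambda stands in for the translation function tr.

# Usage sketch (assumptions: placeholder credentials; any OpenAI-compatible endpoint works).
ok, msg = test_openai_compatible_text_model(
    api_key="sk-...",                        # placeholder, not a real key
    base_url="https://api.deepseek.com/v1",  # example endpoint; "" falls back to the official OpenAI endpoint
    model_name="deepseek-chat",
    tr=lambda s: s,                          # identity function instead of the real translator
)
print(ok, msg)
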
def render_vision_llm_settings(tr):
    """渲染视频分析模型设置(LiteLLM 统一配置)"""
    """渲染视频分析模型设置(OpenAI 兼容 统一配置)"""
    st.subheader(tr("Vision Model Settings"))

    # 固定使用 LiteLLM 提供商
    config.app["vision_llm_provider"] = "litellm"
    # 固定使用 OpenAI 兼容 提供商
    config.app["vision_llm_provider"] = "openai"

    # 获取已保存的 LiteLLM 配置
    full_vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
    vision_api_key = config.app.get("vision_litellm_api_key", "")
    vision_base_url = config.app.get("vision_litellm_base_url", "")
    # 获取已保存的配置
    full_vision_model_name = config.app.get("vision_openai_model_name") or "gemini/gemini-2.0-flash-lite"
    vision_api_key = config.app.get("vision_openai_api_key", "")
    vision_base_url = config.app.get("vision_openai_base_url", "")

    # 解析 provider 和 model
    default_provider = "gemini"
@ -563,7 +439,7 @@ def render_vision_llm_settings(tr):
        current_model = full_vision_model_name

    # 定义支持的 provider 列表
    LITELLM_PROVIDERS = [
    OPENAI_COMPATIBLE_PROVIDERS = [
        "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
        "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
        "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
@ -572,16 +448,16 @@ def render_vision_llm_settings(tr):
    ]

    # 如果当前 provider 不在列表中,添加到列表头部
    if current_provider not in LITELLM_PROVIDERS:
        LITELLM_PROVIDERS.insert(0, current_provider)
    if current_provider not in OPENAI_COMPATIBLE_PROVIDERS:
        OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider)

    # 渲染配置输入框
    col1, col2 = st.columns([1, 2])
    with col1:
        selected_provider = st.selectbox(
            tr("Vision Model Provider"),
            options=LITELLM_PROVIDERS,
            index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
            options=OPENAI_COMPATIBLE_PROVIDERS,
            index=OPENAI_COMPATIBLE_PROVIDERS.index(current_provider) if current_provider in OPENAI_COMPATIBLE_PROVIDERS else 0,
            key="vision_provider_select"
        )

@ -595,7 +471,7 @@ def render_vision_llm_settings(tr):
            "• gpt-4o\n"
            "• qwen-vl-max\n"
            "• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n"
            "支持 100+ providers,详见: https://docs.litellm.ai/docs/providers",
            "支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow)",
            key="vision_model_input"
        )

@ -641,7 +517,7 @@ def render_vision_llm_settings(tr):
        else:
            with st.spinner(tr("Testing connection...")):
                try:
                    success, message = test_litellm_vision_model(
                    success, message = test_openai_compatible_vision_model(
                        api_key=st_vision_api_key,
                        base_url=st_vision_base_url,
                        model_name=st_vision_model_name,
@ -654,7 +530,7 @@ def render_vision_llm_settings(tr):
                        st.error(message)
                except Exception as e:
                    st.error(f"测试连接时发生错误: {str(e)}")
                    logger.error(f"LiteLLM 视频分析模型连接测试失败: {str(e)}")
                    logger.error(f"OpenAI 兼容 视频分析模型连接测试失败: {str(e)}")

    # 验证和保存配置
    validation_errors = []
@ -663,10 +539,10 @@ def render_vision_llm_settings(tr):
    # 验证模型名称
    if st_vision_model_name:
        # 这里的验证逻辑可能需要微调,因为我们现在是自动组合的
        is_valid, error_msg = validate_litellm_model_name(st_vision_model_name, "视频分析")
        is_valid, error_msg = validate_openai_compatible_model_name(st_vision_model_name, "视频分析")
        if is_valid:
            config.app["vision_litellm_model_name"] = st_vision_model_name
            st.session_state["vision_litellm_model_name"] = st_vision_model_name
            config.app["vision_openai_model_name"] = st_vision_model_name
            st.session_state["vision_openai_model_name"] = st_vision_model_name
            config_changed = True
        else:
            validation_errors.append(error_msg)
@ -675,8 +551,8 @@ def render_vision_llm_settings(tr):
    if st_vision_api_key:
        is_valid, error_msg = validate_api_key(st_vision_api_key, "视频分析")
        if is_valid:
            config.app["vision_litellm_api_key"] = st_vision_api_key
            st.session_state["vision_litellm_api_key"] = st_vision_api_key
            config.app["vision_openai_api_key"] = st_vision_api_key
            st.session_state["vision_openai_api_key"] = st_vision_api_key
            config_changed = True
        else:
            validation_errors.append(error_msg)
@ -685,8 +561,8 @@ def render_vision_llm_settings(tr):
    if st_vision_base_url:
        is_valid, error_msg = validate_base_url(st_vision_base_url, "视频分析")
        if is_valid:
            config.app["vision_litellm_base_url"] = st_vision_base_url
            st.session_state["vision_litellm_base_url"] = st_vision_base_url
            config.app["vision_openai_base_url"] = st_vision_base_url
            st.session_state["vision_openai_base_url"] = st_vision_base_url
            config_changed = True
        else:
            validation_errors.append(error_msg)
@ -701,7 +577,7 @@ def render_vision_llm_settings(tr):
        # 清除缓存,确保下次使用新配置
        UnifiedLLMService.clear_cache()
        if st_vision_api_key or st_vision_base_url or st_vision_model_name:
            st.success(f"视频分析模型配置已保存(LiteLLM)")
            st.success(f"视频分析模型配置已保存(OpenAI 兼容)")
    except Exception as e:
        st.error(f"保存配置失败: {str(e)}")
        logger.error(f"保存视频分析配置失败: {str(e)}")

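The settings panel above persists one full model name composed as "provider/model" (defaults such as "gemini/gemini-2.0-flash-lite") and splits it back when rendering. A minimal sketch of that split convention, assuming a fallback to a default provider when no "/" is present; the function name and fallback rule are illustrative, not the repository's exact helper.

# Illustrative sketch (assumed convention, not the actual helper from this commit).
def split_full_model_name(full_name: str, default_provider: str) -> tuple[str, str]:
    """Split "provider/model" into its parts; fall back to default_provider."""
    if "/" in full_name:
        provider, model = full_name.split("/", 1)
        return provider, model
    return default_provider, full_name

# Round-trip check against the defaults used above.
assert split_full_model_name("gemini/gemini-2.0-flash-lite", "gemini") == ("gemini", "gemini-2.0-flash-lite")
assert split_full_model_name("deepseek-chat", "deepseek") == ("deepseek", "deepseek-chat")
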
@ -811,16 +687,16 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):


def render_text_llm_settings(tr):
    """渲染文案生成模型设置(LiteLLM 统一配置)"""
    """渲染文案生成模型设置(OpenAI 兼容 统一配置)"""
    st.subheader(tr("Text Generation Model Settings"))

    # 固定使用 LiteLLM 提供商
    config.app["text_llm_provider"] = "litellm"
    # 固定使用 OpenAI 兼容 提供商
    config.app["text_llm_provider"] = "openai"

    # 获取已保存的 LiteLLM 配置
    full_text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
    text_api_key = config.app.get("text_litellm_api_key", "")
    text_base_url = config.app.get("text_litellm_base_url", "")
    # 获取已保存的配置
    full_text_model_name = config.app.get("text_openai_model_name") or "deepseek/deepseek-chat"
    text_api_key = config.app.get("text_openai_api_key", "")
    text_base_url = config.app.get("text_openai_base_url", "")

    # 解析 provider 和 model
    default_provider = "deepseek"
@ -835,7 +711,7 @@ def render_text_llm_settings(tr):
        current_model = full_text_model_name

    # 定义支持的 provider 列表
    LITELLM_PROVIDERS = [
    OPENAI_COMPATIBLE_PROVIDERS = [
        "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
        "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
        "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
@ -844,16 +720,16 @@ def render_text_llm_settings(tr):
    ]

    # 如果当前 provider 不在列表中,添加到列表头部
    if current_provider not in LITELLM_PROVIDERS:
        LITELLM_PROVIDERS.insert(0, current_provider)
    if current_provider not in OPENAI_COMPATIBLE_PROVIDERS:
        OPENAI_COMPATIBLE_PROVIDERS.insert(0, current_provider)

    # 渲染配置输入框
    col1, col2 = st.columns([1, 2])
    with col1:
        selected_provider = st.selectbox(
            tr("Text Model Provider"),
            options=LITELLM_PROVIDERS,
            index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
            options=OPENAI_COMPATIBLE_PROVIDERS,
            index=OPENAI_COMPATIBLE_PROVIDERS.index(current_provider) if current_provider in OPENAI_COMPATIBLE_PROVIDERS else 0,
            key="text_provider_select"
        )

@ -867,7 +743,7 @@ def render_text_llm_settings(tr):
            "• gpt-4o\n"
            "• gemini-2.0-flash\n"
            "• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n"
            "支持 100+ providers,详见: https://docs.litellm.ai/docs/providers",
            "支持常见 OpenAI 兼容网关(如 OpenAI/DeepSeek/OpenRouter/SiliconFlow)",
            key="text_model_input"
        )

@ -915,7 +791,7 @@ def render_text_llm_settings(tr):
        else:
            with st.spinner(tr("Testing connection...")):
                try:
                    success, message = test_litellm_text_model(
                    success, message = test_openai_compatible_text_model(
                        api_key=st_text_api_key,
                        base_url=st_text_base_url,
                        model_name=st_text_model_name,
@ -928,7 +804,7 @@ def render_text_llm_settings(tr):
                        st.error(message)
                except Exception as e:
                    st.error(f"测试连接时发生错误: {str(e)}")
                    logger.error(f"LiteLLM 文案生成模型连接测试失败: {str(e)}")
                    logger.error(f"OpenAI 兼容 文案生成模型连接测试失败: {str(e)}")

    # 验证和保存配置
    text_validation_errors = []
@ -936,10 +812,10 @@ def render_text_llm_settings(tr):

    # 验证模型名称
    if st_text_model_name:
        is_valid, error_msg = validate_litellm_model_name(st_text_model_name, "文案生成")
        is_valid, error_msg = validate_openai_compatible_model_name(st_text_model_name, "文案生成")
        if is_valid:
            config.app["text_litellm_model_name"] = st_text_model_name
            st.session_state["text_litellm_model_name"] = st_text_model_name
            config.app["text_openai_model_name"] = st_text_model_name
            st.session_state["text_openai_model_name"] = st_text_model_name
            text_config_changed = True
        else:
            text_validation_errors.append(error_msg)
@ -948,8 +824,8 @@ def render_text_llm_settings(tr):
    if st_text_api_key:
        is_valid, error_msg = validate_api_key(st_text_api_key, "文案生成")
        if is_valid:
            config.app["text_litellm_api_key"] = st_text_api_key
            st.session_state["text_litellm_api_key"] = st_text_api_key
            config.app["text_openai_api_key"] = st_text_api_key
            st.session_state["text_openai_api_key"] = st_text_api_key
            text_config_changed = True
        else:
            text_validation_errors.append(error_msg)
@ -958,8 +834,8 @@ def render_text_llm_settings(tr):
    if st_text_base_url:
        is_valid, error_msg = validate_base_url(st_text_base_url, "文案生成")
        if is_valid:
            config.app["text_litellm_base_url"] = st_text_base_url
            st.session_state["text_litellm_base_url"] = st_text_base_url
            config.app["text_openai_base_url"] = st_text_base_url
            st.session_state["text_openai_base_url"] = st_text_base_url
            text_config_changed = True
        else:
            text_validation_errors.append(error_msg)
@ -974,7 +850,7 @@ def render_text_llm_settings(tr):
        # 清除缓存,确保下次使用新配置
        UnifiedLLMService.clear_cache()
        if st_text_api_key or st_text_base_url or st_text_model_name:
            st.success(f"文案生成模型配置已保存(LiteLLM)")
            st.success(f"文案生成模型配置已保存(OpenAI 兼容)")
    except Exception as e:
        st.error(f"保存配置失败: {str(e)}")
        logger.error(f"保存文案生成配置失败: {str(e)}")

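After this migration, all of the relevant configuration keys carry the openai infix. A sketch of the resulting config.app state follows; the key names come from the diff above, while the values are placeholder assumptions.

# Assumed post-migration configuration state; values are placeholders.
config.app.update({
    "vision_llm_provider": "openai",
    "vision_openai_model_name": "gemini/gemini-2.0-flash-lite",
    "vision_openai_api_key": "sk-...",   # placeholder
    "vision_openai_base_url": "",        # empty -> official OpenAI endpoint
    "text_llm_provider": "openai",
    "text_openai_model_name": "deepseek/deepseek-chat",
    "text_openai_api_key": "sk-...",     # placeholder
    "text_openai_base_url": "",
})
UnifiedLLMService.clear_cache()  # as in the diff: drop cached clients so new settings take effect
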
@ -123,7 +123,7 @@ def generate_script_docu(params):
    # 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值
    vision_llm_provider = (
        st.session_state.get('vision_llm_provider') or
        config.app.get('vision_llm_provider', 'litellm')
        config.app.get('vision_llm_provider', 'openai')
    ).lower()

    logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")