"""
|
||
LiteLLM 统一提供商实现
|
||
|
||
使用 LiteLLM 库提供统一的 LLM 接口,支持 100+ providers
|
||
包括 OpenAI, Anthropic, Gemini, Qwen, DeepSeek, SiliconFlow 等
|
||
"""
|
||
|
||
import asyncio
import base64
import io
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger

try:
    import litellm
    from litellm import acompletion, completion
    from litellm.exceptions import (
        AuthenticationError as LiteLLMAuthError,
        RateLimitError as LiteLLMRateLimitError,
        BadRequestError as LiteLLMBadRequestError,
        APIError as LiteLLMAPIError
    )
except ImportError:
    logger.error("LiteLLM is not installed. Run: pip install litellm")
    raise

from .base import VisionModelProvider, TextModelProvider
from .exceptions import (
    APICallError,
    AuthenticationError,
    RateLimitError,
    ContentFilterError
)


# Configure LiteLLM global settings
def configure_litellm():
    """Configure LiteLLM global parameters."""
    from app.config import config

    # Retry count
    litellm.num_retries = config.app.get('llm_max_retries', 3)

    # Default request timeout
    litellm.request_timeout = config.app.get('llm_text_timeout', 180)

    # Enable verbose logging (development only)
    # litellm.set_verbose = True

    logger.info(f"LiteLLM configured: retries={litellm.num_retries}, timeout={litellm.request_timeout}s")


# Initialize configuration at import time
configure_litellm()
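
# The two keys above are read from the application's config; a minimal sketch
# of the assumed section (key names taken from the lookups above, defaults
# shown; the exact file format is an assumption):
#
#   [app]
#   llm_max_retries = 3     # -> litellm.num_retries
#   llm_text_timeout = 180  # -> litellm.request_timeout (seconds)

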
class LiteLLMVisionProvider(VisionModelProvider):
    """Unified vision model provider backed by LiteLLM."""

    @property
    def provider_name(self) -> str:
        # Extract the provider name from model_name (e.g. "gemini/gemini-2.0-flash")
        if "/" in self.model_name:
            return self.model_name.split("/")[0]
        return "litellm"

    @property
    def supported_models(self) -> List[str]:
        # LiteLLM supports 100+ providers and hundreds of models, too many to enumerate.
        # An empty list skips the predefined-list check; LiteLLM validates at call time.
        return []

    def _validate_model_support(self):
        """
        Override the model validation logic.

        For LiteLLM we skip the predefined-list check because:
        1. LiteLLM supports 100+ providers and hundreds of models, too many to enumerate
        2. LiteLLM validates the model when the call is actually made
        3. If the model is unsupported, LiteLLM returns a clear error message

        Only a basic (optional) format check is done here.
        """
        # Optional: check the model name format (provider/model)
        if "/" not in self.model_name:
            logger.debug(
                f"LiteLLM model name '{self.model_name}' has no provider prefix; "
                f"LiteLLM will try to infer it. Prefer the 'provider/model' format, e.g. 'gemini/gemini-2.5-flash'"
            )

        # Do not raise; let LiteLLM validate on the actual call
        logger.debug(f"LiteLLM vision model configured: {self.model_name}")

    def _initialize(self):
        """Initialize LiteLLM-specific settings."""
        # Export the API key as an environment variable (LiteLLM reads it automatically)
        import os

        # Decide which API key to set based on model_name
        provider = self.provider_name.lower()

        # Map providers to environment variable names
        env_key_mapping = {
            "gemini": "GEMINI_API_KEY",
            "google": "GEMINI_API_KEY",
            "openai": "OPENAI_API_KEY",
            "qwen": "QWEN_API_KEY",
            "dashscope": "DASHSCOPE_API_KEY",
            "siliconflow": "SILICONFLOW_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "claude": "ANTHROPIC_API_KEY"
        }

        env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY")

        if self.api_key and env_var:
            os.environ[env_var] = self.api_key
            logger.debug(f"Set environment variable: {env_var}")

        # If a base_url was provided, pass it through to LiteLLM
        if self.base_url:
            # LiteLLM accepts a custom URL via the api_base parameter
            self._api_base = self.base_url
            logger.debug(f"Using custom API base URL: {self.base_url}")
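
    # Example: model_name "qwen/qwen-vl-max" -> provider "qwen" -> QWEN_API_KEY;
    # a provider missing from the mapping falls back to "<PROVIDER>_API_KEY"
    # (e.g. a hypothetical "groq/..." model would set GROQ_API_KEY).
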
    async def analyze_images(self,
                             images: List[Union[str, Path, PIL.Image.Image]],
                             prompt: str,
                             batch_size: int = 10,
                             **kwargs) -> List[str]:
        """
        Analyze images with LiteLLM.

        Args:
            images: List of image paths or PIL image objects
            prompt: Analysis prompt
            batch_size: Batch size
            **kwargs: Additional parameters

        Returns:
            List of analysis results, one string per batch
        """
        logger.info(f"Analyzing {len(images)} images with LiteLLM ({self.model_name})")

        # Preprocess the images
        processed_images = self._prepare_images(images)

        # Process in batches
        results = []
        for i in range(0, len(processed_images), batch_size):
            batch = processed_images[i:i + batch_size]
            logger.info(f"Processing batch {i//batch_size + 1} with {len(batch)} images")

            try:
                result = await self._analyze_batch(batch, prompt, **kwargs)
                results.append(result)
            except Exception as e:
                logger.error(f"Batch {i//batch_size + 1} failed: {str(e)}")
                results.append(f"Batch processing failed: {str(e)}")

        return results
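
    # Note: results are per batch, not per image; 25 images at batch_size=10
    # produce 3 result strings, so callers needing per-image output must split them.
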
    async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
        """Analyze one batch of images."""
        # Build the message in LiteLLM (OpenAI-compatible) format
        content = [{"type": "text", "text": prompt}]

        # Attach the images as base64 data URLs
        for img in batch:
            base64_image = self._image_to_base64(img)
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        messages = [{
            "role": "user",
            "content": content
        }]

        # Call LiteLLM
        try:
            # Prepare the call arguments
            completion_kwargs = {
                "model": self.model_name,
                "messages": messages,
                "temperature": kwargs.get("temperature", 1.0),
                "max_tokens": kwargs.get("max_tokens", 4000)
            }

            # Forward any custom base URL as api_base
            if hasattr(self, '_api_base'):
                completion_kwargs["api_base"] = self._api_base

            response = await acompletion(**completion_kwargs)

            if response.choices and len(response.choices) > 0:
                content = response.choices[0].message.content
                logger.debug(f"LiteLLM call succeeded, tokens used: {response.usage.total_tokens if response.usage else 'N/A'}")
                return content
            else:
                raise APICallError("LiteLLM returned an empty response")

        except LiteLLMAuthError as e:
            logger.error(f"LiteLLM authentication failed: {str(e)}")
            raise AuthenticationError()
        except LiteLLMRateLimitError as e:
            logger.error(f"LiteLLM rate limited: {str(e)}")
            raise RateLimitError()
        except LiteLLMBadRequestError as e:
            error_msg = str(e)
            if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower():
                raise ContentFilterError(f"Content blocked by safety filter: {error_msg}")
            logger.error(f"LiteLLM bad request: {error_msg}")
            raise APICallError(f"Bad request: {error_msg}")
        except LiteLLMAPIError as e:
            logger.error(f"LiteLLM API error: {str(e)}")
            raise APICallError(f"API error: {str(e)}")
        except Exception as e:
            logger.error(f"LiteLLM call failed: {str(e)}")
            raise APICallError(f"Call failed: {str(e)}")

    def _image_to_base64(self, img: PIL.Image.Image) -> str:
        """Convert a PIL image to a base64-encoded JPEG."""
        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG', quality=85)
        img_bytes = img_buffer.getvalue()
        return base64.b64encode(img_bytes).decode('utf-8')

    async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Kept for base-class compatibility (actual calls go through the LiteLLM SDK)."""
        pass
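

# Usage sketch (illustrative only): how this provider might be driven. The
# constructor arguments are assumptions based on the fields used above
# (model_name, api_key, base_url); check VisionModelProvider in .base:
#
#   provider = LiteLLMVisionProvider(model_name="gemini/gemini-2.0-flash",
#                                    api_key="...")
#   results = asyncio.run(provider.analyze_images(
#       images=["frame_001.jpg", "frame_002.jpg"],
#       prompt="Describe what happens across these frames.",
#   ))

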
class LiteLLMTextProvider(TextModelProvider):
    """Unified text generation provider backed by LiteLLM."""

    @property
    def provider_name(self) -> str:
        # Extract the provider name from model_name
        if "/" in self.model_name:
            return self.model_name.split("/")[0]
        # Try to infer the provider from the model name
        model_lower = self.model_name.lower()
        if "gpt" in model_lower:
            return "openai"
        elif "claude" in model_lower:
            return "anthropic"
        elif "gemini" in model_lower:
            return "gemini"
        elif "qwen" in model_lower:
            return "qwen"
        elif "deepseek" in model_lower:
            return "deepseek"
        return "litellm"
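
    # Examples: "openai/gpt-4o" -> "openai" (explicit prefix),
    # "gpt-4o" -> "openai" (inferred), anything unrecognized -> "litellm".
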
    @property
    def supported_models(self) -> List[str]:
        # LiteLLM supports 100+ providers and hundreds of models, too many to enumerate.
        # An empty list skips the predefined-list check; LiteLLM validates at call time.
        return []

    def _validate_model_support(self):
        """
        Override the model validation logic.

        For LiteLLM we skip the predefined-list check because:
        1. LiteLLM supports 100+ providers and hundreds of models, too many to enumerate
        2. LiteLLM validates the model when the call is actually made
        3. If the model is unsupported, LiteLLM returns a clear error message

        Only a basic (optional) format check is done here.
        """
        # Optional: check the model name format (provider/model)
        if "/" not in self.model_name:
            logger.debug(
                f"LiteLLM model name '{self.model_name}' has no provider prefix; "
                f"LiteLLM will try to infer it. Prefer the 'provider/model' format, e.g. 'gemini/gemini-2.5-flash'"
            )

        # Do not raise; let LiteLLM validate on the actual call
        logger.debug(f"LiteLLM text model configured: {self.model_name}")

    def _initialize(self):
        """Initialize LiteLLM-specific settings."""
        import os

        # Decide which API key to set based on model_name
        provider = self.provider_name.lower()

        # Map providers to environment variable names
        env_key_mapping = {
            "gemini": "GEMINI_API_KEY",
            "google": "GEMINI_API_KEY",
            "openai": "OPENAI_API_KEY",
            "qwen": "QWEN_API_KEY",
            "dashscope": "DASHSCOPE_API_KEY",
            "siliconflow": "SILICONFLOW_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "claude": "ANTHROPIC_API_KEY",
            "moonshot": "MOONSHOT_API_KEY"
        }

        env_var = env_key_mapping.get(provider, f"{provider.upper()}_API_KEY")

        if self.api_key and env_var:
            os.environ[env_var] = self.api_key
            logger.debug(f"Set environment variable: {env_var}")

        # If a base_url was provided, keep it for later calls
        if self.base_url:
            self._api_base = self.base_url
            logger.debug(f"Using custom API base URL: {self.base_url}")
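
    # Example: a self-hosted or proxy endpoint such as "https://example.com/v1"
    # (hypothetical URL) set as base_url here is forwarded to LiteLLM as
    # api_base on every generate_text call below.
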
    async def generate_text(self,
                            prompt: str,
                            system_prompt: Optional[str] = None,
                            temperature: float = 1.0,
                            max_tokens: Optional[int] = None,
                            response_format: Optional[str] = None,
                            **kwargs) -> str:
        """
        Generate text with LiteLLM.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            temperature: Sampling temperature
            max_tokens: Maximum number of tokens
            response_format: Response format ('json' or None)
            **kwargs: Additional parameters

        Returns:
            The generated text content
        """
        # Build the message list
        messages = self._build_messages(prompt, system_prompt)

        # Prepare the call arguments
        completion_kwargs = {
            "model": self.model_name,
            "messages": messages,
            "temperature": temperature
        }

        if max_tokens:
            completion_kwargs["max_tokens"] = max_tokens

        # Request JSON output; LiteLLM translates JSON mode across providers.
        # Models that reject response_format are handled in the error path below.
        if response_format == "json":
            completion_kwargs["response_format"] = {"type": "json_object"}

        # Forward any custom base URL as api_base
        if hasattr(self, '_api_base'):
            completion_kwargs["api_base"] = self._api_base
        try:
            # Call LiteLLM (retries automatically)
            response = await acompletion(**completion_kwargs)

            if response.choices and len(response.choices) > 0:
                content = response.choices[0].message.content

                # Strip possible markdown code fences; some providers ignore
                # JSON mode and still wrap the output in ```json fences
                if response_format == "json":
                    content = self._clean_json_output(content)

                logger.debug(f"LiteLLM call succeeded, tokens used: {response.usage.total_tokens if response.usage else 'N/A'}")
                return content
            else:
                raise APICallError("LiteLLM returned an empty response")

        except LiteLLMAuthError as e:
            logger.error(f"LiteLLM authentication failed: {str(e)}")
            raise AuthenticationError()
        except LiteLLMRateLimitError as e:
            logger.error(f"LiteLLM rate limited: {str(e)}")
            raise RateLimitError()
        except LiteLLMBadRequestError as e:
            error_msg = str(e)
            # Handle models that do not support response_format
            if "response_format" in error_msg and response_format == "json":
                logger.warning("Model does not support response_format; retrying without the format constraint")
                completion_kwargs.pop("response_format", None)
                messages[-1]["content"] += "\n\nMake sure the output is strict JSON, with no extra text or markup."

                # Retry
                response = await acompletion(**completion_kwargs)
                if response.choices and len(response.choices) > 0:
                    content = response.choices[0].message.content
                    content = self._clean_json_output(content)
                    return content
                else:
                    raise APICallError("LiteLLM returned an empty response")

            # Check for safety filtering
            if "SAFETY" in error_msg.upper() or "content_filter" in error_msg.lower():
                raise ContentFilterError(f"Content blocked by safety filter: {error_msg}")

            logger.error(f"LiteLLM bad request: {error_msg}")
            raise APICallError(f"Bad request: {error_msg}")
        except LiteLLMAPIError as e:
            logger.error(f"LiteLLM API error: {str(e)}")
            raise APICallError(f"API error: {str(e)}")
        except Exception as e:
            logger.error(f"LiteLLM call failed: {str(e)}")
            raise APICallError(f"Call failed: {str(e)}")

    def _clean_json_output(self, output: str) -> str:
        """Clean JSON output: strip markdown fences and surrounding whitespace."""
        import re

        # Remove possible markdown code-fence markers
        output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
        output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
        output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)

        # Strip leading/trailing whitespace
        output = output.strip()

        return output
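
    # Example: '```json\n{"title": "..."}\n```' -> '{"title": "..."}'
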
    async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Kept for base-class compatibility (actual calls go through the LiteLLM SDK)."""
        pass
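

# Usage sketch (illustrative only): generating structured JSON with the text
# provider. Constructor arguments are assumptions based on the fields used
# above (model_name, api_key, base_url); check TextModelProvider in .base:
#
#   provider = LiteLLMTextProvider(model_name="deepseek/deepseek-chat",
#                                  api_key="...")
#   text = asyncio.run(provider.generate_text(
#       prompt="Summarize this plot outline as JSON with keys title and beats.",
#       system_prompt="You are a concise script editor.",
#       response_format="json",
#   ))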