mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
feat(llm): 重构解说文案生成和视觉分析器,支持新的LLM服务架构
更新generate_narration_script.py、base.py和generate_short_summary.py文件,重构解说文案生成和视觉分析器的实现,优先使用新的LLM服务架构。添加回退机制以确保兼容性,增强系统的稳定性和用户体验。
This commit is contained in:
parent
dd59d5295d
commit
7309208282
@ -11,9 +11,14 @@
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
|
||||
# 导入新的LLM服务模块 - 确保提供商被注册
|
||||
import app.services.llm # 这会触发提供商注册
|
||||
from app.services.llm.migration_adapter import generate_narration as generate_narration_new
|
||||
|
||||
|
||||
def parse_frame_analysis_to_markdown(json_file_path):
|
||||
"""
|
||||
@ -79,11 +84,34 @@ def parse_frame_analysis_to_markdown(json_file_path):
|
||||
|
||||
def generate_narration(markdown_content, api_key, base_url, model):
|
||||
"""
|
||||
调用OpenAI API根据视频帧分析的Markdown内容生成解说文案
|
||||
|
||||
调用大模型API根据视频帧分析的Markdown内容生成解说文案 - 已重构为使用新的LLM服务架构
|
||||
|
||||
:param markdown_content: Markdown格式的视频帧分析内容
|
||||
:param api_key: OpenAI API密钥
|
||||
:param base_url: API基础URL,如果使用非官方API
|
||||
:param api_key: API密钥
|
||||
:param base_url: API基础URL
|
||||
:param model: 使用的模型名称
|
||||
:return: 生成的解说文案
|
||||
"""
|
||||
try:
|
||||
# 优先使用新的LLM服务架构
|
||||
logger.info("使用新的LLM服务架构生成解说文案")
|
||||
result = generate_narration_new(markdown_content, api_key, base_url, model)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}")
|
||||
|
||||
# 回退到旧的实现以确保兼容性
|
||||
return _generate_narration_legacy(markdown_content, api_key, base_url, model)
|
||||
|
||||
|
||||
def _generate_narration_legacy(markdown_content, api_key, base_url, model):
|
||||
"""
|
||||
旧的解说文案生成实现 - 保留作为备用方案
|
||||
|
||||
:param markdown_content: Markdown格式的视频帧分析内容
|
||||
:param api_key: API密钥
|
||||
:param base_url: API基础URL
|
||||
:param model: 使用的模型名称
|
||||
:return: 生成的解说文案
|
||||
"""
|
||||
|
||||
52
app/services/llm/__init__.py
Normal file
52
app/services/llm/__init__.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
NarratoAI 大模型服务模块
|
||||
|
||||
统一的大模型服务抽象层,支持多供应商切换和严格的输出格式验证
|
||||
包含视觉模型和文本生成模型的统一接口
|
||||
|
||||
主要组件:
|
||||
- BaseLLMProvider: 大模型服务提供商基类
|
||||
- VisionModelProvider: 视觉模型提供商基类
|
||||
- TextModelProvider: 文本模型提供商基类
|
||||
- LLMServiceManager: 大模型服务管理器
|
||||
- OutputValidator: 输出格式验证器
|
||||
|
||||
支持的供应商:
|
||||
视觉模型: Gemini, QwenVL, Siliconflow
|
||||
文本模型: OpenAI, DeepSeek, Gemini, Qwen, Moonshot, Siliconflow
|
||||
"""
|
||||
|
||||
from .manager import LLMServiceManager
|
||||
from .base import BaseLLMProvider, VisionModelProvider, TextModelProvider
|
||||
from .validators import OutputValidator, ValidationError
|
||||
from .exceptions import LLMServiceError, ProviderNotFoundError, ConfigurationError
|
||||
|
||||
# 确保提供商在模块导入时被注册
|
||||
def _ensure_providers_registered():
|
||||
"""确保所有提供商都已注册"""
|
||||
try:
|
||||
# 导入providers模块会自动执行注册
|
||||
from . import providers
|
||||
from loguru import logger
|
||||
logger.debug("LLM服务提供商注册完成")
|
||||
except Exception as e:
|
||||
from loguru import logger
|
||||
logger.error(f"LLM服务提供商注册失败: {str(e)}")
|
||||
|
||||
# 自动注册提供商
|
||||
_ensure_providers_registered()
|
||||
|
||||
__all__ = [
|
||||
'LLMServiceManager',
|
||||
'BaseLLMProvider',
|
||||
'VisionModelProvider',
|
||||
'TextModelProvider',
|
||||
'OutputValidator',
|
||||
'ValidationError',
|
||||
'LLMServiceError',
|
||||
'ProviderNotFoundError',
|
||||
'ConfigurationError'
|
||||
]
|
||||
|
||||
# 版本信息
|
||||
__version__ = '1.0.0'
|
||||
175
app/services/llm/base.py
Normal file
175
app/services/llm/base.py
Normal file
@ -0,0 +1,175 @@
|
||||
"""
|
||||
大模型服务提供商基类定义
|
||||
|
||||
定义了统一的大模型服务接口,包括视觉模型和文本生成模型的抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from loguru import logger
|
||||
|
||||
from .exceptions import LLMServiceError, ConfigurationError
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
|
||||
"""大模型服务提供商基类"""
|
||||
|
||||
def __init__(self,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
初始化大模型服务提供商
|
||||
|
||||
Args:
|
||||
api_key: API密钥
|
||||
model_name: 模型名称
|
||||
base_url: API基础URL
|
||||
**kwargs: 其他配置参数
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.config = kwargs
|
||||
|
||||
# 验证必要配置
|
||||
self._validate_config()
|
||||
|
||||
# 初始化提供商特定设置
|
||||
self._initialize()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider_name(self) -> str:
|
||||
"""供应商名称"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_models(self) -> List[str]:
|
||||
"""支持的模型列表"""
|
||||
pass
|
||||
|
||||
def _validate_config(self):
|
||||
"""验证配置参数"""
|
||||
if not self.api_key:
|
||||
raise ConfigurationError("API密钥不能为空", "api_key")
|
||||
|
||||
if not self.model_name:
|
||||
raise ConfigurationError("模型名称不能为空", "model_name")
|
||||
|
||||
if self.model_name not in self.supported_models:
|
||||
from .exceptions import ModelNotSupportedError
|
||||
raise ModelNotSupportedError(self.model_name, self.provider_name)
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化提供商特定设置,子类可重写"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用,子类必须实现"""
|
||||
pass
|
||||
|
||||
def _handle_api_error(self, status_code: int, response_text: str) -> LLMServiceError:
|
||||
"""处理API错误,返回适当的异常"""
|
||||
from .exceptions import APICallError, RateLimitError, AuthenticationError
|
||||
|
||||
if status_code == 401:
|
||||
return AuthenticationError()
|
||||
elif status_code == 429:
|
||||
return RateLimitError()
|
||||
else:
|
||||
return APICallError(f"HTTP {status_code}", status_code, response_text)
|
||||
|
||||
|
||||
class VisionModelProvider(BaseLLMProvider):
|
||||
"""视觉模型提供商基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
分析图片并返回结果
|
||||
|
||||
Args:
|
||||
images: 图片路径列表或PIL图片对象列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
pass
|
||||
|
||||
def _prepare_images(self, images: List[Union[str, Path, PIL.Image.Image]]) -> List[PIL.Image.Image]:
|
||||
"""预处理图片,统一转换为PIL.Image对象"""
|
||||
processed_images = []
|
||||
|
||||
for img in images:
|
||||
try:
|
||||
if isinstance(img, (str, Path)):
|
||||
pil_img = PIL.Image.open(img)
|
||||
elif isinstance(img, PIL.Image.Image):
|
||||
pil_img = img
|
||||
else:
|
||||
logger.warning(f"不支持的图片类型: {type(img)}")
|
||||
continue
|
||||
|
||||
# 调整图片大小以优化性能
|
||||
if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
|
||||
pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
|
||||
|
||||
processed_images.append(pil_img)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载图片失败 {img}: {str(e)}")
|
||||
continue
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
class TextModelProvider(BaseLLMProvider):
|
||||
"""文本生成模型提供商基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
生成文本内容
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
pass
|
||||
|
||||
def _build_messages(self, prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
|
||||
"""构建消息列表"""
|
||||
messages = []
|
||||
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
return messages
|
||||
307
app/services/llm/config_validator.py
Normal file
307
app/services/llm/config_validator.py
Normal file
@ -0,0 +1,307 @@
|
||||
"""
|
||||
LLM服务配置验证器
|
||||
|
||||
验证大模型服务的配置是否正确,并提供配置建议
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from .manager import LLMServiceManager
|
||||
from .exceptions import ConfigurationError
|
||||
|
||||
|
||||
class LLMConfigValidator:
|
||||
"""LLM服务配置验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate_all_configs() -> Dict[str, Any]:
|
||||
"""
|
||||
验证所有LLM服务配置
|
||||
|
||||
Returns:
|
||||
验证结果字典
|
||||
"""
|
||||
results = {
|
||||
"vision_providers": {},
|
||||
"text_providers": {},
|
||||
"summary": {
|
||||
"total_vision_providers": 0,
|
||||
"valid_vision_providers": 0,
|
||||
"total_text_providers": 0,
|
||||
"valid_text_providers": 0,
|
||||
"errors": [],
|
||||
"warnings": []
|
||||
}
|
||||
}
|
||||
|
||||
# 验证视觉模型提供商
|
||||
vision_providers = LLMServiceManager.list_vision_providers()
|
||||
results["summary"]["total_vision_providers"] = len(vision_providers)
|
||||
|
||||
for provider in vision_providers:
|
||||
try:
|
||||
validation_result = LLMConfigValidator.validate_vision_provider(provider)
|
||||
results["vision_providers"][provider] = validation_result
|
||||
|
||||
if validation_result["is_valid"]:
|
||||
results["summary"]["valid_vision_providers"] += 1
|
||||
else:
|
||||
results["summary"]["errors"].extend(validation_result["errors"])
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"验证视觉模型提供商 {provider} 时发生错误: {str(e)}"
|
||||
results["vision_providers"][provider] = {
|
||||
"is_valid": False,
|
||||
"errors": [error_msg],
|
||||
"warnings": []
|
||||
}
|
||||
results["summary"]["errors"].append(error_msg)
|
||||
|
||||
# 验证文本模型提供商
|
||||
text_providers = LLMServiceManager.list_text_providers()
|
||||
results["summary"]["total_text_providers"] = len(text_providers)
|
||||
|
||||
for provider in text_providers:
|
||||
try:
|
||||
validation_result = LLMConfigValidator.validate_text_provider(provider)
|
||||
results["text_providers"][provider] = validation_result
|
||||
|
||||
if validation_result["is_valid"]:
|
||||
results["summary"]["valid_text_providers"] += 1
|
||||
else:
|
||||
results["summary"]["errors"].extend(validation_result["errors"])
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"验证文本模型提供商 {provider} 时发生错误: {str(e)}"
|
||||
results["text_providers"][provider] = {
|
||||
"is_valid": False,
|
||||
"errors": [error_msg],
|
||||
"warnings": []
|
||||
}
|
||||
results["summary"]["errors"].append(error_msg)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def validate_vision_provider(provider_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
验证视觉模型提供商配置
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
Returns:
|
||||
验证结果字典
|
||||
"""
|
||||
result = {
|
||||
"is_valid": False,
|
||||
"errors": [],
|
||||
"warnings": [],
|
||||
"config": {}
|
||||
}
|
||||
|
||||
try:
|
||||
# 获取配置
|
||||
config_prefix = f"vision_{provider_name}"
|
||||
api_key = config.app.get(f'{config_prefix}_api_key')
|
||||
model_name = config.app.get(f'{config_prefix}_model_name')
|
||||
base_url = config.app.get(f'{config_prefix}_base_url')
|
||||
|
||||
result["config"] = {
|
||||
"api_key": "***" if api_key else None,
|
||||
"model_name": model_name,
|
||||
"base_url": base_url
|
||||
}
|
||||
|
||||
# 验证必需配置
|
||||
if not api_key:
|
||||
result["errors"].append(f"缺少API密钥配置: {config_prefix}_api_key")
|
||||
|
||||
if not model_name:
|
||||
result["errors"].append(f"缺少模型名称配置: {config_prefix}_model_name")
|
||||
|
||||
# 尝试创建提供商实例
|
||||
if api_key and model_name:
|
||||
try:
|
||||
provider_instance = LLMServiceManager.get_vision_provider(provider_name)
|
||||
result["is_valid"] = True
|
||||
logger.debug(f"视觉模型提供商 {provider_name} 配置验证成功")
|
||||
|
||||
except Exception as e:
|
||||
result["errors"].append(f"创建提供商实例失败: {str(e)}")
|
||||
|
||||
# 添加警告
|
||||
if not base_url:
|
||||
result["warnings"].append(f"未配置base_url,将使用默认值")
|
||||
|
||||
except Exception as e:
|
||||
result["errors"].append(f"配置验证过程中发生错误: {str(e)}")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def validate_text_provider(provider_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
验证文本模型提供商配置
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
Returns:
|
||||
验证结果字典
|
||||
"""
|
||||
result = {
|
||||
"is_valid": False,
|
||||
"errors": [],
|
||||
"warnings": [],
|
||||
"config": {}
|
||||
}
|
||||
|
||||
try:
|
||||
# 获取配置
|
||||
config_prefix = f"text_{provider_name}"
|
||||
api_key = config.app.get(f'{config_prefix}_api_key')
|
||||
model_name = config.app.get(f'{config_prefix}_model_name')
|
||||
base_url = config.app.get(f'{config_prefix}_base_url')
|
||||
|
||||
result["config"] = {
|
||||
"api_key": "***" if api_key else None,
|
||||
"model_name": model_name,
|
||||
"base_url": base_url
|
||||
}
|
||||
|
||||
# 验证必需配置
|
||||
if not api_key:
|
||||
result["errors"].append(f"缺少API密钥配置: {config_prefix}_api_key")
|
||||
|
||||
if not model_name:
|
||||
result["errors"].append(f"缺少模型名称配置: {config_prefix}_model_name")
|
||||
|
||||
# 尝试创建提供商实例
|
||||
if api_key and model_name:
|
||||
try:
|
||||
provider_instance = LLMServiceManager.get_text_provider(provider_name)
|
||||
result["is_valid"] = True
|
||||
logger.debug(f"文本模型提供商 {provider_name} 配置验证成功")
|
||||
|
||||
except Exception as e:
|
||||
result["errors"].append(f"创建提供商实例失败: {str(e)}")
|
||||
|
||||
# 添加警告
|
||||
if not base_url:
|
||||
result["warnings"].append(f"未配置base_url,将使用默认值")
|
||||
|
||||
except Exception as e:
|
||||
result["errors"].append(f"配置验证过程中发生错误: {str(e)}")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_config_suggestions() -> Dict[str, Any]:
|
||||
"""
|
||||
获取配置建议
|
||||
|
||||
Returns:
|
||||
配置建议字典
|
||||
"""
|
||||
suggestions = {
|
||||
"vision_providers": {},
|
||||
"text_providers": {},
|
||||
"general_tips": [
|
||||
"确保所有API密钥都已正确配置",
|
||||
"建议为每个提供商配置base_url以提高稳定性",
|
||||
"定期检查模型名称是否为最新版本",
|
||||
"建议配置多个提供商作为备用方案"
|
||||
]
|
||||
}
|
||||
|
||||
# 为每个视觉模型提供商提供建议
|
||||
vision_providers = LLMServiceManager.list_vision_providers()
|
||||
for provider in vision_providers:
|
||||
suggestions["vision_providers"][provider] = {
|
||||
"required_configs": [
|
||||
f"vision_{provider}_api_key",
|
||||
f"vision_{provider}_model_name"
|
||||
],
|
||||
"optional_configs": [
|
||||
f"vision_{provider}_base_url"
|
||||
],
|
||||
"example_models": LLMConfigValidator._get_example_models(provider, "vision")
|
||||
}
|
||||
|
||||
# 为每个文本模型提供商提供建议
|
||||
text_providers = LLMServiceManager.list_text_providers()
|
||||
for provider in text_providers:
|
||||
suggestions["text_providers"][provider] = {
|
||||
"required_configs": [
|
||||
f"text_{provider}_api_key",
|
||||
f"text_{provider}_model_name"
|
||||
],
|
||||
"optional_configs": [
|
||||
f"text_{provider}_base_url"
|
||||
],
|
||||
"example_models": LLMConfigValidator._get_example_models(provider, "text")
|
||||
}
|
||||
|
||||
return suggestions
|
||||
|
||||
@staticmethod
|
||||
def _get_example_models(provider_name: str, model_type: str) -> List[str]:
|
||||
"""获取示例模型名称"""
|
||||
examples = {
|
||||
"gemini": {
|
||||
"vision": ["gemini-2.0-flash-lite", "gemini-2.0-flash"],
|
||||
"text": ["gemini-2.0-flash", "gemini-1.5-pro"]
|
||||
},
|
||||
"openai": {
|
||||
"vision": [],
|
||||
"text": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
|
||||
},
|
||||
"qwen": {
|
||||
"vision": ["qwen2.5-vl-32b-instruct"],
|
||||
"text": ["qwen-plus-1127", "qwen-turbo"]
|
||||
},
|
||||
"deepseek": {
|
||||
"vision": [],
|
||||
"text": ["deepseek-chat", "deepseek-reasoner"]
|
||||
},
|
||||
"siliconflow": {
|
||||
"vision": ["Qwen/Qwen2.5-VL-32B-Instruct"],
|
||||
"text": ["deepseek-ai/DeepSeek-R1", "Qwen/Qwen2.5-72B-Instruct"]
|
||||
}
|
||||
}
|
||||
|
||||
return examples.get(provider_name, {}).get(model_type, [])
|
||||
|
||||
@staticmethod
|
||||
def print_validation_report(validation_results: Dict[str, Any]):
|
||||
"""
|
||||
打印验证报告
|
||||
|
||||
Args:
|
||||
validation_results: 验证结果
|
||||
"""
|
||||
summary = validation_results["summary"]
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("LLM服务配置验证报告")
|
||||
print("="*60)
|
||||
|
||||
print(f"\n📊 总体统计:")
|
||||
print(f" 视觉模型提供商: {summary['valid_vision_providers']}/{summary['total_vision_providers']} 有效")
|
||||
print(f" 文本模型提供商: {summary['valid_text_providers']}/{summary['total_text_providers']} 有效")
|
||||
|
||||
if summary["errors"]:
|
||||
print(f"\n❌ 错误 ({len(summary['errors'])}):")
|
||||
for error in summary["errors"]:
|
||||
print(f" - {error}")
|
||||
|
||||
if summary["warnings"]:
|
||||
print(f"\n⚠️ 警告 ({len(summary['warnings'])}):")
|
||||
for warning in summary["warnings"]:
|
||||
print(f" - {warning}")
|
||||
|
||||
print(f"\n✅ 配置验证完成")
|
||||
print("="*60)
|
||||
118
app/services/llm/exceptions.py
Normal file
118
app/services/llm/exceptions.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""
|
||||
大模型服务异常类定义
|
||||
|
||||
定义了大模型服务中可能出现的各种异常类型,
|
||||
提供统一的错误处理机制
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class LLMServiceError(Exception):
|
||||
"""大模型服务基础异常类"""
|
||||
|
||||
def __init__(self, message: str, error_code: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
self.details = details or {}
|
||||
|
||||
def __str__(self):
|
||||
if self.error_code:
|
||||
return f"[{self.error_code}] {self.message}"
|
||||
return self.message
|
||||
|
||||
|
||||
class ProviderNotFoundError(LLMServiceError):
|
||||
"""供应商未找到异常"""
|
||||
|
||||
def __init__(self, provider_name: str):
|
||||
super().__init__(
|
||||
message=f"未找到大模型供应商: {provider_name}",
|
||||
error_code="PROVIDER_NOT_FOUND",
|
||||
details={"provider_name": provider_name}
|
||||
)
|
||||
|
||||
|
||||
class ConfigurationError(LLMServiceError):
|
||||
"""配置错误异常"""
|
||||
|
||||
def __init__(self, message: str, config_key: Optional[str] = None):
|
||||
super().__init__(
|
||||
message=f"配置错误: {message}",
|
||||
error_code="CONFIGURATION_ERROR",
|
||||
details={"config_key": config_key} if config_key else {}
|
||||
)
|
||||
|
||||
|
||||
class APICallError(LLMServiceError):
|
||||
"""API调用错误异常"""
|
||||
|
||||
def __init__(self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None):
|
||||
super().__init__(
|
||||
message=f"API调用失败: {message}",
|
||||
error_code="API_CALL_ERROR",
|
||||
details={
|
||||
"status_code": status_code,
|
||||
"response_text": response_text
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ValidationError(LLMServiceError):
|
||||
"""输出验证错误异常"""
|
||||
|
||||
def __init__(self, message: str, validation_type: Optional[str] = None, invalid_data: Optional[Any] = None):
|
||||
super().__init__(
|
||||
message=f"输出验证失败: {message}",
|
||||
error_code="VALIDATION_ERROR",
|
||||
details={
|
||||
"validation_type": validation_type,
|
||||
"invalid_data": str(invalid_data) if invalid_data else None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ModelNotSupportedError(LLMServiceError):
|
||||
"""模型不支持异常"""
|
||||
|
||||
def __init__(self, model_name: str, provider_name: str):
|
||||
super().__init__(
|
||||
message=f"供应商 {provider_name} 不支持模型 {model_name}",
|
||||
error_code="MODEL_NOT_SUPPORTED",
|
||||
details={
|
||||
"model_name": model_name,
|
||||
"provider_name": provider_name
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RateLimitError(LLMServiceError):
|
||||
"""API速率限制异常"""
|
||||
|
||||
def __init__(self, message: str = "API调用频率超限", retry_after: Optional[int] = None):
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_code="RATE_LIMIT_ERROR",
|
||||
details={"retry_after": retry_after}
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationError(LLMServiceError):
|
||||
"""认证错误异常"""
|
||||
|
||||
def __init__(self, message: str = "API密钥无效或权限不足"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_code="AUTHENTICATION_ERROR"
|
||||
)
|
||||
|
||||
|
||||
class ContentFilterError(LLMServiceError):
|
||||
"""内容过滤异常"""
|
||||
|
||||
def __init__(self, message: str = "内容被安全过滤器阻止"):
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_code="CONTENT_FILTER_ERROR"
|
||||
)
|
||||
214
app/services/llm/manager.py
Normal file
214
app/services/llm/manager.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""
|
||||
大模型服务管理器
|
||||
|
||||
统一管理所有大模型服务提供商,提供简单的工厂方法来创建和获取服务实例
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, Optional
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from .base import VisionModelProvider, TextModelProvider
|
||||
from .exceptions import ProviderNotFoundError, ConfigurationError
|
||||
|
||||
|
||||
class LLMServiceManager:
|
||||
"""大模型服务管理器"""
|
||||
|
||||
# 注册的视觉模型提供商
|
||||
_vision_providers: Dict[str, Type[VisionModelProvider]] = {}
|
||||
|
||||
# 注册的文本模型提供商
|
||||
_text_providers: Dict[str, Type[TextModelProvider]] = {}
|
||||
|
||||
# 缓存的提供商实例
|
||||
_vision_instance_cache: Dict[str, VisionModelProvider] = {}
|
||||
_text_instance_cache: Dict[str, TextModelProvider] = {}
|
||||
|
||||
@classmethod
|
||||
def register_vision_provider(cls, name: str, provider_class: Type[VisionModelProvider]):
|
||||
"""注册视觉模型提供商"""
|
||||
cls._vision_providers[name.lower()] = provider_class
|
||||
logger.debug(f"注册视觉模型提供商: {name}")
|
||||
|
||||
@classmethod
|
||||
def register_text_provider(cls, name: str, provider_class: Type[TextModelProvider]):
|
||||
"""注册文本模型提供商"""
|
||||
cls._text_providers[name.lower()] = provider_class
|
||||
logger.debug(f"注册文本模型提供商: {name}")
|
||||
|
||||
@classmethod
|
||||
def _ensure_providers_registered(cls):
|
||||
"""确保提供商已注册"""
|
||||
try:
|
||||
# 如果没有注册的提供商,强制导入providers模块
|
||||
if not cls._vision_providers or not cls._text_providers:
|
||||
from . import providers
|
||||
logger.debug("LLMServiceManager强制注册提供商")
|
||||
except Exception as e:
|
||||
logger.error(f"LLMServiceManager确保提供商注册时发生错误: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def get_vision_provider(cls, provider_name: Optional[str] = None) -> VisionModelProvider:
|
||||
"""
|
||||
获取视觉模型提供商实例
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称,如果不指定则从配置中获取
|
||||
|
||||
Returns:
|
||||
视觉模型提供商实例
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: 提供商未找到
|
||||
ConfigurationError: 配置错误
|
||||
"""
|
||||
# 确保提供商已注册
|
||||
cls._ensure_providers_registered()
|
||||
|
||||
# 确定提供商名称
|
||||
if not provider_name:
|
||||
provider_name = config.app.get('vision_llm_provider', 'gemini').lower()
|
||||
else:
|
||||
provider_name = provider_name.lower()
|
||||
|
||||
# 检查缓存
|
||||
cache_key = f"vision_{provider_name}"
|
||||
if cache_key in cls._vision_instance_cache:
|
||||
return cls._vision_instance_cache[cache_key]
|
||||
|
||||
# 检查提供商是否已注册
|
||||
if provider_name not in cls._vision_providers:
|
||||
raise ProviderNotFoundError(provider_name)
|
||||
|
||||
# 获取配置
|
||||
config_prefix = f"vision_{provider_name}"
|
||||
api_key = config.app.get(f'{config_prefix}_api_key')
|
||||
model_name = config.app.get(f'{config_prefix}_model_name')
|
||||
base_url = config.app.get(f'{config_prefix}_base_url')
|
||||
|
||||
if not api_key:
|
||||
raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key")
|
||||
|
||||
if not model_name:
|
||||
raise ConfigurationError(f"缺少模型名称配置: {config_prefix}_model_name")
|
||||
|
||||
# 创建提供商实例
|
||||
provider_class = cls._vision_providers[provider_name]
|
||||
try:
|
||||
instance = provider_class(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
# 缓存实例
|
||||
cls._vision_instance_cache[cache_key] = instance
|
||||
|
||||
logger.info(f"创建视觉模型提供商实例: {provider_name} - {model_name}")
|
||||
return instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建视觉模型提供商实例失败: {provider_name} - {str(e)}")
|
||||
raise ConfigurationError(f"创建提供商实例失败: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def get_text_provider(cls, provider_name: Optional[str] = None) -> TextModelProvider:
|
||||
"""
|
||||
获取文本模型提供商实例
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称,如果不指定则从配置中获取
|
||||
|
||||
Returns:
|
||||
文本模型提供商实例
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: 提供商未找到
|
||||
ConfigurationError: 配置错误
|
||||
"""
|
||||
# 确保提供商已注册
|
||||
cls._ensure_providers_registered()
|
||||
|
||||
# 确定提供商名称
|
||||
if not provider_name:
|
||||
provider_name = config.app.get('text_llm_provider', 'openai').lower()
|
||||
else:
|
||||
provider_name = provider_name.lower()
|
||||
|
||||
# 检查缓存
|
||||
cache_key = f"text_{provider_name}"
|
||||
if cache_key in cls._text_instance_cache:
|
||||
return cls._text_instance_cache[cache_key]
|
||||
|
||||
# 检查提供商是否已注册
|
||||
if provider_name not in cls._text_providers:
|
||||
raise ProviderNotFoundError(provider_name)
|
||||
|
||||
# 获取配置
|
||||
config_prefix = f"text_{provider_name}"
|
||||
api_key = config.app.get(f'{config_prefix}_api_key')
|
||||
model_name = config.app.get(f'{config_prefix}_model_name')
|
||||
base_url = config.app.get(f'{config_prefix}_base_url')
|
||||
|
||||
if not api_key:
|
||||
raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key")
|
||||
|
||||
if not model_name:
|
||||
raise ConfigurationError(f"缺少模型名称配置: {config_prefix}_model_name")
|
||||
|
||||
# 创建提供商实例
|
||||
provider_class = cls._text_providers[provider_name]
|
||||
try:
|
||||
instance = provider_class(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
# 缓存实例
|
||||
cls._text_instance_cache[cache_key] = instance
|
||||
|
||||
logger.info(f"创建文本模型提供商实例: {provider_name} - {model_name}")
|
||||
return instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建文本模型提供商实例失败: {provider_name} - {str(e)}")
|
||||
raise ConfigurationError(f"创建提供商实例失败: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
"""清空提供商实例缓存"""
|
||||
cls._vision_instance_cache.clear()
|
||||
cls._text_instance_cache.clear()
|
||||
logger.info("已清空提供商实例缓存")
|
||||
|
||||
@classmethod
|
||||
def list_vision_providers(cls) -> list[str]:
|
||||
"""列出所有已注册的视觉模型提供商"""
|
||||
return list(cls._vision_providers.keys())
|
||||
|
||||
@classmethod
|
||||
def list_text_providers(cls) -> list[str]:
|
||||
"""列出所有已注册的文本模型提供商"""
|
||||
return list(cls._text_providers.keys())
|
||||
|
||||
@classmethod
|
||||
def get_provider_info(cls) -> Dict[str, Dict[str, any]]:
|
||||
"""获取所有提供商信息"""
|
||||
return {
|
||||
"vision_providers": {
|
||||
name: {
|
||||
"class": provider_class.__name__,
|
||||
"module": provider_class.__module__
|
||||
}
|
||||
for name, provider_class in cls._vision_providers.items()
|
||||
},
|
||||
"text_providers": {
|
||||
name: {
|
||||
"class": provider_class.__name__,
|
||||
"module": provider_class.__module__
|
||||
}
|
||||
for name, provider_class in cls._text_providers.items()
|
||||
}
|
||||
}
|
||||
322
app/services/llm/migration_adapter.py
Normal file
322
app/services/llm/migration_adapter.py
Normal file
@ -0,0 +1,322 @@
|
||||
"""
|
||||
迁移适配器
|
||||
|
||||
为现有代码提供向后兼容的接口,方便逐步迁移到新的LLM服务架构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from loguru import logger
|
||||
|
||||
from .unified_service import UnifiedLLMService
|
||||
from .exceptions import LLMServiceError
|
||||
|
||||
# 确保提供商已注册
|
||||
def _ensure_providers_registered():
|
||||
"""确保所有提供商都已注册"""
|
||||
try:
|
||||
from .manager import LLMServiceManager
|
||||
# 检查是否有已注册的提供商
|
||||
if not LLMServiceManager.list_text_providers() or not LLMServiceManager.list_vision_providers():
|
||||
# 如果没有注册的提供商,强制导入providers模块
|
||||
from . import providers
|
||||
logger.debug("迁移适配器强制注册LLM服务提供商")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移适配器确保LLM服务提供商注册时发生错误: {str(e)}")
|
||||
|
||||
# 在模块加载时确保提供商已注册
|
||||
_ensure_providers_registered()
|
||||
|
||||
|
||||
def _run_async_safely(coro_func, *args, **kwargs):
|
||||
"""
|
||||
安全地运行异步协程,处理各种事件循环情况
|
||||
|
||||
Args:
|
||||
coro_func: 协程函数(不是协程对象)
|
||||
*args: 协程函数的位置参数
|
||||
**kwargs: 协程函数的关键字参数
|
||||
|
||||
Returns:
|
||||
协程的执行结果
|
||||
"""
|
||||
def run_in_new_loop():
|
||||
"""在新的事件循环中运行协程"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro_func(*args, **kwargs))
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
try:
|
||||
# 尝试获取当前事件循环
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# 如果有运行中的事件循环,使用线程池执行
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
# 没有运行中的事件循环,直接运行
|
||||
return run_in_new_loop()
|
||||
except Exception as e:
|
||||
logger.error(f"异步执行失败: {str(e)}")
|
||||
raise LLMServiceError(f"异步执行失败: {str(e)}")
|
||||
|
||||
|
||||
class LegacyLLMAdapter:
|
||||
"""传统LLM接口适配器"""
|
||||
|
||||
@staticmethod
|
||||
def create_vision_analyzer(provider: str, api_key: str, model: str, base_url: str = None):
|
||||
"""
|
||||
创建视觉分析器实例 - 兼容原有接口
|
||||
|
||||
Args:
|
||||
provider: 提供商名称
|
||||
api_key: API密钥
|
||||
model: 模型名称
|
||||
base_url: API基础URL
|
||||
|
||||
Returns:
|
||||
适配器实例
|
||||
"""
|
||||
return VisionAnalyzerAdapter(provider, api_key, model, base_url)
|
||||
|
||||
@staticmethod
|
||||
def generate_narration(markdown_content: str, api_key: str, base_url: str, model: str) -> str:
|
||||
"""
|
||||
生成解说文案 - 兼容原有接口
|
||||
|
||||
Args:
|
||||
markdown_content: Markdown格式的视频帧分析内容
|
||||
api_key: API密钥
|
||||
base_url: API基础URL
|
||||
model: 模型名称
|
||||
|
||||
Returns:
|
||||
生成的解说文案JSON字符串
|
||||
"""
|
||||
try:
|
||||
# 构建提示词
|
||||
prompt = f"""
|
||||
我是一名荒野建造解说的博主,以下是一些同行的对标文案,请你深度学习并总结这些文案的风格特点跟内容特点:
|
||||
|
||||
<video_frame_description>
|
||||
{markdown_content}
|
||||
</video_frame_description>
|
||||
|
||||
请根据以上视频帧描述,生成引人入胜的解说文案。
|
||||
|
||||
<output>
|
||||
{{
|
||||
"items": [
|
||||
{{
|
||||
"_id": 1,
|
||||
"timestamp": "00:00:05,390-00:00:10,430",
|
||||
"picture": "画面描述",
|
||||
"narration": "解说文案",
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</output>
|
||||
|
||||
<restriction>
|
||||
1. 只输出 json 内容,不要输出其他任何说明性的文字
|
||||
2. 解说文案的语言使用 简体中文
|
||||
3. 严禁虚构画面,所有画面只能从 <video_frame_description> 中摘取
|
||||
</restriction>
|
||||
"""
|
||||
|
||||
# 使用统一服务生成文案
|
||||
result = _run_async_safely(
|
||||
UnifiedLLMService.generate_text,
|
||||
prompt=prompt,
|
||||
system_prompt="你是一名专业的短视频解说文案撰写专家。",
|
||||
temperature=1.5,
|
||||
response_format="json"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成解说文案失败: {str(e)}")
|
||||
return f"生成解说文案失败: {str(e)}"
|
||||
|
||||
|
||||
class VisionAnalyzerAdapter:
|
||||
"""视觉分析器适配器"""
|
||||
|
||||
def __init__(self, provider: str, api_key: str, model: str, base_url: str = None):
|
||||
self.provider = provider
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10) -> List[str]:
|
||||
"""
|
||||
分析图片 - 兼容原有接口
|
||||
|
||||
Args:
|
||||
images: 图片列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
try:
|
||||
# 使用统一服务分析图片
|
||||
results = await UnifiedLLMService.analyze_images(
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
provider=self.provider,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图片分析失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class SubtitleAnalyzerAdapter:
|
||||
"""字幕分析器适配器"""
|
||||
|
||||
def __init__(self, api_key: str, model: str, base_url: str, provider: str = None):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.provider = provider or "openai"
|
||||
|
||||
def _run_async_safely(self, coro_func, *args, **kwargs):
|
||||
"""安全地运行异步协程"""
|
||||
return _run_async_safely(coro_func, *args, **kwargs)
|
||||
|
||||
def _clean_json_output(self, output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除开头和结尾的```标记
|
||||
output = re.sub(r'^```', '', output)
|
||||
output = re.sub(r'```$', '', output)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
分析字幕内容 - 兼容原有接口
|
||||
|
||||
Args:
|
||||
subtitle_content: 字幕内容
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
try:
|
||||
# 使用统一服务分析字幕
|
||||
result = self._run_async_safely(
|
||||
UnifiedLLMService.analyze_subtitle,
|
||||
subtitle_content=subtitle_content,
|
||||
provider=self.provider,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"analysis": result,
|
||||
"model": self.model,
|
||||
"temperature": 1.0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"字幕分析失败: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"temperature": 1.0
|
||||
}
|
||||
|
||||
def generate_narration_script(self, short_name: str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]:
|
||||
"""
|
||||
生成解说文案 - 兼容原有接口
|
||||
|
||||
Args:
|
||||
short_name: 短剧名称
|
||||
plot_analysis: 剧情分析内容
|
||||
temperature: 生成温度
|
||||
|
||||
Returns:
|
||||
生成结果字典
|
||||
"""
|
||||
try:
|
||||
# 构建提示词
|
||||
prompt = f"""
|
||||
根据以下剧情分析,为短剧《{short_name}》生成引人入胜的解说文案:
|
||||
|
||||
{plot_analysis}
|
||||
|
||||
请生成JSON格式的解说文案,包含以下字段:
|
||||
- narration_script: 解说文案内容
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"narration_script": "解说文案内容"
|
||||
}}
|
||||
"""
|
||||
|
||||
# 使用统一服务生成文案
|
||||
result = self._run_async_safely(
|
||||
UnifiedLLMService.generate_text,
|
||||
prompt=prompt,
|
||||
system_prompt="你是一位专业的短视频解说脚本撰写专家。",
|
||||
provider=self.provider,
|
||||
temperature=temperature,
|
||||
response_format="json"
|
||||
)
|
||||
|
||||
# 清理JSON输出
|
||||
cleaned_result = self._clean_json_output(result)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"narration_script": cleaned_result,
|
||||
"model": self.model,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解说文案生成失败: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
|
||||
# 为了向后兼容,提供一些全局函数
|
||||
def create_vision_analyzer(provider: str, api_key: str, model: str, base_url: str = None):
|
||||
"""创建视觉分析器 - 全局函数"""
|
||||
return LegacyLLMAdapter.create_vision_analyzer(provider, api_key, model, base_url)
|
||||
|
||||
|
||||
def generate_narration(markdown_content: str, api_key: str, base_url: str, model: str) -> str:
|
||||
"""生成解说文案 - 全局函数"""
|
||||
return LegacyLLMAdapter.generate_narration(markdown_content, api_key, base_url, model)
|
||||
47
app/services/llm/providers/__init__.py
Normal file
47
app/services/llm/providers/__init__.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""
|
||||
大模型服务提供商实现
|
||||
|
||||
包含各种大模型服务提供商的具体实现
|
||||
"""
|
||||
|
||||
from .gemini_provider import GeminiVisionProvider, GeminiTextProvider
|
||||
from .gemini_openai_provider import GeminiOpenAIVisionProvider, GeminiOpenAITextProvider
|
||||
from .openai_provider import OpenAITextProvider
|
||||
from .qwen_provider import QwenVisionProvider, QwenTextProvider
|
||||
from .deepseek_provider import DeepSeekTextProvider
|
||||
from .siliconflow_provider import SiliconflowVisionProvider, SiliconflowTextProvider
|
||||
|
||||
# 自动注册所有提供商
|
||||
from ..manager import LLMServiceManager
|
||||
|
||||
def register_all_providers():
|
||||
"""注册所有提供商"""
|
||||
# 注册视觉模型提供商
|
||||
LLMServiceManager.register_vision_provider('gemini', GeminiVisionProvider)
|
||||
LLMServiceManager.register_vision_provider('gemini(openai)', GeminiOpenAIVisionProvider)
|
||||
LLMServiceManager.register_vision_provider('qwenvl', QwenVisionProvider)
|
||||
LLMServiceManager.register_vision_provider('siliconflow', SiliconflowVisionProvider)
|
||||
|
||||
# 注册文本模型提供商
|
||||
LLMServiceManager.register_text_provider('gemini', GeminiTextProvider)
|
||||
LLMServiceManager.register_text_provider('gemini(openai)', GeminiOpenAITextProvider)
|
||||
LLMServiceManager.register_text_provider('openai', OpenAITextProvider)
|
||||
LLMServiceManager.register_text_provider('qwen', QwenTextProvider)
|
||||
LLMServiceManager.register_text_provider('deepseek', DeepSeekTextProvider)
|
||||
LLMServiceManager.register_text_provider('siliconflow', SiliconflowTextProvider)
|
||||
|
||||
# 自动注册
|
||||
register_all_providers()
|
||||
|
||||
__all__ = [
|
||||
'GeminiVisionProvider',
|
||||
'GeminiTextProvider',
|
||||
'GeminiOpenAIVisionProvider',
|
||||
'GeminiOpenAITextProvider',
|
||||
'OpenAITextProvider',
|
||||
'QwenVisionProvider',
|
||||
'QwenTextProvider',
|
||||
'DeepSeekTextProvider',
|
||||
'SiliconflowVisionProvider',
|
||||
'SiliconflowTextProvider'
|
||||
]
|
||||
157
app/services/llm/providers/deepseek_provider.py
Normal file
157
app/services/llm/providers/deepseek_provider.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""
|
||||
DeepSeek API提供商实现
|
||||
|
||||
支持DeepSeek的文本生成模型
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI, BadRequestError
|
||||
from loguru import logger
|
||||
|
||||
from ..base import TextModelProvider
|
||||
from ..exceptions import APICallError
|
||||
|
||||
|
||||
class DeepSeekTextProvider(TextModelProvider):
|
||||
"""DeepSeek文本生成提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "deepseek"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
"deepseek-r1",
|
||||
"deepseek-v3"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化DeepSeek客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://api.deepseek.com"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
使用DeepSeek API生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
# 构建消息列表
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
|
||||
# 构建请求参数
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 处理JSON格式输出
|
||||
# DeepSeek R1 和 V3 不支持 response_format=json_object
|
||||
if response_format == "json":
|
||||
if self._supports_response_format():
|
||||
request_params["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
# 对于不支持response_format的模型,在提示词中添加约束
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
# 提取生成的内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 对于不支持response_format的模型,清理输出
|
||||
if response_format == "json" and not self._supports_response_format():
|
||||
content = self._clean_json_output(content)
|
||||
|
||||
logger.debug(f"DeepSeek API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
|
||||
return content
|
||||
else:
|
||||
raise APICallError("DeepSeek API返回空响应")
|
||||
|
||||
except BadRequestError as e:
|
||||
# 处理不支持response_format的情况
|
||||
if "response_format" in str(e) and response_format == "json":
|
||||
logger.warning(f"DeepSeek模型 {self.model_name} 不支持response_format,重试不带格式约束的请求")
|
||||
request_params.pop("response_format", None)
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
content = self._clean_json_output(content)
|
||||
return content
|
||||
else:
|
||||
raise APICallError("DeepSeek API返回空响应")
|
||||
else:
|
||||
raise APICallError(f"DeepSeek API请求失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek API调用失败: {str(e)}")
|
||||
raise APICallError(f"DeepSeek API调用失败: {str(e)}")
|
||||
|
||||
def _supports_response_format(self) -> bool:
|
||||
"""检查模型是否支持response_format参数"""
|
||||
# DeepSeek R1 和 V3 不支持 response_format=json_object
|
||||
unsupported_models = [
|
||||
"deepseek-reasoner",
|
||||
"deepseek-r1",
|
||||
"deepseek-v3"
|
||||
]
|
||||
|
||||
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
|
||||
|
||||
def _clean_json_output(self, output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
235
app/services/llm/providers/gemini_openai_provider.py
Normal file
235
app/services/llm/providers/gemini_openai_provider.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""
|
||||
OpenAI兼容的Gemini API提供商实现
|
||||
|
||||
使用OpenAI兼容接口调用Gemini服务,支持视觉分析和文本生成
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
|
||||
from ..base import VisionModelProvider, TextModelProvider
|
||||
from ..exceptions import APICallError
|
||||
|
||||
|
||||
class GeminiOpenAIVisionProvider(VisionModelProvider):
|
||||
"""OpenAI兼容的Gemini视觉模型提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "gemini(openai)"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化OpenAI兼容的Gemini客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
使用OpenAI兼容的Gemini API分析图片
|
||||
|
||||
Args:
|
||||
images: 图片列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理")
|
||||
|
||||
# 预处理图片
|
||||
processed_images = self._prepare_images(images)
|
||||
|
||||
# 分批处理
|
||||
results = []
|
||||
for i in range(0, len(processed_images), batch_size):
|
||||
batch = processed_images[i:i + batch_size]
|
||||
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
|
||||
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
|
||||
results.append(f"批次处理失败: {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
|
||||
"""分析一批图片"""
|
||||
# 构建OpenAI格式的消息内容
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
|
||||
# 添加图片
|
||||
for img in batch:
|
||||
base64_image = self._image_to_base64(img)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
})
|
||||
|
||||
# 构建消息
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}]
|
||||
|
||||
# 调用API
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=4000,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
if response.choices and len(response.choices) > 0:
|
||||
return response.choices[0].message.content
|
||||
else:
|
||||
raise APICallError("OpenAI兼容Gemini API返回空响应")
|
||||
|
||||
def _image_to_base64(self, img: PIL.Image.Image) -> str:
|
||||
"""将PIL图片转换为base64编码"""
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format='JPEG', quality=85)
|
||||
img_bytes = img_buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode('utf-8')
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
|
||||
|
||||
class GeminiOpenAITextProvider(TextModelProvider):
|
||||
"""OpenAI兼容的Gemini文本生成提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "gemini(openai)"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化OpenAI兼容的Gemini客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
使用OpenAI兼容的Gemini API生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
# 构建消息列表
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
|
||||
# 构建请求参数
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 处理JSON格式输出 - Gemini通过OpenAI接口可能不完全支持response_format
|
||||
if response_format == "json":
|
||||
# 在提示词中添加JSON格式约束
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
# 提取生成的内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 对于JSON格式,清理输出
|
||||
if response_format == "json":
|
||||
content = self._clean_json_output(content)
|
||||
|
||||
logger.debug(f"OpenAI兼容Gemini API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
|
||||
return content
|
||||
else:
|
||||
raise APICallError("OpenAI兼容Gemini API返回空响应")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI兼容Gemini API调用失败: {str(e)}")
|
||||
raise APICallError(f"OpenAI兼容Gemini API调用失败: {str(e)}")
|
||||
|
||||
def _clean_json_output(self, output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
346
app/services/llm/providers/gemini_provider.py
Normal file
346
app/services/llm/providers/gemini_provider.py
Normal file
@ -0,0 +1,346 @@
|
||||
"""
|
||||
原生Gemini API提供商实现
|
||||
|
||||
使用Google原生Gemini API进行视觉分析和文本生成
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import requests
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from loguru import logger
|
||||
|
||||
from ..base import VisionModelProvider, TextModelProvider
|
||||
from ..exceptions import APICallError, ContentFilterError
|
||||
|
||||
|
||||
class GeminiVisionProvider(VisionModelProvider):
|
||||
"""原生Gemini视觉模型提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "gemini"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化Gemini特定设置"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
使用原生Gemini API分析图片
|
||||
|
||||
Args:
|
||||
images: 图片列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
logger.info(f"开始分析 {len(images)} 张图片,使用原生Gemini API")
|
||||
|
||||
# 预处理图片
|
||||
processed_images = self._prepare_images(images)
|
||||
|
||||
# 分批处理
|
||||
results = []
|
||||
for i in range(0, len(processed_images), batch_size):
|
||||
batch = processed_images[i:i + batch_size]
|
||||
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
|
||||
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
|
||||
results.append(f"批次处理失败: {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
|
||||
"""分析一批图片"""
|
||||
# 构建请求数据
|
||||
parts = [{"text": prompt}]
|
||||
|
||||
# 添加图片数据
|
||||
for img in batch:
|
||||
img_data = self._image_to_base64(img)
|
||||
parts.append({
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": img_data
|
||||
}
|
||||
})
|
||||
|
||||
payload = {
|
||||
"systemInstruction": {
|
||||
"parts": [{"text": "你是一位专业的视觉内容分析师,请仔细分析图片内容并提供详细描述。"}]
|
||||
},
|
||||
"contents": [{"parts": parts}],
|
||||
"generationConfig": {
|
||||
"temperature": 1.0,
|
||||
"topK": 40,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": 4000,
|
||||
"candidateCount": 1
|
||||
},
|
||||
"safetySettings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 发送API请求
|
||||
response_data = await self._make_api_call(payload)
|
||||
|
||||
# 解析响应
|
||||
return self._parse_vision_response(response_data)
|
||||
|
||||
def _image_to_base64(self, img: PIL.Image.Image) -> str:
|
||||
"""将PIL图片转换为base64编码"""
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format='JPEG', quality=85)
|
||||
img_bytes = img_buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode('utf-8')
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行原生Gemini API调用"""
|
||||
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url,
|
||||
json=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "NarratoAI/1.0"
|
||||
},
|
||||
timeout=120
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error = self._handle_api_error(response.status_code, response.text)
|
||||
raise error
|
||||
|
||||
return response.json()
|
||||
|
||||
def _parse_vision_response(self, response_data: Dict[str, Any]) -> str:
|
||||
"""解析视觉分析响应"""
|
||||
if "candidates" not in response_data or not response_data["candidates"]:
|
||||
raise APICallError("原生Gemini API返回无效响应")
|
||||
|
||||
candidate = response_data["candidates"][0]
|
||||
|
||||
# 检查是否被安全过滤阻止
|
||||
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
|
||||
raise ContentFilterError("内容被Gemini安全过滤器阻止")
|
||||
|
||||
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||
raise APICallError("原生Gemini API返回内容格式错误")
|
||||
|
||||
# 提取文本内容
|
||||
result = ""
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
result += part["text"]
|
||||
|
||||
if not result.strip():
|
||||
raise APICallError("原生Gemini API返回空内容")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class GeminiTextProvider(TextModelProvider):
|
||||
"""原生Gemini文本生成提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "gemini"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化Gemini特定设置"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
使用原生Gemini API生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
# 构建请求数据
|
||||
payload = {
|
||||
"contents": [{"parts": [{"text": prompt}]}],
|
||||
"generationConfig": {
|
||||
"temperature": temperature,
|
||||
"topK": 40,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": max_tokens or 4000,
|
||||
"candidateCount": 1
|
||||
},
|
||||
"safetySettings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 添加系统提示词
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {
|
||||
"parts": [{"text": system_prompt}]
|
||||
}
|
||||
|
||||
# 如果需要JSON格式,调整提示词和配置
|
||||
if response_format == "json":
|
||||
# 使用更温和的JSON格式约束
|
||||
enhanced_prompt = f"{prompt}\n\n请以JSON格式输出结果。"
|
||||
payload["contents"][0]["parts"][0]["text"] = enhanced_prompt
|
||||
# 移除可能导致问题的stopSequences
|
||||
# payload["generationConfig"]["stopSequences"] = ["```", "注意", "说明"]
|
||||
|
||||
# 记录请求信息
|
||||
logger.debug(f"Gemini文本生成请求: {payload}")
|
||||
|
||||
# 发送API请求
|
||||
response_data = await self._make_api_call(payload)
|
||||
|
||||
# 解析响应
|
||||
return self._parse_text_response(response_data)
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行原生Gemini API调用"""
|
||||
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url,
|
||||
json=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "NarratoAI/1.0"
|
||||
},
|
||||
timeout=120
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error = self._handle_api_error(response.status_code, response.text)
|
||||
raise error
|
||||
|
||||
return response.json()
|
||||
|
||||
def _parse_text_response(self, response_data: Dict[str, Any]) -> str:
|
||||
"""解析文本生成响应"""
|
||||
logger.debug(f"Gemini API响应数据: {response_data}")
|
||||
|
||||
if "candidates" not in response_data or not response_data["candidates"]:
|
||||
logger.error(f"Gemini API返回无效响应结构: {response_data}")
|
||||
raise APICallError("原生Gemini API返回无效响应")
|
||||
|
||||
candidate = response_data["candidates"][0]
|
||||
logger.debug(f"Gemini候选响应: {candidate}")
|
||||
|
||||
# 检查完成原因
|
||||
finish_reason = candidate.get("finishReason", "UNKNOWN")
|
||||
logger.debug(f"Gemini完成原因: {finish_reason}")
|
||||
|
||||
# 检查是否被安全过滤阻止
|
||||
if finish_reason == "SAFETY":
|
||||
safety_ratings = candidate.get("safetyRatings", [])
|
||||
logger.warning(f"内容被Gemini安全过滤器阻止,安全评级: {safety_ratings}")
|
||||
raise ContentFilterError("内容被Gemini安全过滤器阻止")
|
||||
|
||||
# 检查是否因为其他原因停止
|
||||
if finish_reason in ["RECITATION", "OTHER"]:
|
||||
logger.warning(f"Gemini因为{finish_reason}原因停止生成")
|
||||
raise APICallError(f"Gemini因为{finish_reason}原因停止生成")
|
||||
|
||||
if "content" not in candidate:
|
||||
logger.error(f"Gemini候选响应中缺少content字段: {candidate}")
|
||||
raise APICallError("原生Gemini API返回内容格式错误")
|
||||
|
||||
if "parts" not in candidate["content"]:
|
||||
logger.error(f"Gemini内容中缺少parts字段: {candidate['content']}")
|
||||
raise APICallError("原生Gemini API返回内容格式错误")
|
||||
|
||||
# 提取文本内容
|
||||
result = ""
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
result += part["text"]
|
||||
|
||||
if not result.strip():
|
||||
logger.error(f"Gemini API返回空文本内容,完整响应: {response_data}")
|
||||
raise APICallError("原生Gemini API返回空内容")
|
||||
|
||||
logger.debug(f"Gemini成功生成内容,长度: {len(result)}")
|
||||
return result
|
||||
168
app/services/llm/providers/openai_provider.py
Normal file
168
app/services/llm/providers/openai_provider.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""
|
||||
OpenAI API提供商实现
|
||||
|
||||
使用OpenAI API进行文本生成,也支持OpenAI兼容的其他服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI, BadRequestError
|
||||
from loguru import logger
|
||||
|
||||
from ..base import TextModelProvider
|
||||
from ..exceptions import APICallError, RateLimitError, AuthenticationError
|
||||
|
||||
|
||||
class OpenAITextProvider(TextModelProvider):
|
||||
"""OpenAI文本生成提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
# 支持其他OpenAI兼容模型
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
"qwen-plus",
|
||||
"qwen-turbo",
|
||||
"moonshot-v1-8k",
|
||||
"moonshot-v1-32k",
|
||||
"moonshot-v1-128k"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化OpenAI客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://api.openai.com/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
使用OpenAI API生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
# 构建消息列表
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
|
||||
# 构建请求参数
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 处理JSON格式输出
|
||||
if response_format == "json":
|
||||
# 检查模型是否支持response_format
|
||||
if self._supports_response_format():
|
||||
request_params["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
# 对于不支持response_format的模型,在提示词中添加约束
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
# 提取生成的内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 对于不支持response_format的模型,清理输出
|
||||
if response_format == "json" and not self._supports_response_format():
|
||||
content = self._clean_json_output(content)
|
||||
|
||||
logger.debug(f"OpenAI API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
|
||||
return content
|
||||
else:
|
||||
raise APICallError("OpenAI API返回空响应")
|
||||
|
||||
except BadRequestError as e:
|
||||
# 处理不支持response_format的情况
|
||||
if "response_format" in str(e) and response_format == "json":
|
||||
logger.warning(f"模型 {self.model_name} 不支持response_format,重试不带格式约束的请求")
|
||||
request_params.pop("response_format", None)
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
content = self._clean_json_output(content)
|
||||
return content
|
||||
else:
|
||||
raise APICallError("OpenAI API返回空响应")
|
||||
else:
|
||||
raise APICallError(f"OpenAI API请求失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API调用失败: {str(e)}")
|
||||
raise APICallError(f"OpenAI API调用失败: {str(e)}")
|
||||
|
||||
def _supports_response_format(self) -> bool:
|
||||
"""检查模型是否支持response_format参数"""
|
||||
# 已知不支持response_format的模型
|
||||
unsupported_models = [
|
||||
"deepseek-reasoner",
|
||||
"deepseek-r1"
|
||||
]
|
||||
|
||||
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
|
||||
|
||||
def _clean_json_output(self, output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
# 这个方法在OpenAI提供商中不直接使用,因为我们使用OpenAI SDK
|
||||
# 但为了兼容基类接口,保留此方法
|
||||
pass
|
||||
247
app/services/llm/providers/qwen_provider.py
Normal file
247
app/services/llm/providers/qwen_provider.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""
|
||||
通义千问API提供商实现
|
||||
|
||||
支持通义千问的视觉模型和文本生成模型
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
|
||||
from ..base import VisionModelProvider, TextModelProvider
|
||||
from ..exceptions import APICallError
|
||||
|
||||
|
||||
class QwenVisionProvider(VisionModelProvider):
|
||||
"""通义千问视觉模型提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "qwenvl"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"qwen2.5-vl-32b-instruct",
|
||||
"qwen2-vl-72b-instruct",
|
||||
"qwen-vl-max",
|
||||
"qwen-vl-plus"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化通义千问客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
使用通义千问VL分析图片
|
||||
|
||||
Args:
|
||||
images: 图片列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
logger.info(f"开始分析 {len(images)} 张图片,使用通义千问VL")
|
||||
|
||||
# 预处理图片
|
||||
processed_images = self._prepare_images(images)
|
||||
|
||||
# 分批处理
|
||||
results = []
|
||||
for i in range(0, len(processed_images), batch_size):
|
||||
batch = processed_images[i:i + batch_size]
|
||||
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
|
||||
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
|
||||
results.append(f"批次处理失败: {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
|
||||
"""分析一批图片"""
|
||||
# 构建消息内容
|
||||
content = []
|
||||
|
||||
# 添加图片
|
||||
for img in batch:
|
||||
base64_image = self._image_to_base64(img)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
})
|
||||
|
||||
# 添加文本提示,使用占位符来引用图片数量
|
||||
content.append({
|
||||
"type": "text",
|
||||
"text": prompt % (len(batch), len(batch), len(batch))
|
||||
})
|
||||
|
||||
# 构建消息
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}]
|
||||
|
||||
# 调用API
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
model=self.model_name,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
if response.choices and len(response.choices) > 0:
|
||||
return response.choices[0].message.content
|
||||
else:
|
||||
raise APICallError("通义千问VL API返回空响应")
|
||||
|
||||
def _image_to_base64(self, img: PIL.Image.Image) -> str:
|
||||
"""将PIL图片转换为base64编码"""
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format='JPEG', quality=85)
|
||||
img_bytes = img_buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode('utf-8')
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
|
||||
|
||||
class QwenTextProvider(TextModelProvider):
|
||||
"""通义千问文本生成提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "qwen"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"qwen-plus-1127",
|
||||
"qwen-plus",
|
||||
"qwen-turbo",
|
||||
"qwen-max",
|
||||
"qwen2.5-72b-instruct",
|
||||
"qwen2.5-32b-instruct",
|
||||
"qwen2.5-14b-instruct",
|
||||
"qwen2.5-7b-instruct"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化通义千问客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
使用通义千问API生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
# 构建消息列表
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
|
||||
# 构建请求参数
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 处理JSON格式输出
|
||||
if response_format == "json":
|
||||
# 通义千问支持response_format
|
||||
try:
|
||||
request_params["response_format"] = {"type": "json_object"}
|
||||
except:
|
||||
# 如果不支持,在提示词中添加约束
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
# 提取生成的内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 对于JSON格式,清理输出
|
||||
if response_format == "json" and "response_format" not in request_params:
|
||||
content = self._clean_json_output(content)
|
||||
|
||||
logger.debug(f"通义千问API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
|
||||
return content
|
||||
else:
|
||||
raise APICallError("通义千问API返回空响应")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"通义千问API调用失败: {str(e)}")
|
||||
raise APICallError(f"通义千问API调用失败: {str(e)}")
|
||||
|
||||
def _clean_json_output(self, output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
251
app/services/llm/providers/siliconflow_provider.py
Normal file
251
app/services/llm/providers/siliconflow_provider.py
Normal file
@ -0,0 +1,251 @@
|
||||
"""
|
||||
硅基流动API提供商实现
|
||||
|
||||
支持硅基流动的视觉模型和文本生成模型
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
|
||||
from ..base import VisionModelProvider, TextModelProvider
|
||||
from ..exceptions import APICallError
|
||||
|
||||
|
||||
class SiliconflowVisionProvider(VisionModelProvider):
|
||||
"""硅基流动视觉模型提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "siliconflow"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"Qwen/Qwen2.5-VL-32B-Instruct",
|
||||
"Qwen/Qwen2-VL-72B-Instruct",
|
||||
"deepseek-ai/deepseek-vl2",
|
||||
"OpenGVLab/InternVL2-26B"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化硅基流动客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://api.siliconflow.cn/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
使用硅基流动API分析图片
|
||||
|
||||
Args:
|
||||
images: 图片列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
logger.info(f"开始分析 {len(images)} 张图片,使用硅基流动")
|
||||
|
||||
# 预处理图片
|
||||
processed_images = self._prepare_images(images)
|
||||
|
||||
# 分批处理
|
||||
results = []
|
||||
for i in range(0, len(processed_images), batch_size):
|
||||
batch = processed_images[i:i + batch_size]
|
||||
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
|
||||
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
|
||||
results.append(f"批次处理失败: {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
|
||||
"""分析一批图片"""
|
||||
# 构建消息内容
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
|
||||
# 添加图片
|
||||
for img in batch:
|
||||
base64_image = self._image_to_base64(img)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
})
|
||||
|
||||
# 构建消息
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}]
|
||||
|
||||
# 调用API
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=4000,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
if response.choices and len(response.choices) > 0:
|
||||
return response.choices[0].message.content
|
||||
else:
|
||||
raise APICallError("硅基流动API返回空响应")
|
||||
|
||||
def _image_to_base64(self, img: PIL.Image.Image) -> str:
|
||||
"""将PIL图片转换为base64编码"""
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format='JPEG', quality=85)
|
||||
img_bytes = img_buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode('utf-8')
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
|
||||
|
||||
class SiliconflowTextProvider(TextModelProvider):
|
||||
"""硅基流动文本生成提供商"""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "siliconflow"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return [
|
||||
"deepseek-ai/DeepSeek-R1",
|
||||
"deepseek-ai/DeepSeek-V3",
|
||||
"Qwen/Qwen2.5-72B-Instruct",
|
||||
"Qwen/Qwen2.5-32B-Instruct",
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"01-ai/Yi-1.5-34B-Chat"
|
||||
]
|
||||
|
||||
def _initialize(self):
|
||||
"""初始化硅基流动客户端"""
|
||||
if not self.base_url:
|
||||
self.base_url = "https://api.siliconflow.cn/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
async def generate_text(self,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
使用硅基流动API生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
"""
|
||||
# 构建消息列表
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
|
||||
# 构建请求参数
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 处理JSON格式输出
|
||||
if response_format == "json":
|
||||
if self._supports_response_format():
|
||||
request_params["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
# 对于不支持response_format的模型,在提示词中添加约束
|
||||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
**request_params
|
||||
)
|
||||
|
||||
# 提取生成的内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 对于不支持response_format的模型,清理输出
|
||||
if response_format == "json" and not self._supports_response_format():
|
||||
content = self._clean_json_output(content)
|
||||
|
||||
logger.debug(f"硅基流动API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
|
||||
return content
|
||||
else:
|
||||
raise APICallError("硅基流动API返回空响应")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"硅基流动API调用失败: {str(e)}")
|
||||
raise APICallError(f"硅基流动API调用失败: {str(e)}")
|
||||
|
||||
def _supports_response_format(self) -> bool:
|
||||
"""检查模型是否支持response_format参数"""
|
||||
# DeepSeek R1 和 V3 不支持 response_format=json_object
|
||||
unsupported_models = [
|
||||
"deepseek-ai/deepseek-r1",
|
||||
"deepseek-ai/deepseek-v3"
|
||||
]
|
||||
|
||||
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
|
||||
|
||||
def _clean_json_output(self, output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||||
pass
|
||||
263
app/services/llm/test_llm_service.py
Normal file
263
app/services/llm/test_llm_service.py
Normal file
@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
"""
|
||||
LLM服务测试脚本
|
||||
|
||||
测试新的LLM服务架构是否正常工作
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from app.services.llm.config_validator import LLMConfigValidator
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
from app.services.llm.exceptions import LLMServiceError
|
||||
|
||||
|
||||
async def test_text_generation():
|
||||
"""测试文本生成功能"""
|
||||
print("\n🔤 测试文本生成功能...")
|
||||
|
||||
try:
|
||||
# 简单的文本生成测试
|
||||
prompt = "请用一句话介绍人工智能。"
|
||||
|
||||
result = await UnifiedLLMService.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt="你是一个专业的AI助手。",
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
print(f"✅ 文本生成成功:")
|
||||
print(f" 提示词: {prompt}")
|
||||
print(f" 生成结果: {result[:100]}...")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 文本生成失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_json_generation():
|
||||
"""测试JSON格式生成功能"""
|
||||
print("\n📄 测试JSON格式生成功能...")
|
||||
|
||||
try:
|
||||
prompt = """
|
||||
请生成一个简单的解说文案示例,包含以下字段:
|
||||
- title: 标题
|
||||
- content: 内容
|
||||
- duration: 时长(秒)
|
||||
|
||||
输出JSON格式。
|
||||
"""
|
||||
|
||||
result = await UnifiedLLMService.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt="你是一个专业的文案撰写专家。",
|
||||
temperature=0.7,
|
||||
response_format="json"
|
||||
)
|
||||
|
||||
# 尝试解析JSON
|
||||
import json
|
||||
parsed_result = json.loads(result)
|
||||
|
||||
print(f"✅ JSON生成成功:")
|
||||
print(f" 生成结果: {json.dumps(parsed_result, ensure_ascii=False, indent=2)}")
|
||||
|
||||
return True
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"❌ JSON解析失败: {str(e)}")
|
||||
print(f" 原始结果: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ JSON生成失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_narration_script_generation():
|
||||
"""测试解说文案生成功能"""
|
||||
print("\n🎬 测试解说文案生成功能...")
|
||||
|
||||
try:
|
||||
prompt = """
|
||||
根据以下视频描述生成解说文案:
|
||||
|
||||
视频内容:一个人在森林中建造木屋,首先挖掘地基,然后搭建墙壁,最后安装屋顶。
|
||||
|
||||
请生成JSON格式的解说文案,包含items数组,每个item包含:
|
||||
- _id: 序号
|
||||
- timestamp: 时间戳(格式:HH:MM:SS,mmm-HH:MM:SS,mmm)
|
||||
- picture: 画面描述
|
||||
- narration: 解说文案
|
||||
"""
|
||||
|
||||
result = await UnifiedLLMService.generate_narration_script(
|
||||
prompt=prompt,
|
||||
temperature=0.8,
|
||||
validate_output=True
|
||||
)
|
||||
|
||||
print(f"✅ 解说文案生成成功:")
|
||||
print(f" 生成了 {len(result)} 个片段")
|
||||
for item in result[:2]: # 只显示前2个
|
||||
print(f" - {item.get('timestamp', 'N/A')}: {item.get('narration', 'N/A')[:50]}...")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 解说文案生成失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_subtitle_analysis():
|
||||
"""测试字幕分析功能"""
|
||||
print("\n📝 测试字幕分析功能...")
|
||||
|
||||
try:
|
||||
subtitle_content = """
|
||||
1
|
||||
00:00:01,000 --> 00:00:05,000
|
||||
大家好,欢迎来到我的频道。
|
||||
|
||||
2
|
||||
00:00:05,000 --> 00:00:10,000
|
||||
今天我们要学习如何使用人工智能。
|
||||
|
||||
3
|
||||
00:00:10,000 --> 00:00:15,000
|
||||
人工智能是一项非常有趣的技术。
|
||||
"""
|
||||
|
||||
result = await UnifiedLLMService.analyze_subtitle(
|
||||
subtitle_content=subtitle_content,
|
||||
temperature=0.7,
|
||||
validate_output=True
|
||||
)
|
||||
|
||||
print(f"✅ 字幕分析成功:")
|
||||
print(f" 分析结果: {result[:100]}...")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 字幕分析失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def test_config_validation():
|
||||
"""测试配置验证功能"""
|
||||
print("\n⚙️ 测试配置验证功能...")
|
||||
|
||||
try:
|
||||
# 验证所有配置
|
||||
validation_results = LLMConfigValidator.validate_all_configs()
|
||||
|
||||
summary = validation_results["summary"]
|
||||
print(f"✅ 配置验证完成:")
|
||||
print(f" 视觉模型提供商: {summary['valid_vision_providers']}/{summary['total_vision_providers']} 有效")
|
||||
print(f" 文本模型提供商: {summary['valid_text_providers']}/{summary['total_text_providers']} 有效")
|
||||
|
||||
if summary["errors"]:
|
||||
print(f" 发现 {len(summary['errors'])} 个错误")
|
||||
for error in summary["errors"][:3]: # 只显示前3个错误
|
||||
print(f" - {error}")
|
||||
|
||||
return summary['valid_text_providers'] > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 配置验证失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def test_provider_info():
|
||||
"""测试提供商信息获取"""
|
||||
print("\n📋 测试提供商信息获取...")
|
||||
|
||||
try:
|
||||
provider_info = UnifiedLLMService.get_provider_info()
|
||||
|
||||
vision_providers = list(provider_info["vision_providers"].keys())
|
||||
text_providers = list(provider_info["text_providers"].keys())
|
||||
|
||||
print(f"✅ 提供商信息获取成功:")
|
||||
print(f" 视觉模型提供商: {', '.join(vision_providers)}")
|
||||
print(f" 文本模型提供商: {', '.join(text_providers)}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 提供商信息获取失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def run_all_tests():
|
||||
"""运行所有测试"""
|
||||
print("🚀 开始LLM服务测试...")
|
||||
print("="*60)
|
||||
|
||||
# 测试结果统计
|
||||
test_results = []
|
||||
|
||||
# 1. 测试配置验证
|
||||
test_results.append(("配置验证", test_config_validation()))
|
||||
|
||||
# 2. 测试提供商信息
|
||||
test_results.append(("提供商信息", test_provider_info()))
|
||||
|
||||
# 3. 测试文本生成
|
||||
test_results.append(("文本生成", await test_text_generation()))
|
||||
|
||||
# 4. 测试JSON生成
|
||||
test_results.append(("JSON生成", await test_json_generation()))
|
||||
|
||||
# 5. 测试字幕分析
|
||||
test_results.append(("字幕分析", await test_subtitle_analysis()))
|
||||
|
||||
# 6. 测试解说文案生成
|
||||
test_results.append(("解说文案生成", await test_narration_script_generation()))
|
||||
|
||||
# 输出测试结果
|
||||
print("\n" + "="*60)
|
||||
print("📊 测试结果汇总:")
|
||||
print("="*60)
|
||||
|
||||
passed = 0
|
||||
total = len(test_results)
|
||||
|
||||
for test_name, result in test_results:
|
||||
status = "✅ 通过" if result else "❌ 失败"
|
||||
print(f" {test_name:<15} {status}")
|
||||
if result:
|
||||
passed += 1
|
||||
|
||||
print(f"\n总计: {passed}/{total} 个测试通过")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 所有测试通过!LLM服务工作正常。")
|
||||
elif passed > 0:
|
||||
print("⚠️ 部分测试通过,请检查失败的测试项。")
|
||||
else:
|
||||
print("💥 所有测试失败,请检查配置和网络连接。")
|
||||
|
||||
print("="*60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置日志级别
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level="INFO")
|
||||
|
||||
# 运行测试
|
||||
asyncio.run(run_all_tests())
|
||||
274
app/services/llm/unified_service.py
Normal file
274
app/services/llm/unified_service.py
Normal file
@ -0,0 +1,274 @@
|
||||
"""
|
||||
统一的大模型服务接口
|
||||
|
||||
提供简化的API接口,方便现有代码迁移到新的架构
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
from loguru import logger
|
||||
|
||||
from .manager import LLMServiceManager
|
||||
from .validators import OutputValidator
|
||||
from .exceptions import LLMServiceError
|
||||
|
||||
# 确保提供商已注册
|
||||
def _ensure_providers_registered():
|
||||
"""确保所有提供商都已注册"""
|
||||
try:
|
||||
# 检查是否有已注册的提供商
|
||||
if not LLMServiceManager.list_text_providers() or not LLMServiceManager.list_vision_providers():
|
||||
# 如果没有注册的提供商,强制导入providers模块
|
||||
from . import providers
|
||||
logger.debug("强制注册LLM服务提供商")
|
||||
except Exception as e:
|
||||
logger.error(f"确保LLM服务提供商注册时发生错误: {str(e)}")
|
||||
|
||||
# 在模块加载时确保提供商已注册
|
||||
_ensure_providers_registered()
|
||||
|
||||
|
||||
class UnifiedLLMService:
|
||||
"""统一的大模型服务接口"""
|
||||
|
||||
@staticmethod
|
||||
async def analyze_images(images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
batch_size: int = 10,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
分析图片内容
|
||||
|
||||
Args:
|
||||
images: 图片路径列表或PIL图片对象列表
|
||||
prompt: 分析提示词
|
||||
provider: 视觉模型提供商名称,如果不指定则使用配置中的默认值
|
||||
batch_size: 批处理大小
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
|
||||
Raises:
|
||||
LLMServiceError: 服务调用失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 获取视觉模型提供商
|
||||
vision_provider = LLMServiceManager.get_vision_provider(provider)
|
||||
|
||||
# 执行图片分析
|
||||
results = await vision_provider.analyze_images(
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
batch_size=batch_size,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
logger.info(f"图片分析完成,共处理 {len(images)} 张图片,生成 {len(results)} 个结果")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图片分析失败: {str(e)}")
|
||||
raise LLMServiceError(f"图片分析失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def generate_text(prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
生成文本内容
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
system_prompt: 系统提示词
|
||||
provider: 文本模型提供商名称,如果不指定则使用配置中的默认值
|
||||
temperature: 生成温度
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式 ('json' 或 None)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的文本内容
|
||||
|
||||
Raises:
|
||||
LLMServiceError: 服务调用失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 获取文本模型提供商
|
||||
text_provider = LLMServiceManager.get_text_provider(provider)
|
||||
|
||||
# 执行文本生成
|
||||
result = await text_provider.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
logger.info(f"文本生成完成,生成内容长度: {len(result)} 字符")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文本生成失败: {str(e)}")
|
||||
raise LLMServiceError(f"文本生成失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def generate_narration_script(prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
validate_output: bool = True,
|
||||
**kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
生成解说文案
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
provider: 文本模型提供商名称
|
||||
temperature: 生成温度
|
||||
validate_output: 是否验证输出格式
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
解说文案列表
|
||||
|
||||
Raises:
|
||||
LLMServiceError: 服务调用失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 生成文本
|
||||
result = await UnifiedLLMService.generate_text(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
temperature=temperature,
|
||||
response_format="json",
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 验证输出格式
|
||||
if validate_output:
|
||||
narration_items = OutputValidator.validate_narration_script(result)
|
||||
logger.info(f"解说文案生成并验证完成,共 {len(narration_items)} 个片段")
|
||||
return narration_items
|
||||
else:
|
||||
# 简单的JSON解析
|
||||
import json
|
||||
parsed_result = json.loads(result)
|
||||
if "items" in parsed_result:
|
||||
return parsed_result["items"]
|
||||
else:
|
||||
return parsed_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解说文案生成失败: {str(e)}")
|
||||
raise LLMServiceError(f"解说文案生成失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def analyze_subtitle(subtitle_content: str,
|
||||
provider: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
validate_output: bool = True,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
分析字幕内容
|
||||
|
||||
Args:
|
||||
subtitle_content: 字幕内容
|
||||
provider: 文本模型提供商名称
|
||||
temperature: 生成温度
|
||||
validate_output: 是否验证输出格式
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
分析结果
|
||||
|
||||
Raises:
|
||||
LLMServiceError: 服务调用失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 构建分析提示词
|
||||
system_prompt = "你是一位专业的剧本分析师和剧情概括助手。请仔细分析字幕内容,提取关键剧情信息。"
|
||||
|
||||
# 生成分析结果
|
||||
result = await UnifiedLLMService.generate_text(
|
||||
prompt=subtitle_content,
|
||||
system_prompt=system_prompt,
|
||||
provider=provider,
|
||||
temperature=temperature,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 验证输出格式
|
||||
if validate_output:
|
||||
validated_result = OutputValidator.validate_subtitle_analysis(result)
|
||||
logger.info("字幕分析完成并验证通过")
|
||||
return validated_result
|
||||
else:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"字幕分析失败: {str(e)}")
|
||||
raise LLMServiceError(f"字幕分析失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def get_provider_info() -> Dict[str, Any]:
|
||||
"""
|
||||
获取所有提供商信息
|
||||
|
||||
Returns:
|
||||
提供商信息字典
|
||||
"""
|
||||
return LLMServiceManager.get_provider_info()
|
||||
|
||||
@staticmethod
|
||||
def list_vision_providers() -> List[str]:
|
||||
"""
|
||||
列出所有视觉模型提供商
|
||||
|
||||
Returns:
|
||||
提供商名称列表
|
||||
"""
|
||||
return LLMServiceManager.list_vision_providers()
|
||||
|
||||
@staticmethod
|
||||
def list_text_providers() -> List[str]:
|
||||
"""
|
||||
列出所有文本模型提供商
|
||||
|
||||
Returns:
|
||||
提供商名称列表
|
||||
"""
|
||||
return LLMServiceManager.list_text_providers()
|
||||
|
||||
@staticmethod
|
||||
def clear_cache():
|
||||
"""清空提供商实例缓存"""
|
||||
LLMServiceManager.clear_cache()
|
||||
logger.info("已清空大模型服务缓存")
|
||||
|
||||
|
||||
# 为了向后兼容,提供一些便捷函数
|
||||
async def analyze_images_unified(images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
batch_size: int = 10) -> List[str]:
|
||||
"""便捷的图片分析函数"""
|
||||
return await UnifiedLLMService.analyze_images(images, prompt, provider, batch_size)
|
||||
|
||||
|
||||
async def generate_text_unified(prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
response_format: Optional[str] = None) -> str:
|
||||
"""便捷的文本生成函数"""
|
||||
return await UnifiedLLMService.generate_text(
|
||||
prompt, system_prompt, provider, temperature, response_format=response_format
|
||||
)
|
||||
200
app/services/llm/validators.py
Normal file
200
app/services/llm/validators.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""
|
||||
输出格式验证器
|
||||
|
||||
提供严格的输出格式验证机制,确保大模型输出符合预期格式
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from loguru import logger
|
||||
|
||||
from .exceptions import ValidationError
|
||||
|
||||
|
||||
class OutputValidator:
|
||||
"""输出格式验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate_json_output(output: str, schema: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
验证JSON输出格式
|
||||
|
||||
Args:
|
||||
output: 待验证的输出字符串
|
||||
schema: JSON Schema (可选)
|
||||
|
||||
Returns:
|
||||
解析后的JSON对象
|
||||
|
||||
Raises:
|
||||
ValidationError: 验证失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 清理输出字符串,移除可能的markdown代码块标记
|
||||
cleaned_output = OutputValidator._clean_json_output(output)
|
||||
|
||||
# 解析JSON
|
||||
parsed_json = json.loads(cleaned_output)
|
||||
|
||||
# 如果提供了schema,进行schema验证
|
||||
if schema:
|
||||
OutputValidator._validate_json_schema(parsed_json, schema)
|
||||
|
||||
return parsed_json
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {str(e)}")
|
||||
logger.error(f"原始输出: {output}")
|
||||
raise ValidationError(f"JSON格式无效: {str(e)}", "json_parse", output)
|
||||
except Exception as e:
|
||||
logger.error(f"JSON验证失败: {str(e)}")
|
||||
raise ValidationError(f"JSON验证失败: {str(e)}", "json_validation", output)
|
||||
|
||||
@staticmethod
|
||||
def _clean_json_output(output: str) -> str:
|
||||
"""清理JSON输出,移除markdown标记等"""
|
||||
# 移除可能的markdown代码块标记
|
||||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||||
|
||||
# 移除开头和结尾的```标记
|
||||
output = re.sub(r'^```', '', output)
|
||||
output = re.sub(r'```$', '', output)
|
||||
|
||||
# 移除前后空白字符
|
||||
output = output.strip()
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]):
|
||||
"""验证JSON Schema (简化版本)"""
|
||||
# 这里可以集成jsonschema库进行更严格的验证
|
||||
# 目前实现基础的类型检查
|
||||
|
||||
if "type" in schema:
|
||||
expected_type = schema["type"]
|
||||
if expected_type == "object" and not isinstance(data, dict):
|
||||
raise ValidationError(f"期望对象类型,实际为 {type(data)}", "schema_type")
|
||||
elif expected_type == "array" and not isinstance(data, list):
|
||||
raise ValidationError(f"期望数组类型,实际为 {type(data)}", "schema_type")
|
||||
|
||||
if "required" in schema and isinstance(data, dict):
|
||||
for required_field in schema["required"]:
|
||||
if required_field not in data:
|
||||
raise ValidationError(f"缺少必需字段: {required_field}", "schema_required")
|
||||
|
||||
@staticmethod
|
||||
def validate_narration_script(output: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
验证解说文案输出格式
|
||||
|
||||
Args:
|
||||
output: 待验证的解说文案输出
|
||||
|
||||
Returns:
|
||||
解析后的解说文案列表
|
||||
|
||||
Raises:
|
||||
ValidationError: 验证失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 定义解说文案的JSON Schema
|
||||
narration_schema = {
|
||||
"type": "object",
|
||||
"required": ["items"],
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["_id", "timestamp", "picture", "narration"],
|
||||
"properties": {
|
||||
"_id": {"type": "number"},
|
||||
"timestamp": {"type": "string"},
|
||||
"picture": {"type": "string"},
|
||||
"narration": {"type": "string"},
|
||||
"OST": {"type": "number"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 验证JSON格式
|
||||
parsed_data = OutputValidator.validate_json_output(output, narration_schema)
|
||||
|
||||
# 提取items数组
|
||||
items = parsed_data.get("items", [])
|
||||
|
||||
# 验证每个item的具体内容
|
||||
for i, item in enumerate(items):
|
||||
OutputValidator._validate_narration_item(item, i)
|
||||
|
||||
logger.info(f"解说文案验证成功,共 {len(items)} 个片段")
|
||||
return items
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"解说文案验证失败: {str(e)}")
|
||||
raise ValidationError(f"解说文案验证失败: {str(e)}", "narration_validation", output)
|
||||
|
||||
@staticmethod
|
||||
def _validate_narration_item(item: Dict[str, Any], index: int):
|
||||
"""验证单个解说文案项目"""
|
||||
# 验证时间戳格式
|
||||
timestamp = item.get("timestamp", "")
|
||||
if not re.match(r'\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}', timestamp):
|
||||
raise ValidationError(f"第{index+1}项时间戳格式无效: {timestamp}", "timestamp_format")
|
||||
|
||||
# 验证内容不为空
|
||||
if not item.get("picture", "").strip():
|
||||
raise ValidationError(f"第{index+1}项画面描述不能为空", "empty_picture")
|
||||
|
||||
if not item.get("narration", "").strip():
|
||||
raise ValidationError(f"第{index+1}项解说文案不能为空", "empty_narration")
|
||||
|
||||
# 验证ID为正整数
|
||||
item_id = item.get("_id")
|
||||
if not isinstance(item_id, (int, float)) or item_id <= 0:
|
||||
raise ValidationError(f"第{index+1}项ID必须为正整数: {item_id}", "invalid_id")
|
||||
|
||||
@staticmethod
|
||||
def validate_subtitle_analysis(output: str) -> str:
|
||||
"""
|
||||
验证字幕分析输出格式
|
||||
|
||||
Args:
|
||||
output: 待验证的字幕分析输出
|
||||
|
||||
Returns:
|
||||
验证后的分析内容
|
||||
|
||||
Raises:
|
||||
ValidationError: 验证失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 基础验证:内容不能为空
|
||||
if not output or not output.strip():
|
||||
raise ValidationError("字幕分析结果不能为空", "empty_analysis")
|
||||
|
||||
# 验证内容长度合理
|
||||
if len(output.strip()) < 50:
|
||||
raise ValidationError("字幕分析结果过短,可能不完整", "analysis_too_short")
|
||||
|
||||
# 验证是否包含基本的分析要素(可根据需要调整)
|
||||
analysis_keywords = ["剧情", "情节", "角色", "故事", "内容"]
|
||||
if not any(keyword in output for keyword in analysis_keywords):
|
||||
logger.warning("字幕分析结果可能缺少关键分析要素")
|
||||
|
||||
logger.info("字幕分析验证成功")
|
||||
return output.strip()
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"字幕分析验证失败: {str(e)}")
|
||||
raise ValidationError(f"字幕分析验证失败: {str(e)}", "analysis_validation", output)
|
||||
367
docs/LLM_MIGRATION_GUIDE.md
Normal file
367
docs/LLM_MIGRATION_GUIDE.md
Normal file
@ -0,0 +1,367 @@
|
||||
# NarratoAI 大模型服务迁移指南
|
||||
|
||||
## 📋 概述
|
||||
|
||||
本指南帮助开发者将现有代码从旧的大模型调用方式迁移到新的统一LLM服务架构。新架构提供了更好的模块化、错误处理和配置管理。
|
||||
|
||||
## 🔄 迁移对比
|
||||
|
||||
### 旧的调用方式 vs 新的调用方式
|
||||
|
||||
#### 1. 视觉分析器创建
|
||||
|
||||
**旧方式:**
|
||||
```python
|
||||
from app.utils import gemini_analyzer, qwenvl_analyzer
|
||||
|
||||
if provider == 'gemini':
|
||||
analyzer = gemini_analyzer.VisionAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
elif provider == 'qwenvl':
|
||||
analyzer = qwenvl_analyzer.QwenAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
```
|
||||
|
||||
**新方式:**
|
||||
```python
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
|
||||
# 方式1: 直接使用统一服务
|
||||
results = await UnifiedLLMService.analyze_images(
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
provider=provider # 可选,使用配置中的默认值
|
||||
)
|
||||
|
||||
# 方式2: 使用迁移适配器(向后兼容)
|
||||
from app.services.llm.migration_adapter import create_vision_analyzer
|
||||
analyzer = create_vision_analyzer(provider, api_key, model, base_url)
|
||||
results = await analyzer.analyze_images(images, prompt)
|
||||
```
|
||||
|
||||
#### 2. 文本生成
|
||||
|
||||
**旧方式:**
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=temperature,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
result = response.choices[0].message.content
|
||||
```
|
||||
|
||||
**新方式:**
|
||||
```python
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
|
||||
result = await UnifiedLLMService.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature,
|
||||
response_format="json"
|
||||
)
|
||||
```
|
||||
|
||||
#### 3. 解说文案生成
|
||||
|
||||
**旧方式:**
|
||||
```python
|
||||
from app.services.generate_narration_script import generate_narration
|
||||
|
||||
narration = generate_narration(
|
||||
markdown_content,
|
||||
api_key,
|
||||
base_url=base_url,
|
||||
model=model
|
||||
)
|
||||
# 手动解析JSON和验证格式
|
||||
import json
|
||||
narration_dict = json.loads(narration)['items']
|
||||
```
|
||||
|
||||
**新方式:**
|
||||
```python
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
|
||||
# 自动验证输出格式
|
||||
narration_items = await UnifiedLLMService.generate_narration_script(
|
||||
prompt=prompt,
|
||||
validate_output=True # 自动验证JSON格式和字段
|
||||
)
|
||||
```
|
||||
|
||||
## 📝 具体迁移步骤
|
||||
|
||||
### 步骤1: 更新配置文件
|
||||
|
||||
**旧配置格式:**
|
||||
```toml
|
||||
[app]
|
||||
llm_provider = "openai"
|
||||
openai_api_key = "sk-xxx"
|
||||
openai_model_name = "gpt-4"
|
||||
|
||||
vision_llm_provider = "gemini"
|
||||
gemini_api_key = "xxx"
|
||||
gemini_model_name = "gemini-1.5-pro"
|
||||
```
|
||||
|
||||
**新配置格式:**
|
||||
```toml
|
||||
[app]
|
||||
# 视觉模型配置
|
||||
vision_llm_provider = "gemini"
|
||||
vision_gemini_api_key = "xxx"
|
||||
vision_gemini_model_name = "gemini-2.0-flash-lite"
|
||||
vision_gemini_base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
# 文本模型配置
|
||||
text_llm_provider = "openai"
|
||||
text_openai_api_key = "sk-xxx"
|
||||
text_openai_model_name = "gpt-4o-mini"
|
||||
text_openai_base_url = "https://api.openai.com/v1"
|
||||
```
|
||||
|
||||
### 步骤2: 更新导入语句
|
||||
|
||||
**旧导入:**
|
||||
```python
|
||||
from app.utils import gemini_analyzer, qwenvl_analyzer
|
||||
from app.services.generate_narration_script import generate_narration
|
||||
from app.services.SDE.short_drama_explanation import analyze_subtitle
|
||||
```
|
||||
|
||||
**新导入:**
|
||||
```python
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
from app.services.llm.migration_adapter import (
|
||||
create_vision_analyzer,
|
||||
SubtitleAnalyzerAdapter
|
||||
)
|
||||
```
|
||||
|
||||
### 步骤3: 更新函数调用
|
||||
|
||||
#### 图片分析迁移
|
||||
|
||||
**旧代码:**
|
||||
```python
|
||||
def analyze_images_old(provider, api_key, model, base_url, images, prompt):
|
||||
if provider == 'gemini':
|
||||
analyzer = gemini_analyzer.VisionAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
else:
|
||||
analyzer = qwenvl_analyzer.QwenAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
# 同步调用
|
||||
results = []
|
||||
for batch in batches:
|
||||
result = analyzer.analyze_batch(batch, prompt)
|
||||
results.append(result)
|
||||
return results
|
||||
```
|
||||
|
||||
**新代码:**
|
||||
```python
|
||||
async def analyze_images_new(images, prompt, provider=None):
|
||||
# 异步调用,自动批处理
|
||||
results = await UnifiedLLMService.analyze_images(
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
batch_size=10
|
||||
)
|
||||
return results
|
||||
```
|
||||
|
||||
#### 字幕分析迁移
|
||||
|
||||
**旧代码:**
|
||||
```python
|
||||
from app.services.SDE.short_drama_explanation import analyze_subtitle
|
||||
|
||||
result = analyze_subtitle(
|
||||
subtitle_file_path=subtitle_path,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
provider=provider
|
||||
)
|
||||
```
|
||||
|
||||
**新代码:**
|
||||
```python
|
||||
# 方式1: 使用统一服务
|
||||
with open(subtitle_path, 'r', encoding='utf-8') as f:
|
||||
subtitle_content = f.read()
|
||||
|
||||
result = await UnifiedLLMService.analyze_subtitle(
|
||||
subtitle_content=subtitle_content,
|
||||
provider=provider,
|
||||
validate_output=True
|
||||
)
|
||||
|
||||
# 方式2: 使用适配器
|
||||
from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter
|
||||
|
||||
analyzer = SubtitleAnalyzerAdapter(api_key, model, base_url, provider)
|
||||
result = analyzer.analyze_subtitle(subtitle_content)
|
||||
```
|
||||
|
||||
## 🔧 常见迁移问题
|
||||
|
||||
### 1. 同步 vs 异步调用
|
||||
|
||||
**问题:** 新架构使用异步调用,旧代码是同步的。
|
||||
|
||||
**解决方案:**
|
||||
```python
|
||||
# 在同步函数中调用异步函数
|
||||
import asyncio
|
||||
|
||||
def sync_function():
|
||||
result = asyncio.run(UnifiedLLMService.generate_text(prompt))
|
||||
return result
|
||||
|
||||
# 或者将整个函数改为异步
|
||||
async def async_function():
|
||||
result = await UnifiedLLMService.generate_text(prompt)
|
||||
return result
|
||||
```
|
||||
|
||||
### 2. 配置获取方式变化
|
||||
|
||||
**问题:** 配置键名发生变化。
|
||||
|
||||
**解决方案:**
|
||||
```python
|
||||
# 旧方式
|
||||
api_key = config.app.get('openai_api_key')
|
||||
model = config.app.get('openai_model_name')
|
||||
|
||||
# 新方式
|
||||
provider = config.app.get('text_llm_provider', 'openai')
|
||||
api_key = config.app.get(f'text_{provider}_api_key')
|
||||
model = config.app.get(f'text_{provider}_model_name')
|
||||
```
|
||||
|
||||
### 3. 错误处理更新
|
||||
|
||||
**旧方式:**
|
||||
```python
|
||||
try:
|
||||
result = some_llm_call()
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
```
|
||||
|
||||
**新方式:**
|
||||
```python
|
||||
from app.services.llm.exceptions import LLMServiceError, ValidationError
|
||||
|
||||
try:
|
||||
result = await UnifiedLLMService.generate_text(prompt)
|
||||
except ValidationError as e:
|
||||
print(f"输出验证失败: {e.message}")
|
||||
except LLMServiceError as e:
|
||||
print(f"LLM服务错误: {e.message}")
|
||||
except Exception as e:
|
||||
print(f"未知错误: {e}")
|
||||
```
|
||||
|
||||
## ✅ 迁移检查清单
|
||||
|
||||
### 配置迁移
|
||||
- [ ] 更新配置文件格式
|
||||
- [ ] 验证所有API密钥配置正确
|
||||
- [ ] 运行配置验证器检查
|
||||
|
||||
### 代码迁移
|
||||
- [ ] 更新导入语句
|
||||
- [ ] 将同步调用改为异步调用
|
||||
- [ ] 更新错误处理机制
|
||||
- [ ] 使用新的统一接口
|
||||
|
||||
### 测试验证
|
||||
- [ ] 运行LLM服务测试脚本
|
||||
- [ ] 测试所有功能模块
|
||||
- [ ] 验证输出格式正确
|
||||
- [ ] 检查性能和稳定性
|
||||
|
||||
### 清理工作
|
||||
- [ ] 移除未使用的旧代码
|
||||
- [ ] 更新文档和注释
|
||||
- [ ] 清理过时的依赖
|
||||
|
||||
## 🚀 迁移最佳实践
|
||||
|
||||
### 1. 渐进式迁移
|
||||
- 先迁移一个模块,测试通过后再迁移其他模块
|
||||
- 保留旧代码作为备用方案
|
||||
- 使用迁移适配器确保向后兼容
|
||||
|
||||
### 2. 充分测试
|
||||
- 在每个迁移步骤后运行测试
|
||||
- 比较新旧实现的输出结果
|
||||
- 测试边界情况和错误处理
|
||||
|
||||
### 3. 监控和日志
|
||||
- 启用详细日志记录
|
||||
- 监控API调用成功率
|
||||
- 跟踪性能指标
|
||||
|
||||
### 4. 文档更新
|
||||
- 更新代码注释
|
||||
- 更新API文档
|
||||
- 记录迁移过程中的问题和解决方案
|
||||
|
||||
## 📞 获取帮助
|
||||
|
||||
如果在迁移过程中遇到问题:
|
||||
|
||||
1. **查看测试脚本输出**:
|
||||
```bash
|
||||
python app/services/llm/test_llm_service.py
|
||||
```
|
||||
|
||||
2. **验证配置**:
|
||||
```python
|
||||
from app.services.llm.config_validator import LLMConfigValidator
|
||||
results = LLMConfigValidator.validate_all_configs()
|
||||
LLMConfigValidator.print_validation_report(results)
|
||||
```
|
||||
|
||||
3. **查看详细日志**:
|
||||
```python
|
||||
from loguru import logger
|
||||
logger.add("migration.log", level="DEBUG")
|
||||
```
|
||||
|
||||
4. **参考示例代码**:
|
||||
- 查看 `app/services/llm/test_llm_service.py` 中的使用示例
|
||||
- 参考已迁移的文件如 `webui/tools/base.py`
|
||||
|
||||
---
|
||||
|
||||
*最后更新: 2025-01-07*
|
||||
294
docs/LLM_SERVICE_GUIDE.md
Normal file
294
docs/LLM_SERVICE_GUIDE.md
Normal file
@ -0,0 +1,294 @@
|
||||
# NarratoAI 大模型服务使用指南
|
||||
|
||||
## 📖 概述
|
||||
|
||||
NarratoAI 项目已完成大模型服务的全面重构,提供了统一、模块化、可扩展的大模型集成架构。新架构支持多种大模型供应商,具有严格的输出格式验证和完善的错误处理机制。
|
||||
|
||||
## 🏗️ 架构概览
|
||||
|
||||
### 核心组件
|
||||
|
||||
```
|
||||
app/services/llm/
|
||||
├── __init__.py # 模块入口
|
||||
├── base.py # 抽象基类
|
||||
├── manager.py # 服务管理器
|
||||
├── unified_service.py # 统一服务接口
|
||||
├── validators.py # 输出格式验证器
|
||||
├── exceptions.py # 异常类定义
|
||||
├── migration_adapter.py # 迁移适配器
|
||||
├── config_validator.py # 配置验证器
|
||||
├── test_llm_service.py # 测试脚本
|
||||
└── providers/ # 提供商实现
|
||||
├── __init__.py
|
||||
├── gemini_provider.py
|
||||
├── gemini_openai_provider.py
|
||||
├── openai_provider.py
|
||||
├── qwen_provider.py
|
||||
├── deepseek_provider.py
|
||||
└── siliconflow_provider.py
|
||||
```
|
||||
|
||||
### 支持的供应商
|
||||
|
||||
#### 视觉模型供应商
|
||||
- **Gemini** (原生API + OpenAI兼容)
|
||||
- **QwenVL** (通义千问视觉)
|
||||
- **Siliconflow** (硅基流动)
|
||||
|
||||
#### 文本生成模型供应商
|
||||
- **OpenAI** (标准OpenAI API)
|
||||
- **Gemini** (原生API + OpenAI兼容)
|
||||
- **DeepSeek** (深度求索)
|
||||
- **Qwen** (通义千问)
|
||||
- **Siliconflow** (硅基流动)
|
||||
|
||||
## ⚙️ 配置说明
|
||||
|
||||
### 配置文件格式
|
||||
|
||||
在 `config.toml` 中配置大模型服务:
|
||||
|
||||
```toml
|
||||
[app]
|
||||
# 视觉模型提供商配置
|
||||
vision_llm_provider = "gemini"
|
||||
|
||||
# Gemini 视觉模型
|
||||
vision_gemini_api_key = "your_gemini_api_key"
|
||||
vision_gemini_model_name = "gemini-2.0-flash-lite"
|
||||
vision_gemini_base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
# QwenVL 视觉模型
|
||||
vision_qwenvl_api_key = "your_qwen_api_key"
|
||||
vision_qwenvl_model_name = "qwen2.5-vl-32b-instruct"
|
||||
vision_qwenvl_base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
# 文本模型提供商配置
|
||||
text_llm_provider = "openai"
|
||||
|
||||
# OpenAI 文本模型
|
||||
text_openai_api_key = "your_openai_api_key"
|
||||
text_openai_model_name = "gpt-4o-mini"
|
||||
text_openai_base_url = "https://api.openai.com/v1"
|
||||
|
||||
# DeepSeek 文本模型
|
||||
text_deepseek_api_key = "your_deepseek_api_key"
|
||||
text_deepseek_model_name = "deepseek-chat"
|
||||
text_deepseek_base_url = "https://api.deepseek.com"
|
||||
```
|
||||
|
||||
### 配置验证
|
||||
|
||||
使用配置验证器检查配置是否正确:
|
||||
|
||||
```python
|
||||
from app.services.llm.config_validator import LLMConfigValidator
|
||||
|
||||
# 验证所有配置
|
||||
results = LLMConfigValidator.validate_all_configs()
|
||||
|
||||
# 打印验证报告
|
||||
LLMConfigValidator.print_validation_report(results)
|
||||
|
||||
# 获取配置建议
|
||||
suggestions = LLMConfigValidator.get_config_suggestions()
|
||||
```
|
||||
|
||||
## 🚀 使用方法
|
||||
|
||||
### 1. 统一服务接口(推荐)
|
||||
|
||||
```python
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
|
||||
# 图片分析
|
||||
results = await UnifiedLLMService.analyze_images(
|
||||
images=["path/to/image1.jpg", "path/to/image2.jpg"],
|
||||
prompt="请描述这些图片的内容",
|
||||
provider="gemini", # 可选,不指定则使用配置中的默认值
|
||||
batch_size=10
|
||||
)
|
||||
|
||||
# 文本生成
|
||||
text = await UnifiedLLMService.generate_text(
|
||||
prompt="请介绍人工智能的发展历史",
|
||||
system_prompt="你是一个专业的AI专家",
|
||||
provider="openai", # 可选
|
||||
temperature=0.7,
|
||||
response_format="json" # 可选,支持JSON格式输出
|
||||
)
|
||||
|
||||
# 解说文案生成(带验证)
|
||||
narration_items = await UnifiedLLMService.generate_narration_script(
|
||||
prompt="根据视频内容生成解说文案...",
|
||||
validate_output=True # 自动验证输出格式
|
||||
)
|
||||
|
||||
# 字幕分析
|
||||
analysis = await UnifiedLLMService.analyze_subtitle(
|
||||
subtitle_content="字幕内容...",
|
||||
validate_output=True
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 直接使用服务管理器
|
||||
|
||||
```python
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
|
||||
# 获取视觉模型提供商
|
||||
vision_provider = LLMServiceManager.get_vision_provider("gemini")
|
||||
results = await vision_provider.analyze_images(images, prompt)
|
||||
|
||||
# 获取文本模型提供商
|
||||
text_provider = LLMServiceManager.get_text_provider("openai")
|
||||
text = await text_provider.generate_text(prompt)
|
||||
```
|
||||
|
||||
### 3. 迁移适配器(向后兼容)
|
||||
|
||||
```python
|
||||
from app.services.llm.migration_adapter import create_vision_analyzer
|
||||
|
||||
# 兼容旧的接口
|
||||
analyzer = create_vision_analyzer("gemini", api_key, model, base_url)
|
||||
results = await analyzer.analyze_images(images, prompt)
|
||||
```
|
||||
|
||||
## 🔍 输出格式验证
|
||||
|
||||
### 解说文案验证
|
||||
|
||||
```python
|
||||
from app.services.llm.validators import OutputValidator
|
||||
|
||||
# 验证解说文案格式
|
||||
try:
|
||||
narration_items = OutputValidator.validate_narration_script(output)
|
||||
print(f"验证成功,共 {len(narration_items)} 个片段")
|
||||
except ValidationError as e:
|
||||
print(f"验证失败: {e.message}")
|
||||
```
|
||||
|
||||
### JSON输出验证
|
||||
|
||||
```python
|
||||
# 验证JSON格式
|
||||
try:
|
||||
data = OutputValidator.validate_json_output(output)
|
||||
print("JSON格式验证成功")
|
||||
except ValidationError as e:
|
||||
print(f"JSON验证失败: {e.message}")
|
||||
```
|
||||
|
||||
## 🧪 测试和调试
|
||||
|
||||
### 运行测试脚本
|
||||
|
||||
```bash
|
||||
# 运行完整的LLM服务测试
|
||||
python app/services/llm/test_llm_service.py
|
||||
```
|
||||
|
||||
测试脚本会验证:
|
||||
- 配置有效性
|
||||
- 提供商信息获取
|
||||
- 文本生成功能
|
||||
- JSON格式生成
|
||||
- 字幕分析功能
|
||||
- 解说文案生成功能
|
||||
|
||||
### 调试技巧
|
||||
|
||||
1. **启用详细日志**:
|
||||
```python
|
||||
from loguru import logger
|
||||
logger.add("llm_service.log", level="DEBUG")
|
||||
```
|
||||
|
||||
2. **清空提供商缓存**:
|
||||
```python
|
||||
UnifiedLLMService.clear_cache()
|
||||
```
|
||||
|
||||
3. **检查提供商信息**:
|
||||
```python
|
||||
info = UnifiedLLMService.get_provider_info()
|
||||
print(info)
|
||||
```
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### 1. API密钥安全
|
||||
- 不要在代码中硬编码API密钥
|
||||
- 使用环境变量或配置文件管理密钥
|
||||
- 定期轮换API密钥
|
||||
|
||||
### 2. 错误处理
|
||||
- 所有LLM服务调用都应该包装在try-catch中
|
||||
- 使用适当的异常类型进行错误处理
|
||||
- 实现重试机制处理临时性错误
|
||||
|
||||
### 3. 性能优化
|
||||
- 合理设置批处理大小
|
||||
- 使用缓存避免重复调用
|
||||
- 监控API调用频率和成本
|
||||
|
||||
### 4. 模型选择
|
||||
- 根据任务类型选择合适的模型
|
||||
- 考虑成本和性能的平衡
|
||||
- 定期更新到最新的模型版本
|
||||
|
||||
## 🔧 扩展新供应商
|
||||
|
||||
### 1. 创建提供商类
|
||||
|
||||
```python
|
||||
# app/services/llm/providers/new_provider.py
|
||||
from ..base import TextModelProvider
|
||||
|
||||
class NewTextProvider(TextModelProvider):
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "new_provider"
|
||||
|
||||
@property
|
||||
def supported_models(self) -> List[str]:
|
||||
return ["model-1", "model-2"]
|
||||
|
||||
async def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
# 实现具体的API调用逻辑
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. 注册提供商
|
||||
|
||||
```python
|
||||
# app/services/llm/providers/__init__.py
|
||||
from .new_provider import NewTextProvider
|
||||
|
||||
LLMServiceManager.register_text_provider('new_provider', NewTextProvider)
|
||||
```
|
||||
|
||||
### 3. 添加配置支持
|
||||
|
||||
```toml
|
||||
# config.toml
|
||||
text_new_provider_api_key = "your_api_key"
|
||||
text_new_provider_model_name = "model-1"
|
||||
text_new_provider_base_url = "https://api.newprovider.com/v1"
|
||||
```
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
如果在使用过程中遇到问题:
|
||||
|
||||
1. 首先运行测试脚本检查配置
|
||||
2. 查看日志文件了解详细错误信息
|
||||
3. 检查API密钥和网络连接
|
||||
4. 参考本文档的故障排除部分
|
||||
|
||||
---
|
||||
|
||||
*最后更新: 2025-01-07*
|
||||
@ -6,34 +6,45 @@ from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from app.config import config
|
||||
# 导入新的LLM服务模块 - 确保提供商被注册
|
||||
import app.services.llm # 这会触发提供商注册
|
||||
from app.services.llm.migration_adapter import create_vision_analyzer as create_vision_analyzer_new
|
||||
# 保留旧的导入以确保向后兼容
|
||||
from app.utils import gemini_analyzer, qwenvl_analyzer
|
||||
|
||||
|
||||
def create_vision_analyzer(provider, api_key, model, base_url):
|
||||
"""
|
||||
创建视觉分析器实例
|
||||
|
||||
创建视觉分析器实例 - 已重构为使用新的LLM服务架构
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 ('gemini' 或 'qwenvl')
|
||||
provider: 提供商名称 ('gemini', 'gemini(openai)', 'qwenvl', 'siliconflow')
|
||||
api_key: API密钥
|
||||
model: 模型名称
|
||||
base_url: API基础URL
|
||||
|
||||
|
||||
Returns:
|
||||
VisionAnalyzer 或 QwenAnalyzer 实例
|
||||
视觉分析器实例
|
||||
"""
|
||||
if provider == 'gemini':
|
||||
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||
elif provider == 'gemini(openai)':
|
||||
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
|
||||
return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
# 只传入必要的参数
|
||||
return qwenvl_analyzer.QwenAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
try:
|
||||
# 优先使用新的LLM服务架构
|
||||
return create_vision_analyzer_new(provider, api_key, model, base_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}")
|
||||
|
||||
# 回退到旧的实现以确保兼容性
|
||||
if provider == 'gemini':
|
||||
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||
elif provider == 'gemini(openai)':
|
||||
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
|
||||
return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
# 只传入必要的参数
|
||||
return qwenvl_analyzer.QwenAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
|
||||
def get_batch_timestamps(batch_files, prev_batch_files=None):
|
||||
|
||||
@ -16,6 +16,9 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
|
||||
# 导入新的LLM服务模块 - 确保提供商被注册
|
||||
import app.services.llm # 这会触发提供商注册
|
||||
from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter
|
||||
import re
|
||||
|
||||
|
||||
@ -132,32 +135,29 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
||||
return
|
||||
|
||||
"""
|
||||
2. 分析字幕总结剧情
|
||||
2. 分析字幕总结剧情 - 使用新的LLM服务架构
|
||||
"""
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
analysis_result = analyze_subtitle(
|
||||
subtitle_file_path=subtitle_path,
|
||||
api_key=text_api_key,
|
||||
model=text_model,
|
||||
base_url=text_base_url,
|
||||
save_result=True,
|
||||
temperature=temperature,
|
||||
provider=text_provider
|
||||
)
|
||||
"""
|
||||
3. 根据剧情生成解说文案
|
||||
"""
|
||||
if analysis_result["status"] == "success":
|
||||
logger.info("字幕分析成功!")
|
||||
update_progress(60, "正在生成文案...")
|
||||
|
||||
# 根据剧情生成解说文案
|
||||
narration_result = generate_narration_script(
|
||||
short_name=video_theme,
|
||||
plot_analysis=analysis_result["analysis"],
|
||||
try:
|
||||
# 优先使用新的LLM服务架构
|
||||
logger.info("使用新的LLM服务架构进行字幕分析")
|
||||
analyzer = SubtitleAnalyzerAdapter(text_api_key, text_model, text_base_url, text_provider)
|
||||
|
||||
# 读取字幕文件
|
||||
with open(subtitle_path, 'r', encoding='utf-8') as f:
|
||||
subtitle_content = f.read()
|
||||
|
||||
analysis_result = analyzer.analyze_subtitle(subtitle_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}")
|
||||
# 回退到旧的实现
|
||||
analysis_result = analyze_subtitle(
|
||||
subtitle_file_path=subtitle_path,
|
||||
api_key=text_api_key,
|
||||
model=text_model,
|
||||
base_url=text_base_url,
|
||||
@ -165,6 +165,35 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
||||
temperature=temperature,
|
||||
provider=text_provider
|
||||
)
|
||||
"""
|
||||
3. 根据剧情生成解说文案
|
||||
"""
|
||||
if analysis_result["status"] == "success":
|
||||
logger.info("字幕分析成功!")
|
||||
update_progress(60, "正在生成文案...")
|
||||
|
||||
# 根据剧情生成解说文案 - 使用新的LLM服务架构
|
||||
try:
|
||||
# 优先使用新的LLM服务架构
|
||||
logger.info("使用新的LLM服务架构生成解说文案")
|
||||
narration_result = analyzer.generate_narration_script(
|
||||
short_name=video_theme,
|
||||
plot_analysis=analysis_result["analysis"],
|
||||
temperature=temperature
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}")
|
||||
# 回退到旧的实现
|
||||
narration_result = generate_narration_script(
|
||||
short_name=video_theme,
|
||||
plot_analysis=analysis_result["analysis"],
|
||||
api_key=text_api_key,
|
||||
model=text_model,
|
||||
base_url=text_base_url,
|
||||
save_result=True,
|
||||
temperature=temperature,
|
||||
provider=text_provider
|
||||
)
|
||||
|
||||
if narration_result["status"] == "success":
|
||||
logger.info("\n解说文案生成成功!")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user