Merge pull request #157 from linyqh/dev066

Dev067 重构 LLM 管理器和提示词管理器
This commit is contained in:
viccy 2025-07-07 18:58:44 +08:00 committed by GitHub
commit 63375883c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
56 changed files with 7698 additions and 473 deletions

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : audio_config
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/1/7
@Description: 音频配置管理
'''

View File

@ -15,41 +15,60 @@ from typing import Dict, Any, Optional
from loguru import logger
from app.config import config
from app.utils.utils import get_uuid, storage_dir
from app.services.SDE.prompt import subtitle_plot_analysis_v1, plot_writing
# 导入新的提示词管理系统
from app.services.prompts import PromptManager
class SubtitleAnalyzer:
"""字幕剧情分析器,负责分析字幕内容并提取关键剧情段落"""
def __init__(
self,
self,
api_key: Optional[str] = None,
model: Optional[str] = None,
base_url: Optional[str] = None,
custom_prompt: Optional[str] = None,
temperature: Optional[float] = 1.0,
provider: Optional[str] = None,
):
"""
初始化字幕分析器
Args:
api_key: API密钥如果不提供则从配置中读取
model: 模型名称如果不提供则从配置中读取
base_url: API基础URL如果不提供则从配置中读取或使用默认值
custom_prompt: 自定义提示词如果不提供则使用默认值
temperature: 模型温度
provider: 提供商类型用于确定API调用格式
"""
# 使用传入的参数或从配置中获取
self.api_key = api_key
self.model = model
self.base_url = base_url
self.temperature = temperature
self.provider = provider or self._detect_provider()
# 设置提示词模板
self.prompt_template = custom_prompt or subtitle_plot_analysis_v1
if custom_prompt:
self.prompt_template = custom_prompt
else:
# 使用新的提示词管理系统
self.prompt_template = PromptManager.get_prompt(
category="short_drama_narration",
name="plot_analysis",
parameters={}
)
# 根据提供商类型确定是否为原生Gemini
self.is_native_gemini = self.provider.lower() == 'gemini'
# 初始化HTTP请求所需的头信息
self._init_headers()
def _detect_provider(self):
"""根据配置自动检测提供商类型"""
return config.app.get('text_llm_provider', 'gemini').lower()
def _init_headers(self):
"""初始化HTTP请求头"""
@ -67,18 +86,152 @@ class SubtitleAnalyzer:
def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]:
"""
分析字幕内容
Args:
subtitle_content: 字幕内容文本
Returns:
Dict[str, Any]: 包含分析结果的字典
"""
try:
# 构建完整提示词
prompt = f"{self.prompt_template}\n\n{subtitle_content}"
# 构建请求体数据
if self.is_native_gemini:
# 使用原生Gemini API格式
return self._call_native_gemini_api(prompt)
else:
# 使用OpenAI兼容格式
return self._call_openai_compatible_api(prompt)
except Exception as e:
logger.error(f"字幕分析过程中发生错误: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": self.temperature
}
def _call_native_gemini_api(self, prompt: str) -> Dict[str, Any]:
"""调用原生Gemini API"""
try:
# 构建原生Gemini API请求数据
payload = {
"systemInstruction": {
"parts": [{"text": "你是一位专业的剧本分析师和剧情概括助手。请严格按照要求的格式输出分析结果。"}]
},
"contents": [{
"parts": [{"text": prompt}]
}],
"generationConfig": {
"temperature": self.temperature,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 4000,
"candidateCount": 1
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
# 发送请求
response = requests.post(
url,
json=payload,
headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
timeout=120
)
if response.status_code == 200:
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
return {
"status": "error",
"message": "原生Gemini API返回无效响应可能触发了安全过滤",
"temperature": self.temperature
}
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
return {
"status": "error",
"message": "内容被Gemini安全过滤器阻止",
"temperature": self.temperature
}
if "content" not in candidate or "parts" not in candidate["content"]:
return {
"status": "error",
"message": "原生Gemini API返回内容格式错误",
"temperature": self.temperature
}
# 提取文本内容
analysis_result = ""
for part in candidate["content"]["parts"]:
if "text" in part:
analysis_result += part["text"]
if not analysis_result.strip():
return {
"status": "error",
"message": "原生Gemini API返回空内容",
"temperature": self.temperature
}
logger.debug(f"原生Gemini字幕分析完成")
return {
"status": "success",
"analysis": analysis_result,
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
"model": self.model,
"temperature": self.temperature
}
else:
error_msg = f"原生Gemini API请求失败状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": self.temperature
}
except Exception as e:
logger.error(f"原生Gemini API调用失败: {str(e)}")
return {
"status": "error",
"message": f"原生Gemini API调用失败: {str(e)}",
"temperature": self.temperature
}
def _call_openai_compatible_api(self, prompt: str) -> Dict[str, Any]:
"""调用OpenAI兼容的API"""
try:
# 构建OpenAI格式的请求数据
payload = {
"model": self.model,
"messages": [
@ -87,22 +240,22 @@ class SubtitleAnalyzer:
],
"temperature": self.temperature
}
# 构建请求地址
url = f"{self.base_url}/chat/completions"
# 发送HTTP请求
response = requests.post(url, headers=self.headers, json=payload)
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
# 解析响应
if response.status_code == 200:
response_data = response.json()
# 提取响应内容
if "choices" in response_data and len(response_data["choices"]) > 0:
analysis_result = response_data["choices"][0]["message"]["content"]
logger.debug(f"字幕分析完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
logger.debug(f"OpenAI兼容API字幕分析完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
# 返回结果
return {
"status": "success",
@ -112,26 +265,26 @@ class SubtitleAnalyzer:
"temperature": self.temperature
}
else:
logger.error("字幕分析失败: 未获取到有效响应")
logger.error("OpenAI兼容API字幕分析失败: 未获取到有效响应")
return {
"status": "error",
"message": "未获取到有效响应",
"temperature": self.temperature
}
else:
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": self.temperature
}
except Exception as e:
logger.error(f"字幕分析过程中发生错误: {str(e)}")
logger.error(f"OpenAI兼容API调用失败: {str(e)}")
return {
"status": "error",
"message": str(e),
"message": f"OpenAI兼容API调用失败: {str(e)}",
"temperature": self.temperature
}
@ -206,20 +359,165 @@ class SubtitleAnalyzer:
def generate_narration_script(self, short_name:str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
根据剧情分析生成解说文案
Args:
short_name: 短剧名称
plot_analysis: 剧情分析内容
temperature: 生成温度控制创造性默认0.7
Returns:
Dict[str, Any]: 包含生成结果的字典
"""
try:
# 构建完整提示词
prompt = plot_writing % (short_name, plot_analysis)
# 使用新的提示词管理系统构建提示词
prompt = PromptManager.get_prompt(
category="short_drama_narration",
name="script_generation",
parameters={
"drama_name": short_name,
"plot_analysis": plot_analysis
}
)
# 构建请求体数据
if self.is_native_gemini:
# 使用原生Gemini API格式
return self._generate_narration_with_native_gemini(prompt, temperature)
else:
# 使用OpenAI兼容格式
return self._generate_narration_with_openai_compatible(prompt, temperature)
except Exception as e:
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": self.temperature
}
def _generate_narration_with_native_gemini(self, prompt: str, temperature: float) -> Dict[str, Any]:
"""使用原生Gemini API生成解说文案"""
try:
# 构建原生Gemini API请求数据
# 为了确保JSON输出在提示词中添加更强的约束
enhanced_prompt = f"{prompt}\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
payload = {
"systemInstruction": {
"parts": [{"text": "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出不能包含任何其他文字、说明或代码块标记。"}]
},
"contents": [{
"parts": [{"text": enhanced_prompt}]
}],
"generationConfig": {
"temperature": temperature,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 4000,
"candidateCount": 1,
"stopSequences": ["```", "注意", "说明"]
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
# 发送请求
response = requests.post(
url,
json=payload,
headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
timeout=120
)
if response.status_code == 200:
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
return {
"status": "error",
"message": "原生Gemini API返回无效响应可能触发了安全过滤",
"temperature": temperature
}
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
return {
"status": "error",
"message": "内容被Gemini安全过滤器阻止",
"temperature": temperature
}
if "content" not in candidate or "parts" not in candidate["content"]:
return {
"status": "error",
"message": "原生Gemini API返回内容格式错误",
"temperature": temperature
}
# 提取文本内容
narration_script = ""
for part in candidate["content"]["parts"]:
if "text" in part:
narration_script += part["text"]
if not narration_script.strip():
return {
"status": "error",
"message": "原生Gemini API返回空内容",
"temperature": temperature
}
logger.debug(f"原生Gemini解说文案生成完成")
return {
"status": "success",
"narration_script": narration_script,
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
"model": self.model,
"temperature": temperature
}
else:
error_msg = f"原生Gemini API请求失败状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": temperature
}
except Exception as e:
logger.error(f"原生Gemini API解说文案生成失败: {str(e)}")
return {
"status": "error",
"message": f"原生Gemini API解说文案生成失败: {str(e)}",
"temperature": temperature
}
def _generate_narration_with_openai_compatible(self, prompt: str, temperature: float) -> Dict[str, Any]:
"""使用OpenAI兼容API生成解说文案"""
try:
# 构建OpenAI格式的请求数据
payload = {
"model": self.model,
"messages": [
@ -228,56 +526,56 @@ class SubtitleAnalyzer:
],
"temperature": temperature
}
# 对特定模型添加响应格式设置
if self.model not in ["deepseek-reasoner"]:
payload["response_format"] = {"type": "json_object"}
# 构建请求地址
url = f"{self.base_url}/chat/completions"
# 发送HTTP请求
response = requests.post(url, headers=self.headers, json=payload)
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
# 解析响应
if response.status_code == 200:
response_data = response.json()
# 提取响应内容
if "choices" in response_data and len(response_data["choices"]) > 0:
narration_script = response_data["choices"][0]["message"]["content"]
logger.debug(f"解说文案生成完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
logger.debug(f"OpenAI兼容API解说文案生成完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
# 返回结果
return {
"status": "success",
"narration_script": narration_script,
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
"model": self.model,
"temperature": self.temperature
"temperature": temperature
}
else:
logger.error("解说文案生成失败: 未获取到有效响应")
logger.error("OpenAI兼容API解说文案生成失败: 未获取到有效响应")
return {
"status": "error",
"message": "未获取到有效响应",
"temperature": self.temperature
"temperature": temperature
}
else:
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": self.temperature
"temperature": temperature
}
except Exception as e:
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
logger.error(f"OpenAI兼容API解说文案生成失败: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": self.temperature
"message": f"OpenAI兼容API解说文案生成失败: {str(e)}",
"temperature": temperature
}
def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str:
@ -324,11 +622,12 @@ def analyze_subtitle(
custom_prompt: Optional[str] = None,
temperature: float = 1.0,
save_result: bool = False,
output_path: Optional[str] = None
output_path: Optional[str] = None,
provider: Optional[str] = None
) -> Dict[str, Any]:
"""
分析字幕内容的便捷函数
Args:
subtitle_content: 字幕内容文本
subtitle_file_path: 字幕文件路径
@ -339,7 +638,8 @@ def analyze_subtitle(
temperature: 模型温度
save_result: 是否保存结果到文件
output_path: 输出文件路径
provider: 提供商类型
Returns:
Dict[str, Any]: 包含分析结果的字典
"""
@ -349,7 +649,8 @@ def analyze_subtitle(
api_key=api_key,
model=model,
base_url=base_url,
custom_prompt=custom_prompt
custom_prompt=custom_prompt,
provider=provider
)
logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}")
# 分析字幕
@ -379,11 +680,12 @@ def generate_narration_script(
base_url: Optional[str] = None,
temperature: float = 1.0,
save_result: bool = False,
output_path: Optional[str] = None
output_path: Optional[str] = None,
provider: Optional[str] = None
) -> Dict[str, Any]:
"""
根据剧情分析生成解说文案的便捷函数
Args:
short_name: 短剧名称
plot_analysis: 剧情分析内容直接提供
@ -393,7 +695,8 @@ def generate_narration_script(
temperature: 生成温度控制创造性
save_result: 是否保存结果到文件
output_path: 输出文件路径
provider: 提供商类型
Returns:
Dict[str, Any]: 包含生成结果的字典
"""
@ -402,7 +705,8 @@ def generate_narration_script(
temperature=temperature,
api_key=api_key,
model=model,
base_url=base_url
base_url=base_url,
provider=provider
)
# 生成解说文案

View File

@ -6,12 +6,17 @@ from .utils.step1_subtitle_analyzer_openai import analyze_subtitle
from .utils.step5_merge_script import merge_script
def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5):
def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5, provider: str = None):
"""生成视频混剪脚本
Args:
srt_path: 字幕文件路径
api_key: API密钥
model_name: 模型名称
output_path: 输出文件路径可选
base_url: API基础URL
custom_clips: 自定义片段数量
provider: LLM服务提供商
Returns:
str: 生成的脚本内容
@ -27,7 +32,8 @@ def generate_script(srt_path: str, api_key: str, model_name: str, output_path: s
api_key=api_key,
model_name=model_name,
base_url=base_url,
custom_clips=custom_clips
custom_clips=custom_clips,
provider=provider
)
# 合并生成最终脚本

View File

@ -1,12 +1,18 @@
"""
使用OpenAI API分析字幕文件返回剧情梗概和爆点
使用统一LLM服务分析字幕文件返回剧情梗概和爆点
"""
import traceback
from openai import OpenAI, BadRequestError
import os
import json
import asyncio
from loguru import logger
from .utils import load_srt
# 导入新的提示词管理系统
from app.services.prompts import PromptManager
# 导入统一LLM服务
from app.services.llm.unified_service import UnifiedLLMService
# 导入安全的异步执行函数
from app.services.llm.migration_adapter import _run_async_safely
def analyze_subtitle(
@ -14,15 +20,18 @@ def analyze_subtitle(
model_name: str,
api_key: str = None,
base_url: str = None,
custom_clips: int = 5
custom_clips: int = 5,
provider: str = None
) -> dict:
"""分析字幕内容,返回完整的分析结果
Args:
srt_path (str): SRT字幕文件路径
model_name (str): 大模型名称
api_key (str, optional): 大模型API密钥. Defaults to None.
model_name (str, optional): 大模型名称. Defaults to "gpt-4o-2024-11-20".
base_url (str, optional): 大模型API基础URL. Defaults to None.
custom_clips (int): 需要提取的片段数量. Defaults to 5.
provider (str, optional): LLM服务提供商. Defaults to None.
Returns:
dict: 包含剧情梗概和结构化的时间段分析的字典
@ -32,126 +41,103 @@ def analyze_subtitle(
subtitles = load_srt(srt_path)
subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles])
# 初始化客户端
global client
if "deepseek" in model_name.lower():
client = OpenAI(
api_key=api_key or os.getenv('DeepSeek_API_KEY'),
base_url="https://api.siliconflow.cn/v1" # 使用第三方 硅基流动 API
)
else:
client = OpenAI(
api_key=api_key or os.getenv('OPENAI_API_KEY'),
base_url=base_url
)
# 初始化统一LLM服务
llm_service = UnifiedLLMService()
messages = [
{
"role": "system",
"content": """你是一名经验丰富的短剧编剧,擅长根据字幕内容按照先后顺序分析关键剧情,并找出 %s 个关键片段。
请返回一个JSON对象包含以下字段
{
"summary": "整体剧情梗概",
"plot_titles": [
"关键剧情1",
"关键剧情2",
"关键剧情3",
"关键剧情4",
"关键剧情5",
"..."
]
}
请确保返回的是合法的JSON格式, 请确保返回的是 %s 个片段
""" % (custom_clips, custom_clips)
},
{
"role": "user",
"content": f"srt字幕如下{subtitle_content}"
# 如果没有指定provider根据model_name推断
if not provider:
if "deepseek" in model_name.lower():
provider = "deepseek"
elif "gpt" in model_name.lower():
provider = "openai"
elif "gemini" in model_name.lower():
provider = "gemini"
else:
provider = "openai" # 默认使用openai
logger.info(f"使用LLM服务分析字幕提供商: {provider}, 模型: {model_name}")
# 使用新的提示词管理系统
subtitle_analysis_prompt = PromptManager.get_prompt(
category="short_drama_editing",
name="subtitle_analysis",
parameters={
"subtitle_content": subtitle_content,
"custom_clips": custom_clips
}
]
# DeepSeek R1 和 V3 不支持 response_format=json_object
try:
completion = client.chat.completions.create(
model=model_name,
messages=messages,
response_format={"type": "json_object"}
)
summary_data = json.loads(completion.choices[0].message.content)
except BadRequestError as e:
completion = client.chat.completions.create(
model=model_name,
messages=messages
)
# 去除 completion 字符串前的 ```json 和 结尾的 ```
completion = completion.choices[0].message.content.replace("```json", "").replace("```", "")
summary_data = json.loads(completion)
except Exception as e:
raise Exception(f"大模型解析发生错误:{str(e)}\n{traceback.format_exc()}")
)
# 使用统一LLM服务生成文本
logger.info("开始分析字幕内容...")
response = _run_async_safely(
UnifiedLLMService.generate_text,
prompt=subtitle_analysis_prompt,
provider=provider,
model=model_name,
api_key=api_key,
base_url=base_url,
temperature=0.1, # 使用较低的温度以获得更稳定的结果
max_tokens=4000
)
# 解析JSON响应
from webui.tools.generate_short_summary import parse_and_fix_json
summary_data = parse_and_fix_json(response)
if not summary_data:
raise Exception("无法解析LLM返回的JSON数据")
logger.info(f"字幕分析完成,找到 {len(summary_data.get('plot_titles', []))} 个关键情节")
print(json.dumps(summary_data, indent=4, ensure_ascii=False))
# 获取爆点时间段分析
prompt = f"""剧情梗概:
{summary_data['summary']}
需要定位的爆点内容
"""
# 构建爆点标题列表
plot_titles_text = ""
print(f"找到 {len(summary_data['plot_titles'])} 个片段")
for i, point in enumerate(summary_data['plot_titles'], 1):
prompt += f"{i}. {point}\n"
plot_titles_text += f"{i}. {point}\n"
messages = [
{
"role": "system",
"content": """你是一名短剧编剧,非常擅长根据字幕中分析视频中关键剧情出现的具体时间段。
请仔细阅读剧情梗概和爆点内容然后在字幕中找出每个爆点发生的具体时间段和爆点前后的详细剧情
请返回一个JSON对象包含一个名为"plot_points"的数组数组中包含多个对象每个对象都要包含以下字段
{
"plot_points": [
{
"timestamp": "时间段格式为xx:xx:xx,xxx-xx:xx:xx,xxx",
"title": "关键剧情的主题",
"picture": "关键剧情前后的详细剧情描述"
}
]
}
请确保返回的是合法的JSON格式"""
},
{
"role": "user",
"content": f"""字幕内容:
{subtitle_content}
{prompt}"""
# 使用新的提示词管理系统
plot_extraction_prompt = PromptManager.get_prompt(
category="short_drama_editing",
name="plot_extraction",
parameters={
"subtitle_content": subtitle_content,
"plot_summary": summary_data['summary'],
"plot_titles": plot_titles_text
}
]
# DeepSeek R1 和 V3 不支持 response_format=json_object
try:
completion = client.chat.completions.create(
model=model_name,
messages=messages,
response_format={"type": "json_object"}
)
plot_points_data = json.loads(completion.choices[0].message.content)
except BadRequestError as e:
completion = client.chat.completions.create(
model=model_name,
messages=messages
)
# 去除 completion 字符串前的 ```json 和 结尾的 ```
completion = completion.choices[0].message.content.replace("```json", "").replace("```", "")
plot_points_data = json.loads(completion)
except Exception as e:
raise Exception(f"大模型解析错误:{str(e)}\n{traceback.format_exc()}")
)
print(json.dumps(plot_points_data, indent=4, ensure_ascii=False))
# 使用统一LLM服务进行爆点时间段分析
logger.info("开始分析爆点时间段...")
response = _run_async_safely(
UnifiedLLMService.generate_text,
prompt=plot_extraction_prompt,
provider=provider,
model=model_name,
api_key=api_key,
base_url=base_url,
temperature=0.1,
max_tokens=4000
)
# 解析JSON响应
plot_data = parse_and_fix_json(response)
if not plot_data:
raise Exception("无法解析爆点分析的JSON数据")
logger.info(f"爆点分析完成,找到 {len(plot_data.get('plot_points', []))} 个时间段")
# 合并结果
return {
"plot_summary": summary_data,
"plot_points": plot_points_data["plot_points"]
result = {
"summary": summary_data.get("summary", ""),
"plot_titles": summary_data.get("plot_titles", []),
"plot_points": plot_data.get("plot_points", [])
}
return result
except Exception as e:
logger.error(f"分析字幕时发生错误: {str(e)}")
raise Exception(f"分析字幕时发生错误:{str(e)}\n{traceback.format_exc()}")

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : audio_normalizer
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/1/7
@Description: 音频响度分析和标准化工具
'''

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : clip_video
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/6 下午6:14
'''

View File

@ -4,16 +4,23 @@
'''
@Project: NarratoAI
@File : 生成介绍文案
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/8 上午11:33
'''
import json
import os
import traceback
import asyncio
from openai import OpenAI
from loguru import logger
# 导入新的LLM服务模块 - 确保提供商被注册
import app.services.llm # 这会触发提供商注册
from app.services.llm.migration_adapter import generate_narration as generate_narration_new
# 导入新的提示词管理系统
from app.services.prompts import PromptManager
def parse_frame_analysis_to_markdown(json_file_path):
"""
@ -79,104 +86,52 @@ def parse_frame_analysis_to_markdown(json_file_path):
def generate_narration(markdown_content, api_key, base_url, model):
"""
调用OpenAI API根据视频帧分析的Markdown内容生成解说文案
调用大模型API根据视频帧分析的Markdown内容生成解说文案 - 已重构为使用新的LLM服务架构
:param markdown_content: Markdown格式的视频帧分析内容
:param api_key: OpenAI API密钥
:param base_url: API基础URL如果使用非官方API
:param api_key: API密钥
:param base_url: API基础URL
:param model: 使用的模型名称
:return: 生成的解说文案
"""
try:
# 构建提示词
prompt = """
我是一名荒野建造解说的博主以下是一些同行的对标文案请你深度学习并总结这些文案的风格特点跟内容特点
# 优先使用新的LLM服务架构
logger.info("使用新的LLM服务架构生成解说文案")
result = generate_narration_new(markdown_content, api_key, base_url, model)
return result
<example_text_1>
解压助眠的天花板就是荒野建造沉浸丝滑的搭建过程可以说每一帧都是极致享受我保证强迫症来了都找不出一丁点毛病更别说全屋严丝合缝的拼接工艺还能轻松抵御零下二十度气温让你居住的每一天都温暖如春
在家闲不住的西姆今天也打算来一次野外建造行走没多久他就发现许多倒塌的树任由它们自生自灭不如将其利用起来想到这他就开始挥舞铲子要把地基挖掘出来虽然每次只能挖一点点但架不住他体能惊人没多长时间一个 2x3 的深坑就赫然出现这深度住他一人绰绰有余
随后他去附近收集来原木这些都是搭建墙壁的最好材料而在投入使用前自然要把表皮刮掉防止森林中的白蚁蛀虫处理好一大堆后西姆还在两端打孔使用木钉固定在一起这可不是用来做墙壁的而是做庇护所的承重柱只要木头间的缝隙足够紧密那搭建出的木屋就能足够坚固
每向上搭建一层他都会在中间塞入苔藓防寒保证不会泄露一丝热量其他几面也是用相同方法很快西姆就做好了三面墙壁每一根木头都极其工整保证强迫症来了都要点个赞再走
在继续搭建墙壁前西姆决定将壁炉制作出来毕竟森林夜晚的气温会很低保暖措施可是重中之重完成后他找来一块大树皮用来充当庇护所的大门而上面刮掉的木屑还能作为壁炉的引火物可以说再完美不过
测试了排烟没问题后他才开始搭建最后一面墙壁这一面要预留门和窗所以在搭建到一半后还需要在原木中间开出卡口让自己劈砍时能轻松许多此时只需将另外一根如法炮制两端拼接在一起后就是一扇大小适中的窗户而随着随后一层苔藓铺好最后一根原木落位这个庇护所的雏形就算完成
大门的安装他没选择用合页而是在底端雕刻出榫头门框上则雕刻出榫眼只能说西姆的眼就是一把尺这完全就是严丝合缝此时他才开始搭建屋顶这里西姆用的方法不同他先把最外围的原木固定好随后将原木平铺在上面就能得到完美的斜面屋顶等他将四周的围栏也装好后工整的屋顶看起来十分舒服西姆躺上去都不想动
稍作休息后他利用剩余的苔藓对屋顶的缝隙处密封可这样西姆觉得不够保险于是他找来一些黏土再次对原本的缝隙二次加工保管这庇护所冬天也暖和最后只需要平铺上枯叶以及挖掘出的泥土整个屋顶就算完成
考虑到庇护所的美观性自然少不了覆盖上苔藓翠绿的颜色看起来十分舒服就连门口的庭院旁他都移植了许多小树做点缀让这木屋与周边环境融为一体西姆才刚完成好这件事一场大雨就骤然降临好在此时的他已经不用淋雨更别说这屋顶防水十分不错室内没一点雨水渗透进来
等待温度回升的过程西姆利用墙壁本身的凹槽把床框镶嵌在上面只需要铺上苔藓以及自带的床单枕头一张完美的单人床就做好辛苦劳作一整天西姆可不会亏待自己他将自带的牛肉腌制好后直接放到壁炉中烤只需要等待三十分钟就能享受这美味的一顿
在辛苦建造一星期后他终于可以在自己搭建的庇护所中享受最纯正的野外露营后面西姆回家补给了一堆物资再次回来时森林已经大雪纷飞让他原本翠绿的小屋更换上了冬季限定皮肤好在内部设施没受什么影响和他离开时一样整洁
就是房间中已经没多少柴火让西姆今天又得劈柴寒冷干燥的天气让木头劈起来十分轻松没多久他就收集到一大堆这些足够燃烧好几天虽然此时外面大雪纷飞但小屋中却开始逐渐温暖这次他除了带来一些食物外还有几瓶调味料以及一整套被褥让自己的居住舒适度提高一大截
而秋天他有收集干草的缘故只需要塞入枕套中密封起来就能作为靠垫用就这居住条件比一般人在家过的还要奢侈趁着壁炉木头变木炭的过程西姆则开始不紧不慢的处理食物他取出一块牛排改好花刀以后撒上一堆调料腌制起来接着用锡纸包裹好放到壁炉中直接炭烤搭配上自带的红酒是一个非常好的选择
随着时间来到第二天外面的积雪融化了不少西姆简单做顿煎蛋补充体力后决定制作一个室外篝火堆用来晚上驱散周边野兽搭建这玩意没什么技巧只需要找到一大堆木棍利用大树的夹缝将其掰弯然后将其堆积在一起就是一个简易版的篝火堆看这外形有点像帐篷好在西姆没想那么多
等待天色暗淡下来后他才来到室外将其点燃顺便处理下多余的废料只可惜这场景没朋友陪在身边对西姆来说可能是个遗憾而哪怕森林只有他一个人都依旧做了好几个小时等到里面的篝火彻底燃尽后西姆还找来雪球覆盖到上面将火熄灭这防火意识可谓十分好最后在室内二十五度的高温下裹着被子睡觉
</example_text_1>
except Exception as e:
logger.warning(f"使用新LLM服务失败回退到旧实现: {str(e)}")
# 回退到旧的实现以确保兼容性
return _generate_narration_legacy(markdown_content, api_key, base_url, model)
def _generate_narration_legacy(markdown_content, api_key, base_url, model):
"""
旧的解说文案生成实现 - 保留作为备用方案
:param markdown_content: Markdown格式的视频帧分析内容
:param api_key: API密钥
:param base_url: API基础URL
:param model: 使用的模型名称
:return: 生成的解说文案
"""
try:
# 使用新的提示词管理系统构建提示词
prompt = PromptManager.get_prompt(
category="documentary",
name="narration_generation",
parameters={
"video_frame_description": markdown_content
}
)
<example_text_2>
解压助眠的天花板就是荒野建造沉浸丝滑的搭建过程每一帧都是极致享受全屋严丝合缝的拼接工艺能轻松抵御零下二十度气温居住体验温暖如春
在家闲不住的西姆开启野外建造他发现倒塌的树决定加以利用先挖掘出 2x3 的深坑作为地基接着收集原木刮掉表皮防白蚁蛀虫打孔用木钉固定制作承重柱搭建墙壁时每一层都塞入苔藓防寒很快做好三面墙
为应对森林夜晚低温西姆制作壁炉用大树皮当大门刮下的木屑做引火物搭建最后一面墙时预留门窗通过在原木中间开口拼接做出窗户大门采用榫卯结构安装严丝合缝
搭建屋顶时先固定外围原木再平铺原木形成斜面屋顶之后用苔藓黏土密封缝隙铺上枯叶和泥土为美观在木屋覆盖苔藓移植小树点缀完工时遇大雨木屋防水良好
西姆利用墙壁凹槽镶嵌床框铺上苔藓床单枕头做成床劳作一天后他用壁炉烤牛肉享用建造一星期后他开始野外露营
后来西姆回家补给物资回来时森林大雪纷飞他劈柴储备带回食物调味料和被褥提高居住舒适度还用干草做靠垫他用壁炉烤牛排搭配红酒
第二天积雪融化西姆制作室外篝火堆防野兽用大树夹缝掰弯木棍堆积而成晚上点燃处理废料结束后用雪球灭火最后在室内二十五度的环境中裹被入睡
</example_text_2>
<example_text_3>
如果战争到来这个深埋地下十几米的庇护所绝对是 bug 般的存在即使被敌人发现还能通过快速通道一秒逃出里面不仅有竹子地暖地下水井还自制抽水机在解决用水问题的同时甚至自研无土栽培技术过上完全自给自足的生活
阿伟的老婆美如花但阿伟从来不回家来到野外他乐哈哈一言不合就开挖众所周知当战争来临时地下堡垒的安全性是最高的阿伟苦苦研习两载半只为练就一身挖洞本领在这双逆天麒麟臂的加持下如此坚硬的泥土都只能当做炮灰
得到了充足的空间后他便开始对这些边缘进行打磨随后阿伟将细线捆在木棍上以此描绘出圆柱的轮廓接着再一点点铲掉多余的部分虽然是由泥土一体式打造但这样的桌子保准用上千年都不成问题
考虑到十几米的深度进出非常不方便于是阿伟找来两根长达 66.6 米的木头打算为庇护所打造一条快速通道只见他将木桩牢牢地插入地下并顺着洞口的方向延伸出去直到贯穿整个山洞接着在每个木桩的连接处钉入铁钉确保轨道不能有一毫米的偏差完成后再制作一个木质框架从而达到前后滑动的效果
不得不说阿伟这手艺简直就是大钢管子杵青蛙在上面放上一个木制的车斗还能加快搬运泥土的速度没多久庇护所的内部就已经初见雏形为了住起来更加舒适还需要为自己打造一张床虽然深处的泥土同样很坚固但好处就是不用担心垮塌的风险
阿伟不仅设计了更加符合人体工学的拱形并且还在一旁雕刻处壁龛就是这氛围怎么看着有点不太吉利别看阿伟一身腱子肉但这身体里的艺术细菌可不少每个边缘的地方他都做了精雕细琢瞬间让整个卧室的颜值提升一大截
住在地下的好处就是房子面积全靠挖每平方消耗两个半馒头不仅没有了房贷的压力就连买墓地的钱也省了阿伟将中间的墙壁挖空从而得到取暖的壁炉当然最重要的还有排烟问题要想从上往下打通十几米的山体是件极其困难的事好在阿伟年轻时报过忆坤年的古墓派补习班这打洞技术堪比隔壁学校的土拨鼠专业虽然深度长达十几米但排烟效果却一点不受影响一个字专业
随后阿伟继续对壁炉底部雕刻打通了底部放柴火的空间并制作出放锅的灶头完成后阿伟从侧面将壁炉打通并制作出一条导热的通道以此连接到床铺的位置毕竟住在这么一个风湿宝地不注意保暖除湿很容易得老寒腿
阿伟在床面上挖出一条条管道以便于温度能传输到床的每个角落接下来就可以根据这些通道的长度裁切出同样长短的竹子根据竹筒的大小凿出相互连接的孔洞最后再将竹筒内部打通以达到温度传送的效果
而后阿伟将这些管道安装到凹槽内在他严谨的制作工艺下每根竹子刚好都能镶嵌进去在铺设床面之前还需要用木塞把圆孔堵住防止泥土掉落进管道泥土虽然不能隔绝湿气但却是十分优良的导热材料等他把床面都压平后就可以小心的将这些木塞拔出来最后再用黏土把剩余的管道也遮盖起来直到整个墙面恢复原样
接下来还需要测试一下加热效果当他把火点起来后温度很快就传送到了管道内把火力一点点加大直到热气流淌到更远的床面随着小孔里的青烟冒出也预示着阿伟的地暖可以投入使用而后阿伟制作了一些竹条并用细绳将它们喜结连理
千里之行始于足下美好的家园要靠自己双手打造明明可以靠才艺吃饭的阿伟偏偏要用八块腹肌征服大家就问这样的男人哪个野生婆娘不喜欢完成后阿伟还用自己 35 码的大腚感受了一下真烫
随后阿伟来到野区找到一根上好的雷击木他当即就把木头咔嚓成两段并取下两节较为完整的带了回去刚好能和圆桌配套另外一个在里面凿出凹槽并插入木棍连接得到一个夯土的木锤住过农村的小伙伴都知道这样夯出来的地面堪比水泥地不仅坚硬耐磨还不用担心脚底打滑忙碌了一天的阿伟已经饥渴难耐拿出野生小烤肠安安心心住新房光脚爬上大热炕一觉能睡到天亮
第二天阿伟打算将房间扩宽毕竟吃住的地方有了还要解决个人卫生的问题阿伟在另一侧增加了一个房间他打算将这里打造成洗澡的地方为了防止泥土垮塌他将顶部做成圆弧形等挖出足够的空间后旁边的泥土已经堆成了小山
为了方便清理这些泥土阿伟在之前的轨道增加了转弯交接处依然是用铁钉固定一直延伸到房间的最里面有了运输车的帮助这些成吨的泥土也能轻松的运送出去并且还能体验过山车的感觉很快他就完成了清理工作
为了更方便的在里面洗澡他将底部一点点挖空这么大的浴缸看来阿伟并不打算一个人住完成后他将墙面雕刻的凹凸有致让这里看起来更加豪华接着用洛阳铲挖出排水口并用一根相同大小的竹筒作为开关
由于四周都是泥土还不能防水阿伟特意找了一些白蚁巢用来制作可以防水的野生水泥现在就可以将里里外外能接触到水的地方都涂抹一遍细心的阿伟还找来这种 500 克一斤的鹅卵石对池子表面进行装饰
没错水源问题阿伟早已经考虑在内他打算直接在旁边挖个水井毕竟已经挖了这么深再向下挖一挖应该就能到达地下水的深度经过几日的奋战能看得出阿伟已经消瘦了不少但一想到马上就能拥有的豪宅他直接化身为无情的挖土机器很快就挖到了好几米的深度
考虑到自己的弹跳力有限阿伟在一旁定入木桩然后通过绳子爬上爬下随着深度越来越深井底已经开始渗出水来这也预示着打井成功没多久这里面将渗满泉水仅凭一次就能挖到水源看来这里还真是块风湿宝地
随后阿伟在井口四周挖出凹槽以便于井盖的安置这一量才知道井的深度已经达到了足足的 5 阿伟把木板组合在一起再沿着标记切掉多余部分他甚至还给井盖做了把手可是如何从这么深的井里打水还是个问题但从阿伟坚定的眼神来看他应该想到了解决办法
只见他将树桩锯成两半然后用凿子把里面一点点掏空另外一半也是如法炮制接着还要在底部挖出圆孔要想成功将水从 5 米深的地方抽上来那就不得不提到大家熟知的勾股定理没错这跟勾股定理没什么关系
阿伟给竹筒做了一个木塞并在里面打上安装连接轴的孔为了增加密闭性阿伟不得不牺牲了自己的 AJ剪出与木塞相同的大小后再用木钉固定住随后他收集了一些树胶并放到火上加热融化接下来就可以涂在木塞上增加使用寿命
现在将竹筒组装完成就可以利用虹吸原理将水抽上来完成后就可以把井盖盖上去再用泥土在上面覆盖现在就不用担心失足掉下去了
接下来阿伟去采集了一些大漆将它涂抹在木桶接缝处就能将其二合为一完了再接入旁边浴缸的入水口每个连接的地方都要做好密封不然后面很容易漏水随后就可以安装上活塞并用一根木桩作为省力杠杆根据空气压强的原理将井水抽上来
经过半小时的来回拉扯硕大的浴缸终于被灌满阿伟也是忍不住洗了把脸接下来还需要解决排水的问题阿伟在地上挖出沟渠一直贯穿到屋外然后再用竹筒从出水口连接每个接口处都要抹上胶水就连门外的出水口他都做了隐藏
在野外最重要的就是庇护所水源还有食物既然已经完成了前二者那么阿伟还需要拥有可持续发展的食物来源他先是在地上挖了两排地洞然后在每根竹筒的表面都打上无数孔洞这就是他打算用来种植的载体在此之前还需要用大火对竹筒进行杀菌消毒
趁着这时候他去搬了一麻袋的木屑先用芭蕉叶覆盖在上面再铺上厚厚的黏土隔绝温度在火焰的温度下能让里面的木屑达到生长条件
等到第二天所有材料都晾凉后阿伟才将竹筒内部掏空并将木屑一点点地塞入竹筒一切准备就绪就可以将竹筒插入提前挖好的地洞最后再往竹筒里塞入种子依靠房间内的湿度和温度就能达到大棚种植的效果稍加时日这些种子就会慢慢发芽
虽然暂时还吃不上自己培养的食物但好在阿伟从表哥贺强那里学到不少钓鱼本领哪怕只有一根小小的竹竿也能让他钓上两斤半的大鲶鱼新鲜的食材那肯定是少不了高温消毒的过程趁着鱼没熟阿伟直接爬进浴缸冰凉的井水瞬间洗去了身上的疲惫这一刻的阿伟是无比的享受
不久后鱼也烤得差不多了阿伟的生活现在可以说是有滋有味住在十几米的地下不仅能安全感满满哪怕遇到危险还能通过轨道快速逃生
<example_text_3>
<video_frame_description>
%s
</video_frame_description>
我正在尝试做这个内容的解说纪录片视频我需要你以 <video_frame_description> </video_frame_description> 中的内容为解说目标根据我刚才提供给你的对标文案 <example_text> 特点以及你总结的特点帮我生成一段关于荒野建造的解说文案文案需要符合平台受欢迎的解说风格请使用 json 格式进行输出使用 <output> 中的输出格式
<output>
{
"items": [
{
"_id": 1, # 唯一递增id
"timestamp": "00:00:05,390-00:00:10,430",
"picture": "画面描述",
"narration": "解说文案",
}
}
</output>
<restriction>
1. 只输出 json 内容不要输出其他任何说明性的文字
2. 解说文案的语言使用 简体中文
3. 严禁虚构画面所有画面只能从 <video_frame_description> 中摘取
</restriction>
""" % (markdown_content)
# 使用OpenAI SDK初始化客户端
client = OpenAI(

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : generate_video
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/7 上午11:55
'''

View File

@ -0,0 +1,52 @@
"""
NarratoAI 大模型服务模块
统一的大模型服务抽象层支持多供应商切换和严格的输出格式验证
包含视觉模型和文本生成模型的统一接口
主要组件:
- BaseLLMProvider: 大模型服务提供商基类
- VisionModelProvider: 视觉模型提供商基类
- TextModelProvider: 文本模型提供商基类
- LLMServiceManager: 大模型服务管理器
- OutputValidator: 输出格式验证器
支持的供应商:
视觉模型: Gemini, QwenVL, Siliconflow
文本模型: OpenAI, DeepSeek, Gemini, Qwen, Moonshot, Siliconflow
"""
from .manager import LLMServiceManager
from .base import BaseLLMProvider, VisionModelProvider, TextModelProvider
from .validators import OutputValidator, ValidationError
from .exceptions import LLMServiceError, ProviderNotFoundError, ConfigurationError
# 确保提供商在模块导入时被注册
def _ensure_providers_registered():
"""确保所有提供商都已注册"""
try:
# 导入providers模块会自动执行注册
from . import providers
from loguru import logger
logger.debug("LLM服务提供商注册完成")
except Exception as e:
from loguru import logger
logger.error(f"LLM服务提供商注册失败: {str(e)}")
# 自动注册提供商
_ensure_providers_registered()
__all__ = [
'LLMServiceManager',
'BaseLLMProvider',
'VisionModelProvider',
'TextModelProvider',
'OutputValidator',
'ValidationError',
'LLMServiceError',
'ProviderNotFoundError',
'ConfigurationError'
]
# 版本信息
__version__ = '1.0.0'

175
app/services/llm/base.py Normal file
View File

@ -0,0 +1,175 @@
"""
大模型服务提供商基类定义
定义了统一的大模型服务接口包括视觉模型和文本生成模型的抽象基类
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger
from .exceptions import LLMServiceError, ConfigurationError
class BaseLLMProvider(ABC):
"""大模型服务提供商基类"""
def __init__(self,
api_key: str,
model_name: str,
base_url: Optional[str] = None,
**kwargs):
"""
初始化大模型服务提供商
Args:
api_key: API密钥
model_name: 模型名称
base_url: API基础URL
**kwargs: 其他配置参数
"""
self.api_key = api_key
self.model_name = model_name
self.base_url = base_url
self.config = kwargs
# 验证必要配置
self._validate_config()
# 初始化提供商特定设置
self._initialize()
@property
@abstractmethod
def provider_name(self) -> str:
"""供应商名称"""
pass
@property
@abstractmethod
def supported_models(self) -> List[str]:
"""支持的模型列表"""
pass
def _validate_config(self):
"""验证配置参数"""
if not self.api_key:
raise ConfigurationError("API密钥不能为空", "api_key")
if not self.model_name:
raise ConfigurationError("模型名称不能为空", "model_name")
if self.model_name not in self.supported_models:
from .exceptions import ModelNotSupportedError
raise ModelNotSupportedError(self.model_name, self.provider_name)
def _initialize(self):
"""初始化提供商特定设置,子类可重写"""
pass
@abstractmethod
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用子类必须实现"""
pass
def _handle_api_error(self, status_code: int, response_text: str) -> LLMServiceError:
"""处理API错误返回适当的异常"""
from .exceptions import APICallError, RateLimitError, AuthenticationError
if status_code == 401:
return AuthenticationError()
elif status_code == 429:
return RateLimitError()
else:
return APICallError(f"HTTP {status_code}", status_code, response_text)
class VisionModelProvider(BaseLLMProvider):
"""视觉模型提供商基类"""
@abstractmethod
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
分析图片并返回结果
Args:
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
"""
pass
def _prepare_images(self, images: List[Union[str, Path, PIL.Image.Image]]) -> List[PIL.Image.Image]:
"""预处理图片统一转换为PIL.Image对象"""
processed_images = []
for img in images:
try:
if isinstance(img, (str, Path)):
pil_img = PIL.Image.open(img)
elif isinstance(img, PIL.Image.Image):
pil_img = img
else:
logger.warning(f"不支持的图片类型: {type(img)}")
continue
# 调整图片大小以优化性能
if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
processed_images.append(pil_img)
except Exception as e:
logger.error(f"加载图片失败 {img}: {str(e)}")
continue
return processed_images
class TextModelProvider(BaseLLMProvider):
"""文本生成模型提供商基类"""
@abstractmethod
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
生成文本内容
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
pass
def _build_messages(self, prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
"""构建消息列表"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
return messages

View File

@ -0,0 +1,307 @@
"""
LLM服务配置验证器
验证大模型服务的配置是否正确并提供配置建议
"""
from typing import Dict, List, Any, Optional
from loguru import logger
from app.config import config
from .manager import LLMServiceManager
from .exceptions import ConfigurationError
class LLMConfigValidator:
"""LLM服务配置验证器"""
@staticmethod
def validate_all_configs() -> Dict[str, Any]:
"""
验证所有LLM服务配置
Returns:
验证结果字典
"""
results = {
"vision_providers": {},
"text_providers": {},
"summary": {
"total_vision_providers": 0,
"valid_vision_providers": 0,
"total_text_providers": 0,
"valid_text_providers": 0,
"errors": [],
"warnings": []
}
}
# 验证视觉模型提供商
vision_providers = LLMServiceManager.list_vision_providers()
results["summary"]["total_vision_providers"] = len(vision_providers)
for provider in vision_providers:
try:
validation_result = LLMConfigValidator.validate_vision_provider(provider)
results["vision_providers"][provider] = validation_result
if validation_result["is_valid"]:
results["summary"]["valid_vision_providers"] += 1
else:
results["summary"]["errors"].extend(validation_result["errors"])
except Exception as e:
error_msg = f"验证视觉模型提供商 {provider} 时发生错误: {str(e)}"
results["vision_providers"][provider] = {
"is_valid": False,
"errors": [error_msg],
"warnings": []
}
results["summary"]["errors"].append(error_msg)
# 验证文本模型提供商
text_providers = LLMServiceManager.list_text_providers()
results["summary"]["total_text_providers"] = len(text_providers)
for provider in text_providers:
try:
validation_result = LLMConfigValidator.validate_text_provider(provider)
results["text_providers"][provider] = validation_result
if validation_result["is_valid"]:
results["summary"]["valid_text_providers"] += 1
else:
results["summary"]["errors"].extend(validation_result["errors"])
except Exception as e:
error_msg = f"验证文本模型提供商 {provider} 时发生错误: {str(e)}"
results["text_providers"][provider] = {
"is_valid": False,
"errors": [error_msg],
"warnings": []
}
results["summary"]["errors"].append(error_msg)
return results
@staticmethod
def validate_vision_provider(provider_name: str) -> Dict[str, Any]:
"""
验证视觉模型提供商配置
Args:
provider_name: 提供商名称
Returns:
验证结果字典
"""
result = {
"is_valid": False,
"errors": [],
"warnings": [],
"config": {}
}
try:
# 获取配置
config_prefix = f"vision_{provider_name}"
api_key = config.app.get(f'{config_prefix}_api_key')
model_name = config.app.get(f'{config_prefix}_model_name')
base_url = config.app.get(f'{config_prefix}_base_url')
result["config"] = {
"api_key": "***" if api_key else None,
"model_name": model_name,
"base_url": base_url
}
# 验证必需配置
if not api_key:
result["errors"].append(f"缺少API密钥配置: {config_prefix}_api_key")
if not model_name:
result["errors"].append(f"缺少模型名称配置: {config_prefix}_model_name")
# 尝试创建提供商实例
if api_key and model_name:
try:
provider_instance = LLMServiceManager.get_vision_provider(provider_name)
result["is_valid"] = True
logger.debug(f"视觉模型提供商 {provider_name} 配置验证成功")
except Exception as e:
result["errors"].append(f"创建提供商实例失败: {str(e)}")
# 添加警告
if not base_url:
result["warnings"].append(f"未配置base_url将使用默认值")
except Exception as e:
result["errors"].append(f"配置验证过程中发生错误: {str(e)}")
return result
@staticmethod
def validate_text_provider(provider_name: str) -> Dict[str, Any]:
"""
验证文本模型提供商配置
Args:
provider_name: 提供商名称
Returns:
验证结果字典
"""
result = {
"is_valid": False,
"errors": [],
"warnings": [],
"config": {}
}
try:
# 获取配置
config_prefix = f"text_{provider_name}"
api_key = config.app.get(f'{config_prefix}_api_key')
model_name = config.app.get(f'{config_prefix}_model_name')
base_url = config.app.get(f'{config_prefix}_base_url')
result["config"] = {
"api_key": "***" if api_key else None,
"model_name": model_name,
"base_url": base_url
}
# 验证必需配置
if not api_key:
result["errors"].append(f"缺少API密钥配置: {config_prefix}_api_key")
if not model_name:
result["errors"].append(f"缺少模型名称配置: {config_prefix}_model_name")
# 尝试创建提供商实例
if api_key and model_name:
try:
provider_instance = LLMServiceManager.get_text_provider(provider_name)
result["is_valid"] = True
logger.debug(f"文本模型提供商 {provider_name} 配置验证成功")
except Exception as e:
result["errors"].append(f"创建提供商实例失败: {str(e)}")
# 添加警告
if not base_url:
result["warnings"].append(f"未配置base_url将使用默认值")
except Exception as e:
result["errors"].append(f"配置验证过程中发生错误: {str(e)}")
return result
@staticmethod
def get_config_suggestions() -> Dict[str, Any]:
"""
获取配置建议
Returns:
配置建议字典
"""
suggestions = {
"vision_providers": {},
"text_providers": {},
"general_tips": [
"确保所有API密钥都已正确配置",
"建议为每个提供商配置base_url以提高稳定性",
"定期检查模型名称是否为最新版本",
"建议配置多个提供商作为备用方案"
]
}
# 为每个视觉模型提供商提供建议
vision_providers = LLMServiceManager.list_vision_providers()
for provider in vision_providers:
suggestions["vision_providers"][provider] = {
"required_configs": [
f"vision_{provider}_api_key",
f"vision_{provider}_model_name"
],
"optional_configs": [
f"vision_{provider}_base_url"
],
"example_models": LLMConfigValidator._get_example_models(provider, "vision")
}
# 为每个文本模型提供商提供建议
text_providers = LLMServiceManager.list_text_providers()
for provider in text_providers:
suggestions["text_providers"][provider] = {
"required_configs": [
f"text_{provider}_api_key",
f"text_{provider}_model_name"
],
"optional_configs": [
f"text_{provider}_base_url"
],
"example_models": LLMConfigValidator._get_example_models(provider, "text")
}
return suggestions
@staticmethod
def _get_example_models(provider_name: str, model_type: str) -> List[str]:
"""获取示例模型名称"""
examples = {
"gemini": {
"vision": ["gemini-2.0-flash-lite", "gemini-2.0-flash"],
"text": ["gemini-2.0-flash", "gemini-1.5-pro"]
},
"openai": {
"vision": [],
"text": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
},
"qwen": {
"vision": ["qwen2.5-vl-32b-instruct"],
"text": ["qwen-plus-1127", "qwen-turbo"]
},
"deepseek": {
"vision": [],
"text": ["deepseek-chat", "deepseek-reasoner"]
},
"siliconflow": {
"vision": ["Qwen/Qwen2.5-VL-32B-Instruct"],
"text": ["deepseek-ai/DeepSeek-R1", "Qwen/Qwen2.5-72B-Instruct"]
}
}
return examples.get(provider_name, {}).get(model_type, [])
@staticmethod
def print_validation_report(validation_results: Dict[str, Any]):
"""
打印验证报告
Args:
validation_results: 验证结果
"""
summary = validation_results["summary"]
print("\n" + "="*60)
print("LLM服务配置验证报告")
print("="*60)
print(f"\n📊 总体统计:")
print(f" 视觉模型提供商: {summary['valid_vision_providers']}/{summary['total_vision_providers']} 有效")
print(f" 文本模型提供商: {summary['valid_text_providers']}/{summary['total_text_providers']} 有效")
if summary["errors"]:
print(f"\n❌ 错误 ({len(summary['errors'])}):")
for error in summary["errors"]:
print(f" - {error}")
if summary["warnings"]:
print(f"\n⚠️ 警告 ({len(summary['warnings'])}):")
for warning in summary["warnings"]:
print(f" - {warning}")
print(f"\n✅ 配置验证完成")
print("="*60)

View File

@ -0,0 +1,118 @@
"""
大模型服务异常类定义
定义了大模型服务中可能出现的各种异常类型
提供统一的错误处理机制
"""
from typing import Optional, Dict, Any
class LLMServiceError(Exception):
"""大模型服务基础异常类"""
def __init__(self, message: str, error_code: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
super().__init__(message)
self.message = message
self.error_code = error_code
self.details = details or {}
def __str__(self):
if self.error_code:
return f"[{self.error_code}] {self.message}"
return self.message
class ProviderNotFoundError(LLMServiceError):
"""供应商未找到异常"""
def __init__(self, provider_name: str):
super().__init__(
message=f"未找到大模型供应商: {provider_name}",
error_code="PROVIDER_NOT_FOUND",
details={"provider_name": provider_name}
)
class ConfigurationError(LLMServiceError):
"""配置错误异常"""
def __init__(self, message: str, config_key: Optional[str] = None):
super().__init__(
message=f"配置错误: {message}",
error_code="CONFIGURATION_ERROR",
details={"config_key": config_key} if config_key else {}
)
class APICallError(LLMServiceError):
"""API调用错误异常"""
def __init__(self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None):
super().__init__(
message=f"API调用失败: {message}",
error_code="API_CALL_ERROR",
details={
"status_code": status_code,
"response_text": response_text
}
)
class ValidationError(LLMServiceError):
"""输出验证错误异常"""
def __init__(self, message: str, validation_type: Optional[str] = None, invalid_data: Optional[Any] = None):
super().__init__(
message=f"输出验证失败: {message}",
error_code="VALIDATION_ERROR",
details={
"validation_type": validation_type,
"invalid_data": str(invalid_data) if invalid_data else None
}
)
class ModelNotSupportedError(LLMServiceError):
"""模型不支持异常"""
def __init__(self, model_name: str, provider_name: str):
super().__init__(
message=f"供应商 {provider_name} 不支持模型 {model_name}",
error_code="MODEL_NOT_SUPPORTED",
details={
"model_name": model_name,
"provider_name": provider_name
}
)
class RateLimitError(LLMServiceError):
"""API速率限制异常"""
def __init__(self, message: str = "API调用频率超限", retry_after: Optional[int] = None):
super().__init__(
message=message,
error_code="RATE_LIMIT_ERROR",
details={"retry_after": retry_after}
)
class AuthenticationError(LLMServiceError):
"""认证错误异常"""
def __init__(self, message: str = "API密钥无效或权限不足"):
super().__init__(
message=message,
error_code="AUTHENTICATION_ERROR"
)
class ContentFilterError(LLMServiceError):
"""内容过滤异常"""
def __init__(self, message: str = "内容被安全过滤器阻止"):
super().__init__(
message=message,
error_code="CONTENT_FILTER_ERROR"
)

214
app/services/llm/manager.py Normal file
View File

@ -0,0 +1,214 @@
"""
大模型服务管理器
统一管理所有大模型服务提供商提供简单的工厂方法来创建和获取服务实例
"""
from typing import Dict, Type, Optional
from loguru import logger
from app.config import config
from .base import VisionModelProvider, TextModelProvider
from .exceptions import ProviderNotFoundError, ConfigurationError
class LLMServiceManager:
"""大模型服务管理器"""
# 注册的视觉模型提供商
_vision_providers: Dict[str, Type[VisionModelProvider]] = {}
# 注册的文本模型提供商
_text_providers: Dict[str, Type[TextModelProvider]] = {}
# 缓存的提供商实例
_vision_instance_cache: Dict[str, VisionModelProvider] = {}
_text_instance_cache: Dict[str, TextModelProvider] = {}
@classmethod
def register_vision_provider(cls, name: str, provider_class: Type[VisionModelProvider]):
"""注册视觉模型提供商"""
cls._vision_providers[name.lower()] = provider_class
logger.debug(f"注册视觉模型提供商: {name}")
@classmethod
def register_text_provider(cls, name: str, provider_class: Type[TextModelProvider]):
"""注册文本模型提供商"""
cls._text_providers[name.lower()] = provider_class
logger.debug(f"注册文本模型提供商: {name}")
@classmethod
def _ensure_providers_registered(cls):
"""确保提供商已注册"""
try:
# 如果没有注册的提供商强制导入providers模块
if not cls._vision_providers or not cls._text_providers:
from . import providers
logger.debug("LLMServiceManager强制注册提供商")
except Exception as e:
logger.error(f"LLMServiceManager确保提供商注册时发生错误: {str(e)}")
@classmethod
def get_vision_provider(cls, provider_name: Optional[str] = None) -> VisionModelProvider:
"""
获取视觉模型提供商实例
Args:
provider_name: 提供商名称如果不指定则从配置中获取
Returns:
视觉模型提供商实例
Raises:
ProviderNotFoundError: 提供商未找到
ConfigurationError: 配置错误
"""
# 确保提供商已注册
cls._ensure_providers_registered()
# 确定提供商名称
if not provider_name:
provider_name = config.app.get('vision_llm_provider', 'gemini').lower()
else:
provider_name = provider_name.lower()
# 检查缓存
cache_key = f"vision_{provider_name}"
if cache_key in cls._vision_instance_cache:
return cls._vision_instance_cache[cache_key]
# 检查提供商是否已注册
if provider_name not in cls._vision_providers:
raise ProviderNotFoundError(provider_name)
# 获取配置
config_prefix = f"vision_{provider_name}"
api_key = config.app.get(f'{config_prefix}_api_key')
model_name = config.app.get(f'{config_prefix}_model_name')
base_url = config.app.get(f'{config_prefix}_base_url')
if not api_key:
raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key")
if not model_name:
raise ConfigurationError(f"缺少模型名称配置: {config_prefix}_model_name")
# 创建提供商实例
provider_class = cls._vision_providers[provider_name]
try:
instance = provider_class(
api_key=api_key,
model_name=model_name,
base_url=base_url
)
# 缓存实例
cls._vision_instance_cache[cache_key] = instance
logger.info(f"创建视觉模型提供商实例: {provider_name} - {model_name}")
return instance
except Exception as e:
logger.error(f"创建视觉模型提供商实例失败: {provider_name} - {str(e)}")
raise ConfigurationError(f"创建提供商实例失败: {str(e)}")
@classmethod
def get_text_provider(cls, provider_name: Optional[str] = None) -> TextModelProvider:
"""
获取文本模型提供商实例
Args:
provider_name: 提供商名称如果不指定则从配置中获取
Returns:
文本模型提供商实例
Raises:
ProviderNotFoundError: 提供商未找到
ConfigurationError: 配置错误
"""
# 确保提供商已注册
cls._ensure_providers_registered()
# 确定提供商名称
if not provider_name:
provider_name = config.app.get('text_llm_provider', 'openai').lower()
else:
provider_name = provider_name.lower()
# 检查缓存
cache_key = f"text_{provider_name}"
if cache_key in cls._text_instance_cache:
return cls._text_instance_cache[cache_key]
# 检查提供商是否已注册
if provider_name not in cls._text_providers:
raise ProviderNotFoundError(provider_name)
# 获取配置
config_prefix = f"text_{provider_name}"
api_key = config.app.get(f'{config_prefix}_api_key')
model_name = config.app.get(f'{config_prefix}_model_name')
base_url = config.app.get(f'{config_prefix}_base_url')
if not api_key:
raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key")
if not model_name:
raise ConfigurationError(f"缺少模型名称配置: {config_prefix}_model_name")
# 创建提供商实例
provider_class = cls._text_providers[provider_name]
try:
instance = provider_class(
api_key=api_key,
model_name=model_name,
base_url=base_url
)
# 缓存实例
cls._text_instance_cache[cache_key] = instance
logger.info(f"创建文本模型提供商实例: {provider_name} - {model_name}")
return instance
except Exception as e:
logger.error(f"创建文本模型提供商实例失败: {provider_name} - {str(e)}")
raise ConfigurationError(f"创建提供商实例失败: {str(e)}")
@classmethod
def clear_cache(cls):
"""清空提供商实例缓存"""
cls._vision_instance_cache.clear()
cls._text_instance_cache.clear()
logger.info("已清空提供商实例缓存")
@classmethod
def list_vision_providers(cls) -> list[str]:
"""列出所有已注册的视觉模型提供商"""
return list(cls._vision_providers.keys())
@classmethod
def list_text_providers(cls) -> list[str]:
"""列出所有已注册的文本模型提供商"""
return list(cls._text_providers.keys())
@classmethod
def get_provider_info(cls) -> Dict[str, Dict[str, any]]:
"""获取所有提供商信息"""
return {
"vision_providers": {
name: {
"class": provider_class.__name__,
"module": provider_class.__module__
}
for name, provider_class in cls._vision_providers.items()
},
"text_providers": {
name: {
"class": provider_class.__name__,
"module": provider_class.__module__
}
for name, provider_class in cls._text_providers.items()
}
}

View File

@ -0,0 +1,348 @@
"""
迁移适配器
为现有代码提供向后兼容的接口方便逐步迁移到新的LLM服务架构
"""
import asyncio
import json
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger
from .unified_service import UnifiedLLMService
from .exceptions import LLMServiceError
# 导入新的提示词管理系统
from app.services.prompts import PromptManager
# 确保提供商已注册
def _ensure_providers_registered():
"""确保所有提供商都已注册"""
try:
from .manager import LLMServiceManager
# 检查是否有已注册的提供商
if not LLMServiceManager.list_text_providers() or not LLMServiceManager.list_vision_providers():
# 如果没有注册的提供商强制导入providers模块
from . import providers
logger.debug("迁移适配器强制注册LLM服务提供商")
except Exception as e:
logger.error(f"迁移适配器确保LLM服务提供商注册时发生错误: {str(e)}")
# 在模块加载时确保提供商已注册
_ensure_providers_registered()
def _run_async_safely(coro_func, *args, **kwargs):
"""
安全地运行异步协程处理各种事件循环情况
Args:
coro_func: 协程函数不是协程对象
*args: 协程函数的位置参数
**kwargs: 协程函数的关键字参数
Returns:
协程的执行结果
"""
def run_in_new_loop():
"""在新的事件循环中运行协程"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro_func(*args, **kwargs))
finally:
loop.close()
asyncio.set_event_loop(None)
try:
# 尝试获取当前事件循环
try:
loop = asyncio.get_running_loop()
# 如果有运行中的事件循环,使用线程池执行
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# 没有运行中的事件循环,直接运行
return run_in_new_loop()
except Exception as e:
logger.error(f"异步执行失败: {str(e)}")
raise LLMServiceError(f"异步执行失败: {str(e)}")
class LegacyLLMAdapter:
"""传统LLM接口适配器"""
@staticmethod
def create_vision_analyzer(provider: str, api_key: str, model: str, base_url: str = None):
"""
创建视觉分析器实例 - 兼容原有接口
Args:
provider: 提供商名称
api_key: API密钥
model: 模型名称
base_url: API基础URL
Returns:
适配器实例
"""
return VisionAnalyzerAdapter(provider, api_key, model, base_url)
@staticmethod
def generate_narration(markdown_content: str, api_key: str, base_url: str, model: str) -> str:
"""
生成解说文案 - 兼容原有接口
Args:
markdown_content: Markdown格式的视频帧分析内容
api_key: API密钥
base_url: API基础URL
model: 模型名称
Returns:
生成的解说文案JSON字符串
"""
try:
# 使用新的提示词管理系统
prompt = PromptManager.get_prompt(
category="documentary",
name="narration_generation",
parameters={
"video_frame_description": markdown_content
}
)
# 使用统一服务生成文案
result = _run_async_safely(
UnifiedLLMService.generate_text,
prompt=prompt,
system_prompt="你是一名专业的短视频解说文案撰写专家。",
temperature=1.5,
response_format="json"
)
# 使用增强的JSON解析器
from webui.tools.generate_short_summary import parse_and_fix_json
parsed_result = parse_and_fix_json(result)
if not parsed_result:
logger.error("无法解析LLM返回的JSON数据")
# 返回一个基本的JSON结构而不是错误字符串
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "解析失败请检查LLM输出",
"narration": "解说文案生成失败,请重试"
}
]
}, ensure_ascii=False)
# 确保返回的是JSON字符串
return json.dumps(parsed_result, ensure_ascii=False)
except Exception as e:
logger.error(f"生成解说文案失败: {str(e)}")
# 返回一个基本的JSON结构而不是错误字符串
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "生成失败",
"narration": f"解说文案生成失败: {str(e)}"
}
]
}, ensure_ascii=False)
class VisionAnalyzerAdapter:
"""视觉分析器适配器"""
def __init__(self, provider: str, api_key: str, model: str, base_url: str = None):
self.provider = provider
self.api_key = api_key
self.model = model
self.base_url = base_url
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10) -> List[Dict[str, Any]]:
"""
分析图片 - 兼容原有接口
Args:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
Returns:
分析结果列表格式与旧实现兼容
"""
try:
# 使用统一服务分析图片
results = await UnifiedLLMService.analyze_images(
images=images,
prompt=prompt,
provider=self.provider,
batch_size=batch_size
)
# 转换为旧格式以保持向后兼容性
# 新实现返回 List[str],需要转换为 List[Dict]
compatible_results = []
for i, result in enumerate(results):
# 计算这个批次处理的图片数量
start_idx = i * batch_size
end_idx = min(start_idx + batch_size, len(images))
images_processed = end_idx - start_idx
compatible_results.append({
'batch_index': i,
'images_processed': images_processed,
'response': result,
'model_used': self.model
})
logger.info(f"图片分析完成,共处理 {len(images)} 张图片,生成 {len(compatible_results)} 个批次结果")
return compatible_results
except Exception as e:
logger.error(f"图片分析失败: {str(e)}")
raise
class SubtitleAnalyzerAdapter:
"""字幕分析器适配器"""
def __init__(self, api_key: str, model: str, base_url: str, provider: str = None):
self.api_key = api_key
self.model = model
self.base_url = base_url
self.provider = provider or "openai"
def _run_async_safely(self, coro_func, *args, **kwargs):
"""安全地运行异步协程"""
return _run_async_safely(coro_func, *args, **kwargs)
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除开头和结尾的```标记
output = re.sub(r'^```', '', output)
output = re.sub(r'```$', '', output)
# 移除前后空白字符
output = output.strip()
return output
def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]:
"""
分析字幕内容 - 兼容原有接口
Args:
subtitle_content: 字幕内容
Returns:
分析结果字典
"""
try:
# 使用统一服务分析字幕
result = self._run_async_safely(
UnifiedLLMService.analyze_subtitle,
subtitle_content=subtitle_content,
provider=self.provider,
temperature=1.0
)
return {
"status": "success",
"analysis": result,
"model": self.model,
"temperature": 1.0
}
except Exception as e:
logger.error(f"字幕分析失败: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": 1.0
}
def generate_narration_script(self, short_name: str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
生成解说文案 - 兼容原有接口
Args:
short_name: 短剧名称
plot_analysis: 剧情分析内容
temperature: 生成温度
Returns:
生成结果字典
"""
try:
# 使用新的提示词管理系统构建提示词
prompt = PromptManager.get_prompt(
category="short_drama_narration",
name="script_generation",
parameters={
"drama_name": short_name,
"plot_analysis": plot_analysis
}
)
# 使用统一服务生成文案
result = self._run_async_safely(
UnifiedLLMService.generate_text,
prompt=prompt,
system_prompt="你是一位专业的短视频解说脚本撰写专家。",
provider=self.provider,
temperature=temperature,
response_format="json"
)
# 清理JSON输出
cleaned_result = self._clean_json_output(result)
# 新的提示词系统返回的是包含items数组的JSON格式
# 为了保持向后兼容我们需要直接返回这个JSON字符串
# 调用方会期望这是一个包含items数组的JSON字符串
return {
"status": "success",
"narration_script": cleaned_result,
"model": self.model,
"temperature": temperature
}
except Exception as e:
logger.error(f"解说文案生成失败: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": temperature
}
# 为了向后兼容,提供一些全局函数
def create_vision_analyzer(provider: str, api_key: str, model: str, base_url: str = None):
"""创建视觉分析器 - 全局函数"""
return LegacyLLMAdapter.create_vision_analyzer(provider, api_key, model, base_url)
def generate_narration(markdown_content: str, api_key: str, base_url: str, model: str) -> str:
"""生成解说文案 - 全局函数"""
return LegacyLLMAdapter.generate_narration(markdown_content, api_key, base_url, model)

View File

@ -0,0 +1,47 @@
"""
大模型服务提供商实现
包含各种大模型服务提供商的具体实现
"""
from .gemini_provider import GeminiVisionProvider, GeminiTextProvider
from .gemini_openai_provider import GeminiOpenAIVisionProvider, GeminiOpenAITextProvider
from .openai_provider import OpenAITextProvider
from .qwen_provider import QwenVisionProvider, QwenTextProvider
from .deepseek_provider import DeepSeekTextProvider
from .siliconflow_provider import SiliconflowVisionProvider, SiliconflowTextProvider
# 自动注册所有提供商
from ..manager import LLMServiceManager
def register_all_providers():
"""注册所有提供商"""
# 注册视觉模型提供商
LLMServiceManager.register_vision_provider('gemini', GeminiVisionProvider)
LLMServiceManager.register_vision_provider('gemini(openai)', GeminiOpenAIVisionProvider)
LLMServiceManager.register_vision_provider('qwenvl', QwenVisionProvider)
LLMServiceManager.register_vision_provider('siliconflow', SiliconflowVisionProvider)
# 注册文本模型提供商
LLMServiceManager.register_text_provider('gemini', GeminiTextProvider)
LLMServiceManager.register_text_provider('gemini(openai)', GeminiOpenAITextProvider)
LLMServiceManager.register_text_provider('openai', OpenAITextProvider)
LLMServiceManager.register_text_provider('qwen', QwenTextProvider)
LLMServiceManager.register_text_provider('deepseek', DeepSeekTextProvider)
LLMServiceManager.register_text_provider('siliconflow', SiliconflowTextProvider)
# 自动注册
register_all_providers()
__all__ = [
'GeminiVisionProvider',
'GeminiTextProvider',
'GeminiOpenAIVisionProvider',
'GeminiOpenAITextProvider',
'OpenAITextProvider',
'QwenVisionProvider',
'QwenTextProvider',
'DeepSeekTextProvider',
'SiliconflowVisionProvider',
'SiliconflowTextProvider'
]

View File

@ -0,0 +1,157 @@
"""
DeepSeek API提供商实现
支持DeepSeek的文本生成模型
"""
import asyncio
from typing import List, Dict, Any, Optional
from openai import OpenAI, BadRequestError
from loguru import logger
from ..base import TextModelProvider
from ..exceptions import APICallError
class DeepSeekTextProvider(TextModelProvider):
"""DeepSeek文本生成提供商"""
@property
def provider_name(self) -> str:
return "deepseek"
@property
def supported_models(self) -> List[str]:
return [
"deepseek-chat",
"deepseek-reasoner",
"deepseek-r1",
"deepseek-v3"
]
def _initialize(self):
"""初始化DeepSeek客户端"""
if not self.base_url:
self.base_url = "https://api.deepseek.com"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用DeepSeek API生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建消息列表
messages = self._build_messages(prompt, system_prompt)
# 构建请求参数
request_params = {
"model": self.model_name,
"messages": messages,
"temperature": temperature
}
if max_tokens:
request_params["max_tokens"] = max_tokens
# 处理JSON格式输出
# DeepSeek R1 和 V3 不支持 response_format=json_object
if response_format == "json":
if self._supports_response_format():
request_params["response_format"] = {"type": "json_object"}
else:
# 对于不支持response_format的模型在提示词中添加约束
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
try:
# 发送API请求
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
# 提取生成的内容
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
# 对于不支持response_format的模型清理输出
if response_format == "json" and not self._supports_response_format():
content = self._clean_json_output(content)
logger.debug(f"DeepSeek API调用成功消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("DeepSeek API返回空响应")
except BadRequestError as e:
# 处理不支持response_format的情况
if "response_format" in str(e) and response_format == "json":
logger.warning(f"DeepSeek模型 {self.model_name} 不支持response_format重试不带格式约束的请求")
request_params.pop("response_format", None)
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
content = self._clean_json_output(content)
return content
else:
raise APICallError("DeepSeek API返回空响应")
else:
raise APICallError(f"DeepSeek API请求失败: {str(e)}")
except Exception as e:
logger.error(f"DeepSeek API调用失败: {str(e)}")
raise APICallError(f"DeepSeek API调用失败: {str(e)}")
def _supports_response_format(self) -> bool:
"""检查模型是否支持response_format参数"""
# DeepSeek R1 和 V3 不支持 response_format=json_object
unsupported_models = [
"deepseek-reasoner",
"deepseek-r1",
"deepseek-v3"
]
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除前后空白字符
output = output.strip()
return output
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass

View File

@ -0,0 +1,235 @@
"""
OpenAI兼容的Gemini API提供商实现
使用OpenAI兼容接口调用Gemini服务支持视觉分析和文本生成
"""
import asyncio
import base64
import io
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from openai import OpenAI
from loguru import logger
from ..base import VisionModelProvider, TextModelProvider
from ..exceptions import APICallError
class GeminiOpenAIVisionProvider(VisionModelProvider):
"""OpenAI兼容的Gemini视觉模型提供商"""
@property
def provider_name(self) -> str:
return "gemini(openai)"
@property
def supported_models(self) -> List[str]:
return [
"gemini-2.0-flash-lite",
"gemini-2.0-flash",
"gemini-1.5-pro",
"gemini-1.5-flash"
]
def _initialize(self):
"""初始化OpenAI兼容的Gemini客户端"""
if not self.base_url:
self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
使用OpenAI兼容的Gemini API分析图片
Args:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
"""
logger.info(f"开始分析 {len(images)} 张图片使用OpenAI兼容Gemini代理")
# 预处理图片
processed_images = self._prepare_images(images)
# 分批处理
results = []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i:i + batch_size]
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt)
results.append(result)
except Exception as e:
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
results.append(f"批次处理失败: {str(e)}")
return results
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
"""分析一批图片"""
# 构建OpenAI格式的消息内容
content = [{"type": "text", "text": prompt}]
# 添加图片
for img in batch:
base64_image = self._image_to_base64(img)
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
# 构建消息
messages = [{
"role": "user",
"content": content
}]
# 调用API
response = await asyncio.to_thread(
self.client.chat.completions.create,
model=self.model_name,
messages=messages,
max_tokens=4000,
temperature=1.0
)
if response.choices and len(response.choices) > 0:
return response.choices[0].message.content
else:
raise APICallError("OpenAI兼容Gemini API返回空响应")
def _image_to_base64(self, img: PIL.Image.Image) -> str:
"""将PIL图片转换为base64编码"""
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85)
img_bytes = img_buffer.getvalue()
return base64.b64encode(img_bytes).decode('utf-8')
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass
class GeminiOpenAITextProvider(TextModelProvider):
"""OpenAI兼容的Gemini文本生成提供商"""
@property
def provider_name(self) -> str:
return "gemini(openai)"
@property
def supported_models(self) -> List[str]:
return [
"gemini-2.0-flash-lite",
"gemini-2.0-flash",
"gemini-1.5-pro",
"gemini-1.5-flash"
]
def _initialize(self):
"""初始化OpenAI兼容的Gemini客户端"""
if not self.base_url:
self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用OpenAI兼容的Gemini API生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建消息列表
messages = self._build_messages(prompt, system_prompt)
# 构建请求参数
request_params = {
"model": self.model_name,
"messages": messages,
"temperature": temperature
}
if max_tokens:
request_params["max_tokens"] = max_tokens
# 处理JSON格式输出 - Gemini通过OpenAI接口可能不完全支持response_format
if response_format == "json":
# 在提示词中添加JSON格式约束
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
try:
# 发送API请求
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
# 提取生成的内容
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
# 对于JSON格式清理输出
if response_format == "json":
content = self._clean_json_output(content)
logger.debug(f"OpenAI兼容Gemini API调用成功消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("OpenAI兼容Gemini API返回空响应")
except Exception as e:
logger.error(f"OpenAI兼容Gemini API调用失败: {str(e)}")
raise APICallError(f"OpenAI兼容Gemini API调用失败: {str(e)}")
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除前后空白字符
output = output.strip()
return output
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass

View File

@ -0,0 +1,346 @@
"""
原生Gemini API提供商实现
使用Google原生Gemini API进行视觉分析和文本生成
"""
import asyncio
import base64
import io
import requests
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger
from ..base import VisionModelProvider, TextModelProvider
from ..exceptions import APICallError, ContentFilterError
class GeminiVisionProvider(VisionModelProvider):
"""原生Gemini视觉模型提供商"""
@property
def provider_name(self) -> str:
return "gemini"
@property
def supported_models(self) -> List[str]:
return [
"gemini-2.0-flash-lite",
"gemini-2.0-flash",
"gemini-1.5-pro",
"gemini-1.5-flash"
]
def _initialize(self):
"""初始化Gemini特定设置"""
if not self.base_url:
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
使用原生Gemini API分析图片
Args:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
"""
logger.info(f"开始分析 {len(images)} 张图片使用原生Gemini API")
# 预处理图片
processed_images = self._prepare_images(images)
# 分批处理
results = []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i:i + batch_size]
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt)
results.append(result)
except Exception as e:
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
results.append(f"批次处理失败: {str(e)}")
return results
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
"""分析一批图片"""
# 构建请求数据
parts = [{"text": prompt}]
# 添加图片数据
for img in batch:
img_data = self._image_to_base64(img)
parts.append({
"inline_data": {
"mime_type": "image/jpeg",
"data": img_data
}
})
payload = {
"systemInstruction": {
"parts": [{"text": "你是一位专业的视觉内容分析师,请仔细分析图片内容并提供详细描述。"}]
},
"contents": [{"parts": parts}],
"generationConfig": {
"temperature": 1.0,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 4000,
"candidateCount": 1
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 发送API请求
response_data = await self._make_api_call(payload)
# 解析响应
return self._parse_vision_response(response_data)
def _image_to_base64(self, img: PIL.Image.Image) -> str:
"""将PIL图片转换为base64编码"""
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85)
img_bytes = img_buffer.getvalue()
return base64.b64encode(img_bytes).decode('utf-8')
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行原生Gemini API调用"""
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
response = await asyncio.to_thread(
requests.post,
url,
json=payload,
headers={
"Content-Type": "application/json",
"User-Agent": "NarratoAI/1.0"
},
timeout=120
)
if response.status_code != 200:
error = self._handle_api_error(response.status_code, response.text)
raise error
return response.json()
def _parse_vision_response(self, response_data: Dict[str, Any]) -> str:
"""解析视觉分析响应"""
if "candidates" not in response_data or not response_data["candidates"]:
raise APICallError("原生Gemini API返回无效响应")
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
raise ContentFilterError("内容被Gemini安全过滤器阻止")
if "content" not in candidate or "parts" not in candidate["content"]:
raise APICallError("原生Gemini API返回内容格式错误")
# 提取文本内容
result = ""
for part in candidate["content"]["parts"]:
if "text" in part:
result += part["text"]
if not result.strip():
raise APICallError("原生Gemini API返回空内容")
return result
class GeminiTextProvider(TextModelProvider):
"""原生Gemini文本生成提供商"""
@property
def provider_name(self) -> str:
return "gemini"
@property
def supported_models(self) -> List[str]:
return [
"gemini-2.0-flash-lite",
"gemini-2.0-flash",
"gemini-1.5-pro",
"gemini-1.5-flash"
]
def _initialize(self):
"""初始化Gemini特定设置"""
if not self.base_url:
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用原生Gemini API生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建请求数据
payload = {
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"temperature": temperature,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": max_tokens or 4000,
"candidateCount": 1
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 添加系统提示词
if system_prompt:
payload["systemInstruction"] = {
"parts": [{"text": system_prompt}]
}
# 如果需要JSON格式调整提示词和配置
if response_format == "json":
# 使用更温和的JSON格式约束
enhanced_prompt = f"{prompt}\n\n请以JSON格式输出结果。"
payload["contents"][0]["parts"][0]["text"] = enhanced_prompt
# 移除可能导致问题的stopSequences
# payload["generationConfig"]["stopSequences"] = ["```", "注意", "说明"]
# 记录请求信息
logger.debug(f"Gemini文本生成请求: {payload}")
# 发送API请求
response_data = await self._make_api_call(payload)
# 解析响应
return self._parse_text_response(response_data)
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行原生Gemini API调用"""
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
response = await asyncio.to_thread(
requests.post,
url,
json=payload,
headers={
"Content-Type": "application/json",
"User-Agent": "NarratoAI/1.0"
},
timeout=120
)
if response.status_code != 200:
error = self._handle_api_error(response.status_code, response.text)
raise error
return response.json()
def _parse_text_response(self, response_data: Dict[str, Any]) -> str:
"""解析文本生成响应"""
logger.debug(f"Gemini API响应数据: {response_data}")
if "candidates" not in response_data or not response_data["candidates"]:
logger.error(f"Gemini API返回无效响应结构: {response_data}")
raise APICallError("原生Gemini API返回无效响应")
candidate = response_data["candidates"][0]
logger.debug(f"Gemini候选响应: {candidate}")
# 检查完成原因
finish_reason = candidate.get("finishReason", "UNKNOWN")
logger.debug(f"Gemini完成原因: {finish_reason}")
# 检查是否被安全过滤阻止
if finish_reason == "SAFETY":
safety_ratings = candidate.get("safetyRatings", [])
logger.warning(f"内容被Gemini安全过滤器阻止安全评级: {safety_ratings}")
raise ContentFilterError("内容被Gemini安全过滤器阻止")
# 检查是否因为其他原因停止
if finish_reason in ["RECITATION", "OTHER"]:
logger.warning(f"Gemini因为{finish_reason}原因停止生成")
raise APICallError(f"Gemini因为{finish_reason}原因停止生成")
if "content" not in candidate:
logger.error(f"Gemini候选响应中缺少content字段: {candidate}")
raise APICallError("原生Gemini API返回内容格式错误")
if "parts" not in candidate["content"]:
logger.error(f"Gemini内容中缺少parts字段: {candidate['content']}")
raise APICallError("原生Gemini API返回内容格式错误")
# 提取文本内容
result = ""
for part in candidate["content"]["parts"]:
if "text" in part:
result += part["text"]
if not result.strip():
logger.error(f"Gemini API返回空文本内容完整响应: {response_data}")
raise APICallError("原生Gemini API返回空内容")
logger.debug(f"Gemini成功生成内容长度: {len(result)}")
return result

View File

@ -0,0 +1,168 @@
"""
OpenAI API提供商实现
使用OpenAI API进行文本生成也支持OpenAI兼容的其他服务
"""
import asyncio
from typing import List, Dict, Any, Optional
from openai import OpenAI, BadRequestError
from loguru import logger
from ..base import TextModelProvider
from ..exceptions import APICallError, RateLimitError, AuthenticationError
class OpenAITextProvider(TextModelProvider):
"""OpenAI文本生成提供商"""
@property
def provider_name(self) -> str:
return "openai"
@property
def supported_models(self) -> List[str]:
return [
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
# 支持其他OpenAI兼容模型
"deepseek-chat",
"deepseek-reasoner",
"qwen-plus",
"qwen-turbo",
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k"
]
def _initialize(self):
"""初始化OpenAI客户端"""
if not self.base_url:
self.base_url = "https://api.openai.com/v1"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用OpenAI API生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建消息列表
messages = self._build_messages(prompt, system_prompt)
# 构建请求参数
request_params = {
"model": self.model_name,
"messages": messages,
"temperature": temperature
}
if max_tokens:
request_params["max_tokens"] = max_tokens
# 处理JSON格式输出
if response_format == "json":
# 检查模型是否支持response_format
if self._supports_response_format():
request_params["response_format"] = {"type": "json_object"}
else:
# 对于不支持response_format的模型在提示词中添加约束
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
try:
# 发送API请求
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
# 提取生成的内容
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
# 对于不支持response_format的模型清理输出
if response_format == "json" and not self._supports_response_format():
content = self._clean_json_output(content)
logger.debug(f"OpenAI API调用成功消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("OpenAI API返回空响应")
except BadRequestError as e:
# 处理不支持response_format的情况
if "response_format" in str(e) and response_format == "json":
logger.warning(f"模型 {self.model_name} 不支持response_format重试不带格式约束的请求")
request_params.pop("response_format", None)
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
content = self._clean_json_output(content)
return content
else:
raise APICallError("OpenAI API返回空响应")
else:
raise APICallError(f"OpenAI API请求失败: {str(e)}")
except Exception as e:
logger.error(f"OpenAI API调用失败: {str(e)}")
raise APICallError(f"OpenAI API调用失败: {str(e)}")
def _supports_response_format(self) -> bool:
"""检查模型是否支持response_format参数"""
# 已知不支持response_format的模型
unsupported_models = [
"deepseek-reasoner",
"deepseek-r1"
]
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除前后空白字符
output = output.strip()
return output
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
# 这个方法在OpenAI提供商中不直接使用因为我们使用OpenAI SDK
# 但为了兼容基类接口,保留此方法
pass

View File

@ -0,0 +1,247 @@
"""
通义千问API提供商实现
支持通义千问的视觉模型和文本生成模型
"""
import asyncio
import base64
import io
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from openai import OpenAI
from loguru import logger
from ..base import VisionModelProvider, TextModelProvider
from ..exceptions import APICallError
class QwenVisionProvider(VisionModelProvider):
"""通义千问视觉模型提供商"""
@property
def provider_name(self) -> str:
return "qwenvl"
@property
def supported_models(self) -> List[str]:
return [
"qwen2.5-vl-32b-instruct",
"qwen2-vl-72b-instruct",
"qwen-vl-max",
"qwen-vl-plus"
]
def _initialize(self):
"""初始化通义千问客户端"""
if not self.base_url:
self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
使用通义千问VL分析图片
Args:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
"""
logger.info(f"开始分析 {len(images)} 张图片使用通义千问VL")
# 预处理图片
processed_images = self._prepare_images(images)
# 分批处理
results = []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i:i + batch_size]
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt)
results.append(result)
except Exception as e:
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
results.append(f"批次处理失败: {str(e)}")
return results
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
"""分析一批图片"""
# 构建消息内容
content = []
# 添加图片
for img in batch:
base64_image = self._image_to_base64(img)
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
# 添加文本提示,使用占位符来引用图片数量
content.append({
"type": "text",
"text": prompt % (len(batch), len(batch), len(batch))
})
# 构建消息
messages = [{
"role": "user",
"content": content
}]
# 调用API
response = await asyncio.to_thread(
self.client.chat.completions.create,
model=self.model_name,
messages=messages
)
if response.choices and len(response.choices) > 0:
return response.choices[0].message.content
else:
raise APICallError("通义千问VL API返回空响应")
def _image_to_base64(self, img: PIL.Image.Image) -> str:
"""将PIL图片转换为base64编码"""
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85)
img_bytes = img_buffer.getvalue()
return base64.b64encode(img_bytes).decode('utf-8')
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass
class QwenTextProvider(TextModelProvider):
"""通义千问文本生成提供商"""
@property
def provider_name(self) -> str:
return "qwen"
@property
def supported_models(self) -> List[str]:
return [
"qwen-plus-1127",
"qwen-plus",
"qwen-turbo",
"qwen-max",
"qwen2.5-72b-instruct",
"qwen2.5-32b-instruct",
"qwen2.5-14b-instruct",
"qwen2.5-7b-instruct"
]
def _initialize(self):
"""初始化通义千问客户端"""
if not self.base_url:
self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用通义千问API生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建消息列表
messages = self._build_messages(prompt, system_prompt)
# 构建请求参数
request_params = {
"model": self.model_name,
"messages": messages,
"temperature": temperature
}
if max_tokens:
request_params["max_tokens"] = max_tokens
# 处理JSON格式输出
if response_format == "json":
# 通义千问支持response_format
try:
request_params["response_format"] = {"type": "json_object"}
except:
# 如果不支持,在提示词中添加约束
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
try:
# 发送API请求
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
# 提取生成的内容
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
# 对于JSON格式清理输出
if response_format == "json" and "response_format" not in request_params:
content = self._clean_json_output(content)
logger.debug(f"通义千问API调用成功消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("通义千问API返回空响应")
except Exception as e:
logger.error(f"通义千问API调用失败: {str(e)}")
raise APICallError(f"通义千问API调用失败: {str(e)}")
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除前后空白字符
output = output.strip()
return output
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass

View File

@ -0,0 +1,251 @@
"""
硅基流动API提供商实现
支持硅基流动的视觉模型和文本生成模型
"""
import asyncio
import base64
import io
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from openai import OpenAI
from loguru import logger
from ..base import VisionModelProvider, TextModelProvider
from ..exceptions import APICallError
class SiliconflowVisionProvider(VisionModelProvider):
"""硅基流动视觉模型提供商"""
@property
def provider_name(self) -> str:
return "siliconflow"
@property
def supported_models(self) -> List[str]:
return [
"Qwen/Qwen2.5-VL-32B-Instruct",
"Qwen/Qwen2-VL-72B-Instruct",
"deepseek-ai/deepseek-vl2",
"OpenGVLab/InternVL2-26B"
]
def _initialize(self):
"""初始化硅基流动客户端"""
if not self.base_url:
self.base_url = "https://api.siliconflow.cn/v1"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
使用硅基流动API分析图片
Args:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
"""
logger.info(f"开始分析 {len(images)} 张图片,使用硅基流动")
# 预处理图片
processed_images = self._prepare_images(images)
# 分批处理
results = []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i:i + batch_size]
logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt)
results.append(result)
except Exception as e:
logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}")
results.append(f"批次处理失败: {str(e)}")
return results
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str:
"""分析一批图片"""
# 构建消息内容
content = [{"type": "text", "text": prompt}]
# 添加图片
for img in batch:
base64_image = self._image_to_base64(img)
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
# 构建消息
messages = [{
"role": "user",
"content": content
}]
# 调用API
response = await asyncio.to_thread(
self.client.chat.completions.create,
model=self.model_name,
messages=messages,
max_tokens=4000,
temperature=1.0
)
if response.choices and len(response.choices) > 0:
return response.choices[0].message.content
else:
raise APICallError("硅基流动API返回空响应")
def _image_to_base64(self, img: PIL.Image.Image) -> str:
"""将PIL图片转换为base64编码"""
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85)
img_bytes = img_buffer.getvalue()
return base64.b64encode(img_bytes).decode('utf-8')
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass
class SiliconflowTextProvider(TextModelProvider):
"""硅基流动文本生成提供商"""
@property
def provider_name(self) -> str:
return "siliconflow"
@property
def supported_models(self) -> List[str]:
return [
"deepseek-ai/DeepSeek-R1",
"deepseek-ai/DeepSeek-V3",
"Qwen/Qwen2.5-72B-Instruct",
"Qwen/Qwen2.5-32B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
"01-ai/Yi-1.5-34B-Chat"
]
def _initialize(self):
"""初始化硅基流动客户端"""
if not self.base_url:
self.base_url = "https://api.siliconflow.cn/v1"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
使用硅基流动API生成文本
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
# 构建消息列表
messages = self._build_messages(prompt, system_prompt)
# 构建请求参数
request_params = {
"model": self.model_name,
"messages": messages,
"temperature": temperature
}
if max_tokens:
request_params["max_tokens"] = max_tokens
# 处理JSON格式输出
if response_format == "json":
if self._supports_response_format():
request_params["response_format"] = {"type": "json_object"}
else:
# 对于不支持response_format的模型在提示词中添加约束
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
try:
# 发送API请求
response = await asyncio.to_thread(
self.client.chat.completions.create,
**request_params
)
# 提取生成的内容
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
# 对于不支持response_format的模型清理输出
if response_format == "json" and not self._supports_response_format():
content = self._clean_json_output(content)
logger.debug(f"硅基流动API调用成功消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
return content
else:
raise APICallError("硅基流动API返回空响应")
except Exception as e:
logger.error(f"硅基流动API调用失败: {str(e)}")
raise APICallError(f"硅基流动API调用失败: {str(e)}")
def _supports_response_format(self) -> bool:
"""检查模型是否支持response_format参数"""
# DeepSeek R1 和 V3 不支持 response_format=json_object
unsupported_models = [
"deepseek-ai/deepseek-r1",
"deepseek-ai/deepseek-v3"
]
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
def _clean_json_output(self, output: str) -> str:
"""清理JSON输出移除markdown标记等"""
import re
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除前后空白字符
output = output.strip()
return output
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用 - 由于使用OpenAI SDK这个方法主要用于兼容基类"""
pass

View File

@ -0,0 +1,263 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
LLM服务测试脚本
测试新的LLM服务架构是否正常工作
"""
import asyncio
import sys
import os
from pathlib import Path
from loguru import logger
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
from app.services.llm.config_validator import LLMConfigValidator
from app.services.llm.unified_service import UnifiedLLMService
from app.services.llm.exceptions import LLMServiceError
async def test_text_generation():
"""测试文本生成功能"""
print("\n🔤 测试文本生成功能...")
try:
# 简单的文本生成测试
prompt = "请用一句话介绍人工智能。"
result = await UnifiedLLMService.generate_text(
prompt=prompt,
system_prompt="你是一个专业的AI助手。",
temperature=0.7
)
print(f"✅ 文本生成成功:")
print(f" 提示词: {prompt}")
print(f" 生成结果: {result[:100]}...")
return True
except Exception as e:
print(f"❌ 文本生成失败: {str(e)}")
return False
async def test_json_generation():
"""测试JSON格式生成功能"""
print("\n📄 测试JSON格式生成功能...")
try:
prompt = """
请生成一个简单的解说文案示例包含以下字段
- title: 标题
- content: 内容
- duration: 时长
输出JSON格式
"""
result = await UnifiedLLMService.generate_text(
prompt=prompt,
system_prompt="你是一个专业的文案撰写专家。",
temperature=0.7,
response_format="json"
)
# 尝试解析JSON
import json
parsed_result = json.loads(result)
print(f"✅ JSON生成成功:")
print(f" 生成结果: {json.dumps(parsed_result, ensure_ascii=False, indent=2)}")
return True
except json.JSONDecodeError as e:
print(f"❌ JSON解析失败: {str(e)}")
print(f" 原始结果: {result}")
return False
except Exception as e:
print(f"❌ JSON生成失败: {str(e)}")
return False
async def test_narration_script_generation():
"""测试解说文案生成功能"""
print("\n🎬 测试解说文案生成功能...")
try:
prompt = """
根据以下视频描述生成解说文案
视频内容一个人在森林中建造木屋首先挖掘地基然后搭建墙壁最后安装屋顶
请生成JSON格式的解说文案包含items数组每个item包含
- _id: 序号
- timestamp: 时间戳格式HH:MM:SS,mmm-HH:MM:SS,mmm
- picture: 画面描述
- narration: 解说文案
"""
result = await UnifiedLLMService.generate_narration_script(
prompt=prompt,
temperature=0.8,
validate_output=True
)
print(f"✅ 解说文案生成成功:")
print(f" 生成了 {len(result)} 个片段")
for item in result[:2]: # 只显示前2个
print(f" - {item.get('timestamp', 'N/A')}: {item.get('narration', 'N/A')[:50]}...")
return True
except Exception as e:
print(f"❌ 解说文案生成失败: {str(e)}")
return False
async def test_subtitle_analysis():
"""测试字幕分析功能"""
print("\n📝 测试字幕分析功能...")
try:
subtitle_content = """
1
00:00:01,000 --> 00:00:05,000
大家好欢迎来到我的频道
2
00:00:05,000 --> 00:00:10,000
今天我们要学习如何使用人工智能
3
00:00:10,000 --> 00:00:15,000
人工智能是一项非常有趣的技术
"""
result = await UnifiedLLMService.analyze_subtitle(
subtitle_content=subtitle_content,
temperature=0.7,
validate_output=True
)
print(f"✅ 字幕分析成功:")
print(f" 分析结果: {result[:100]}...")
return True
except Exception as e:
print(f"❌ 字幕分析失败: {str(e)}")
return False
def test_config_validation():
"""测试配置验证功能"""
print("\n⚙️ 测试配置验证功能...")
try:
# 验证所有配置
validation_results = LLMConfigValidator.validate_all_configs()
summary = validation_results["summary"]
print(f"✅ 配置验证完成:")
print(f" 视觉模型提供商: {summary['valid_vision_providers']}/{summary['total_vision_providers']} 有效")
print(f" 文本模型提供商: {summary['valid_text_providers']}/{summary['total_text_providers']} 有效")
if summary["errors"]:
print(f" 发现 {len(summary['errors'])} 个错误")
for error in summary["errors"][:3]: # 只显示前3个错误
print(f" - {error}")
return summary['valid_text_providers'] > 0
except Exception as e:
print(f"❌ 配置验证失败: {str(e)}")
return False
def test_provider_info():
"""测试提供商信息获取"""
print("\n📋 测试提供商信息获取...")
try:
provider_info = UnifiedLLMService.get_provider_info()
vision_providers = list(provider_info["vision_providers"].keys())
text_providers = list(provider_info["text_providers"].keys())
print(f"✅ 提供商信息获取成功:")
print(f" 视觉模型提供商: {', '.join(vision_providers)}")
print(f" 文本模型提供商: {', '.join(text_providers)}")
return True
except Exception as e:
print(f"❌ 提供商信息获取失败: {str(e)}")
return False
async def run_all_tests():
"""运行所有测试"""
print("🚀 开始LLM服务测试...")
print("="*60)
# 测试结果统计
test_results = []
# 1. 测试配置验证
test_results.append(("配置验证", test_config_validation()))
# 2. 测试提供商信息
test_results.append(("提供商信息", test_provider_info()))
# 3. 测试文本生成
test_results.append(("文本生成", await test_text_generation()))
# 4. 测试JSON生成
test_results.append(("JSON生成", await test_json_generation()))
# 5. 测试字幕分析
test_results.append(("字幕分析", await test_subtitle_analysis()))
# 6. 测试解说文案生成
test_results.append(("解说文案生成", await test_narration_script_generation()))
# 输出测试结果
print("\n" + "="*60)
print("📊 测试结果汇总:")
print("="*60)
passed = 0
total = len(test_results)
for test_name, result in test_results:
status = "✅ 通过" if result else "❌ 失败"
print(f" {test_name:<15} {status}")
if result:
passed += 1
print(f"\n总计: {passed}/{total} 个测试通过")
if passed == total:
print("🎉 所有测试通过LLM服务工作正常。")
elif passed > 0:
print("⚠️ 部分测试通过,请检查失败的测试项。")
else:
print("💥 所有测试失败,请检查配置和网络连接。")
print("="*60)
if __name__ == "__main__":
# 设置日志级别
logger.remove()
logger.add(sys.stderr, level="INFO")
# 运行测试
asyncio.run(run_all_tests())

View File

@ -0,0 +1,274 @@
"""
统一的大模型服务接口
提供简化的API接口方便现有代码迁移到新的架构
"""
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger
from .manager import LLMServiceManager
from .validators import OutputValidator
from .exceptions import LLMServiceError
# 确保提供商已注册
def _ensure_providers_registered():
"""确保所有提供商都已注册"""
try:
# 检查是否有已注册的提供商
if not LLMServiceManager.list_text_providers() or not LLMServiceManager.list_vision_providers():
# 如果没有注册的提供商强制导入providers模块
from . import providers
logger.debug("强制注册LLM服务提供商")
except Exception as e:
logger.error(f"确保LLM服务提供商注册时发生错误: {str(e)}")
# 在模块加载时确保提供商已注册
_ensure_providers_registered()
class UnifiedLLMService:
"""统一的大模型服务接口"""
@staticmethod
async def analyze_images(images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
provider: Optional[str] = None,
batch_size: int = 10,
**kwargs) -> List[str]:
"""
分析图片内容
Args:
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
provider: 视觉模型提供商名称如果不指定则使用配置中的默认值
batch_size: 批处理大小
**kwargs: 其他参数
Returns:
分析结果列表
Raises:
LLMServiceError: 服务调用失败时抛出
"""
try:
# 获取视觉模型提供商
vision_provider = LLMServiceManager.get_vision_provider(provider)
# 执行图片分析
results = await vision_provider.analyze_images(
images=images,
prompt=prompt,
batch_size=batch_size,
**kwargs
)
logger.info(f"图片分析完成,共处理 {len(images)} 张图片,生成 {len(results)} 个结果")
return results
except Exception as e:
logger.error(f"图片分析失败: {str(e)}")
raise LLMServiceError(f"图片分析失败: {str(e)}")
@staticmethod
async def generate_text(prompt: str,
system_prompt: Optional[str] = None,
provider: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
生成文本内容
Args:
prompt: 用户提示词
system_prompt: 系统提示词
provider: 文本模型提供商名称如果不指定则使用配置中的默认值
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' None)
**kwargs: 其他参数
Returns:
生成的文本内容
Raises:
LLMServiceError: 服务调用失败时抛出
"""
try:
# 获取文本模型提供商
text_provider = LLMServiceManager.get_text_provider(provider)
# 执行文本生成
result = await text_provider.generate_text(
prompt=prompt,
system_prompt=system_prompt,
temperature=temperature,
max_tokens=max_tokens,
response_format=response_format,
**kwargs
)
logger.info(f"文本生成完成,生成内容长度: {len(result)} 字符")
return result
except Exception as e:
logger.error(f"文本生成失败: {str(e)}")
raise LLMServiceError(f"文本生成失败: {str(e)}")
@staticmethod
async def generate_narration_script(prompt: str,
provider: Optional[str] = None,
temperature: float = 1.0,
validate_output: bool = True,
**kwargs) -> List[Dict[str, Any]]:
"""
生成解说文案
Args:
prompt: 提示词
provider: 文本模型提供商名称
temperature: 生成温度
validate_output: 是否验证输出格式
**kwargs: 其他参数
Returns:
解说文案列表
Raises:
LLMServiceError: 服务调用失败时抛出
"""
try:
# 生成文本
result = await UnifiedLLMService.generate_text(
prompt=prompt,
provider=provider,
temperature=temperature,
response_format="json",
**kwargs
)
# 验证输出格式
if validate_output:
narration_items = OutputValidator.validate_narration_script(result)
logger.info(f"解说文案生成并验证完成,共 {len(narration_items)} 个片段")
return narration_items
else:
# 简单的JSON解析
import json
parsed_result = json.loads(result)
if "items" in parsed_result:
return parsed_result["items"]
else:
return parsed_result
except Exception as e:
logger.error(f"解说文案生成失败: {str(e)}")
raise LLMServiceError(f"解说文案生成失败: {str(e)}")
@staticmethod
async def analyze_subtitle(subtitle_content: str,
provider: Optional[str] = None,
temperature: float = 1.0,
validate_output: bool = True,
**kwargs) -> str:
"""
分析字幕内容
Args:
subtitle_content: 字幕内容
provider: 文本模型提供商名称
temperature: 生成温度
validate_output: 是否验证输出格式
**kwargs: 其他参数
Returns:
分析结果
Raises:
LLMServiceError: 服务调用失败时抛出
"""
try:
# 构建分析提示词
system_prompt = "你是一位专业的剧本分析师和剧情概括助手。请仔细分析字幕内容,提取关键剧情信息。"
# 生成分析结果
result = await UnifiedLLMService.generate_text(
prompt=subtitle_content,
system_prompt=system_prompt,
provider=provider,
temperature=temperature,
**kwargs
)
# 验证输出格式
if validate_output:
validated_result = OutputValidator.validate_subtitle_analysis(result)
logger.info("字幕分析完成并验证通过")
return validated_result
else:
return result
except Exception as e:
logger.error(f"字幕分析失败: {str(e)}")
raise LLMServiceError(f"字幕分析失败: {str(e)}")
@staticmethod
def get_provider_info() -> Dict[str, Any]:
"""
获取所有提供商信息
Returns:
提供商信息字典
"""
return LLMServiceManager.get_provider_info()
@staticmethod
def list_vision_providers() -> List[str]:
"""
列出所有视觉模型提供商
Returns:
提供商名称列表
"""
return LLMServiceManager.list_vision_providers()
@staticmethod
def list_text_providers() -> List[str]:
"""
列出所有文本模型提供商
Returns:
提供商名称列表
"""
return LLMServiceManager.list_text_providers()
@staticmethod
def clear_cache():
"""清空提供商实例缓存"""
LLMServiceManager.clear_cache()
logger.info("已清空大模型服务缓存")
# 为了向后兼容,提供一些便捷函数
async def analyze_images_unified(images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
provider: Optional[str] = None,
batch_size: int = 10) -> List[str]:
"""便捷的图片分析函数"""
return await UnifiedLLMService.analyze_images(images, prompt, provider, batch_size)
async def generate_text_unified(prompt: str,
system_prompt: Optional[str] = None,
provider: Optional[str] = None,
temperature: float = 1.0,
response_format: Optional[str] = None) -> str:
"""便捷的文本生成函数"""
return await UnifiedLLMService.generate_text(
prompt, system_prompt, provider, temperature, response_format=response_format
)

View File

@ -0,0 +1,200 @@
"""
输出格式验证器
提供严格的输出格式验证机制确保大模型输出符合预期格式
"""
import json
import re
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from .exceptions import ValidationError
class OutputValidator:
"""输出格式验证器"""
@staticmethod
def validate_json_output(output: str, schema: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
验证JSON输出格式
Args:
output: 待验证的输出字符串
schema: JSON Schema (可选)
Returns:
解析后的JSON对象
Raises:
ValidationError: 验证失败时抛出
"""
try:
# 清理输出字符串移除可能的markdown代码块标记
cleaned_output = OutputValidator._clean_json_output(output)
# 解析JSON
parsed_json = json.loads(cleaned_output)
# 如果提供了schema进行schema验证
if schema:
OutputValidator._validate_json_schema(parsed_json, schema)
return parsed_json
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {str(e)}")
logger.error(f"原始输出: {output}")
raise ValidationError(f"JSON格式无效: {str(e)}", "json_parse", output)
except Exception as e:
logger.error(f"JSON验证失败: {str(e)}")
raise ValidationError(f"JSON验证失败: {str(e)}", "json_validation", output)
@staticmethod
def _clean_json_output(output: str) -> str:
"""清理JSON输出移除markdown标记等"""
# 移除可能的markdown代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
# 移除开头和结尾的```标记
output = re.sub(r'^```', '', output)
output = re.sub(r'```$', '', output)
# 移除前后空白字符
output = output.strip()
return output
@staticmethod
def _validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]):
"""验证JSON Schema (简化版本)"""
# 这里可以集成jsonschema库进行更严格的验证
# 目前实现基础的类型检查
if "type" in schema:
expected_type = schema["type"]
if expected_type == "object" and not isinstance(data, dict):
raise ValidationError(f"期望对象类型,实际为 {type(data)}", "schema_type")
elif expected_type == "array" and not isinstance(data, list):
raise ValidationError(f"期望数组类型,实际为 {type(data)}", "schema_type")
if "required" in schema and isinstance(data, dict):
for required_field in schema["required"]:
if required_field not in data:
raise ValidationError(f"缺少必需字段: {required_field}", "schema_required")
@staticmethod
def validate_narration_script(output: str) -> List[Dict[str, Any]]:
"""
验证解说文案输出格式
Args:
output: 待验证的解说文案输出
Returns:
解析后的解说文案列表
Raises:
ValidationError: 验证失败时抛出
"""
try:
# 定义解说文案的JSON Schema
narration_schema = {
"type": "object",
"required": ["items"],
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"required": ["_id", "timestamp", "picture", "narration"],
"properties": {
"_id": {"type": "number"},
"timestamp": {"type": "string"},
"picture": {"type": "string"},
"narration": {"type": "string"},
"OST": {"type": "number"}
}
}
}
}
}
# 验证JSON格式
parsed_data = OutputValidator.validate_json_output(output, narration_schema)
# 提取items数组
items = parsed_data.get("items", [])
# 验证每个item的具体内容
for i, item in enumerate(items):
OutputValidator._validate_narration_item(item, i)
logger.info(f"解说文案验证成功,共 {len(items)} 个片段")
return items
except ValidationError:
raise
except Exception as e:
logger.error(f"解说文案验证失败: {str(e)}")
raise ValidationError(f"解说文案验证失败: {str(e)}", "narration_validation", output)
@staticmethod
def _validate_narration_item(item: Dict[str, Any], index: int):
"""验证单个解说文案项目"""
# 验证时间戳格式
timestamp = item.get("timestamp", "")
if not re.match(r'\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}', timestamp):
raise ValidationError(f"{index+1}项时间戳格式无效: {timestamp}", "timestamp_format")
# 验证内容不为空
if not item.get("picture", "").strip():
raise ValidationError(f"{index+1}项画面描述不能为空", "empty_picture")
if not item.get("narration", "").strip():
raise ValidationError(f"{index+1}项解说文案不能为空", "empty_narration")
# 验证ID为正整数
item_id = item.get("_id")
if not isinstance(item_id, (int, float)) or item_id <= 0:
raise ValidationError(f"{index+1}项ID必须为正整数: {item_id}", "invalid_id")
@staticmethod
def validate_subtitle_analysis(output: str) -> str:
"""
验证字幕分析输出格式
Args:
output: 待验证的字幕分析输出
Returns:
验证后的分析内容
Raises:
ValidationError: 验证失败时抛出
"""
try:
# 基础验证:内容不能为空
if not output or not output.strip():
raise ValidationError("字幕分析结果不能为空", "empty_analysis")
# 验证内容长度合理
if len(output.strip()) < 50:
raise ValidationError("字幕分析结果过短,可能不完整", "analysis_too_short")
# 验证是否包含基本的分析要素(可根据需要调整)
analysis_keywords = ["剧情", "情节", "角色", "故事", "内容"]
if not any(keyword in output for keyword in analysis_keywords):
logger.warning("字幕分析结果可能缺少关键分析要素")
logger.info("字幕分析验证成功")
return output.strip()
except ValidationError:
raise
except Exception as e:
logger.error(f"字幕分析验证失败: {str(e)}")
raise ValidationError(f"字幕分析验证失败: {str(e)}", "analysis_validation", output)

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : merger_video
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/6 下午7:38
'''

View File

@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : __init__.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 统一提示词管理模块
"""
from .manager import PromptManager
from .base import BasePrompt, VisionPrompt, TextPrompt, ParameterizedPrompt
from .registry import PromptRegistry
from .template import TemplateRenderer
from .validators import PromptOutputValidator
from .exceptions import (
PromptError,
PromptNotFoundError,
PromptValidationError,
TemplateRenderError
)
# 版本信息
__version__ = "1.0.0"
__author__ = "viccy同学"
# 导出的公共接口
__all__ = [
# 核心管理器
"PromptManager",
# 基础类
"BasePrompt",
"VisionPrompt",
"TextPrompt",
"ParameterizedPrompt",
# 工具类
"PromptRegistry",
"TemplateRenderer",
"PromptOutputValidator",
# 异常类
"PromptError",
"PromptNotFoundError",
"PromptValidationError",
"TemplateRenderError",
# 版本信息
"__version__",
"__author__"
]
# 模块初始化
def initialize_prompts():
"""初始化提示词模块,注册所有提示词"""
from . import documentary
from . import short_drama_editing
from . import short_drama_narration
# 注册各模块的提示词
documentary.register_prompts()
short_drama_editing.register_prompts()
short_drama_narration.register_prompts()
# 自动初始化
initialize_prompts()

View File

@ -0,0 +1,182 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : base.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 提示词基础类定义
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime
class ModelType(Enum):
"""模型类型枚举"""
TEXT = "text" # 文本模型
VISION = "vision" # 视觉模型
MULTIMODAL = "multimodal" # 多模态模型
class OutputFormat(Enum):
"""输出格式枚举"""
TEXT = "text" # 纯文本
JSON = "json" # JSON格式
MARKDOWN = "markdown" # Markdown格式
STRUCTURED = "structured" # 结构化数据
@dataclass
class PromptMetadata:
"""提示词元数据"""
name: str # 提示词名称
category: str # 分类
version: str # 版本
description: str # 描述
model_type: ModelType # 适用的模型类型
output_format: OutputFormat # 输出格式
author: str = "viccy同学" # 作者
created_at: datetime = field(default_factory=datetime.now) # 创建时间
updated_at: datetime = field(default_factory=datetime.now) # 更新时间
tags: List[str] = field(default_factory=list) # 标签
parameters: List[str] = field(default_factory=list) # 支持的参数列表
class BasePrompt(ABC):
"""提示词基础类"""
def __init__(self, metadata: PromptMetadata):
self.metadata = metadata
self._template = None
self._system_prompt = None
self._examples = []
@property
def name(self) -> str:
"""获取提示词名称"""
return self.metadata.name
@property
def category(self) -> str:
"""获取提示词分类"""
return self.metadata.category
@property
def version(self) -> str:
"""获取提示词版本"""
return self.metadata.version
@property
def model_type(self) -> ModelType:
"""获取适用的模型类型"""
return self.metadata.model_type
@property
def output_format(self) -> OutputFormat:
"""获取输出格式"""
return self.metadata.output_format
@abstractmethod
def get_template(self) -> str:
"""获取提示词模板"""
pass
def get_system_prompt(self) -> Optional[str]:
"""获取系统提示词"""
return self._system_prompt
def get_examples(self) -> List[str]:
"""获取示例"""
return self._examples.copy()
def validate_parameters(self, parameters: Dict[str, Any]) -> bool:
"""验证参数"""
required_params = set(self.metadata.parameters)
provided_params = set(parameters.keys())
missing_params = required_params - provided_params
if missing_params:
from .exceptions import TemplateRenderError
raise TemplateRenderError(
template_name=self.name,
error_message="缺少必需参数",
missing_params=list(missing_params)
)
return True
def render(self, parameters: Dict[str, Any] = None) -> str:
"""渲染提示词"""
parameters = parameters or {}
# 验证参数
if self.metadata.parameters:
self.validate_parameters(parameters)
# 渲染模板 - 使用自定义的模板渲染器
template = self.get_template()
try:
from .template import get_renderer
renderer = get_renderer()
return renderer.render(template, parameters)
except Exception as e:
from .exceptions import TemplateRenderError
raise TemplateRenderError(
template_name=self.name,
error_message=f"模板渲染错误: {str(e)}",
missing_params=[]
)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"metadata": {
"name": self.metadata.name,
"category": self.metadata.category,
"version": self.metadata.version,
"description": self.metadata.description,
"model_type": self.metadata.model_type.value,
"output_format": self.metadata.output_format.value,
"author": self.metadata.author,
"created_at": self.metadata.created_at.isoformat(),
"updated_at": self.metadata.updated_at.isoformat(),
"tags": self.metadata.tags,
"parameters": self.metadata.parameters
},
"template": self.get_template(),
"system_prompt": self.get_system_prompt(),
"examples": self.get_examples()
}
class TextPrompt(BasePrompt):
"""文本模型专用提示词"""
def __init__(self, metadata: PromptMetadata):
if metadata.model_type not in [ModelType.TEXT, ModelType.MULTIMODAL]:
raise ValueError(f"TextPrompt只支持TEXT或MULTIMODAL模型类型当前: {metadata.model_type}")
super().__init__(metadata)
class VisionPrompt(BasePrompt):
"""视觉模型专用提示词"""
def __init__(self, metadata: PromptMetadata):
if metadata.model_type not in [ModelType.VISION, ModelType.MULTIMODAL]:
raise ValueError(f"VisionPrompt只支持VISION或MULTIMODAL模型类型当前: {metadata.model_type}")
super().__init__(metadata)
class ParameterizedPrompt(BasePrompt):
"""支持参数化的提示词"""
def __init__(self, metadata: PromptMetadata, required_parameters: List[str] = None):
super().__init__(metadata)
if required_parameters:
self.metadata.parameters.extend(required_parameters)
# 去重
self.metadata.parameters = list(set(self.metadata.parameters))

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : __init__.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 纪录片解说提示词模块
"""
from .frame_analysis import FrameAnalysisPrompt
from .narration_generation import NarrationGenerationPrompt
from ..manager import PromptManager
def register_prompts():
"""注册纪录片解说相关的提示词"""
# 注册视频帧分析提示词
frame_analysis_prompt = FrameAnalysisPrompt()
PromptManager.register_prompt(frame_analysis_prompt, is_default=True)
# 注册解说文案生成提示词
narration_prompt = NarrationGenerationPrompt()
PromptManager.register_prompt(narration_prompt, is_default=True)
__all__ = [
"FrameAnalysisPrompt",
"NarrationGenerationPrompt",
"register_prompts"
]

View File

@ -0,0 +1,67 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : frame_analysis.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 纪录片视频帧分析提示词
"""
from ..base import VisionPrompt, PromptMetadata, ModelType, OutputFormat
class FrameAnalysisPrompt(VisionPrompt):
"""纪录片视频帧分析提示词"""
def __init__(self):
metadata = PromptMetadata(
name="frame_analysis",
category="documentary",
version="v1.0",
description="分析纪录片视频关键帧,提取画面内容和场景描述",
model_type=ModelType.VISION,
output_format=OutputFormat.JSON,
tags=["纪录片", "视频分析", "关键帧", "画面描述"],
parameters=["video_theme", "custom_instructions"]
)
super().__init__(metadata)
self._system_prompt = "你是一名专业的视频内容分析师,擅长分析纪录片视频帧内容,提取关键信息和场景描述。"
def get_template(self) -> str:
return """请仔细分析这些视频关键帧图片,我需要你提供详细的画面分析。
视频主题${video_theme}
分析要求
1. 按时间顺序分析每一帧画面
2. 详细描述画面中的主要内容人物物体环境
3. 注意画面的构图色彩光线等视觉元素
4. 识别画面中的关键动作或变化
5. 提供准确的时间戳信息
${custom_instructions}
请按照以下JSON格式输出分析结果
{
"analysis": [
{
"timestamp": "00:00:05,390",
"picture": "详细的画面描述,包括场景、人物、物体、动作等",
"scene_type": "场景类型(如:建造、准备、完成等)",
"key_elements": ["关键元素1", "关键元素2"],
"visual_quality": "画面质量描述(构图、光线、色彩等)"
}
],
"summary": "整体视频内容概述",
"total_frames": "分析的帧数"
}
重要要求
1. 只输出JSON格式不要添加任何其他文字或代码块标记
2. 画面描述要详细准确为后续解说文案生成提供充分信息
3. 时间戳必须准确对应视频帧
4. 严禁虚构不存在的内容"""

View File

@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : narration_generation.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 纪录片解说文案生成提示词
"""
from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat
class NarrationGenerationPrompt(TextPrompt):
"""纪录片解说文案生成提示词"""
def __init__(self):
metadata = PromptMetadata(
name="narration_generation",
category="documentary",
version="v1.0",
description="根据视频帧分析结果生成纪录片解说文案,特别适用于荒野建造类内容",
model_type=ModelType.TEXT,
output_format=OutputFormat.JSON,
tags=["纪录片", "解说文案", "荒野建造", "文案生成"],
parameters=["video_frame_description"]
)
super().__init__(metadata)
self._system_prompt = "你是一名专业的短视频解说文案撰写专家,擅长创作引人入胜的纪录片解说内容。"
def get_template(self) -> str:
return """我是一名荒野建造解说的博主,以下是一些同行的对标文案,请你深度学习并总结这些文案的风格特点跟内容特点:
<example_text_1>
解压助眠的天花板就是荒野建造沉浸丝滑的搭建过程可以说每一帧都是极致享受我保证强迫症来了都找不出一丁点毛病更别说全屋严丝合缝的拼接工艺还能轻松抵御零下二十度气温让你居住的每一天都温暖如春
在家闲不住的西姆今天也打算来一次野外建造行走没多久他就发现许多倒塌的树任由它们自生自灭不如将其利用起来想到这他就开始挥舞铲子要把地基挖掘出来虽然每次只能挖一点点但架不住他体能惊人没多长时间一个 2x3 的深坑就赫然出现这深度住他一人绰绰有余
随后他去附近收集来原木这些都是搭建墙壁的最好材料而在投入使用前自然要把表皮刮掉防止森林中的白蚁蛀虫处理好一大堆后西姆还在两端打孔使用木钉固定在一起这可不是用来做墙壁的而是做庇护所的承重柱只要木头间的缝隙足够紧密那搭建出的木屋就能足够坚固
每向上搭建一层他都会在中间塞入苔藓防寒保证不会泄露一丝热量其他几面也是用相同方法很快西姆就做好了三面墙壁每一根木头都极其工整保证强迫症来了都要点个赞再走
在继续搭建墙壁前西姆决定将壁炉制作出来毕竟森林夜晚的气温会很低保暖措施可是重中之重完成后他找来一块大树皮用来充当庇护所的大门而上面刮掉的木屑还能作为壁炉的引火物可以说再完美不过
测试了排烟没问题后他才开始搭建最后一面墙壁这一面要预留门和窗所以在搭建到一半后还需要在原木中间开出卡口让自己劈砍时能轻松许多此时只需将另外一根如法炮制两端拼接在一起后就是一扇大小适中的窗户而随着随后一层苔藓铺好最后一根原木落位这个庇护所的雏形就算完成
</example_text_1>
<example_text_2>
解压助眠的天花板就是荒野建造沉浸丝滑的搭建过程每一帧都是极致享受全屋严丝合缝的拼接工艺能轻松抵御零下二十度气温居住体验温暖如春
在家闲不住的西姆开启野外建造他发现倒塌的树决定加以利用先挖掘出 2x3 的深坑作为地基接着收集原木刮掉表皮防白蚁蛀虫打孔用木钉固定制作承重柱搭建墙壁时每一层都塞入苔藓防寒很快做好三面墙
为应对森林夜晚低温西姆制作壁炉用大树皮当大门刮下的木屑做引火物搭建最后一面墙时预留门窗通过在原木中间开口拼接做出窗户大门采用榫卯结构安装严丝合缝
搭建屋顶时先固定外围原木再平铺原木形成斜面屋顶之后用苔藓黏土密封缝隙铺上枯叶和泥土为美观在木屋覆盖苔藓移植小树点缀完工时遇大雨木屋防水良好
西姆利用墙壁凹槽镶嵌床框铺上苔藓床单枕头做成床劳作一天后他用壁炉烤牛肉享用建造一星期后他开始野外露营
后来西姆回家补给物资回来时森林大雪纷飞他劈柴储备带回食物调味料和被褥提高居住舒适度还用干草做靠垫他用壁炉烤牛排搭配红酒
第二天积雪融化西姆制作室外篝火堆防野兽用大树夹缝掰弯木棍堆积而成晚上点燃处理废料结束后用雪球灭火最后在室内二十五度的环境中裹被入睡
</example_text_2>
<video_frame_description>
${video_frame_description}
</video_frame_description>
我正在尝试做这个内容的解说纪录片视频我需要你以 <video_frame_description> </video_frame_description> 中的内容为解说目标根据我刚才提供给你的对标文案特点以及你总结的特点帮我生成一段关于荒野建造的解说文案文案需要符合平台受欢迎的解说风格请使用 json 格式进行输出使用 <output> 中的输出格式
<output>
{
"items": [
{
"_id": 1,
"timestamp": "00:00:05,390-00:00:10,430",
"picture": "画面描述",
"narration": "解说文案"
}
]
}
</output>
<restriction>
1. 只输出 json 内容不要输出其他任何说明性的文字
2. 解说文案的语言使用 简体中文
3. 严禁虚构画面所有画面只能从 <video_frame_description> 中摘取
4. 严禁虚构时间戳所有时间戳只能从 <video_frame_description> 中摘取
5. 解说文案要生动有趣符合荒野建造解说的风格特点
6. 每个片段的解说文案要与画面内容高度匹配
7. 保持解说的连贯性和故事性
</restriction>"""

View File

@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : exceptions.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 提示词管理模块异常定义
"""
class PromptError(Exception):
"""提示词模块基础异常类"""
pass
class PromptNotFoundError(PromptError):
"""提示词未找到异常"""
def __init__(self, category: str, name: str, version: str = None):
self.category = category
self.name = name
self.version = version
if version:
message = f"提示词未找到: {category}.{name} (版本: {version})"
else:
message = f"提示词未找到: {category}.{name}"
super().__init__(message)
class PromptValidationError(PromptError):
"""提示词验证异常"""
def __init__(self, message: str, validation_errors: list = None):
self.validation_errors = validation_errors or []
super().__init__(message)
class TemplateRenderError(PromptError):
"""模板渲染异常"""
def __init__(self, template_name: str, error_message: str, missing_params: list = None):
self.template_name = template_name
self.error_message = error_message
self.missing_params = missing_params or []
message = f"模板渲染失败 '{template_name}': {error_message}"
if missing_params:
message += f" (缺少参数: {', '.join(missing_params)})"
super().__init__(message)
class PromptRegistrationError(PromptError):
"""提示词注册异常"""
def __init__(self, category: str, name: str, reason: str):
self.category = category
self.name = name
self.reason = reason
message = f"提示词注册失败 {category}.{name}: {reason}"
super().__init__(message)
class PromptVersionError(PromptError):
"""提示词版本异常"""
def __init__(self, category: str, name: str, version: str, reason: str):
self.category = category
self.name = name
self.version = version
self.reason = reason
message = f"提示词版本错误 {category}.{name} v{version}: {reason}"
super().__init__(message)

View File

@ -0,0 +1,287 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : manager.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 提示词管理器
"""
from typing import Dict, Any, List, Optional, Union
from loguru import logger
from .base import BasePrompt, ModelType, OutputFormat
from .registry import get_registry
from .template import get_renderer
from .validators import PromptOutputValidator
from .exceptions import (
PromptNotFoundError,
PromptValidationError,
TemplateRenderError
)
class PromptManager:
"""提示词管理器 - 统一的提示词管理接口"""
def __init__(self):
self._registry = get_registry()
self._renderer = get_renderer()
@classmethod
def get_prompt(cls,
category: str,
name: str,
version: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None) -> str:
"""
获取渲染后的提示词
Args:
category: 分类
name: 名称
version: 版本可选默认使用最新版本
parameters: 模板参数可选
Returns:
渲染后的提示词字符串
"""
instance = cls()
prompt_obj = instance._registry.get(category, name, version)
try:
rendered = prompt_obj.render(parameters)
logger.debug(f"提示词渲染成功: {category}.{name} v{prompt_obj.version}")
return rendered
except Exception as e:
logger.error(f"提示词渲染失败: {category}.{name} - {str(e)}")
raise
@classmethod
def get_prompt_object(cls,
category: str,
name: str,
version: Optional[str] = None) -> BasePrompt:
"""
获取提示词对象
Args:
category: 分类
name: 名称
version: 版本可选
Returns:
提示词对象
"""
instance = cls()
return instance._registry.get(category, name, version)
@classmethod
def register_prompt(cls, prompt: BasePrompt, is_default: bool = True) -> None:
"""
注册提示词
Args:
prompt: 提示词对象
is_default: 是否设为默认版本
"""
instance = cls()
instance._registry.register(prompt, is_default)
@classmethod
def list_categories(cls) -> List[str]:
"""列出所有分类"""
instance = cls()
return instance._registry.list_categories()
@classmethod
def list_prompts(cls, category: str) -> List[str]:
"""列出指定分类下的所有提示词"""
instance = cls()
return instance._registry.list_prompts(category)
@classmethod
def list_versions(cls, category: str, name: str) -> List[str]:
"""列出指定提示词的所有版本"""
instance = cls()
return instance._registry.list_versions(category, name)
@classmethod
def exists(cls, category: str, name: str, version: Optional[str] = None) -> bool:
"""检查提示词是否存在"""
instance = cls()
return instance._registry.exists(category, name, version)
@classmethod
def search_prompts(cls,
keyword: str = None,
category: str = None,
model_type: ModelType = None,
output_format: OutputFormat = None) -> List[Dict[str, str]]:
"""
搜索提示词
Args:
keyword: 关键词
category: 分类过滤
model_type: 模型类型过滤
output_format: 输出格式过滤
Returns:
匹配的提示词列表
"""
instance = cls()
results = instance._registry.search(keyword, category, model_type, output_format)
return [
{
"category": cat,
"name": name,
"version": ver,
"full_name": f"{cat}.{name}",
"identifier": f"{cat}.{name}@{ver}"
}
for cat, name, ver in results
]
@classmethod
def get_stats(cls) -> Dict[str, Any]:
"""获取统计信息"""
instance = cls()
registry_stats = instance._registry.get_stats()
return {
"registry": registry_stats,
"categories": cls.list_categories(),
"total_categories": registry_stats["categories"],
"total_prompts": registry_stats["prompts"],
"total_versions": registry_stats["versions"]
}
@classmethod
def validate_output(cls,
output: Union[str, Dict],
category: str,
name: str,
version: Optional[str] = None) -> Any:
"""
验证提示词输出
Args:
output: 输出内容
category: 提示词分类
name: 提示词名称
version: 提示词版本
Returns:
验证后的数据
"""
instance = cls()
prompt_obj = instance._registry.get(category, name, version)
# 根据输出格式进行验证
output_format = prompt_obj.metadata.output_format
try:
if output_format == OutputFormat.JSON:
# 特殊处理解说文案和剧情分析
if "narration" in name.lower() or "script" in name.lower():
return PromptOutputValidator.validate_narration_script(output)
elif "plot" in name.lower() or "analysis" in name.lower():
return PromptOutputValidator.validate_plot_analysis(output)
else:
return PromptOutputValidator.validate_json(output)
else:
return PromptOutputValidator.validate_by_format(output, output_format)
except Exception as e:
logger.error(f"输出验证失败 {category}.{name}: {str(e)}")
raise PromptValidationError(f"输出验证失败: {str(e)}")
@classmethod
def get_prompt_info(cls, category: str, name: str, version: Optional[str] = None) -> Dict[str, Any]:
"""
获取提示词详细信息
Args:
category: 分类
name: 名称
version: 版本
Returns:
提示词详细信息
"""
instance = cls()
prompt_obj = instance._registry.get(category, name, version)
return {
"metadata": {
"name": prompt_obj.metadata.name,
"category": prompt_obj.metadata.category,
"version": prompt_obj.metadata.version,
"description": prompt_obj.metadata.description,
"model_type": prompt_obj.metadata.model_type.value,
"output_format": prompt_obj.metadata.output_format.value,
"author": prompt_obj.metadata.author,
"created_at": prompt_obj.metadata.created_at.isoformat(),
"updated_at": prompt_obj.metadata.updated_at.isoformat(),
"tags": prompt_obj.metadata.tags,
"parameters": prompt_obj.metadata.parameters
},
"template_preview": prompt_obj.get_template()[:500] + "..." if len(prompt_obj.get_template()) > 500 else prompt_obj.get_template(),
"system_prompt": prompt_obj.get_system_prompt(),
"examples_count": len(prompt_obj.get_examples()),
"has_parameters": bool(prompt_obj.metadata.parameters)
}
@classmethod
def export_prompts(cls, category: Optional[str] = None) -> Dict[str, Any]:
"""
导出提示词配置
Args:
category: 分类过滤可选
Returns:
提示词配置数据
"""
instance = cls()
categories = [category] if category else instance._registry.list_categories()
export_data = {
"version": "1.0.0",
"exported_at": instance._get_current_time(),
"categories": {}
}
for cat in categories:
export_data["categories"][cat] = {}
prompts = instance._registry.list_prompts(cat)
for prompt_name in prompts:
versions = instance._registry.list_versions(cat, prompt_name)
export_data["categories"][cat][prompt_name] = {}
for ver in versions:
prompt_obj = instance._registry.get(cat, prompt_name, ver)
export_data["categories"][cat][prompt_name][ver] = prompt_obj.to_dict()
return export_data
def _get_current_time(self) -> str:
"""获取当前时间字符串"""
from datetime import datetime
return datetime.now().isoformat()
# 便捷函数
def get_prompt(category: str, name: str, version: str = None, **parameters) -> str:
"""获取提示词的便捷函数"""
return PromptManager.get_prompt(category, name, version, parameters)
def validate_prompt_output(output: Union[str, Dict], category: str, name: str, version: str = None) -> Any:
"""验证提示词输出的便捷函数"""
return PromptManager.validate_output(output, category, name, version)

View File

@ -0,0 +1,222 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : registry.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 提示词注册机制
"""
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
from loguru import logger
from .base import BasePrompt, ModelType, OutputFormat
from .exceptions import (
PromptNotFoundError,
PromptRegistrationError,
PromptVersionError
)
class PromptRegistry:
"""提示词注册表"""
def __init__(self):
# 存储结构: {category: {name: {version: prompt}}}
self._prompts: Dict[str, Dict[str, Dict[str, BasePrompt]]] = defaultdict(
lambda: defaultdict(dict)
)
# 默认版本映射: {category: {name: default_version}}
self._default_versions: Dict[str, Dict[str, str]] = defaultdict(dict)
def register(self, prompt: BasePrompt, is_default: bool = True) -> None:
"""
注册提示词
Args:
prompt: 提示词实例
is_default: 是否设为默认版本
"""
category = prompt.category
name = prompt.name
version = prompt.version
# 检查是否已存在相同版本
if version in self._prompts[category][name]:
raise PromptRegistrationError(
category=category,
name=name,
reason=f"版本 {version} 已存在"
)
# 注册提示词
self._prompts[category][name][version] = prompt
# 设置默认版本
if is_default or name not in self._default_versions[category]:
self._default_versions[category][name] = version
logger.info(f"已注册提示词: {category}.{name} v{version}")
def get(self, category: str, name: str, version: Optional[str] = None) -> BasePrompt:
"""
获取提示词
Args:
category: 分类
name: 名称
version: 版本为None时使用默认版本
Returns:
提示词实例
"""
if category not in self._prompts:
raise PromptNotFoundError(category, name, version)
if name not in self._prompts[category]:
raise PromptNotFoundError(category, name, version)
# 确定版本
if version is None:
if name not in self._default_versions[category]:
raise PromptNotFoundError(category, name, version)
version = self._default_versions[category][name]
if version not in self._prompts[category][name]:
raise PromptNotFoundError(category, name, version)
return self._prompts[category][name][version]
def list_categories(self) -> List[str]:
"""列出所有分类"""
return list(self._prompts.keys())
def list_prompts(self, category: str) -> List[str]:
"""列出指定分类下的所有提示词名称"""
if category not in self._prompts:
return []
return list(self._prompts[category].keys())
def list_versions(self, category: str, name: str) -> List[str]:
"""列出指定提示词的所有版本"""
if category not in self._prompts or name not in self._prompts[category]:
return []
return list(self._prompts[category][name].keys())
def get_default_version(self, category: str, name: str) -> Optional[str]:
"""获取默认版本"""
return self._default_versions.get(category, {}).get(name)
def set_default_version(self, category: str, name: str, version: str) -> None:
"""设置默认版本"""
if (category not in self._prompts or
name not in self._prompts[category] or
version not in self._prompts[category][name]):
raise PromptVersionError(category, name, version, "版本不存在")
self._default_versions[category][name] = version
logger.info(f"已设置默认版本: {category}.{name} -> v{version}")
def exists(self, category: str, name: str, version: Optional[str] = None) -> bool:
"""检查提示词是否存在"""
try:
self.get(category, name, version)
return True
except PromptNotFoundError:
return False
def remove(self, category: str, name: str, version: Optional[str] = None) -> None:
"""移除提示词"""
if version is None:
# 移除所有版本
if category in self._prompts and name in self._prompts[category]:
del self._prompts[category][name]
if name in self._default_versions.get(category, {}):
del self._default_versions[category][name]
logger.info(f"已移除提示词所有版本: {category}.{name}")
else:
# 移除指定版本
if (category in self._prompts and
name in self._prompts[category] and
version in self._prompts[category][name]):
del self._prompts[category][name][version]
# 如果移除的是默认版本,需要重新设置默认版本
if (self._default_versions.get(category, {}).get(name) == version and
self._prompts[category][name]):
# 选择最新版本作为默认版本
new_default = max(self._prompts[category][name].keys())
self._default_versions[category][name] = new_default
logger.info(f"默认版本已更新: {category}.{name} -> v{new_default}")
logger.info(f"已移除提示词版本: {category}.{name} v{version}")
def search(self,
keyword: str = None,
category: str = None,
model_type: ModelType = None,
output_format: OutputFormat = None) -> List[Tuple[str, str, str]]:
"""
搜索提示词
Args:
keyword: 关键词在名称和描述中搜索
category: 分类过滤
model_type: 模型类型过滤
output_format: 输出格式过滤
Returns:
匹配的提示词列表 [(category, name, version), ...]
"""
results = []
categories = [category] if category else self._prompts.keys()
for cat in categories:
for name in self._prompts[cat]:
for version, prompt in self._prompts[cat][name].items():
# 关键词过滤
if keyword:
if (keyword.lower() not in name.lower() and
keyword.lower() not in prompt.metadata.description.lower()):
continue
# 模型类型过滤
if model_type and prompt.metadata.model_type != model_type:
continue
# 输出格式过滤
if output_format and prompt.metadata.output_format != output_format:
continue
results.append((cat, name, version))
return results
def get_stats(self) -> Dict[str, int]:
"""获取注册表统计信息"""
total_prompts = 0
total_versions = 0
for category in self._prompts:
for name in self._prompts[category]:
total_prompts += 1
total_versions += len(self._prompts[category][name])
return {
"categories": len(self._prompts),
"prompts": total_prompts,
"versions": total_versions
}
# 全局注册表实例
_global_registry = PromptRegistry()
def get_registry() -> PromptRegistry:
"""获取全局注册表实例"""
return _global_registry

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : __init__.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 短剧混剪提示词模块
"""
from .subtitle_analysis import SubtitleAnalysisPrompt
from .plot_extraction import PlotExtractionPrompt
from ..manager import PromptManager
def register_prompts():
"""注册短剧混剪相关的提示词"""
# 注册字幕分析提示词
subtitle_analysis_prompt = SubtitleAnalysisPrompt()
PromptManager.register_prompt(subtitle_analysis_prompt, is_default=True)
# 注册爆点提取提示词
plot_extraction_prompt = PlotExtractionPrompt()
PromptManager.register_prompt(plot_extraction_prompt, is_default=True)
__all__ = [
"SubtitleAnalysisPrompt",
"PlotExtractionPrompt",
"register_prompts"
]

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : plot_extraction.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 短剧爆点提取提示词
"""
from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat
class PlotExtractionPrompt(TextPrompt):
"""短剧爆点提取提示词"""
def __init__(self):
metadata = PromptMetadata(
name="plot_extraction",
category="short_drama_editing",
version="v1.0",
description="根据剧情梗概和字幕内容,精确定位关键剧情的时间段",
model_type=ModelType.TEXT,
output_format=OutputFormat.JSON,
tags=["短剧", "爆点定位", "时间戳", "剧情提取"],
parameters=["subtitle_content", "plot_summary", "plot_titles"]
)
super().__init__(metadata)
self._system_prompt = "你是一名短剧编剧,非常擅长根据字幕中分析视频中关键剧情出现的具体时间段。"
def get_template(self) -> str:
return """请仔细阅读剧情梗概和爆点内容,然后在字幕中找出每个爆点发生的具体时间段和爆点前后的详细剧情。
剧情梗概
${plot_summary}
需要定位的爆点内容
${plot_titles}
字幕内容
${subtitle_content}
分析要求
1. 为每个爆点找到对应的具体时间段
2. 时间段要准确反映该爆点的完整发展过程
3. 提供爆点前后的详细剧情描述
4. 确保时间戳格式正确且存在于字幕中
5. 选择最具戏剧张力的时间段
请返回一个JSON对象包含一个名为"plot_points"的数组数组中包含多个对象每个对象都要包含以下字段
{
"plot_points": [
{
"timestamp": "时间段格式为xx:xx:xx,xxx-xx:xx:xx,xxx",
"title": "关键剧情的主题",
"picture": "关键剧情前后的详细剧情描述,包括人物对话、动作、情感变化等"
}
]
}
重要要求
1. 请确保返回的是合法的JSON格式
2. 时间戳必须严格按照字幕中的格式
3. 剧情描述要详细具体包含关键对话和动作
4. 每个爆点的时间段要合理不能过短或过长
5. 严禁虚构不存在的时间戳或剧情内容
6. 只输出JSON内容不要添加任何说明文字"""

View File

@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : subtitle_analysis.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 短剧字幕分析提示词
"""
from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat
class SubtitleAnalysisPrompt(TextPrompt):
"""短剧字幕分析提示词"""
def __init__(self):
metadata = PromptMetadata(
name="subtitle_analysis",
category="short_drama_editing",
version="v1.0",
description="分析短剧字幕内容,提取剧情梗概和关键情节点",
model_type=ModelType.TEXT,
output_format=OutputFormat.JSON,
tags=["短剧", "字幕分析", "剧情梗概", "情节提取"],
parameters=["subtitle_content", "custom_clips"]
)
super().__init__(metadata)
self._system_prompt = "你是一名短剧编剧和内容分析师,擅长从字幕中提取剧情要点和关键情节。"
def get_template(self) -> str:
return """请仔细分析以下短剧字幕内容,提取剧情梗概和关键情节点。
字幕内容
${subtitle_content}
分析要求
1. 提取整体剧情梗概概括主要故事线和核心冲突
2. 识别 ${custom_clips} 个最具吸引力的关键情节点爆点
3. 每个情节点要包含具体的时间段和详细描述
4. 关注剧情的转折点冲突高潮情感爆发等关键时刻
5. 确保选择的情节点具有强烈的戏剧张力和观看价值
请按照以下JSON格式输出分析结果
{
"summary": "整体剧情梗概,简要概括主要故事线、角色关系和核心冲突",
"plot_titles": [
"情节点1标题",
"情节点2标题",
"情节点3标题"
],
"analysis_details": {
"main_characters": ["主要角色1", "主要角色2"],
"story_theme": "故事主题",
"conflict_type": "冲突类型(如:爱情、复仇、家庭等)",
"emotional_peaks": ["情感高潮点1", "情感高潮点2"]
}
}
重要要求
1. 必须输出有效的JSON格式不能包含注释或其他文字
2. 剧情梗概要简洁明了突出核心看点
3. 情节点标题要吸引人体现戏剧冲突
4. 严禁虚构不存在的剧情内容
5. 分析要客观准确基于字幕实际内容"""

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : __init__.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 短剧解说提示词模块
"""
from .plot_analysis import PlotAnalysisPrompt
from .script_generation import ScriptGenerationPrompt
from ..manager import PromptManager
def register_prompts():
"""注册短剧解说相关的提示词"""
# 注册剧情分析提示词
plot_analysis_prompt = PlotAnalysisPrompt()
PromptManager.register_prompt(plot_analysis_prompt, is_default=True)
# 注册解说脚本生成提示词
script_generation_prompt = ScriptGenerationPrompt()
PromptManager.register_prompt(script_generation_prompt, is_default=True)
__all__ = [
"PlotAnalysisPrompt",
"ScriptGenerationPrompt",
"register_prompts"
]

View File

@ -1,15 +1,37 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
"""
@Project: NarratoAI
@File : prompt
@Author : 小林同学
@Date : 2025/5/9 上午12:57
'''
# 字幕剧情分析提示词
subtitle_plot_analysis_v1 = """
# 角色
@File : plot_analysis.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 短剧剧情分析提示词
"""
from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat
class PlotAnalysisPrompt(TextPrompt):
"""短剧剧情分析提示词"""
def __init__(self):
metadata = PromptMetadata(
name="plot_analysis",
category="short_drama_narration",
version="v1.0",
description="分析短剧字幕内容,提供详细的剧情分析和分段解析",
model_type=ModelType.TEXT,
output_format=OutputFormat.TEXT,
tags=["短剧", "剧情分析", "字幕解析", "分段分析"],
parameters=["subtitle_content"]
)
super().__init__(metadata)
self._system_prompt = "你是一位专业的剧本分析师和剧情概括助手。"
def get_template(self) -> str:
return """# 角色
你是一位专业的剧本分析师和剧情概括助手
# 任务
@ -62,36 +84,7 @@ subtitle_plot_analysis_v1 = """
# 限制
1. 严禁输出与分析结果无关的内容
2.
2. 时间戳必须严格按照字幕中的实际时间
# 请处理以下字幕:
"""
plot_writing = """
我是一个影视解说up主需要为我的粉丝讲解短剧%s的剧情目前正在解说剧情希望能让粉丝通过我的解说了解剧情并且产生 继续观看的兴趣请生成一篇解说脚本包含解说文案以及穿插原声的片段下面<plot>中的内容是短剧的剧情概述
<plot>
%s
</plot>
请使用 json 格式进行输出使用 <output> 中的输出格式
<output>
{
"items": [
{
"_id": 1, # 唯一递增id
"timestamp": "00:00:05,390-00:00:10,430",
"picture": "剧情描述或者备注",
"narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id 进行占位",
"OST": "值为 0 表示当前片段为解说片段,值为 1 表示当前片段为穿插的原片"
}
}
</output>
<restriction>
1. 只输出 json 内容不要输出其他任何说明性的文字
2. 解说文案的语言使用 简体中文
3. 严禁虚构剧情所有画面只能从 <polt> 中摘取
4. 严禁虚构时间戳所有时间戳范围只能从 <polt> 中摘取
</restriction>
"""
${subtitle_content}"""

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : script_generation.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 短剧解说脚本生成提示词
"""
from ..base import ParameterizedPrompt, PromptMetadata, ModelType, OutputFormat
class ScriptGenerationPrompt(ParameterizedPrompt):
"""短剧解说脚本生成提示词"""
def __init__(self):
metadata = PromptMetadata(
name="script_generation",
category="short_drama_narration",
version="v1.0",
description="根据剧情分析生成短剧解说脚本,包含解说文案和原声片段",
model_type=ModelType.TEXT,
output_format=OutputFormat.JSON,
tags=["短剧", "解说脚本", "文案生成", "原声片段"],
parameters=["drama_name", "plot_analysis"]
)
super().__init__(metadata, required_parameters=["drama_name", "plot_analysis"])
self._system_prompt = "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出不能包含任何其他文字、说明或代码块标记。"
def get_template(self) -> str:
return """我是一个影视解说up主需要为我的粉丝讲解短剧《${drama_name}》的剧情,目前正在解说剧情,希望能让粉丝通过我的解说了解剧情,并且产生继续观看的兴趣,请生成一篇解说脚本,包含解说文案,以及穿插原声的片段,下面<plot>中的内容是短剧的剧情概述:
<plot>
${plot_analysis}
</plot>
请严格按照以下JSON格式输出不要添加任何其他文字说明或代码块标记
{
"items": [
{
"_id": 1,
"timestamp": "00:00:05,390-00:00:10,430",
"picture": "剧情描述或者备注",
"narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 '播放原片+_id' 进行占位",
"OST": 0
}
]
}
重要要求
1. 只输出 json 内容不要输出其他任何说明性的文字
2. 解说文案必须遵循---的线性时间链
3. 解说文案需包含角色微表情动作细节场景氛围的描写每段80-150
4. 通过细节关联普遍情感如遗憾和解成长避免直白抒情
5. 所有细节严格源自<plot>可对角色行为进行合理心理推导但不虚构剧情
6. 时间戳从<plot>摘取可根据解说内容拆分原时间片段如将10秒拆分为两个5秒
7. 解说与原片穿插比例控制在7:3关键情绪点保留原片原声
8. 严禁跳脱剧情发展顺序所有描述必须符合先发生A再发生BA导致B的逻辑
9. 强化流程感让观众清晰感知剧情推进的先后顺序"""

View File

@ -0,0 +1,180 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : template.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 模板渲染引擎
"""
import re
from typing import Dict, Any, List, Optional
from string import Template
from loguru import logger
from .exceptions import TemplateRenderError
class TemplateRenderer:
"""模板渲染器"""
def __init__(self):
self._custom_filters = {}
def register_filter(self, name: str, func: callable) -> None:
"""注册自定义过滤器"""
self._custom_filters[name] = func
logger.debug(f"已注册模板过滤器: {name}")
def render(self, template: str, parameters: Dict[str, Any] = None) -> str:
"""
渲染模板
Args:
template: 模板字符串
parameters: 参数字典
Returns:
渲染后的字符串
"""
parameters = parameters or {}
try:
# 使用简单的字符串替换进行参数替换
rendered = template
for key, value in parameters.items():
# 替换 ${key} 格式的参数
rendered = rendered.replace(f"${{{key}}}", str(value))
# 也替换 $key 格式的参数(为了兼容性)
rendered = rendered.replace(f"${key}", str(value))
# 处理自定义过滤器
rendered = self._apply_filters(rendered, parameters)
return rendered
except Exception as e:
raise TemplateRenderError(
template_name="unknown",
error_message=f"模板渲染失败: {str(e)}"
)
def _apply_filters(self, text: str, parameters: Dict[str, Any]) -> str:
"""应用自定义过滤器"""
# 查找过滤器模式: ${variable|filter_name}
filter_pattern = r'\$\{([^}]+)\|([^}]+)\}'
def replace_filter(match):
var_name = match.group(1).strip()
filter_name = match.group(2).strip()
if filter_name not in self._custom_filters:
logger.warning(f"未知的过滤器: {filter_name}")
return match.group(0) # 返回原始文本
if var_name not in parameters:
logger.warning(f"参数不存在: {var_name}")
return match.group(0) # 返回原始文本
try:
filter_func = self._custom_filters[filter_name]
filtered_value = filter_func(parameters[var_name])
return str(filtered_value)
except Exception as e:
logger.error(f"过滤器执行失败 {filter_name}: {str(e)}")
return match.group(0) # 返回原始文本
return re.sub(filter_pattern, replace_filter, text)
def extract_variables(self, template: str) -> List[str]:
"""提取模板中的变量名"""
# 匹配 ${variable} 和 ${variable|filter} 模式
pattern = r'\$\{([^}|]+)(?:\|[^}]+)?\}'
matches = re.findall(pattern, template)
return list(set(match.strip() for match in matches))
def validate_template(self, template: str, required_params: List[str] = None) -> bool:
"""验证模板"""
try:
# 提取模板变量
template_vars = self.extract_variables(template)
# 检查必需参数
if required_params:
missing_params = set(required_params) - set(template_vars)
if missing_params:
raise TemplateRenderError(
template_name="validation",
error_message="模板缺少必需参数",
missing_params=list(missing_params)
)
# 尝试渲染测试
test_params = {var: f"test_{var}" for var in template_vars}
self.render(template, test_params)
return True
except Exception as e:
logger.error(f"模板验证失败: {str(e)}")
return False
# 内置过滤器
def _upper_filter(value: Any) -> str:
"""转换为大写"""
return str(value).upper()
def _lower_filter(value: Any) -> str:
"""转换为小写"""
return str(value).lower()
def _title_filter(value: Any) -> str:
"""转换为标题格式"""
return str(value).title()
def _strip_filter(value: Any) -> str:
"""去除首尾空白"""
return str(value).strip()
def _truncate_filter(value: Any, length: int = 100) -> str:
"""截断文本"""
text = str(value)
if len(text) <= length:
return text
return text[:length] + "..."
def _json_filter(value: Any) -> str:
"""转换为JSON字符串"""
import json
return json.dumps(value, ensure_ascii=False, indent=2)
# 全局渲染器实例
_global_renderer = TemplateRenderer()
# 注册内置过滤器
_global_renderer.register_filter("upper", _upper_filter)
_global_renderer.register_filter("lower", _lower_filter)
_global_renderer.register_filter("title", _title_filter)
_global_renderer.register_filter("strip", _strip_filter)
_global_renderer.register_filter("truncate", _truncate_filter)
_global_renderer.register_filter("json", _json_filter)
def get_renderer() -> TemplateRenderer:
"""获取全局渲染器实例"""
return _global_renderer
def render_template(template: str, parameters: Dict[str, Any] = None) -> str:
"""便捷的模板渲染函数"""
return _global_renderer.render(template, parameters)

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : validators.py
@Author : viccy同学
@Date : 2025/1/7
@Description: 提示词输出验证器
"""
import json
import re
from typing import Dict, Any, List, Optional, Union
from loguru import logger
from .base import OutputFormat
from .exceptions import PromptValidationError
class PromptOutputValidator:
"""提示词输出验证器"""
@staticmethod
def validate_json(output: str, schema: Dict[str, Any] = None) -> Dict[str, Any]:
"""
验证JSON输出
Args:
output: 输出字符串
schema: JSON schema可选
Returns:
解析后的JSON对象
"""
try:
# 清理输出(移除可能的代码块标记)
cleaned_output = PromptOutputValidator._clean_json_output(output)
# 解析JSON
parsed = json.loads(cleaned_output)
# Schema验证如果提供
if schema:
PromptOutputValidator._validate_json_schema(parsed, schema)
return parsed
except json.JSONDecodeError as e:
raise PromptValidationError(f"JSON格式错误: {str(e)}")
except Exception as e:
raise PromptValidationError(f"JSON验证失败: {str(e)}")
@staticmethod
def validate_narration_script(output: Union[str, Dict]) -> Dict[str, Any]:
"""
验证解说文案输出格式
Args:
output: 输出内容字符串或字典
Returns:
验证后的解说文案数据
"""
# 如果是字符串先解析为JSON
if isinstance(output, str):
data = PromptOutputValidator.validate_json(output)
else:
data = output
# 验证必需字段
if "items" not in data:
raise PromptValidationError("解说文案缺少 'items' 字段")
items = data["items"]
if not isinstance(items, list):
raise PromptValidationError("'items' 字段必须是数组")
if not items:
raise PromptValidationError("解说文案不能为空")
# 验证每个item
for i, item in enumerate(items):
PromptOutputValidator._validate_narration_item(item, i)
logger.debug(f"解说文案验证通过,包含 {len(items)} 个片段")
return data
@staticmethod
def validate_plot_analysis(output: Union[str, Dict]) -> Dict[str, Any]:
"""
验证剧情分析输出格式
Args:
output: 输出内容
Returns:
验证后的剧情分析数据
"""
if isinstance(output, str):
data = PromptOutputValidator.validate_json(output)
else:
data = output
# 验证剧情分析必需字段
required_fields = ["summary", "plot_points"]
for field in required_fields:
if field not in data:
raise PromptValidationError(f"剧情分析缺少 '{field}' 字段")
# 验证plot_points
plot_points = data["plot_points"]
if not isinstance(plot_points, list):
raise PromptValidationError("'plot_points' 字段必须是数组")
for i, point in enumerate(plot_points):
PromptOutputValidator._validate_plot_point(point, i)
logger.debug(f"剧情分析验证通过,包含 {len(plot_points)} 个情节点")
return data
@staticmethod
def _clean_json_output(output: str) -> str:
"""清理JSON输出"""
# 移除可能的代码块标记
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
# 移除前后空白
output = output.strip()
# 尝试提取JSON部分如果有其他文本
json_match = re.search(r'\{.*\}', output, re.DOTALL)
if json_match:
output = json_match.group(0)
return output
@staticmethod
def _validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]) -> None:
"""验证JSON Schema"""
# 简单的schema验证实现
for field, field_type in schema.items():
if field not in data:
raise PromptValidationError(f"缺少必需字段: {field}")
if not isinstance(data[field], field_type):
raise PromptValidationError(
f"字段 '{field}' 类型错误,期望: {field_type.__name__},实际: {type(data[field]).__name__}"
)
@staticmethod
def _validate_narration_item(item: Dict[str, Any], index: int) -> None:
"""验证解说文案项目"""
required_fields = ["_id", "timestamp", "picture", "narration"]
for field in required_fields:
if field not in item:
raise PromptValidationError(f"{index + 1} 个片段缺少 '{field}' 字段")
# 验证_id
if not isinstance(item["_id"], int) or item["_id"] <= 0:
raise PromptValidationError(f"{index + 1} 个片段的 '_id' 必须是正整数")
# 验证timestamp格式
timestamp = item["timestamp"]
if not isinstance(timestamp, str):
raise PromptValidationError(f"{index + 1} 个片段的 'timestamp' 必须是字符串")
# 验证时间戳格式 (HH:MM:SS,mmm-HH:MM:SS,mmm)
timestamp_pattern = r'^\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}$'
if not re.match(timestamp_pattern, timestamp):
raise PromptValidationError(
f"{index + 1} 个片段的时间戳格式错误,应为 'HH:MM:SS,mmm-HH:MM:SS,mmm'"
)
# 验证文本字段不为空
for field in ["picture", "narration"]:
if not isinstance(item[field], str) or not item[field].strip():
raise PromptValidationError(f"{index + 1} 个片段的 '{field}' 不能为空")
# 验证OST字段如果存在
if "OST" in item:
if not isinstance(item["OST"], int) or item["OST"] not in [0, 1, 2]:
raise PromptValidationError(
f"{index + 1} 个片段的 'OST' 必须是 0、1 或 2"
)
@staticmethod
def _validate_plot_point(point: Dict[str, Any], index: int) -> None:
"""验证剧情点"""
required_fields = ["timestamp", "title", "picture"]
for field in required_fields:
if field not in point:
raise PromptValidationError(f"{index + 1} 个剧情点缺少 '{field}' 字段")
# 验证字段类型和内容
for field in required_fields:
if not isinstance(point[field], str) or not point[field].strip():
raise PromptValidationError(f"{index + 1} 个剧情点的 '{field}' 不能为空")
# 验证时间戳格式
timestamp = point["timestamp"]
# 支持多种时间戳格式
patterns = [
r'^\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}$', # HH:MM:SS,mmm-HH:MM:SS,mmm
r'^\d{2}:\d{2}:\d{2}-\d{2}:\d{2}:\d{2}$', # HH:MM:SS-HH:MM:SS
]
if not any(re.match(pattern, timestamp) for pattern in patterns):
raise PromptValidationError(
f"{index + 1} 个剧情点的时间戳格式错误"
)
@staticmethod
def validate_by_format(output: str, format_type: OutputFormat, schema: Dict[str, Any] = None) -> Any:
"""
根据格式类型验证输出
Args:
output: 输出内容
format_type: 输出格式类型
schema: 验证schema可选
Returns:
验证后的数据
"""
if format_type == OutputFormat.JSON:
return PromptOutputValidator.validate_json(output, schema)
elif format_type == OutputFormat.TEXT:
return output.strip()
elif format_type == OutputFormat.MARKDOWN:
return output.strip()
elif format_type == OutputFormat.STRUCTURED:
# 结构化数据需要根据具体类型处理
return PromptOutputValidator.validate_json(output, schema)
else:
raise PromptValidationError(f"不支持的输出格式: {format_type}")
# 便捷函数
def validate_json_output(output: str, schema: Dict[str, Any] = None) -> Dict[str, Any]:
"""验证JSON输出的便捷函数"""
return PromptOutputValidator.validate_json(output, schema)
def validate_narration_output(output: Union[str, Dict]) -> Dict[str, Any]:
"""验证解说文案输出的便捷函数"""
return PromptOutputValidator.validate_narration_script(output)

View File

@ -140,14 +140,27 @@ class ScriptGenerator:
# 获取Gemini配置
vision_api_key = config.app.get("vision_gemini_api_key")
vision_model = config.app.get("vision_gemini_model_name")
vision_base_url = config.app.get("vision_gemini_base_url")
if not vision_api_key or not vision_model:
raise ValueError("未配置 Gemini API Key 或者模型")
analyzer = gemini_analyzer.VisionAnalyzer(
model_name=vision_model,
api_key=vision_api_key,
)
# 根据提供商类型选择合适的分析器
if vision_provider == 'gemini(openai)':
# 使用OpenAI兼容的Gemini代理
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
analyzer = GeminiOpenAIAnalyzer(
model_name=vision_model,
api_key=vision_api_key,
base_url=vision_base_url
)
else:
# 使用原生Gemini分析器
analyzer = gemini_analyzer.VisionAnalyzer(
model_name=vision_model,
api_key=vision_api_key,
base_url=vision_base_url
)
progress_callback(40, "正在分析关键帧...")
@ -213,13 +226,35 @@ class ScriptGenerator:
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
video_theme=video_theme
)
# 根据提供商类型选择合适的处理器
if text_provider == 'gemini(openai)':
# 使用OpenAI兼容的Gemini代理
from app.utils.script_generator import GeminiOpenAIGenerator
generator = GeminiOpenAIGenerator(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url
)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
processor.generator = generator
else:
# 使用标准处理器包括原生Gemini
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
return processor.process_frames(frame_content_list)

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : update_script
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/6 下午11:00
'''

View File

@ -5,53 +5,162 @@ from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
from google.api_core import exceptions
import google.generativeai as genai
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
import requests
import PIL.Image
import traceback
import base64
import io
from app.utils import utils
class VisionAnalyzer:
"""视觉分析器类"""
"""原生Gemini视觉分析器类"""
def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
"""初始化视觉分析器"""
if not api_key:
raise ValueError("必须提供API密钥")
self.model_name = model_name
self.api_key = api_key
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
# 初始化配置
self._configure_client()
def _configure_client(self):
"""配置API客户端"""
genai.configure(api_key=self.api_key)
# 开放 Gemini 模型安全设置
from google.generativeai.types import HarmCategory, HarmBlockThreshold
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
self.model = genai.GenerativeModel(self.model_name, safety_settings=safety_settings)
"""配置原生Gemini API客户端"""
# 使用原生Gemini REST API
self.client = None
logger.info(f"配置原生Gemini API端点: {self.base_url}, 模型: {self.model_name}")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(exceptions.ResourceExhausted)
retry=retry_if_exception_type(requests.exceptions.RequestException)
)
async def _generate_content_with_retry(self, prompt, batch):
"""使用重试机制的内部方法来调用 generate_content_async"""
"""使用重试机制调用原生Gemini API"""
try:
return await self.model.generate_content_async([prompt, *batch])
except exceptions.ResourceExhausted as e:
print(f"API配额限制: {str(e)}")
raise RetryError("API调用失败")
return await self._generate_with_gemini_api(prompt, batch)
except requests.exceptions.RequestException as e:
logger.warning(f"Gemini API请求异常: {str(e)}")
raise
except Exception as e:
logger.error(f"Gemini API生成内容时发生错误: {str(e)}")
raise
async def _generate_with_gemini_api(self, prompt, batch):
"""使用原生Gemini REST API生成内容"""
# 将PIL图片转换为base64编码
image_parts = []
for img in batch:
# 将PIL图片转换为字节流
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85) # 优化图片质量
img_bytes = img_buffer.getvalue()
# 转换为base64
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
image_parts.append({
"inline_data": {
"mime_type": "image/jpeg",
"data": img_base64
}
})
# 构建符合官方文档的请求数据
request_data = {
"contents": [{
"parts": [
{"text": prompt},
*image_parts
]
}],
"generationConfig": {
"temperature": 1.0,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 8192,
"candidateCount": 1,
"stopSequences": []
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
# 发送请求
response = await asyncio.to_thread(
requests.post,
url,
json=request_data,
headers={
"Content-Type": "application/json",
"User-Agent": "NarratoAI/1.0"
},
timeout=120 # 增加超时时间
)
# 处理HTTP错误
if response.status_code == 429:
raise requests.exceptions.RequestException(f"API配额限制: {response.text}")
elif response.status_code == 400:
raise Exception(f"请求参数错误: {response.text}")
elif response.status_code == 403:
raise Exception(f"API密钥无效或权限不足: {response.text}")
elif response.status_code != 200:
raise Exception(f"Gemini API请求失败: {response.status_code} - {response.text}")
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
raise Exception("Gemini API返回无效响应可能触发了安全过滤")
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
raise Exception("内容被Gemini安全过滤器阻止")
if "content" not in candidate or "parts" not in candidate["content"]:
raise Exception("Gemini API返回内容格式错误")
# 提取文本内容
text_content = ""
for part in candidate["content"]["parts"]:
if "text" in part:
text_content += part["text"]
if not text_content.strip():
raise Exception("Gemini API返回空内容")
# 创建兼容的响应对象
class CompatibleResponse:
def __init__(self, text):
self.text = text
return CompatibleResponse(text_content)
async def analyze_images(self,
images: Union[List[str], List[PIL.Image.Image]],

View File

@ -0,0 +1,177 @@
"""
OpenAI兼容的Gemini视觉分析器
使用标准OpenAI格式调用Gemini代理服务
"""
import json
from typing import List, Union, Dict
import os
from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
import requests
import PIL.Image
import traceback
import base64
import io
from app.utils import utils
class GeminiOpenAIAnalyzer:
"""OpenAI兼容的Gemini视觉分析器类"""
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
"""初始化OpenAI兼容的Gemini分析器"""
if not api_key:
raise ValueError("必须提供API密钥")
if not base_url:
raise ValueError("必须提供OpenAI兼容的代理端点URL")
self.model_name = model_name
self.api_key = api_key
self.base_url = base_url.rstrip('/')
# 初始化OpenAI客户端
self._configure_client()
def _configure_client(self):
"""配置OpenAI兼容的客户端"""
from openai import OpenAI
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
logger.info(f"配置OpenAI兼容Gemini代理端点: {self.base_url}, 模型: {self.model_name}")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((requests.exceptions.RequestException, Exception))
)
async def _generate_content_with_retry(self, prompt, batch):
"""使用重试机制调用OpenAI兼容的Gemini代理"""
try:
return await self._generate_with_openai_api(prompt, batch)
except Exception as e:
logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}")
raise
async def _generate_with_openai_api(self, prompt, batch):
"""使用OpenAI兼容接口生成内容"""
# 将PIL图片转换为base64编码
image_contents = []
for img in batch:
# 将PIL图片转换为字节流
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=85)
img_bytes = img_buffer.getvalue()
# 转换为base64
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
image_contents.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64}"
}
})
# 构建OpenAI格式的消息
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
*image_contents
]
}
]
# 调用OpenAI兼容接口
response = await asyncio.to_thread(
self.client.chat.completions.create,
model=self.model_name,
messages=messages,
max_tokens=4000,
temperature=1.0
)
# 创建兼容的响应对象
class CompatibleResponse:
def __init__(self, text):
self.text = text
return CompatibleResponse(response.choices[0].message.content)
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10) -> List[str]:
"""
分析图片并返回结果
Args:
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
batch_size: 批处理大小
Returns:
分析结果列表
"""
logger.info(f"开始分析 {len(images)} 张图片使用OpenAI兼容Gemini代理")
# 加载图片
loaded_images = []
for img in images:
if isinstance(img, (str, Path)):
try:
pil_img = PIL.Image.open(img)
# 调整图片大小以优化性能
if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
loaded_images.append(pil_img)
except Exception as e:
logger.error(f"加载图片失败 {img}: {str(e)}")
continue
elif isinstance(img, PIL.Image.Image):
loaded_images.append(img)
else:
logger.warning(f"不支持的图片类型: {type(img)}")
continue
if not loaded_images:
raise ValueError("没有有效的图片可以分析")
# 分批处理
results = []
total_batches = (len(loaded_images) + batch_size - 1) // batch_size
for i in tqdm(range(0, len(loaded_images), batch_size),
desc="分析图片批次", total=total_batches):
batch = loaded_images[i:i + batch_size]
try:
response = await self._generate_content_with_retry(prompt, batch)
results.append(response.text)
# 添加延迟以避免API限流
if i + batch_size < len(loaded_images):
await asyncio.sleep(1)
except Exception as e:
logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}")
results.append(f"分析失败: {str(e)}")
logger.info(f"完成图片分析,共处理 {len(results)} 个批次")
return results
def analyze_images_sync(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10) -> List[str]:
"""
同步版本的图片分析方法
"""
return asyncio.run(self.analyze_images(images, prompt, batch_size))

View File

@ -6,7 +6,7 @@ from loguru import logger
from typing import List, Dict
from datetime import datetime
from openai import OpenAI
import google.generativeai as genai
import requests
import time
@ -134,59 +134,182 @@ class OpenAIGenerator(BaseGenerator):
class GeminiGenerator(BaseGenerator):
"""Google Gemini API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str):
"""原生Gemini API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
super().__init__(model_name, api_key, prompt)
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel(model_name)
# Gemini特定参数
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
self.client = None
# 原生Gemini API参数
self.default_params = {
"temperature": self.default_params["temperature"],
"top_p": self.default_params["top_p"],
"candidate_count": 1,
"stop_sequences": None
"topP": self.default_params["top_p"],
"topK": 40,
"maxOutputTokens": 4000,
"candidateCount": 1,
"stopSequences": []
}
class GeminiOpenAIGenerator(BaseGenerator):
"""OpenAI兼容的Gemini代理生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
super().__init__(model_name, api_key, prompt)
if not base_url:
raise ValueError("OpenAI兼容的Gemini代理必须提供base_url")
self.base_url = base_url.rstrip('/')
# 使用OpenAI兼容接口
from openai import OpenAI
self.client = OpenAI(
api_key=api_key,
base_url=base_url
)
# OpenAI兼容接口参数
self.default_params = {
"temperature": self.default_params["temperature"],
"max_tokens": 4000,
"stream": False
}
def _generate(self, messages: list, params: dict) -> any:
"""实现Gemini特定的生成逻辑"""
while True:
"""实现OpenAI兼容Gemini代理的生成逻辑"""
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
**params
)
return response
except Exception as e:
logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}")
raise
def _process_response(self, response: any) -> str:
"""处理OpenAI兼容接口的响应"""
if not response or not response.choices:
raise ValueError("OpenAI兼容Gemini代理返回无效响应")
return response.choices[0].message.content.strip()
def _generate(self, messages: list, params: dict) -> any:
"""实现原生Gemini API的生成逻辑"""
max_retries = 3
for attempt in range(max_retries):
try:
# 转换消息格式为Gemini格式
prompt = "\n".join([m["content"] for m in messages])
response = self.model.generate_content(
prompt,
generation_config=params
# 构建请求数据
request_data = {
"contents": [{
"parts": [{"text": prompt}]
}],
"generationConfig": params,
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
# 发送请求
response = requests.post(
url,
json=request_data,
headers={
"Content-Type": "application/json",
"User-Agent": "NarratoAI/1.0"
},
timeout=120
)
# 检查响应是否包含有效内容
if (hasattr(response, 'result') and
hasattr(response.result, 'candidates') and
response.result.candidates):
candidate = response.result.candidates[0]
# 检查是否有内容字段
if not hasattr(candidate, 'content'):
logger.warning("Gemini API 返回速率限制响应等待30秒后重试...")
time.sleep(30) # 等待3秒后重试
if response.status_code == 429:
# 处理限流
wait_time = 65 if attempt == 0 else 30
logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
if response.status_code == 400:
raise Exception(f"请求参数错误: {response.text}")
elif response.status_code == 403:
raise Exception(f"API密钥无效或权限不足: {response.text}")
elif response.status_code != 200:
raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}")
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
if attempt < max_retries - 1:
logger.warning("原生Gemini API 返回无效响应等待30秒后重试...")
time.sleep(30)
continue
return response
except Exception as e:
error_str = str(e)
if "429" in error_str:
logger.warning("Gemini API 触发限流等待65秒后重试...")
time.sleep(65) # 等待65秒后重试
else:
raise Exception("原生Gemini API返回无效响应可能触发了安全过滤")
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
raise Exception("内容被Gemini安全过滤器阻止")
# 创建兼容的响应对象
class CompatibleResponse:
def __init__(self, data):
self.data = data
candidate = data["candidates"][0]
if "content" in candidate and "parts" in candidate["content"]:
self.text = ""
for part in candidate["content"]["parts"]:
if "text" in part:
self.text += part["text"]
else:
self.text = ""
return CompatibleResponse(response_data)
except requests.exceptions.RequestException as e:
if attempt < max_retries - 1:
logger.warning(f"网络请求失败等待30秒后重试: {str(e)}")
time.sleep(30)
continue
else:
logger.error(f"Gemini 生成文案错误: \n{error_str}")
logger.error(f"原生Gemini API请求失败: {str(e)}")
raise
except Exception as e:
if attempt < max_retries - 1 and "429" in str(e):
logger.warning("原生Gemini API 触发限流等待65秒后重试...")
time.sleep(65)
continue
else:
logger.error(f"原生Gemini 生成文案错误: {str(e)}")
raise
def _process_response(self, response: any) -> str:
"""处理Gemini的响应"""
"""处理原生Gemini API的响应"""
if not response or not response.text:
raise ValueError("Invalid response from Gemini API")
raise ValueError("原生Gemini API返回无效响应")
return response.text.strip()
@ -318,7 +441,7 @@ class ScriptProcessor:
# 根据模型名称选择对应的生成器
logger.info(f"文本 LLM 提供商: {model_name}")
if 'gemini' in model_name.lower():
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url)
elif 'qwen' in model_name.lower():
self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
elif 'moonshot' in model_name.lower():

View File

@ -1,5 +1,5 @@
[app]
project_version="0.6.6"
project_version="0.6.7"
# 支持视频理解的大模型提供商
# gemini (谷歌, 需要 VPN)
# siliconflow (硅基流动)

367
docs/LLM_MIGRATION_GUIDE.md Normal file
View File

@ -0,0 +1,367 @@
# NarratoAI 大模型服务迁移指南
## 📋 概述
本指南帮助开发者将现有代码从旧的大模型调用方式迁移到新的统一LLM服务架构。新架构提供了更好的模块化、错误处理和配置管理。
## 🔄 迁移对比
### 旧的调用方式 vs 新的调用方式
#### 1. 视觉分析器创建
**旧方式:**
```python
from app.utils import gemini_analyzer, qwenvl_analyzer
if provider == 'gemini':
analyzer = gemini_analyzer.VisionAnalyzer(
model_name=model,
api_key=api_key,
base_url=base_url
)
elif provider == 'qwenvl':
analyzer = qwenvl_analyzer.QwenAnalyzer(
model_name=model,
api_key=api_key,
base_url=base_url
)
```
**新方式:**
```python
from app.services.llm.unified_service import UnifiedLLMService
# 方式1: 直接使用统一服务
results = await UnifiedLLMService.analyze_images(
images=images,
prompt=prompt,
provider=provider # 可选,使用配置中的默认值
)
# 方式2: 使用迁移适配器(向后兼容)
from app.services.llm.migration_adapter import create_vision_analyzer
analyzer = create_vision_analyzer(provider, api_key, model, base_url)
results = await analyzer.analyze_images(images, prompt)
```
#### 2. 文本生成
**旧方式:**
```python
from openai import OpenAI
client = OpenAI(api_key=api_key, base_url=base_url)
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
temperature=temperature,
response_format={"type": "json_object"}
)
result = response.choices[0].message.content
```
**新方式:**
```python
from app.services.llm.unified_service import UnifiedLLMService
result = await UnifiedLLMService.generate_text(
prompt=prompt,
system_prompt=system_prompt,
temperature=temperature,
response_format="json"
)
```
#### 3. 解说文案生成
**旧方式:**
```python
from app.services.generate_narration_script import generate_narration
narration = generate_narration(
markdown_content,
api_key,
base_url=base_url,
model=model
)
# 手动解析JSON和验证格式
import json
narration_dict = json.loads(narration)['items']
```
**新方式:**
```python
from app.services.llm.unified_service import UnifiedLLMService
# 自动验证输出格式
narration_items = await UnifiedLLMService.generate_narration_script(
prompt=prompt,
validate_output=True # 自动验证JSON格式和字段
)
```
## 📝 具体迁移步骤
### 步骤1: 更新配置文件
**旧配置格式:**
```toml
[app]
llm_provider = "openai"
openai_api_key = "sk-xxx"
openai_model_name = "gpt-4"
vision_llm_provider = "gemini"
gemini_api_key = "xxx"
gemini_model_name = "gemini-1.5-pro"
```
**新配置格式:**
```toml
[app]
# 视觉模型配置
vision_llm_provider = "gemini"
vision_gemini_api_key = "xxx"
vision_gemini_model_name = "gemini-2.0-flash-lite"
vision_gemini_base_url = "https://generativelanguage.googleapis.com/v1beta"
# 文本模型配置
text_llm_provider = "openai"
text_openai_api_key = "sk-xxx"
text_openai_model_name = "gpt-4o-mini"
text_openai_base_url = "https://api.openai.com/v1"
```
### 步骤2: 更新导入语句
**旧导入:**
```python
from app.utils import gemini_analyzer, qwenvl_analyzer
from app.services.generate_narration_script import generate_narration
from app.services.SDE.short_drama_explanation import analyze_subtitle
```
**新导入:**
```python
from app.services.llm.unified_service import UnifiedLLMService
from app.services.llm.migration_adapter import (
create_vision_analyzer,
SubtitleAnalyzerAdapter
)
```
### 步骤3: 更新函数调用
#### 图片分析迁移
**旧代码:**
```python
def analyze_images_old(provider, api_key, model, base_url, images, prompt):
if provider == 'gemini':
analyzer = gemini_analyzer.VisionAnalyzer(
model_name=model,
api_key=api_key,
base_url=base_url
)
else:
analyzer = qwenvl_analyzer.QwenAnalyzer(
model_name=model,
api_key=api_key,
base_url=base_url
)
# 同步调用
results = []
for batch in batches:
result = analyzer.analyze_batch(batch, prompt)
results.append(result)
return results
```
**新代码:**
```python
async def analyze_images_new(images, prompt, provider=None):
# 异步调用,自动批处理
results = await UnifiedLLMService.analyze_images(
images=images,
prompt=prompt,
provider=provider,
batch_size=10
)
return results
```
#### 字幕分析迁移
**旧代码:**
```python
from app.services.SDE.short_drama_explanation import analyze_subtitle
result = analyze_subtitle(
subtitle_file_path=subtitle_path,
api_key=api_key,
model=model,
base_url=base_url,
provider=provider
)
```
**新代码:**
```python
# 方式1: 使用统一服务
with open(subtitle_path, 'r', encoding='utf-8') as f:
subtitle_content = f.read()
result = await UnifiedLLMService.analyze_subtitle(
subtitle_content=subtitle_content,
provider=provider,
validate_output=True
)
# 方式2: 使用适配器
from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter
analyzer = SubtitleAnalyzerAdapter(api_key, model, base_url, provider)
result = analyzer.analyze_subtitle(subtitle_content)
```
## 🔧 常见迁移问题
### 1. 同步 vs 异步调用
**问题:** 新架构使用异步调用,旧代码是同步的。
**解决方案:**
```python
# 在同步函数中调用异步函数
import asyncio
def sync_function():
result = asyncio.run(UnifiedLLMService.generate_text(prompt))
return result
# 或者将整个函数改为异步
async def async_function():
result = await UnifiedLLMService.generate_text(prompt)
return result
```
### 2. 配置获取方式变化
**问题:** 配置键名发生变化。
**解决方案:**
```python
# 旧方式
api_key = config.app.get('openai_api_key')
model = config.app.get('openai_model_name')
# 新方式
provider = config.app.get('text_llm_provider', 'openai')
api_key = config.app.get(f'text_{provider}_api_key')
model = config.app.get(f'text_{provider}_model_name')
```
### 3. 错误处理更新
**旧方式:**
```python
try:
result = some_llm_call()
except Exception as e:
print(f"Error: {e}")
```
**新方式:**
```python
from app.services.llm.exceptions import LLMServiceError, ValidationError
try:
result = await UnifiedLLMService.generate_text(prompt)
except ValidationError as e:
print(f"输出验证失败: {e.message}")
except LLMServiceError as e:
print(f"LLM服务错误: {e.message}")
except Exception as e:
print(f"未知错误: {e}")
```
## ✅ 迁移检查清单
### 配置迁移
- [ ] 更新配置文件格式
- [ ] 验证所有API密钥配置正确
- [ ] 运行配置验证器检查
### 代码迁移
- [ ] 更新导入语句
- [ ] 将同步调用改为异步调用
- [ ] 更新错误处理机制
- [ ] 使用新的统一接口
### 测试验证
- [ ] 运行LLM服务测试脚本
- [ ] 测试所有功能模块
- [ ] 验证输出格式正确
- [ ] 检查性能和稳定性
### 清理工作
- [ ] 移除未使用的旧代码
- [ ] 更新文档和注释
- [ ] 清理过时的依赖
## 🚀 迁移最佳实践
### 1. 渐进式迁移
- 先迁移一个模块,测试通过后再迁移其他模块
- 保留旧代码作为备用方案
- 使用迁移适配器确保向后兼容
### 2. 充分测试
- 在每个迁移步骤后运行测试
- 比较新旧实现的输出结果
- 测试边界情况和错误处理
### 3. 监控和日志
- 启用详细日志记录
- 监控API调用成功率
- 跟踪性能指标
### 4. 文档更新
- 更新代码注释
- 更新API文档
- 记录迁移过程中的问题和解决方案
## 📞 获取帮助
如果在迁移过程中遇到问题:
1. **查看测试脚本输出**
```bash
python app/services/llm/test_llm_service.py
```
2. **验证配置**
```python
from app.services.llm.config_validator import LLMConfigValidator
results = LLMConfigValidator.validate_all_configs()
LLMConfigValidator.print_validation_report(results)
```
3. **查看详细日志**
```python
from loguru import logger
logger.add("migration.log", level="DEBUG")
```
4. **参考示例代码**
- 查看 `app/services/llm/test_llm_service.py` 中的使用示例
- 参考已迁移的文件如 `webui/tools/base.py`
---
*最后更新: 2025-01-07*

294
docs/LLM_SERVICE_GUIDE.md Normal file
View File

@ -0,0 +1,294 @@
# NarratoAI 大模型服务使用指南
## 📖 概述
NarratoAI 项目已完成大模型服务的全面重构,提供了统一、模块化、可扩展的大模型集成架构。新架构支持多种大模型供应商,具有严格的输出格式验证和完善的错误处理机制。
## 🏗️ 架构概览
### 核心组件
```
app/services/llm/
├── __init__.py # 模块入口
├── base.py # 抽象基类
├── manager.py # 服务管理器
├── unified_service.py # 统一服务接口
├── validators.py # 输出格式验证器
├── exceptions.py # 异常类定义
├── migration_adapter.py # 迁移适配器
├── config_validator.py # 配置验证器
├── test_llm_service.py # 测试脚本
└── providers/ # 提供商实现
├── __init__.py
├── gemini_provider.py
├── gemini_openai_provider.py
├── openai_provider.py
├── qwen_provider.py
├── deepseek_provider.py
└── siliconflow_provider.py
```
### 支持的供应商
#### 视觉模型供应商
- **Gemini** (原生API + OpenAI兼容)
- **QwenVL** (通义千问视觉)
- **Siliconflow** (硅基流动)
#### 文本生成模型供应商
- **OpenAI** (标准OpenAI API)
- **Gemini** (原生API + OpenAI兼容)
- **DeepSeek** (深度求索)
- **Qwen** (通义千问)
- **Siliconflow** (硅基流动)
## ⚙️ 配置说明
### 配置文件格式
`config.toml` 中配置大模型服务:
```toml
[app]
# 视觉模型提供商配置
vision_llm_provider = "gemini"
# Gemini 视觉模型
vision_gemini_api_key = "your_gemini_api_key"
vision_gemini_model_name = "gemini-2.0-flash-lite"
vision_gemini_base_url = "https://generativelanguage.googleapis.com/v1beta"
# QwenVL 视觉模型
vision_qwenvl_api_key = "your_qwen_api_key"
vision_qwenvl_model_name = "qwen2.5-vl-32b-instruct"
vision_qwenvl_base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# 文本模型提供商配置
text_llm_provider = "openai"
# OpenAI 文本模型
text_openai_api_key = "your_openai_api_key"
text_openai_model_name = "gpt-4o-mini"
text_openai_base_url = "https://api.openai.com/v1"
# DeepSeek 文本模型
text_deepseek_api_key = "your_deepseek_api_key"
text_deepseek_model_name = "deepseek-chat"
text_deepseek_base_url = "https://api.deepseek.com"
```
### 配置验证
使用配置验证器检查配置是否正确:
```python
from app.services.llm.config_validator import LLMConfigValidator
# 验证所有配置
results = LLMConfigValidator.validate_all_configs()
# 打印验证报告
LLMConfigValidator.print_validation_report(results)
# 获取配置建议
suggestions = LLMConfigValidator.get_config_suggestions()
```
## 🚀 使用方法
### 1. 统一服务接口(推荐)
```python
from app.services.llm.unified_service import UnifiedLLMService
# 图片分析
results = await UnifiedLLMService.analyze_images(
images=["path/to/image1.jpg", "path/to/image2.jpg"],
prompt="请描述这些图片的内容",
provider="gemini", # 可选,不指定则使用配置中的默认值
batch_size=10
)
# 文本生成
text = await UnifiedLLMService.generate_text(
prompt="请介绍人工智能的发展历史",
system_prompt="你是一个专业的AI专家",
provider="openai", # 可选
temperature=0.7,
response_format="json" # 可选支持JSON格式输出
)
# 解说文案生成(带验证)
narration_items = await UnifiedLLMService.generate_narration_script(
prompt="根据视频内容生成解说文案...",
validate_output=True # 自动验证输出格式
)
# 字幕分析
analysis = await UnifiedLLMService.analyze_subtitle(
subtitle_content="字幕内容...",
validate_output=True
)
```
### 2. 直接使用服务管理器
```python
from app.services.llm.manager import LLMServiceManager
# 获取视觉模型提供商
vision_provider = LLMServiceManager.get_vision_provider("gemini")
results = await vision_provider.analyze_images(images, prompt)
# 获取文本模型提供商
text_provider = LLMServiceManager.get_text_provider("openai")
text = await text_provider.generate_text(prompt)
```
### 3. 迁移适配器(向后兼容)
```python
from app.services.llm.migration_adapter import create_vision_analyzer
# 兼容旧的接口
analyzer = create_vision_analyzer("gemini", api_key, model, base_url)
results = await analyzer.analyze_images(images, prompt)
```
## 🔍 输出格式验证
### 解说文案验证
```python
from app.services.llm.validators import OutputValidator
# 验证解说文案格式
try:
narration_items = OutputValidator.validate_narration_script(output)
print(f"验证成功,共 {len(narration_items)} 个片段")
except ValidationError as e:
print(f"验证失败: {e.message}")
```
### JSON输出验证
```python
# 验证JSON格式
try:
data = OutputValidator.validate_json_output(output)
print("JSON格式验证成功")
except ValidationError as e:
print(f"JSON验证失败: {e.message}")
```
## 🧪 测试和调试
### 运行测试脚本
```bash
# 运行完整的LLM服务测试
python app/services/llm/test_llm_service.py
```
测试脚本会验证:
- 配置有效性
- 提供商信息获取
- 文本生成功能
- JSON格式生成
- 字幕分析功能
- 解说文案生成功能
### 调试技巧
1. **启用详细日志**
```python
from loguru import logger
logger.add("llm_service.log", level="DEBUG")
```
2. **清空提供商缓存**
```python
UnifiedLLMService.clear_cache()
```
3. **检查提供商信息**
```python
info = UnifiedLLMService.get_provider_info()
print(info)
```
## ⚠️ 注意事项
### 1. API密钥安全
- 不要在代码中硬编码API密钥
- 使用环境变量或配置文件管理密钥
- 定期轮换API密钥
### 2. 错误处理
- 所有LLM服务调用都应该包装在try-catch中
- 使用适当的异常类型进行错误处理
- 实现重试机制处理临时性错误
### 3. 性能优化
- 合理设置批处理大小
- 使用缓存避免重复调用
- 监控API调用频率和成本
### 4. 模型选择
- 根据任务类型选择合适的模型
- 考虑成本和性能的平衡
- 定期更新到最新的模型版本
## 🔧 扩展新供应商
### 1. 创建提供商类
```python
# app/services/llm/providers/new_provider.py
from ..base import TextModelProvider
class NewTextProvider(TextModelProvider):
@property
def provider_name(self) -> str:
return "new_provider"
@property
def supported_models(self) -> List[str]:
return ["model-1", "model-2"]
async def generate_text(self, prompt: str, **kwargs) -> str:
# 实现具体的API调用逻辑
pass
```
### 2. 注册提供商
```python
# app/services/llm/providers/__init__.py
from .new_provider import NewTextProvider
LLMServiceManager.register_text_provider('new_provider', NewTextProvider)
```
### 3. 添加配置支持
```toml
# config.toml
text_new_provider_api_key = "your_api_key"
text_new_provider_model_name = "model-1"
text_new_provider_base_url = "https://api.newprovider.com/v1"
```
## 📞 技术支持
如果在使用过程中遇到问题:
1. 首先运行测试脚本检查配置
2. 查看日志文件了解详细错误信息
3. 检查API密钥和网络连接
4. 参考本文档的故障排除部分
---
*最后更新: 2025-01-07*

View File

@ -0,0 +1,267 @@
# 提示词管理系统文档
## 概述
本项目实现了统一的提示词管理系统,用于集中管理三个核心功能的提示词:
- **纪录片解说** - 视频帧分析和解说文案生成
- **短剧混剪** - 字幕分析和爆点提取
- **短剧解说** - 剧情分析和解说脚本生成
## 系统架构
```
app/services/prompts/
├── __init__.py # 模块初始化
├── base.py # 基础提示词类
├── manager.py # 提示词管理器
├── registry.py # 提示词注册机制
├── template.py # 模板渲染引擎
├── validators.py # 输出验证器
├── exceptions.py # 异常定义
├── documentary/ # 纪录片解说提示词
│ ├── __init__.py
│ ├── frame_analysis.py # 视频帧分析
│ └── narration_generation.py # 解说文案生成
├── short_drama_editing/ # 短剧混剪提示词
│ ├── __init__.py
│ ├── subtitle_analysis.py # 字幕分析
│ └── plot_extraction.py # 爆点提取
└── short_drama_narration/ # 短剧解说提示词
├── __init__.py
├── plot_analysis.py # 剧情分析
└── script_generation.py # 解说脚本生成
```
## 核心特性
### 1. 统一管理
- 所有提示词集中在 `app/services/prompts/` 模块中
- 按功能模块分类组织
- 支持版本控制和回滚
### 2. 模型类型适配
- **TextPrompt**: 文本模型专用
- **VisionPrompt**: 视觉模型专用
- **ParameterizedPrompt**: 支持参数化
### 3. 参数化支持
- 动态参数替换
- 参数验证
- 模板渲染
### 4. 输出验证
- 严格的JSON格式验证
- 特定业务场景验证(解说文案、剧情分析等)
- 自定义验证规则
## 使用方法
### 基本用法
```python
from app.services.prompts import PromptManager
# 获取纪录片解说的视频帧分析提示词
prompt = PromptManager.get_prompt(
category="documentary",
name="frame_analysis",
parameters={
"video_theme": "荒野建造",
"custom_instructions": "请特别关注建造过程的细节"
}
)
# 获取短剧解说的剧情分析提示词
prompt = PromptManager.get_prompt(
category="short_drama_narration",
name="plot_analysis",
parameters={"subtitle_content": "字幕内容..."}
)
```
### 高级功能
```python
# 搜索提示词
results = PromptManager.search_prompts(
keyword="分析",
model_type=ModelType.TEXT
)
# 获取提示词详细信息
info = PromptManager.get_prompt_info(
category="documentary",
name="narration_generation"
)
# 验证输出
validated_data = PromptManager.validate_output(
output=llm_response,
category="documentary",
name="narration_generation"
)
```
## 已注册的提示词
### 纪录片解说 (documentary)
- `frame_analysis` - 视频帧分析提示词
- `narration_generation` - 解说文案生成提示词
### 短剧混剪 (short_drama_editing)
- `subtitle_analysis` - 字幕分析提示词
- `plot_extraction` - 爆点提取提示词
### 短剧解说 (short_drama_narration)
- `plot_analysis` - 剧情分析提示词
- `script_generation` - 解说脚本生成提示词
## 迁移指南
### 旧代码迁移
**之前的用法:**
```python
from app.services.SDE.prompt import subtitle_plot_analysis_v1
prompt = subtitle_plot_analysis_v1
```
**新的用法:**
```python
from app.services.prompts import PromptManager
prompt = PromptManager.get_prompt(
category="short_drama_narration",
name="plot_analysis",
parameters={"subtitle_content": content}
)
```
### 已更新的文件
- `app/services/SDE/short_drama_explanation.py`
- `app/services/SDP/utils/step1_subtitle_analyzer_openai.py`
- `app/services/generate_narration_script.py`
## 扩展指南
### 添加新提示词
1. 在相应分类目录下创建新的提示词类:
```python
from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat
class NewPrompt(TextPrompt):
def __init__(self):
metadata = PromptMetadata(
name="new_prompt",
category="your_category",
version="v1.0",
description="提示词描述",
model_type=ModelType.TEXT,
output_format=OutputFormat.JSON,
parameters=["param1", "param2"]
)
super().__init__(metadata)
def get_template(self) -> str:
return "您的提示词模板内容..."
```
2. 在 `__init__.py` 中注册:
```python
def register_prompts():
new_prompt = NewPrompt()
PromptManager.register_prompt(new_prompt, is_default=True)
```
### 添加新分类
1. 创建新的分类目录
2. 实现提示词类
3. 在主模块的 `__init__.py` 中导入并注册
## 测试
运行测试脚本验证系统功能:
```bash
python test_prompt_system.py
```
## 注意事项
1. **模板参数**: 使用 `${parameter_name}` 格式
2. **JSON格式**: 模板中的JSON示例使用标准格式 `{``}`,不要使用双大括号
3. **参数验证**: 必需参数会自动验证
4. **版本管理**: 支持多版本共存,默认使用最新版本
5. **输出验证**: 建议对LLM输出进行验证以确保格式正确
6. **JSON解析**: 系统提供强大的JSON解析兼容性自动处理各种格式问题
## JSON解析优化
系统提供了强大的JSON解析兼容性能够处理LLM生成的各种格式问题
### 支持的格式修复
1. **双大括号修复**: 自动将 `{{``}}` 转换为标准的 `{``}`
2. **代码块提取**: 自动从 ````json` 代码块中提取JSON内容
3. **额外文本处理**: 自动提取大括号包围的JSON内容忽略前后的额外文本
4. **尾随逗号修复**: 自动移除对象和数组末尾的多余逗号
5. **注释移除**: 自动移除 `//``#` 注释
6. **引号修复**: 自动修复单引号和缺失的属性名引号
### 解析策略
系统采用多重解析策略,按优先级依次尝试:
```python
strategies = [
("直接解析", lambda s: json.loads(s)),
("修复双大括号", _fix_double_braces),
("提取代码块", _extract_code_block),
("提取大括号内容", _extract_braces_content),
("修复常见格式问题", _fix_common_json_issues),
("修复引号问题", _fix_quote_issues),
("修复尾随逗号", _fix_trailing_commas),
("强制修复", _force_fix_json),
]
```
### 使用示例
```python
from webui.tools.generate_short_summary import parse_and_fix_json
# 处理双大括号JSON
json_str = '{{ "items": [{{ "_id": 1, "name": "test" }}] }}'
result = parse_and_fix_json(json_str) # 自动修复并解析
# 处理有额外文本的JSON
json_str = '这是一些文本\n{"items": []}\n更多文本'
result = parse_and_fix_json(json_str) # 自动提取JSON部分
```
## 性能优化
- 提示词模板会被缓存
- 支持批量操作
- 异步渲染支持(未来版本)
- JSON解析采用多策略优化确保高成功率
## 故障排除
### 常见问题
1. **模板渲染错误**: 检查参数名称和格式
2. **提示词未找到**: 确认分类、名称和版本正确
3. **输出验证失败**: 检查LLM输出格式是否符合要求
### 日志调试
系统使用 loguru 记录详细日志,可通过日志排查问题:
```python
from loguru import logger
logger.debug("调试信息")
```

View File

@ -1 +1 @@
0.6.6
0.6.7

View File

@ -7,6 +7,45 @@ from app.utils import utils
from loguru import logger
def validate_api_key(api_key: str, provider: str) -> tuple[bool, str]:
"""验证API密钥格式"""
if not api_key or not api_key.strip():
return False, f"{provider} API密钥不能为空"
# 基本长度检查
if len(api_key.strip()) < 10:
return False, f"{provider} API密钥长度过短请检查是否正确"
return True, ""
def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
"""验证Base URL格式"""
if not base_url or not base_url.strip():
return True, "" # base_url可以为空
base_url = base_url.strip()
if not (base_url.startswith('http://') or base_url.startswith('https://')):
return False, f"{provider} Base URL必须以http://或https://开头"
return True, ""
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
"""验证模型名称"""
if not model_name or not model_name.strip():
return False, f"{provider} 模型名称不能为空"
return True, ""
def show_config_validation_errors(errors: list):
"""显示配置验证错误"""
if errors:
for error in errors:
st.error(error)
def render_basic_settings(tr):
"""渲染基础设置面板"""
with st.expander(tr("Basic Settings"), expanded=False):
@ -87,29 +126,96 @@ def render_proxy_settings(tr):
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
"""测试视觉模型连接
Args:
api_key: API密钥
base_url: 基础URL
model_name: 模型名称
provider: 提供商名称
Returns:
bool: 连接是否成功
str: 测试结果消息
"""
import requests
if provider.lower() == 'gemini':
import google.generativeai as genai
# 原生Gemini API测试
try:
genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name)
model.generate_content("直接回复我文本'当前网络可用'")
return True, tr("gemini model is available")
# 构建请求数据
request_data = {
"contents": [{
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
}],
"generationConfig": {
"temperature": 1.0,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 100,
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
# 发送请求
response = requests.post(
url,
json=request_data,
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code == 200:
return True, tr("原生Gemini模型连接成功")
else:
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
except Exception as e:
return False, f"{tr('gemini model is not available')}: {str(e)}"
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
elif provider.lower() == 'gemini(openai)':
# OpenAI兼容的Gemini代理测试
try:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
test_url = f"{base_url.rstrip('/')}/chat/completions"
test_data = {
"model": model_name,
"messages": [
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
],
"stream": False
}
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
if response.status_code == 200:
return True, tr("OpenAI兼容Gemini代理连接成功")
else:
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
except Exception as e:
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: {str(e)}"
elif provider.lower() == 'narratoapi':
import requests
try:
# 构建测试请求
headers = {
@ -172,7 +278,7 @@ def render_vision_llm_settings(tr):
st.subheader(tr("Vision Model Settings"))
# 视频分析模型提供商选择
vision_providers = ['Siliconflow', 'Gemini', 'QwenVL', 'OpenAI']
vision_providers = ['Siliconflow', 'Gemini', 'Gemini(OpenAI)', 'QwenVL', 'OpenAI']
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
saved_provider_index = 0
@ -191,9 +297,15 @@ def render_vision_llm_settings(tr):
st.session_state['vision_llm_providers'] = vision_provider
# 获取已保存的视觉模型配置
vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
# 处理特殊的提供商名称映射
if vision_provider == 'gemini(openai)':
vision_config_key = 'vision_gemini_openai'
else:
vision_config_key = f'vision_{vision_provider}'
vision_api_key = config.app.get(f"{vision_config_key}_api_key", "")
vision_base_url = config.app.get(f"{vision_config_key}_base_url", "")
vision_model_name = config.app.get(f"{vision_config_key}_model_name", "")
# 渲染视觉模型配置输入框
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
@ -201,15 +313,25 @@ def render_vision_llm_settings(tr):
# 根据不同提供商设置默认值和帮助信息
if vision_provider == 'gemini':
st_vision_base_url = st.text_input(
tr("Vision Base URL"),
value=vision_base_url,
disabled=True,
help=tr("Gemini API does not require a base URL")
tr("Vision Base URL"),
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta",
help=tr("原生Gemini API端点默认: https://generativelanguage.googleapis.com/v1beta")
)
st_vision_model_name = st.text_input(
tr("Vision Model Name"),
value=vision_model_name or "gemini-2.0-flash-lite",
help=tr("Default: gemini-2.0-flash-lite")
tr("Vision Model Name"),
value=vision_model_name or "gemini-2.0-flash-exp",
help=tr("原生Gemini模型默认: gemini-2.0-flash-exp")
)
elif vision_provider == 'gemini(openai)':
st_vision_base_url = st.text_input(
tr("Vision Base URL"),
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
help=tr("OpenAI兼容的Gemini代理端点如: https://your-proxy.com/v1")
)
st_vision_model_name = st.text_input(
tr("Vision Model Name"),
value=vision_model_name or "gemini-2.0-flash-exp",
help=tr("OpenAI格式的Gemini模型名称默认: gemini-2.0-flash-exp")
)
elif vision_provider == 'qwenvl':
st_vision_base_url = st.text_input(
@ -228,30 +350,81 @@ def render_vision_llm_settings(tr):
# 在配置输入框后添加测试按钮
if st.button(tr("Test Connection"), key="test_vision_connection"):
with st.spinner(tr("Testing connection...")):
success, message = test_vision_model_connection(
api_key=st_vision_api_key,
base_url=st_vision_base_url,
model_name=st_vision_model_name,
provider=vision_provider,
tr=tr
)
if success:
st.success(tr(message))
else:
st.error(tr(message))
# 先验证配置
test_errors = []
if not st_vision_api_key:
test_errors.append("请先输入API密钥")
if not st_vision_model_name:
test_errors.append("请先输入模型名称")
# 保存视觉模型配置
if test_errors:
for error in test_errors:
st.error(error)
else:
with st.spinner(tr("Testing connection...")):
try:
success, message = test_vision_model_connection(
api_key=st_vision_api_key,
base_url=st_vision_base_url,
model_name=st_vision_model_name,
provider=vision_provider,
tr=tr
)
if success:
st.success(message)
else:
st.error(message)
except Exception as e:
st.error(f"测试连接时发生错误: {str(e)}")
logger.error(f"视频分析模型连接测试失败: {str(e)}")
# 验证和保存视觉模型配置
validation_errors = []
config_changed = False
# 验证API密钥
if st_vision_api_key:
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
is_valid, error_msg = validate_api_key(st_vision_api_key, f"视频分析({vision_provider})")
if is_valid:
config.app[f"{vision_config_key}_api_key"] = st_vision_api_key
st.session_state[f"{vision_config_key}_api_key"] = st_vision_api_key
config_changed = True
else:
validation_errors.append(error_msg)
# 验证Base URL
if st_vision_base_url:
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
is_valid, error_msg = validate_base_url(st_vision_base_url, f"视频分析({vision_provider})")
if is_valid:
config.app[f"{vision_config_key}_base_url"] = st_vision_base_url
st.session_state[f"{vision_config_key}_base_url"] = st_vision_base_url
config_changed = True
else:
validation_errors.append(error_msg)
# 验证模型名称
if st_vision_model_name:
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
is_valid, error_msg = validate_model_name(st_vision_model_name, f"视频分析({vision_provider})")
if is_valid:
config.app[f"{vision_config_key}_model_name"] = st_vision_model_name
st.session_state[f"{vision_config_key}_model_name"] = st_vision_model_name
config_changed = True
else:
validation_errors.append(error_msg)
# 显示验证错误
show_config_validation_errors(validation_errors)
# 如果配置有变化且没有验证错误,保存到文件
if config_changed and not validation_errors:
try:
config.save_config()
if st_vision_api_key or st_vision_base_url or st_vision_model_name:
st.success(f"视频分析模型({vision_provider})配置已保存")
except Exception as e:
st.error(f"保存配置失败: {str(e)}")
logger.error(f"保存视频分析配置失败: {str(e)}")
def test_text_model_connection(api_key, base_url, model_name, provider, tr):
@ -278,14 +451,74 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
# 特殊处理Gemini
if provider.lower() == 'gemini':
import google.generativeai as genai
# 原生Gemini API测试
try:
genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name)
model.generate_content("直接回复我文本'当前网络可用'")
return True, tr("Gemini model is available")
# 构建请求数据
request_data = {
"contents": [{
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
}],
"generationConfig": {
"temperature": 1.0,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 100,
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
# 发送请求
response = requests.post(
url,
json=request_data,
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code == 200:
return True, tr("原生Gemini模型连接成功")
else:
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
except Exception as e:
return False, f"{tr('Gemini model is not available')}: {str(e)}"
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
elif provider.lower() == 'gemini(openai)':
# OpenAI兼容的Gemini代理测试
test_url = f"{base_url.rstrip('/')}/chat/completions"
test_data = {
"model": model_name,
"messages": [
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
],
"stream": False
}
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
if response.status_code == 200:
return True, tr("OpenAI兼容Gemini代理连接成功")
else:
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
else:
test_url = f"{base_url.rstrip('/')}/chat/completions"
@ -322,7 +555,7 @@ def render_text_llm_settings(tr):
st.subheader(tr("Text Generation Model Settings"))
# 文案生成模型提供商选择
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Qwen', 'Moonshot']
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Gemini(OpenAI)', 'Qwen', 'Moonshot']
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
saved_provider_index = 0
@ -346,32 +579,108 @@ def render_text_llm_settings(tr):
# 渲染文本模型配置输入框
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
# 根据不同提供商设置默认值和帮助信息
if text_provider == 'gemini':
st_text_base_url = st.text_input(
tr("Text Base URL"),
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta",
help=tr("原生Gemini API端点默认: https://generativelanguage.googleapis.com/v1beta")
)
st_text_model_name = st.text_input(
tr("Text Model Name"),
value=text_model_name or "gemini-2.0-flash-exp",
help=tr("原生Gemini模型默认: gemini-2.0-flash-exp")
)
elif text_provider == 'gemini(openai)':
st_text_base_url = st.text_input(
tr("Text Base URL"),
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
help=tr("OpenAI兼容的Gemini代理端点如: https://your-proxy.com/v1")
)
st_text_model_name = st.text_input(
tr("Text Model Name"),
value=text_model_name or "gemini-2.0-flash-exp",
help=tr("OpenAI格式的Gemini模型名称默认: gemini-2.0-flash-exp")
)
else:
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
# 添加测试按钮
if st.button(tr("Test Connection"), key="test_text_connection"):
with st.spinner(tr("Testing connection...")):
success, message = test_text_model_connection(
api_key=st_text_api_key,
base_url=st_text_base_url,
model_name=st_text_model_name,
provider=text_provider,
tr=tr
)
if success:
st.success(message)
else:
st.error(message)
# 先验证配置
test_errors = []
if not st_text_api_key:
test_errors.append("请先输入API密钥")
if not st_text_model_name:
test_errors.append("请先输入模型名称")
# 保存文本模型配置
if test_errors:
for error in test_errors:
st.error(error)
else:
with st.spinner(tr("Testing connection...")):
try:
success, message = test_text_model_connection(
api_key=st_text_api_key,
base_url=st_text_base_url,
model_name=st_text_model_name,
provider=text_provider,
tr=tr
)
if success:
st.success(message)
else:
st.error(message)
except Exception as e:
st.error(f"测试连接时发生错误: {str(e)}")
logger.error(f"文案生成模型连接测试失败: {str(e)}")
# 验证和保存文本模型配置
text_validation_errors = []
text_config_changed = False
# 验证API密钥
if st_text_api_key:
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
is_valid, error_msg = validate_api_key(st_text_api_key, f"文案生成({text_provider})")
if is_valid:
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
text_config_changed = True
else:
text_validation_errors.append(error_msg)
# 验证Base URL
if st_text_base_url:
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
is_valid, error_msg = validate_base_url(st_text_base_url, f"文案生成({text_provider})")
if is_valid:
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
text_config_changed = True
else:
text_validation_errors.append(error_msg)
# 验证模型名称
if st_text_model_name:
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
is_valid, error_msg = validate_model_name(st_text_model_name, f"文案生成({text_provider})")
if is_valid:
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
text_config_changed = True
else:
text_validation_errors.append(error_msg)
# 显示验证错误
show_config_validation_errors(text_validation_errors)
# 如果配置有变化且没有验证错误,保存到文件
if text_config_changed and not text_validation_errors:
try:
config.save_config()
if st_text_api_key or st_text_base_url or st_text_model_name:
st.success(f"文案生成模型({text_provider})配置已保存")
except Exception as e:
st.error(f"保存配置失败: {str(e)}")
logger.error(f"保存文案生成配置失败: {str(e)}")
# # Cloudflare 特殊配置
# if text_provider == 'cloudflare':

View File

@ -6,31 +6,45 @@ from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from app.config import config
# 导入新的LLM服务模块 - 确保提供商被注册
import app.services.llm # 这会触发提供商注册
from app.services.llm.migration_adapter import create_vision_analyzer as create_vision_analyzer_new
# 保留旧的导入以确保向后兼容
from app.utils import gemini_analyzer, qwenvl_analyzer
def create_vision_analyzer(provider, api_key, model, base_url):
"""
创建视觉分析器实例
创建视觉分析器实例 - 已重构为使用新的LLM服务架构
Args:
provider: 提供商名称 ('gemini' 'qwenvl')
provider: 提供商名称 ('gemini', 'gemini(openai)', 'qwenvl', 'siliconflow')
api_key: API密钥
model: 模型名称
base_url: API基础URL
Returns:
VisionAnalyzer QwenAnalyzer 实例
视觉分析器实例
"""
if provider == 'gemini':
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
else:
# 只传入必要的参数
return qwenvl_analyzer.QwenAnalyzer(
model_name=model,
api_key=api_key,
base_url=base_url
)
try:
# 优先使用新的LLM服务架构
return create_vision_analyzer_new(provider, api_key, model, base_url)
except Exception as e:
logger.warning(f"使用新LLM服务失败回退到旧实现: {str(e)}")
# 回退到旧的实现以确保兼容性
if provider == 'gemini':
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
elif provider == 'gemini(openai)':
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
else:
# 只传入必要的参数
return qwenvl_analyzer.QwenAnalyzer(
model_name=model,
api_key=api_key,
base_url=base_url
)
def get_batch_timestamps(batch_files, prev_batch_files=None):

View File

@ -368,7 +368,16 @@ def generate_script_docu(params):
base_url=text_base_url,
model=text_model
)
narration_dict = json.loads(narration)['items']
# 使用增强的JSON解析器
from webui.tools.generate_short_summary import parse_and_fix_json
narration_data = parse_and_fix_json(narration)
if not narration_data or 'items' not in narration_data:
logger.error(f"解说文案JSON解析失败原始内容: {narration[:200]}...")
raise Exception("解说文案格式错误无法解析JSON或缺少items字段")
narration_dict = narration_data['items']
# 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音
narration_dict = [{**item, "OST": 2} for item in narration_dict]
logger.debug(f"解说文案创作完成:\n{"\n".join([item['narration'] for item in narration_dict])}")

View File

@ -69,6 +69,7 @@ def generate_script_short(tr, params, custom_clips=5):
model_name=text_model,
base_url=text_base_url,
custom_clips=custom_clips,
provider=text_provider
)
if script is None:

View File

@ -16,6 +16,122 @@ from loguru import logger
from app.config import config
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
# 导入新的LLM服务模块 - 确保提供商被注册
import app.services.llm # 这会触发提供商注册
from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter
import re
def parse_and_fix_json(json_string):
"""
解析并修复JSON字符串
Args:
json_string: 待解析的JSON字符串
Returns:
dict: 解析后的字典如果解析失败返回None
"""
if not json_string or not json_string.strip():
logger.error("JSON字符串为空")
return None
# 清理字符串
json_string = json_string.strip()
# 尝试直接解析
try:
return json.loads(json_string)
except json.JSONDecodeError as e:
logger.warning(f"直接JSON解析失败: {e}")
# 尝试修复双大括号问题LLM生成的常见问题
try:
# 将双大括号替换为单大括号
fixed_braces = json_string.replace('{{', '{').replace('}}', '}')
logger.info("修复双大括号格式")
return json.loads(fixed_braces)
except json.JSONDecodeError:
pass
# 尝试提取JSON部分
try:
# 查找JSON代码块
json_match = re.search(r'```json\s*(.*?)\s*```', json_string, re.DOTALL)
if json_match:
json_content = json_match.group(1).strip()
logger.info("从代码块中提取JSON内容")
return json.loads(json_content)
except json.JSONDecodeError:
pass
# 尝试查找大括号包围的内容
try:
# 查找第一个 { 到最后一个 } 的内容
start_idx = json_string.find('{')
end_idx = json_string.rfind('}')
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_content = json_string[start_idx:end_idx+1]
logger.info("提取大括号包围的JSON内容")
return json.loads(json_content)
except json.JSONDecodeError:
pass
# 尝试综合修复JSON格式问题
try:
fixed_json = json_string
# 1. 修复双大括号问题
fixed_json = fixed_json.replace('{{', '{').replace('}}', '}')
# 2. 提取JSON内容如果有其他文本包围
start_idx = fixed_json.find('{')
end_idx = fixed_json.rfind('}')
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
fixed_json = fixed_json[start_idx:end_idx+1]
# 3. 移除注释
fixed_json = re.sub(r'#.*', '', fixed_json)
fixed_json = re.sub(r'//.*', '', fixed_json)
# 4. 移除多余的逗号
fixed_json = re.sub(r',\s*}', '}', fixed_json)
fixed_json = re.sub(r',\s*]', ']', fixed_json)
# 5. 修复单引号
fixed_json = re.sub(r"'([^']*)':", r'"\1":', fixed_json)
# 6. 修复没有引号的属性名
fixed_json = re.sub(r'(\w+)(\s*):', r'"\1"\2:', fixed_json)
# 7. 修复重复的引号
fixed_json = re.sub(r'""([^"]*?)""', r'"\1"', fixed_json)
logger.info("尝试综合修复JSON格式问题后解析")
return json.loads(fixed_json)
except json.JSONDecodeError as e:
logger.debug(f"综合修复失败: {e}")
pass
# 如果所有方法都失败,尝试创建一个基本的结构
logger.error(f"所有JSON解析方法都失败原始内容: {json_string[:200]}...")
# 尝试从文本中提取关键信息创建基本结构
try:
# 这是一个简单的回退方案
return {
"items": [
{
"_id": 1,
"timestamp": "00:00:00,000-00:00:10,000",
"picture": "解析失败,使用默认内容",
"narration": json_string[:100] + "..." if len(json_string) > 100 else json_string,
"OST": 0
}
]
}
except Exception:
return None
def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature):
@ -49,20 +165,36 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
return
"""
2. 分析字幕总结剧情
2. 分析字幕总结剧情 - 使用新的LLM服务架构
"""
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
analysis_result = analyze_subtitle(
subtitle_file_path=subtitle_path,
api_key=text_api_key,
model=text_model,
base_url=text_base_url,
save_result=True,
temperature=temperature
)
try:
# 优先使用新的LLM服务架构
logger.info("使用新的LLM服务架构进行字幕分析")
analyzer = SubtitleAnalyzerAdapter(text_api_key, text_model, text_base_url, text_provider)
# 读取字幕文件
with open(subtitle_path, 'r', encoding='utf-8') as f:
subtitle_content = f.read()
analysis_result = analyzer.analyze_subtitle(subtitle_content)
except Exception as e:
logger.warning(f"使用新LLM服务失败回退到旧实现: {str(e)}")
# 回退到旧的实现
analysis_result = analyze_subtitle(
subtitle_file_path=subtitle_path,
api_key=text_api_key,
model=text_model,
base_url=text_base_url,
save_result=True,
temperature=temperature,
provider=text_provider
)
"""
3. 根据剧情生成解说文案
"""
@ -70,16 +202,28 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
logger.info("字幕分析成功!")
update_progress(60, "正在生成文案...")
# 根据剧情生成解说文案
narration_result = generate_narration_script(
short_name=video_theme,
plot_analysis=analysis_result["analysis"],
api_key=text_api_key,
model=text_model,
base_url=text_base_url,
save_result=True,
temperature=temperature
)
# 根据剧情生成解说文案 - 使用新的LLM服务架构
try:
# 优先使用新的LLM服务架构
logger.info("使用新的LLM服务架构生成解说文案")
narration_result = analyzer.generate_narration_script(
short_name=video_theme,
plot_analysis=analysis_result["analysis"],
temperature=temperature
)
except Exception as e:
logger.warning(f"使用新LLM服务失败回退到旧实现: {str(e)}")
# 回退到旧的实现
narration_result = generate_narration_script(
short_name=video_theme,
plot_analysis=analysis_result["analysis"],
api_key=text_api_key,
model=text_model,
base_url=text_base_url,
save_result=True,
temperature=temperature,
provider=text_provider
)
if narration_result["status"] == "success":
logger.info("\n解说文案生成成功!")
@ -100,7 +244,20 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
# 结果转换为JSON字符串
narration_script = narration_result["narration_script"]
narration_dict = json.loads(narration_script)
# 增强JSON解析包含错误处理和修复
narration_dict = parse_and_fix_json(narration_script)
if narration_dict is None:
st.error("生成的解说文案格式错误无法解析为JSON")
logger.error(f"JSON解析失败原始内容: {narration_script}")
st.stop()
# 验证JSON结构
if 'items' not in narration_dict:
st.error("生成的解说文案缺少必要的'items'字段")
logger.error(f"JSON结构错误缺少items字段: {narration_dict}")
st.stop()
script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2)
if script is None: