feat: 更新作者信息并增强API配置验证功能

在基础设置中新增API密钥、基础URL和模型名称的验证功能,确保用户输入的配置有效性,提升系统的稳定性和用户体验。
This commit is contained in:
linyq 2025-07-07 15:40:34 +08:00
parent 04ffda297f
commit dd59d5295d
16 changed files with 1354 additions and 209 deletions

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : audio_config
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/1/7
@Description: 音频配置管理
'''

View File

@ -74,24 +74,27 @@ plot_writing = """
%s
</plot>
使用 json 格式进行输出使用 <output> 中的输出格式
<output>
严格按照以下JSON格式输出不要添加任何其他文字说明或代码块标记
{
"items": [
{
"_id": 1, # 唯一递增id
"_id": 1,
"timestamp": "00:00:05,390-00:00:10,430",
"picture": "剧情描述或者备注",
"narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id 进行占位",
"OST": "值为 0 表示当前片段为解说片段,值为 1 表示当前片段为穿插的原片"
"OST": 0
}
]
}
</output>
<restriction>
1. 只输出 json 内容不要输出其他任何说明性的文字
2. 解说文案的语言使用 简体中文
3. 严禁虚构剧情所有画面只能从 <polt> 中摘取
4. 严禁虚构时间戳所有时间戳范围只能从 <polt> 中摘取
</restriction>
重要要求
1. 必须输出有效的JSON格式不能包含注释
2. OST字段必须是数字0表示解说片段1表示原片片段
3. _id必须是递增的数字
4. 只输出JSON内容不要输出任何说明文字
5. 不要使用代码块标记```json
6. 解说文案使用简体中文
7. 严禁虚构剧情所有内容只能从<plot>中摘取
8. 严禁虚构时间戳所有时间戳只能从<plot>中摘取
"""

View File

@ -22,34 +22,44 @@ class SubtitleAnalyzer:
"""字幕剧情分析器,负责分析字幕内容并提取关键剧情段落"""
def __init__(
self,
self,
api_key: Optional[str] = None,
model: Optional[str] = None,
base_url: Optional[str] = None,
custom_prompt: Optional[str] = None,
temperature: Optional[float] = 1.0,
provider: Optional[str] = None,
):
"""
初始化字幕分析器
Args:
api_key: API密钥如果不提供则从配置中读取
model: 模型名称如果不提供则从配置中读取
base_url: API基础URL如果不提供则从配置中读取或使用默认值
custom_prompt: 自定义提示词如果不提供则使用默认值
temperature: 模型温度
provider: 提供商类型用于确定API调用格式
"""
# 使用传入的参数或从配置中获取
self.api_key = api_key
self.model = model
self.base_url = base_url
self.temperature = temperature
self.provider = provider or self._detect_provider()
# 设置提示词模板
self.prompt_template = custom_prompt or subtitle_plot_analysis_v1
# 根据提供商类型确定是否为原生Gemini
self.is_native_gemini = self.provider.lower() == 'gemini'
# 初始化HTTP请求所需的头信息
self._init_headers()
    def _detect_provider(self):
        """Auto-detect the text LLM provider from app config (defaults to 'gemini', lowercased)."""
        return config.app.get('text_llm_provider', 'gemini').lower()
def _init_headers(self):
"""初始化HTTP请求头"""
@ -67,18 +77,152 @@ class SubtitleAnalyzer:
def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]:
"""
分析字幕内容
Args:
subtitle_content: 字幕内容文本
Returns:
Dict[str, Any]: 包含分析结果的字典
"""
try:
# 构建完整提示词
prompt = f"{self.prompt_template}\n\n{subtitle_content}"
# 构建请求体数据
if self.is_native_gemini:
# 使用原生Gemini API格式
return self._call_native_gemini_api(prompt)
else:
# 使用OpenAI兼容格式
return self._call_openai_compatible_api(prompt)
except Exception as e:
logger.error(f"字幕分析过程中发生错误: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": self.temperature
}
def _call_native_gemini_api(self, prompt: str) -> Dict[str, Any]:
"""调用原生Gemini API"""
try:
# 构建原生Gemini API请求数据
payload = {
"systemInstruction": {
"parts": [{"text": "你是一位专业的剧本分析师和剧情概括助手。请严格按照要求的格式输出分析结果。"}]
},
"contents": [{
"parts": [{"text": prompt}]
}],
"generationConfig": {
"temperature": self.temperature,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 4000,
"candidateCount": 1
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
# 发送请求
response = requests.post(
url,
json=payload,
headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
timeout=120
)
if response.status_code == 200:
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
return {
"status": "error",
"message": "原生Gemini API返回无效响应可能触发了安全过滤",
"temperature": self.temperature
}
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
return {
"status": "error",
"message": "内容被Gemini安全过滤器阻止",
"temperature": self.temperature
}
if "content" not in candidate or "parts" not in candidate["content"]:
return {
"status": "error",
"message": "原生Gemini API返回内容格式错误",
"temperature": self.temperature
}
# 提取文本内容
analysis_result = ""
for part in candidate["content"]["parts"]:
if "text" in part:
analysis_result += part["text"]
if not analysis_result.strip():
return {
"status": "error",
"message": "原生Gemini API返回空内容",
"temperature": self.temperature
}
logger.debug(f"原生Gemini字幕分析完成")
return {
"status": "success",
"analysis": analysis_result,
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
"model": self.model,
"temperature": self.temperature
}
else:
error_msg = f"原生Gemini API请求失败状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": self.temperature
}
except Exception as e:
logger.error(f"原生Gemini API调用失败: {str(e)}")
return {
"status": "error",
"message": f"原生Gemini API调用失败: {str(e)}",
"temperature": self.temperature
}
def _call_openai_compatible_api(self, prompt: str) -> Dict[str, Any]:
"""调用OpenAI兼容的API"""
try:
# 构建OpenAI格式的请求数据
payload = {
"model": self.model,
"messages": [
@ -87,22 +231,22 @@ class SubtitleAnalyzer:
],
"temperature": self.temperature
}
# 构建请求地址
url = f"{self.base_url}/chat/completions"
# 发送HTTP请求
response = requests.post(url, headers=self.headers, json=payload)
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
# 解析响应
if response.status_code == 200:
response_data = response.json()
# 提取响应内容
if "choices" in response_data and len(response_data["choices"]) > 0:
analysis_result = response_data["choices"][0]["message"]["content"]
logger.debug(f"字幕分析完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
logger.debug(f"OpenAI兼容API字幕分析完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
# 返回结果
return {
"status": "success",
@ -112,26 +256,26 @@ class SubtitleAnalyzer:
"temperature": self.temperature
}
else:
logger.error("字幕分析失败: 未获取到有效响应")
logger.error("OpenAI兼容API字幕分析失败: 未获取到有效响应")
return {
"status": "error",
"message": "未获取到有效响应",
"temperature": self.temperature
}
else:
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": self.temperature
}
except Exception as e:
logger.error(f"字幕分析过程中发生错误: {str(e)}")
logger.error(f"OpenAI兼容API调用失败: {str(e)}")
return {
"status": "error",
"message": str(e),
"message": f"OpenAI兼容API调用失败: {str(e)}",
"temperature": self.temperature
}
@ -206,12 +350,12 @@ class SubtitleAnalyzer:
def generate_narration_script(self, short_name:str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
根据剧情分析生成解说文案
Args:
short_name: 短剧名称
plot_analysis: 剧情分析内容
temperature: 生成温度控制创造性默认0.7
Returns:
Dict[str, Any]: 包含生成结果的字典
"""
@ -219,7 +363,145 @@ class SubtitleAnalyzer:
# 构建完整提示词
prompt = plot_writing % (short_name, plot_analysis)
# 构建请求体数据
if self.is_native_gemini:
# 使用原生Gemini API格式
return self._generate_narration_with_native_gemini(prompt, temperature)
else:
# 使用OpenAI兼容格式
return self._generate_narration_with_openai_compatible(prompt, temperature)
except Exception as e:
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": self.temperature
}
def _generate_narration_with_native_gemini(self, prompt: str, temperature: float) -> Dict[str, Any]:
"""使用原生Gemini API生成解说文案"""
try:
# 构建原生Gemini API请求数据
# 为了确保JSON输出在提示词中添加更强的约束
enhanced_prompt = f"{prompt}\n\n请确保输出严格的JSON格式不要包含任何其他文字或标记。"
payload = {
"systemInstruction": {
"parts": [{"text": "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出不能包含任何其他文字、说明或代码块标记。"}]
},
"contents": [{
"parts": [{"text": enhanced_prompt}]
}],
"generationConfig": {
"temperature": temperature,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 4000,
"candidateCount": 1,
"stopSequences": ["```", "注意", "说明"]
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
# 发送请求
response = requests.post(
url,
json=payload,
headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
timeout=120
)
if response.status_code == 200:
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
return {
"status": "error",
"message": "原生Gemini API返回无效响应可能触发了安全过滤",
"temperature": temperature
}
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
return {
"status": "error",
"message": "内容被Gemini安全过滤器阻止",
"temperature": temperature
}
if "content" not in candidate or "parts" not in candidate["content"]:
return {
"status": "error",
"message": "原生Gemini API返回内容格式错误",
"temperature": temperature
}
# 提取文本内容
narration_script = ""
for part in candidate["content"]["parts"]:
if "text" in part:
narration_script += part["text"]
if not narration_script.strip():
return {
"status": "error",
"message": "原生Gemini API返回空内容",
"temperature": temperature
}
logger.debug(f"原生Gemini解说文案生成完成")
return {
"status": "success",
"narration_script": narration_script,
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
"model": self.model,
"temperature": temperature
}
else:
error_msg = f"原生Gemini API请求失败状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": temperature
}
except Exception as e:
logger.error(f"原生Gemini API解说文案生成失败: {str(e)}")
return {
"status": "error",
"message": f"原生Gemini API解说文案生成失败: {str(e)}",
"temperature": temperature
}
def _generate_narration_with_openai_compatible(self, prompt: str, temperature: float) -> Dict[str, Any]:
"""使用OpenAI兼容API生成解说文案"""
try:
# 构建OpenAI格式的请求数据
payload = {
"model": self.model,
"messages": [
@ -228,56 +510,56 @@ class SubtitleAnalyzer:
],
"temperature": temperature
}
# 对特定模型添加响应格式设置
if self.model not in ["deepseek-reasoner"]:
payload["response_format"] = {"type": "json_object"}
# 构建请求地址
url = f"{self.base_url}/chat/completions"
# 发送HTTP请求
response = requests.post(url, headers=self.headers, json=payload)
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
# 解析响应
if response.status_code == 200:
response_data = response.json()
# 提取响应内容
if "choices" in response_data and len(response_data["choices"]) > 0:
narration_script = response_data["choices"][0]["message"]["content"]
logger.debug(f"解说文案生成完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
logger.debug(f"OpenAI兼容API解说文案生成完成消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
# 返回结果
return {
"status": "success",
"narration_script": narration_script,
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
"model": self.model,
"temperature": self.temperature
"temperature": temperature
}
else:
logger.error("解说文案生成失败: 未获取到有效响应")
logger.error("OpenAI兼容API解说文案生成失败: 未获取到有效响应")
return {
"status": "error",
"message": "未获取到有效响应",
"temperature": self.temperature
"temperature": temperature
}
else:
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
logger.error(error_msg)
return {
"status": "error",
"message": error_msg,
"temperature": self.temperature
"temperature": temperature
}
except Exception as e:
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
logger.error(f"OpenAI兼容API解说文案生成失败: {str(e)}")
return {
"status": "error",
"message": str(e),
"temperature": self.temperature
"message": f"OpenAI兼容API解说文案生成失败: {str(e)}",
"temperature": temperature
}
def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str:
@ -324,11 +606,12 @@ def analyze_subtitle(
custom_prompt: Optional[str] = None,
temperature: float = 1.0,
save_result: bool = False,
output_path: Optional[str] = None
output_path: Optional[str] = None,
provider: Optional[str] = None
) -> Dict[str, Any]:
"""
分析字幕内容的便捷函数
Args:
subtitle_content: 字幕内容文本
subtitle_file_path: 字幕文件路径
@ -339,7 +622,8 @@ def analyze_subtitle(
temperature: 模型温度
save_result: 是否保存结果到文件
output_path: 输出文件路径
provider: 提供商类型
Returns:
Dict[str, Any]: 包含分析结果的字典
"""
@ -349,7 +633,8 @@ def analyze_subtitle(
api_key=api_key,
model=model,
base_url=base_url,
custom_prompt=custom_prompt
custom_prompt=custom_prompt,
provider=provider
)
logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}")
# 分析字幕
@ -379,11 +664,12 @@ def generate_narration_script(
base_url: Optional[str] = None,
temperature: float = 1.0,
save_result: bool = False,
output_path: Optional[str] = None
output_path: Optional[str] = None,
provider: Optional[str] = None
) -> Dict[str, Any]:
"""
根据剧情分析生成解说文案的便捷函数
Args:
short_name: 短剧名称
plot_analysis: 剧情分析内容直接提供
@ -393,7 +679,8 @@ def generate_narration_script(
temperature: 生成温度控制创造性
save_result: 是否保存结果到文件
output_path: 输出文件路径
provider: 提供商类型
Returns:
Dict[str, Any]: 包含生成结果的字典
"""
@ -402,7 +689,8 @@ def generate_narration_script(
temperature=temperature,
api_key=api_key,
model=model,
base_url=base_url
base_url=base_url,
provider=provider
)
# 生成解说文案

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : audio_normalizer
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/1/7
@Description: 音频响度分析和标准化工具
'''

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : clip_video
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/6 下午6:14
'''

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : 生成介绍文案
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/8 上午11:33
'''

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : generate_video
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/7 上午11:55
'''

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : merger_video
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/6 下午7:38
'''

View File

@ -140,14 +140,27 @@ class ScriptGenerator:
# 获取Gemini配置
vision_api_key = config.app.get("vision_gemini_api_key")
vision_model = config.app.get("vision_gemini_model_name")
vision_base_url = config.app.get("vision_gemini_base_url")
if not vision_api_key or not vision_model:
raise ValueError("未配置 Gemini API Key 或者模型")
analyzer = gemini_analyzer.VisionAnalyzer(
model_name=vision_model,
api_key=vision_api_key,
)
# 根据提供商类型选择合适的分析器
if vision_provider == 'gemini(openai)':
# 使用OpenAI兼容的Gemini代理
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
analyzer = GeminiOpenAIAnalyzer(
model_name=vision_model,
api_key=vision_api_key,
base_url=vision_base_url
)
else:
# 使用原生Gemini分析器
analyzer = gemini_analyzer.VisionAnalyzer(
model_name=vision_model,
api_key=vision_api_key,
base_url=vision_base_url
)
progress_callback(40, "正在分析关键帧...")
@ -213,13 +226,35 @@ class ScriptGenerator:
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
video_theme=video_theme
)
# 根据提供商类型选择合适的处理器
if text_provider == 'gemini(openai)':
# 使用OpenAI兼容的Gemini代理
from app.utils.script_generator import GeminiOpenAIGenerator
generator = GeminiOpenAIGenerator(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url
)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
processor.generator = generator
else:
# 使用标准处理器包括原生Gemini
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
return processor.process_frames(frame_content_list)

View File

@ -4,7 +4,7 @@
'''
@Project: NarratoAI
@File : update_script
@Author : 小林同学
@Author : Viccy同学
@Date : 2025/5/6 下午11:00
'''

View File

@ -5,53 +5,162 @@ from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
from google.api_core import exceptions
import google.generativeai as genai
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
import requests
import PIL.Image
import traceback
import base64
import io
from app.utils import utils
class VisionAnalyzer:
"""视觉分析器类"""
"""原生Gemini视觉分析器类"""
def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
"""初始化视觉分析器"""
if not api_key:
raise ValueError("必须提供API密钥")
self.model_name = model_name
self.api_key = api_key
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
# 初始化配置
self._configure_client()
def _configure_client(self):
"""配置API客户端"""
genai.configure(api_key=self.api_key)
# 开放 Gemini 模型安全设置
from google.generativeai.types import HarmCategory, HarmBlockThreshold
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
self.model = genai.GenerativeModel(self.model_name, safety_settings=safety_settings)
"""配置原生Gemini API客户端"""
# 使用原生Gemini REST API
self.client = None
logger.info(f"配置原生Gemini API端点: {self.base_url}, 模型: {self.model_name}")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(exceptions.ResourceExhausted)
retry=retry_if_exception_type(requests.exceptions.RequestException)
)
async def _generate_content_with_retry(self, prompt, batch):
"""使用重试机制的内部方法来调用 generate_content_async"""
"""使用重试机制调用原生Gemini API"""
try:
return await self.model.generate_content_async([prompt, *batch])
except exceptions.ResourceExhausted as e:
print(f"API配额限制: {str(e)}")
raise RetryError("API调用失败")
return await self._generate_with_gemini_api(prompt, batch)
except requests.exceptions.RequestException as e:
logger.warning(f"Gemini API请求异常: {str(e)}")
raise
except Exception as e:
logger.error(f"Gemini API生成内容时发生错误: {str(e)}")
raise
    async def _generate_with_gemini_api(self, prompt, batch):
        """Generate content via the native Gemini REST API.

        Args:
            prompt: Text prompt sent alongside the images.
            batch: PIL images to analyze in this single request.

        Returns:
            CompatibleResponse: object exposing a ``.text`` attribute so callers
            keep the old response interface.

        Raises:
            requests.exceptions.RequestException: on HTTP 429 (quota) so the
                tenacity retry wrapper retries it.
            Exception: on other HTTP errors or malformed/blocked responses.
        """
        # Encode each PIL image as an inline base64 JPEG part.
        image_parts = []
        for img in batch:
            # Serialize the PIL image to an in-memory JPEG byte stream.
            img_buffer = io.BytesIO()
            img.save(img_buffer, format='JPEG', quality=85)  # quality 85 trades size for fidelity
            img_bytes = img_buffer.getvalue()

            # Base64-encode for the inline_data payload field.
            img_base64 = base64.b64encode(img_bytes).decode('utf-8')

            image_parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": img_base64
                }
            })

        # Request body per the official generateContent REST schema.
        request_data = {
            "contents": [{
                "parts": [
                    {"text": prompt},
                    *image_parts
                ]
            }],
            "generationConfig": {
                "temperature": 1.0,
                "topK": 40,
                "topP": 0.95,
                "maxOutputTokens": 8192,
                "candidateCount": 1,
                "stopSequences": []
            },
            # All safety categories disabled to avoid false-positive blocking
            # of ordinary video frames.
            "safetySettings": [
                {
                    "category": "HARM_CATEGORY_HARASSMENT",
                    "threshold": "BLOCK_NONE"
                },
                {
                    "category": "HARM_CATEGORY_HATE_SPEECH",
                    "threshold": "BLOCK_NONE"
                },
                {
                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                    "threshold": "BLOCK_NONE"
                },
                {
                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                    "threshold": "BLOCK_NONE"
                }
            ]
        }

        # API key travels as a query parameter in the native Gemini API.
        url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"

        # Run the blocking requests call off the event loop thread.
        response = await asyncio.to_thread(
            requests.post,
            url,
            json=request_data,
            headers={
                "Content-Type": "application/json",
                "User-Agent": "NarratoAI/1.0"
            },
            timeout=120  # generous timeout: multi-image requests can be slow
        )

        # Map HTTP errors; 429 raises RequestException so tenacity retries it.
        if response.status_code == 429:
            raise requests.exceptions.RequestException(f"API配额限制: {response.text}")
        elif response.status_code == 400:
            raise Exception(f"请求参数错误: {response.text}")
        elif response.status_code == 403:
            raise Exception(f"API密钥无效或权限不足: {response.text}")
        elif response.status_code != 200:
            raise Exception(f"Gemini API请求失败: {response.status_code} - {response.text}")

        response_data = response.json()

        # An empty candidates list usually means upstream safety filtering.
        if "candidates" not in response_data or not response_data["candidates"]:
            raise Exception("Gemini API返回无效响应可能触发了安全过滤")

        candidate = response_data["candidates"][0]

        # Explicit safety-filter stop reported by the model.
        if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
            raise Exception("内容被Gemini安全过滤器阻止")

        if "content" not in candidate or "parts" not in candidate["content"]:
            raise Exception("Gemini API返回内容格式错误")

        # Concatenate all text parts of the first candidate.
        text_content = ""
        for part in candidate["content"]["parts"]:
            if "text" in part:
                text_content += part["text"]

        if not text_content.strip():
            raise Exception("Gemini API返回空内容")

        # Minimal adapter so callers can keep using ``response.text``.
        class CompatibleResponse:
            def __init__(self, text):
                self.text = text

        return CompatibleResponse(text_content)
async def analyze_images(self,
images: Union[List[str], List[PIL.Image.Image]],

View File

@ -0,0 +1,177 @@
"""
OpenAI兼容的Gemini视觉分析器
使用标准OpenAI格式调用Gemini代理服务
"""
import json
from typing import List, Union, Dict
import os
from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
import requests
import PIL.Image
import traceback
import base64
import io
from app.utils import utils
class GeminiOpenAIAnalyzer:
    """OpenAI-compatible Gemini vision analyzer.

    Sends image+text requests to a Gemini proxy that speaks the standard
    OpenAI chat-completions protocol.
    """

    def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
        """Initialize the OpenAI-compatible Gemini analyzer.

        Args:
            model_name: Gemini model name exposed by the proxy.
            api_key: Bearer key for the proxy (required).
            base_url: OpenAI-compatible proxy endpoint URL (required).

        Raises:
            ValueError: if api_key or base_url is missing.
        """
        if not api_key:
            raise ValueError("必须提供API密钥")
        if not base_url:
            raise ValueError("必须提供OpenAI兼容的代理端点URL")

        self.model_name = model_name
        self.api_key = api_key
        # Normalize the endpoint so later path joins don't double the slash.
        self.base_url = base_url.rstrip('/')

        # Build the OpenAI SDK client.
        self._configure_client()

    def _configure_client(self):
        """Create the OpenAI SDK client pointed at the Gemini proxy."""
        from openai import OpenAI

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        logger.info(f"配置OpenAI兼容Gemini代理端点: {self.base_url}, 模型: {self.model_name}")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((requests.exceptions.RequestException, Exception))
    )
    async def _generate_content_with_retry(self, prompt, batch):
        """Call the proxy with tenacity retries (3 attempts, exponential backoff).

        NOTE(review): retrying on bare ``Exception`` retries *every* failure,
        including non-transient ones — confirm this is intended.
        """
        try:
            return await self._generate_with_openai_api(prompt, batch)
        except Exception as e:
            logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}")
            raise

    async def _generate_with_openai_api(self, prompt, batch):
        """Send one prompt+image batch through the OpenAI-compatible interface.

        Returns:
            CompatibleResponse: object exposing ``.text`` for interface parity
            with the native Gemini analyzer.
        """
        # Encode each PIL image as a base64 JPEG data URL.
        image_contents = []
        for img in batch:
            # Serialize the PIL image into an in-memory JPEG stream.
            img_buffer = io.BytesIO()
            img.save(img_buffer, format='JPEG', quality=85)
            img_bytes = img_buffer.getvalue()

            # Base64-encode for the image_url data URI.
            img_base64 = base64.b64encode(img_bytes).decode('utf-8')

            image_contents.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{img_base64}"
                }
            })

        # Single user message carrying the text prompt plus all images.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    *image_contents
                ]
            }
        ]

        # Run the blocking SDK call off the event loop thread.
        response = await asyncio.to_thread(
            self.client.chat.completions.create,
            model=self.model_name,
            messages=messages,
            max_tokens=4000,
            temperature=1.0
        )

        # Minimal adapter exposing ``.text`` like the native analyzer's response.
        class CompatibleResponse:
            def __init__(self, text):
                self.text = text

        return CompatibleResponse(response.choices[0].message.content)

    async def analyze_images(self,
                             images: List[Union[str, Path, PIL.Image.Image]],
                             prompt: str,
                             batch_size: int = 10) -> List[str]:
        """
        Analyze images and return per-batch results.

        Args:
            images: Image paths or PIL image objects.
            prompt: Analysis prompt sent with every batch.
            batch_size: Number of images per request.

        Returns:
            List of analysis strings, one per batch; a failed batch yields an
            error string instead of raising.

        Raises:
            ValueError: if no image could be loaded.
        """
        logger.info(f"开始分析 {len(images)} 张图片使用OpenAI兼容Gemini代理")

        # Load paths into PIL images; skip unreadable or unsupported entries.
        loaded_images = []
        for img in images:
            if isinstance(img, (str, Path)):
                try:
                    pil_img = PIL.Image.open(img)
                    # Downscale large images to cut payload size and latency.
                    if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
                        pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
                    loaded_images.append(pil_img)
                except Exception as e:
                    logger.error(f"加载图片失败 {img}: {str(e)}")
                    continue
            elif isinstance(img, PIL.Image.Image):
                loaded_images.append(img)
            else:
                logger.warning(f"不支持的图片类型: {type(img)}")
                continue

        if not loaded_images:
            raise ValueError("没有有效的图片可以分析")

        # Process in batches, collecting one result string per batch.
        results = []
        total_batches = (len(loaded_images) + batch_size - 1) // batch_size

        for i in tqdm(range(0, len(loaded_images), batch_size),
                      desc="分析图片批次", total=total_batches):
            batch = loaded_images[i:i + batch_size]
            try:
                response = await self._generate_content_with_retry(prompt, batch)
                results.append(response.text)

                # Brief pause between batches to avoid API rate limiting.
                if i + batch_size < len(loaded_images):
                    await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}")
                results.append(f"分析失败: {str(e)}")

        logger.info(f"完成图片分析,共处理 {len(results)} 个批次")
        return results

    def analyze_images_sync(self,
                            images: List[Union[str, Path, PIL.Image.Image]],
                            prompt: str,
                            batch_size: int = 10) -> List[str]:
        """
        Synchronous wrapper around :meth:`analyze_images`.
        """
        return asyncio.run(self.analyze_images(images, prompt, batch_size))

View File

@ -6,7 +6,7 @@ from loguru import logger
from typing import List, Dict
from datetime import datetime
from openai import OpenAI
import google.generativeai as genai
import requests
import time
@ -134,59 +134,182 @@ class OpenAIGenerator(BaseGenerator):
class GeminiGenerator(BaseGenerator):
"""Google Gemini API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str):
"""原生Gemini API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
super().__init__(model_name, api_key, prompt)
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel(model_name)
# Gemini特定参数
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
self.client = None
# 原生Gemini API参数
self.default_params = {
"temperature": self.default_params["temperature"],
"top_p": self.default_params["top_p"],
"candidate_count": 1,
"stop_sequences": None
"topP": self.default_params["top_p"],
"topK": 40,
"maxOutputTokens": 4000,
"candidateCount": 1,
"stopSequences": []
}
class GeminiOpenAIGenerator(BaseGenerator):
    """Generator that talks to a Gemini proxy via the OpenAI-compatible API.

    The proxy speaks the standard chat-completions protocol, so this class
    reuses the OpenAI SDK instead of the native Gemini REST schema.
    """

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
        """
        Args:
            model_name: Gemini model name exposed by the proxy.
            api_key: Bearer key for the proxy.
            prompt: Base prompt handed to BaseGenerator.
            base_url: OpenAI-compatible proxy endpoint URL (required).

        Raises:
            ValueError: if base_url is missing.
        """
        super().__init__(model_name, api_key, prompt)
        if not base_url:
            raise ValueError("OpenAI兼容的Gemini代理必须提供base_url")

        # Normalize so later joins don't produce a double slash.
        self.base_url = base_url.rstrip('/')

        # OpenAI-compatible client pointed at the proxy.
        from openai import OpenAI
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )

        # Keep the inherited temperature; add OpenAI-specific request params.
        # Fix: removed two stale lines from the old GeminiGenerator._generate
        # (a leftover docstring and a `while True:`) that were interleaved into
        # this class body and made it syntactically invalid.
        self.default_params = {
            "temperature": self.default_params["temperature"],
            "max_tokens": 4000,
            "stream": False
        }

    def _generate(self, messages: list, params: dict) -> any:
        """Send the chat request through the OpenAI-compatible proxy."""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params
            )
            return response
        except Exception as e:
            logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}")
            raise

    def _process_response(self, response: any) -> str:
        """Extract and trim the first choice's message content.

        Raises:
            ValueError: if the response carries no choices.
        """
        if not response or not response.choices:
            raise ValueError("OpenAI兼容Gemini代理返回无效响应")
        return response.choices[0].message.content.strip()
def _generate(self, messages: list, params: dict) -> any:
"""实现原生Gemini API的生成逻辑"""
max_retries = 3
for attempt in range(max_retries):
try:
# 转换消息格式为Gemini格式
prompt = "\n".join([m["content"] for m in messages])
response = self.model.generate_content(
prompt,
generation_config=params
# 构建请求数据
request_data = {
"contents": [{
"parts": [{"text": prompt}]
}],
"generationConfig": params,
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
# 发送请求
response = requests.post(
url,
json=request_data,
headers={
"Content-Type": "application/json",
"User-Agent": "NarratoAI/1.0"
},
timeout=120
)
# 检查响应是否包含有效内容
if (hasattr(response, 'result') and
hasattr(response.result, 'candidates') and
response.result.candidates):
candidate = response.result.candidates[0]
# 检查是否有内容字段
if not hasattr(candidate, 'content'):
logger.warning("Gemini API 返回速率限制响应等待30秒后重试...")
time.sleep(30) # 等待3秒后重试
if response.status_code == 429:
# 处理限流
wait_time = 65 if attempt == 0 else 30
logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
if response.status_code == 400:
raise Exception(f"请求参数错误: {response.text}")
elif response.status_code == 403:
raise Exception(f"API密钥无效或权限不足: {response.text}")
elif response.status_code != 200:
raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}")
response_data = response.json()
# 检查响应格式
if "candidates" not in response_data or not response_data["candidates"]:
if attempt < max_retries - 1:
logger.warning("原生Gemini API 返回无效响应等待30秒后重试...")
time.sleep(30)
continue
return response
except Exception as e:
error_str = str(e)
if "429" in error_str:
logger.warning("Gemini API 触发限流等待65秒后重试...")
time.sleep(65) # 等待65秒后重试
else:
raise Exception("原生Gemini API返回无效响应可能触发了安全过滤")
candidate = response_data["candidates"][0]
# 检查是否被安全过滤阻止
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
raise Exception("内容被Gemini安全过滤器阻止")
# 创建兼容的响应对象
class CompatibleResponse:
def __init__(self, data):
self.data = data
candidate = data["candidates"][0]
if "content" in candidate and "parts" in candidate["content"]:
self.text = ""
for part in candidate["content"]["parts"]:
if "text" in part:
self.text += part["text"]
else:
self.text = ""
return CompatibleResponse(response_data)
except requests.exceptions.RequestException as e:
if attempt < max_retries - 1:
logger.warning(f"网络请求失败等待30秒后重试: {str(e)}")
time.sleep(30)
continue
else:
logger.error(f"Gemini 生成文案错误: \n{error_str}")
logger.error(f"原生Gemini API请求失败: {str(e)}")
raise
except Exception as e:
if attempt < max_retries - 1 and "429" in str(e):
logger.warning("原生Gemini API 触发限流等待65秒后重试...")
time.sleep(65)
continue
else:
logger.error(f"原生Gemini 生成文案错误: {str(e)}")
raise
def _process_response(self, response: any) -> str:
"""处理Gemini的响应"""
"""处理原生Gemini API的响应"""
if not response or not response.text:
raise ValueError("Invalid response from Gemini API")
raise ValueError("原生Gemini API返回无效响应")
return response.text.strip()
@ -318,7 +441,7 @@ class ScriptProcessor:
# 根据模型名称选择对应的生成器
logger.info(f"文本 LLM 提供商: {model_name}")
if 'gemini' in model_name.lower():
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url)
elif 'qwen' in model_name.lower():
self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
elif 'moonshot' in model_name.lower():

View File

@ -7,6 +7,45 @@ from app.utils import utils
from loguru import logger
def validate_api_key(api_key: str, provider: str) -> tuple[bool, str]:
    """Validate the basic format of an API key.

    Returns:
        (is_valid, error_message): error_message is empty when valid.
    """
    key = (api_key or "").strip()
    if not key:
        return False, f"{provider} API密钥不能为空"
    # Real keys are always longer than a handful of characters.
    if len(key) < 10:
        return False, f"{provider} API密钥长度过短请检查是否正确"
    return True, ""
def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
    """Validate an optional Base URL: empty is allowed, otherwise must be HTTP(S).

    Returns:
        (is_valid, error_message): error_message is empty when valid.
    """
    url = (base_url or "").strip()
    if not url:
        # An empty base URL is valid — a default endpoint is used instead.
        return True, ""
    if url.startswith(('http://', 'https://')):
        return True, ""
    return False, f"{provider} Base URL必须以http://或https://开头"
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
    """Validate that a model name is non-empty (whitespace-only counts as empty).

    Returns:
        (is_valid, error_message): error_message is empty when valid.
    """
    if model_name and model_name.strip():
        return True, ""
    return False, f"{provider} 模型名称不能为空"
def show_config_validation_errors(errors: list):
    """Render each configuration validation error in the Streamlit UI."""
    # An empty or None list renders nothing.
    for message in (errors or []):
        st.error(message)
def render_basic_settings(tr):
"""渲染基础设置面板"""
with st.expander(tr("Basic Settings"), expanded=False):
@ -87,29 +126,96 @@ def render_proxy_settings(tr):
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
"""测试视觉模型连接
Args:
api_key: API密钥
base_url: 基础URL
model_name: 模型名称
provider: 提供商名称
Returns:
bool: 连接是否成功
str: 测试结果消息
"""
import requests
if provider.lower() == 'gemini':
import google.generativeai as genai
# 原生Gemini API测试
try:
genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name)
model.generate_content("直接回复我文本'当前网络可用'")
return True, tr("gemini model is available")
# 构建请求数据
request_data = {
"contents": [{
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
}],
"generationConfig": {
"temperature": 1.0,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 100,
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
# 发送请求
response = requests.post(
url,
json=request_data,
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code == 200:
return True, tr("原生Gemini模型连接成功")
else:
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
except Exception as e:
return False, f"{tr('gemini model is not available')}: {str(e)}"
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
elif provider.lower() == 'gemini(openai)':
# OpenAI兼容的Gemini代理测试
try:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
test_url = f"{base_url.rstrip('/')}/chat/completions"
test_data = {
"model": model_name,
"messages": [
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
],
"stream": False
}
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
if response.status_code == 200:
return True, tr("OpenAI兼容Gemini代理连接成功")
else:
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
except Exception as e:
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: {str(e)}"
elif provider.lower() == 'narratoapi':
import requests
try:
# 构建测试请求
headers = {
@ -172,7 +278,7 @@ def render_vision_llm_settings(tr):
st.subheader(tr("Vision Model Settings"))
# 视频分析模型提供商选择
vision_providers = ['Siliconflow', 'Gemini', 'QwenVL', 'OpenAI']
vision_providers = ['Siliconflow', 'Gemini', 'Gemini(OpenAI)', 'QwenVL', 'OpenAI']
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
saved_provider_index = 0
@ -191,9 +297,15 @@ def render_vision_llm_settings(tr):
st.session_state['vision_llm_providers'] = vision_provider
# 获取已保存的视觉模型配置
vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
# 处理特殊的提供商名称映射
if vision_provider == 'gemini(openai)':
vision_config_key = 'vision_gemini_openai'
else:
vision_config_key = f'vision_{vision_provider}'
vision_api_key = config.app.get(f"{vision_config_key}_api_key", "")
vision_base_url = config.app.get(f"{vision_config_key}_base_url", "")
vision_model_name = config.app.get(f"{vision_config_key}_model_name", "")
# 渲染视觉模型配置输入框
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
@ -201,15 +313,25 @@ def render_vision_llm_settings(tr):
# 根据不同提供商设置默认值和帮助信息
if vision_provider == 'gemini':
st_vision_base_url = st.text_input(
tr("Vision Base URL"),
value=vision_base_url,
disabled=True,
help=tr("Gemini API does not require a base URL")
tr("Vision Base URL"),
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta",
help=tr("原生Gemini API端点默认: https://generativelanguage.googleapis.com/v1beta")
)
st_vision_model_name = st.text_input(
tr("Vision Model Name"),
value=vision_model_name or "gemini-2.0-flash-lite",
help=tr("Default: gemini-2.0-flash-lite")
tr("Vision Model Name"),
value=vision_model_name or "gemini-2.0-flash-exp",
help=tr("原生Gemini模型默认: gemini-2.0-flash-exp")
)
elif vision_provider == 'gemini(openai)':
st_vision_base_url = st.text_input(
tr("Vision Base URL"),
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
help=tr("OpenAI兼容的Gemini代理端点如: https://your-proxy.com/v1")
)
st_vision_model_name = st.text_input(
tr("Vision Model Name"),
value=vision_model_name or "gemini-2.0-flash-exp",
help=tr("OpenAI格式的Gemini模型名称默认: gemini-2.0-flash-exp")
)
elif vision_provider == 'qwenvl':
st_vision_base_url = st.text_input(
@ -228,30 +350,81 @@ def render_vision_llm_settings(tr):
# 在配置输入框后添加测试按钮
if st.button(tr("Test Connection"), key="test_vision_connection"):
with st.spinner(tr("Testing connection...")):
success, message = test_vision_model_connection(
api_key=st_vision_api_key,
base_url=st_vision_base_url,
model_name=st_vision_model_name,
provider=vision_provider,
tr=tr
)
if success:
st.success(tr(message))
else:
st.error(tr(message))
# 先验证配置
test_errors = []
if not st_vision_api_key:
test_errors.append("请先输入API密钥")
if not st_vision_model_name:
test_errors.append("请先输入模型名称")
# 保存视觉模型配置
if test_errors:
for error in test_errors:
st.error(error)
else:
with st.spinner(tr("Testing connection...")):
try:
success, message = test_vision_model_connection(
api_key=st_vision_api_key,
base_url=st_vision_base_url,
model_name=st_vision_model_name,
provider=vision_provider,
tr=tr
)
if success:
st.success(message)
else:
st.error(message)
except Exception as e:
st.error(f"测试连接时发生错误: {str(e)}")
logger.error(f"视频分析模型连接测试失败: {str(e)}")
# 验证和保存视觉模型配置
validation_errors = []
config_changed = False
# 验证API密钥
if st_vision_api_key:
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
is_valid, error_msg = validate_api_key(st_vision_api_key, f"视频分析({vision_provider})")
if is_valid:
config.app[f"{vision_config_key}_api_key"] = st_vision_api_key
st.session_state[f"{vision_config_key}_api_key"] = st_vision_api_key
config_changed = True
else:
validation_errors.append(error_msg)
# 验证Base URL
if st_vision_base_url:
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
is_valid, error_msg = validate_base_url(st_vision_base_url, f"视频分析({vision_provider})")
if is_valid:
config.app[f"{vision_config_key}_base_url"] = st_vision_base_url
st.session_state[f"{vision_config_key}_base_url"] = st_vision_base_url
config_changed = True
else:
validation_errors.append(error_msg)
# 验证模型名称
if st_vision_model_name:
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
is_valid, error_msg = validate_model_name(st_vision_model_name, f"视频分析({vision_provider})")
if is_valid:
config.app[f"{vision_config_key}_model_name"] = st_vision_model_name
st.session_state[f"{vision_config_key}_model_name"] = st_vision_model_name
config_changed = True
else:
validation_errors.append(error_msg)
# 显示验证错误
show_config_validation_errors(validation_errors)
# 如果配置有变化且没有验证错误,保存到文件
if config_changed and not validation_errors:
try:
config.save_config()
if st_vision_api_key or st_vision_base_url or st_vision_model_name:
st.success(f"视频分析模型({vision_provider})配置已保存")
except Exception as e:
st.error(f"保存配置失败: {str(e)}")
logger.error(f"保存视频分析配置失败: {str(e)}")
def test_text_model_connection(api_key, base_url, model_name, provider, tr):
@ -278,14 +451,74 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
# 特殊处理Gemini
if provider.lower() == 'gemini':
import google.generativeai as genai
# 原生Gemini API测试
try:
genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name)
model.generate_content("直接回复我文本'当前网络可用'")
return True, tr("Gemini model is available")
# 构建请求数据
request_data = {
"contents": [{
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
}],
"generationConfig": {
"temperature": 1.0,
"topK": 40,
"topP": 0.95,
"maxOutputTokens": 100,
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
}
# 构建请求URL
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
# 发送请求
response = requests.post(
url,
json=request_data,
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code == 200:
return True, tr("原生Gemini模型连接成功")
else:
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
except Exception as e:
return False, f"{tr('Gemini model is not available')}: {str(e)}"
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
elif provider.lower() == 'gemini(openai)':
# OpenAI兼容的Gemini代理测试
test_url = f"{base_url.rstrip('/')}/chat/completions"
test_data = {
"model": model_name,
"messages": [
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
],
"stream": False
}
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
if response.status_code == 200:
return True, tr("OpenAI兼容Gemini代理连接成功")
else:
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
else:
test_url = f"{base_url.rstrip('/')}/chat/completions"
@ -322,7 +555,7 @@ def render_text_llm_settings(tr):
st.subheader(tr("Text Generation Model Settings"))
# 文案生成模型提供商选择
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Qwen', 'Moonshot']
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Gemini(OpenAI)', 'Qwen', 'Moonshot']
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
saved_provider_index = 0
@ -346,32 +579,108 @@ def render_text_llm_settings(tr):
# 渲染文本模型配置输入框
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
# 根据不同提供商设置默认值和帮助信息
if text_provider == 'gemini':
st_text_base_url = st.text_input(
tr("Text Base URL"),
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta",
help=tr("原生Gemini API端点默认: https://generativelanguage.googleapis.com/v1beta")
)
st_text_model_name = st.text_input(
tr("Text Model Name"),
value=text_model_name or "gemini-2.0-flash-exp",
help=tr("原生Gemini模型默认: gemini-2.0-flash-exp")
)
elif text_provider == 'gemini(openai)':
st_text_base_url = st.text_input(
tr("Text Base URL"),
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
help=tr("OpenAI兼容的Gemini代理端点如: https://your-proxy.com/v1")
)
st_text_model_name = st.text_input(
tr("Text Model Name"),
value=text_model_name or "gemini-2.0-flash-exp",
help=tr("OpenAI格式的Gemini模型名称默认: gemini-2.0-flash-exp")
)
else:
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
# 添加测试按钮
if st.button(tr("Test Connection"), key="test_text_connection"):
with st.spinner(tr("Testing connection...")):
success, message = test_text_model_connection(
api_key=st_text_api_key,
base_url=st_text_base_url,
model_name=st_text_model_name,
provider=text_provider,
tr=tr
)
if success:
st.success(message)
else:
st.error(message)
# 先验证配置
test_errors = []
if not st_text_api_key:
test_errors.append("请先输入API密钥")
if not st_text_model_name:
test_errors.append("请先输入模型名称")
# 保存文本模型配置
if test_errors:
for error in test_errors:
st.error(error)
else:
with st.spinner(tr("Testing connection...")):
try:
success, message = test_text_model_connection(
api_key=st_text_api_key,
base_url=st_text_base_url,
model_name=st_text_model_name,
provider=text_provider,
tr=tr
)
if success:
st.success(message)
else:
st.error(message)
except Exception as e:
st.error(f"测试连接时发生错误: {str(e)}")
logger.error(f"文案生成模型连接测试失败: {str(e)}")
# 验证和保存文本模型配置
text_validation_errors = []
text_config_changed = False
# 验证API密钥
if st_text_api_key:
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
is_valid, error_msg = validate_api_key(st_text_api_key, f"文案生成({text_provider})")
if is_valid:
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
text_config_changed = True
else:
text_validation_errors.append(error_msg)
# 验证Base URL
if st_text_base_url:
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
is_valid, error_msg = validate_base_url(st_text_base_url, f"文案生成({text_provider})")
if is_valid:
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
text_config_changed = True
else:
text_validation_errors.append(error_msg)
# 验证模型名称
if st_text_model_name:
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
is_valid, error_msg = validate_model_name(st_text_model_name, f"文案生成({text_provider})")
if is_valid:
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
text_config_changed = True
else:
text_validation_errors.append(error_msg)
# 显示验证错误
show_config_validation_errors(text_validation_errors)
# 如果配置有变化且没有验证错误,保存到文件
if text_config_changed and not text_validation_errors:
try:
config.save_config()
if st_text_api_key or st_text_base_url or st_text_model_name:
st.success(f"文案生成模型({text_provider})配置已保存")
except Exception as e:
st.error(f"保存配置失败: {str(e)}")
logger.error(f"保存文案生成配置失败: {str(e)}")
# # Cloudflare 特殊配置
# if text_provider == 'cloudflare':

View File

@ -23,11 +23,14 @@ def create_vision_analyzer(provider, api_key, model, base_url):
VisionAnalyzer QwenAnalyzer 实例
"""
if provider == 'gemini':
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
elif provider == 'gemini(openai)':
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
else:
# 只传入必要的参数
return qwenvl_analyzer.QwenAnalyzer(
model_name=model,
model_name=model,
api_key=api_key,
base_url=base_url
)

View File

@ -16,6 +16,89 @@ from loguru import logger
from app.config import config
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
import re
def parse_and_fix_json(json_string):
    """
    Parse a JSON string, applying a chain of repair heuristics when direct
    parsing fails.

    The fallbacks are tried in order: (1) parse as-is, (2) extract a
    ```json fenced block, (3) take the span between the first '{' and the
    last '}', (4) scrub common format errors (comments, trailing commas,
    single-quoted keys) and retry, (5) wrap the raw text in a minimal stub
    structure.

    Args:
        json_string: the raw JSON text, possibly wrapped in markdown or
            containing minor formatting mistakes

    Returns:
        dict: the parsed object, or None when nothing could be recovered
    """
    # Reject empty / whitespace-only input outright.
    if not json_string or not json_string.strip():
        logger.error("JSON字符串为空")
        return None

    text = json_string.strip()

    # Attempt 1: the string may already be valid JSON.
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        logger.warning(f"直接JSON解析失败: {e}")

    # Attempt 2: the payload may be wrapped in a ```json fenced code block.
    try:
        fenced = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
        if fenced:
            logger.info("从代码块中提取JSON内容")
            return json.loads(fenced.group(1).strip())
    except json.JSONDecodeError:
        pass

    # Attempt 3: grab everything between the first '{' and the last '}'.
    try:
        start = text.find('{')
        end = text.rfind('}')
        if -1 < start < end:
            logger.info("提取大括号包围的JSON内容")
            return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        pass

    # Attempt 4: scrub common formatting mistakes and retry.
    try:
        text = re.sub(r'#.*', '', text)                # strip '#' comments
        text = re.sub(r',\s*}', '}', text)             # trailing comma before '}'
        text = re.sub(r',\s*]', ']', text)             # trailing comma before ']'
        text = re.sub(r"'([^']*)':", r'"\1":', text)   # single-quoted keys
        logger.info("尝试修复JSON格式问题后解析")
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Every strategy failed: log the (possibly scrubbed) content and fall back
    # to a minimal stub structure that carries the raw text as narration.
    logger.error(f"所有JSON解析方法都失败原始内容: {text[:200]}...")
    try:
        return {
            "items": [
                {
                    "_id": 1,
                    "timestamp": "00:00:00,000-00:00:10,000",
                    "picture": "解析失败使用默认内容",
                    "narration": text[:100] + "..." if len(text) > 100 else text,
                    "OST": 0
                }
            ]
        }
    except Exception:
        return None
def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature):
@ -61,7 +144,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
model=text_model,
base_url=text_base_url,
save_result=True,
temperature=temperature
temperature=temperature,
provider=text_provider
)
"""
3. 根据剧情生成解说文案
@ -78,7 +162,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
model=text_model,
base_url=text_base_url,
save_result=True,
temperature=temperature
temperature=temperature,
provider=text_provider
)
if narration_result["status"] == "success":
@ -100,7 +185,20 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
# 结果转换为JSON字符串
narration_script = narration_result["narration_script"]
narration_dict = json.loads(narration_script)
# 增强JSON解析包含错误处理和修复
narration_dict = parse_and_fix_json(narration_script)
if narration_dict is None:
st.error("生成的解说文案格式错误无法解析为JSON")
logger.error(f"JSON解析失败原始内容: {narration_script}")
st.stop()
# 验证JSON结构
if 'items' not in narration_dict:
st.error("生成的解说文案缺少必要的'items'字段")
logger.error(f"JSON结构错误缺少items字段: {narration_dict}")
st.stop()
script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2)
if script is None: