mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 10:32:49 +00:00
feat: 更新作者信息并增强API配置验证功能
在基础设置中新增API密钥、基础URL和模型名称的验证功能,确保用户输入的配置有效性,提升系统的稳定性和用户体验。
This commit is contained in:
parent
04ffda297f
commit
dd59d5295d
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : audio_config
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/1/7
|
||||
@Description: 音频配置管理
|
||||
'''
|
||||
|
||||
@ -74,24 +74,27 @@ plot_writing = """
|
||||
%s
|
||||
</plot>
|
||||
|
||||
请使用 json 格式进行输出;使用 <output> 中的输出格式:
|
||||
<output>
|
||||
请严格按照以下JSON格式输出,不要添加任何其他文字、说明或代码块标记:
|
||||
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"_id": 1, # 唯一递增id
|
||||
"_id": 1,
|
||||
"timestamp": "00:00:05,390-00:00:10,430",
|
||||
"picture": "剧情描述或者备注",
|
||||
"narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id‘ 进行占位",
|
||||
"OST": "值为 0 表示当前片段为解说片段,值为 1 表示当前片段为穿插的原片"
|
||||
"OST": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
|
||||
<restriction>
|
||||
1. 只输出 json 内容,不要输出其他任何说明性的文字
|
||||
2. 解说文案的语言使用 简体中文
|
||||
3. 严禁虚构剧情,所有画面只能从 <polt> 中摘取
|
||||
4. 严禁虚构时间戳,所有时间戳范围只能从 <polt> 中摘取
|
||||
</restriction>
|
||||
重要要求:
|
||||
1. 必须输出有效的JSON格式,不能包含注释
|
||||
2. OST字段必须是数字:0表示解说片段,1表示原片片段
|
||||
3. _id必须是递增的数字
|
||||
4. 只输出JSON内容,不要输出任何说明文字
|
||||
5. 不要使用代码块标记(如```json)
|
||||
6. 解说文案使用简体中文
|
||||
7. 严禁虚构剧情,所有内容只能从<plot>中摘取
|
||||
8. 严禁虚构时间戳,所有时间戳只能从<plot>中摘取
|
||||
"""
|
||||
@ -28,6 +28,7 @@ class SubtitleAnalyzer:
|
||||
base_url: Optional[str] = None,
|
||||
custom_prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = 1.0,
|
||||
provider: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化字幕分析器
|
||||
@ -38,19 +39,28 @@ class SubtitleAnalyzer:
|
||||
base_url: API基础URL,如果不提供则从配置中读取或使用默认值
|
||||
custom_prompt: 自定义提示词,如果不提供则使用默认值
|
||||
temperature: 模型温度
|
||||
provider: 提供商类型,用于确定API调用格式
|
||||
"""
|
||||
# 使用传入的参数或从配置中获取
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
self.provider = provider or self._detect_provider()
|
||||
|
||||
# 设置提示词模板
|
||||
self.prompt_template = custom_prompt or subtitle_plot_analysis_v1
|
||||
|
||||
# 根据提供商类型确定是否为原生Gemini
|
||||
self.is_native_gemini = self.provider.lower() == 'gemini'
|
||||
|
||||
# 初始化HTTP请求所需的头信息
|
||||
self._init_headers()
|
||||
|
||||
def _detect_provider(self):
    """Return the text LLM provider name from the app config, lower-cased.

    Falls back to ``'gemini'`` when ``text_llm_provider`` is not configured.
    """
    provider_name = config.app.get('text_llm_provider', 'gemini')
    return provider_name.lower()
|
||||
|
||||
def _init_headers(self):
|
||||
"""初始化HTTP请求头"""
|
||||
try:
|
||||
@ -78,7 +88,141 @@ class SubtitleAnalyzer:
|
||||
# 构建完整提示词
|
||||
prompt = f"{self.prompt_template}\n\n{subtitle_content}"
|
||||
|
||||
# 构建请求体数据
|
||||
if self.is_native_gemini:
|
||||
# 使用原生Gemini API格式
|
||||
return self._call_native_gemini_api(prompt)
|
||||
else:
|
||||
# 使用OpenAI兼容格式
|
||||
return self._call_openai_compatible_api(prompt)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"字幕分析过程中发生错误: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"temperature": self.temperature
|
||||
}
|
||||
|
||||
def _call_native_gemini_api(self, prompt: str) -> Dict[str, Any]:
    """Run subtitle analysis through the native Gemini ``generateContent`` REST API.

    Args:
        prompt: Complete prompt text, sent as the single user message.

    Returns:
        A dict that always contains ``status`` and ``temperature``. On
        success it also carries ``analysis``, ``tokens_used`` and ``model``;
        on failure it carries ``message``. Exceptions are caught and reported
        through the error dict rather than raised.
    """

    def _error(message: str) -> Dict[str, Any]:
        # Single place that shapes every failure result of this call.
        return {
            "status": "error",
            "message": message,
            "temperature": self.temperature
        }

    def _redact(text: str) -> str:
        # The key travels as a ``?key=`` query parameter and requests embeds
        # the full URL in its exception text; mask it before it is logged
        # or handed back to the caller.
        return text.replace(self.api_key, "***") if self.api_key else text

    try:
        # Request body in the native Gemini format.
        payload = {
            "systemInstruction": {
                "parts": [{"text": "你是一位专业的剧本分析师和剧情概括助手。请严格按照要求的格式输出分析结果。"}]
            },
            "contents": [{
                "parts": [{"text": prompt}]
            }],
            "generationConfig": {
                "temperature": self.temperature,
                "topK": 40,
                "topP": 0.95,
                "maxOutputTokens": 4000,
                "candidateCount": 1
            },
            "safetySettings": [
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in (
                    "HARM_CATEGORY_HARASSMENT",
                    "HARM_CATEGORY_HATE_SPEECH",
                    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                    "HARM_CATEGORY_DANGEROUS_CONTENT",
                )
            ]
        }

        # Tolerate a configured base URL that ends with "/".
        base_url = (self.base_url or "").rstrip("/")
        url = f"{base_url}/models/{self.model}:generateContent"

        response = requests.post(
            url,
            params={"key": self.api_key},
            json=payload,
            headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
            timeout=120
        )

        if response.status_code != 200:
            error_msg = _redact(
                f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}"
            )
            logger.error(error_msg)
            return _error(error_msg)

        response_data = response.json()

        candidates = response_data.get("candidates")
        if not candidates:
            return _error("原生Gemini API返回无效响应,可能触发了安全过滤")

        candidate = candidates[0]

        # A SAFETY finish reason means the answer was withheld by the filter.
        if candidate.get("finishReason") == "SAFETY":
            return _error("内容被Gemini安全过滤器阻止")

        content = candidate.get("content")
        if not content or "parts" not in content:
            return _error("原生Gemini API返回内容格式错误")

        # The answer may be split over several parts; keep only text parts.
        analysis_result = "".join(
            part["text"] for part in content["parts"] if "text" in part
        )
        if not analysis_result.strip():
            return _error("原生Gemini API返回空内容")

        # Native Gemini reports usage under ``usageMetadata``; fall back to
        # the OpenAI-style field in case a proxy rewrites the response.
        tokens_used = (response_data.get("usageMetadata") or {}).get("totalTokenCount")
        if tokens_used is None:
            tokens_used = (response_data.get("usage") or {}).get("total_tokens", 0)

        logger.debug("原生Gemini字幕分析完成")

        return {
            "status": "success",
            "analysis": analysis_result,
            "tokens_used": tokens_used,
            "model": self.model,
            "temperature": self.temperature
        }

    except Exception as e:
        error_msg = _redact(f"原生Gemini API调用失败: {str(e)}")
        logger.error(error_msg)
        return _error(error_msg)
|
||||
|
||||
def _call_openai_compatible_api(self, prompt: str) -> Dict[str, Any]:
|
||||
"""调用OpenAI兼容的API"""
|
||||
try:
|
||||
# 构建OpenAI格式的请求数据
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
@ -92,7 +236,7 @@ class SubtitleAnalyzer:
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
|
||||
# 发送HTTP请求
|
||||
response = requests.post(url, headers=self.headers, json=payload)
|
||||
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
|
||||
|
||||
# 解析响应
|
||||
if response.status_code == 200:
|
||||
@ -101,7 +245,7 @@ class SubtitleAnalyzer:
|
||||
# 提取响应内容
|
||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||
analysis_result = response_data["choices"][0]["message"]["content"]
|
||||
logger.debug(f"字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
||||
logger.debug(f"OpenAI兼容API字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
||||
|
||||
# 返回结果
|
||||
return {
|
||||
@ -112,14 +256,14 @@ class SubtitleAnalyzer:
|
||||
"temperature": self.temperature
|
||||
}
|
||||
else:
|
||||
logger.error("字幕分析失败: 未获取到有效响应")
|
||||
logger.error("OpenAI兼容API字幕分析失败: 未获取到有效响应")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "未获取到有效响应",
|
||||
"temperature": self.temperature
|
||||
}
|
||||
else:
|
||||
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||
logger.error(error_msg)
|
||||
return {
|
||||
"status": "error",
|
||||
@ -128,10 +272,10 @@ class SubtitleAnalyzer:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"字幕分析过程中发生错误: {str(e)}")
|
||||
logger.error(f"OpenAI兼容API调用失败: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"message": f"OpenAI兼容API调用失败: {str(e)}",
|
||||
"temperature": self.temperature
|
||||
}
|
||||
|
||||
@ -219,7 +363,145 @@ class SubtitleAnalyzer:
|
||||
# 构建完整提示词
|
||||
prompt = plot_writing % (short_name, plot_analysis)
|
||||
|
||||
# 构建请求体数据
|
||||
if self.is_native_gemini:
|
||||
# 使用原生Gemini API格式
|
||||
return self._generate_narration_with_native_gemini(prompt, temperature)
|
||||
else:
|
||||
# 使用OpenAI兼容格式
|
||||
return self._generate_narration_with_openai_compatible(prompt, temperature)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"temperature": self.temperature
|
||||
}
|
||||
|
||||
def _generate_narration_with_native_gemini(self, prompt: str, temperature: float) -> Dict[str, Any]:
    """Generate the narration script through the native Gemini REST API.

    Args:
        prompt: Fully rendered narration prompt.
        temperature: Sampling temperature for this request; it is also
            echoed back in the result dict.

    Returns:
        A dict that always contains ``status`` and ``temperature``. On
        success it also carries ``narration_script``, ``tokens_used`` and
        ``model``; on failure it carries ``message``. Exceptions are caught
        and reported through the error dict rather than raised.
    """

    def _error(message: str) -> Dict[str, Any]:
        # Single place that shapes every failure result of this call.
        return {
            "status": "error",
            "message": message,
            "temperature": temperature
        }

    def _redact(text: str) -> str:
        # The key travels as a ``?key=`` query parameter and requests embeds
        # the full URL in its exception text; mask it before it is logged
        # or handed back to the caller.
        return text.replace(self.api_key, "***") if self.api_key else text

    def _strip_code_fence(text: str) -> str:
        # Remove one wrapping markdown fence (optionally tagged, e.g. json)
        # so the caller receives bare JSON; unfenced text is returned as is.
        stripped = text.strip()
        if not stripped.startswith("```"):
            return text
        body = stripped[3:]
        if body.endswith("```"):
            body = body[:-3]
        first_line, newline, rest = body.partition("\n")
        tag = first_line.strip()
        if newline and (not tag or tag.isalnum()):
            body = rest
        return body.strip()

    try:
        # Repeat the JSON-only constraint at the end of the prompt.
        enhanced_prompt = f"{prompt}\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"

        payload = {
            "systemInstruction": {
                "parts": [{"text": "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出,不能包含任何其他文字、说明或代码块标记。"}]
            },
            "contents": [{
                "parts": [{"text": enhanced_prompt}]
            }],
            # No stop sequences: stopping on "```" yields an empty answer when
            # the model opens with a code fence, and stopping on everyday
            # words such as "注意"/"说明" cuts the JSON off as soon as the
            # narration text contains them. A fence is stripped afterwards.
            "generationConfig": {
                "temperature": temperature,
                "topK": 40,
                "topP": 0.95,
                "maxOutputTokens": 4000,
                "candidateCount": 1
            },
            "safetySettings": [
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in (
                    "HARM_CATEGORY_HARASSMENT",
                    "HARM_CATEGORY_HATE_SPEECH",
                    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                    "HARM_CATEGORY_DANGEROUS_CONTENT",
                )
            ]
        }

        # Tolerate a configured base URL that ends with "/".
        base_url = (self.base_url or "").rstrip("/")
        url = f"{base_url}/models/{self.model}:generateContent"

        response = requests.post(
            url,
            params={"key": self.api_key},
            json=payload,
            headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
            timeout=120
        )

        if response.status_code != 200:
            error_msg = _redact(
                f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}"
            )
            logger.error(error_msg)
            return _error(error_msg)

        response_data = response.json()

        candidates = response_data.get("candidates")
        if not candidates:
            return _error("原生Gemini API返回无效响应,可能触发了安全过滤")

        candidate = candidates[0]

        # A SAFETY finish reason means the answer was withheld by the filter.
        if candidate.get("finishReason") == "SAFETY":
            return _error("内容被Gemini安全过滤器阻止")

        content = candidate.get("content")
        if not content or "parts" not in content:
            return _error("原生Gemini API返回内容格式错误")

        # The answer may be split over several parts; keep only text parts.
        narration_script = _strip_code_fence(
            "".join(part["text"] for part in content["parts"] if "text" in part)
        )
        if not narration_script.strip():
            return _error("原生Gemini API返回空内容")

        # Native Gemini reports usage under ``usageMetadata``; fall back to
        # the OpenAI-style field in case a proxy rewrites the response.
        tokens_used = (response_data.get("usageMetadata") or {}).get("totalTokenCount")
        if tokens_used is None:
            tokens_used = (response_data.get("usage") or {}).get("total_tokens", 0)

        logger.debug("原生Gemini解说文案生成完成")

        return {
            "status": "success",
            "narration_script": narration_script,
            "tokens_used": tokens_used,
            "model": self.model,
            "temperature": temperature
        }

    except Exception as e:
        error_msg = _redact(f"原生Gemini API解说文案生成失败: {str(e)}")
        logger.error(error_msg)
        return _error(error_msg)
|
||||
|
||||
def _generate_narration_with_openai_compatible(self, prompt: str, temperature: float) -> Dict[str, Any]:
|
||||
"""使用OpenAI兼容API生成解说文案"""
|
||||
try:
|
||||
# 构建OpenAI格式的请求数据
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
@ -237,7 +519,7 @@ class SubtitleAnalyzer:
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
|
||||
# 发送HTTP请求
|
||||
response = requests.post(url, headers=self.headers, json=payload)
|
||||
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
|
||||
|
||||
# 解析响应
|
||||
if response.status_code == 200:
|
||||
@ -246,7 +528,7 @@ class SubtitleAnalyzer:
|
||||
# 提取响应内容
|
||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||
narration_script = response_data["choices"][0]["message"]["content"]
|
||||
logger.debug(f"解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
||||
logger.debug(f"OpenAI兼容API解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
||||
|
||||
# 返回结果
|
||||
return {
|
||||
@ -254,30 +536,30 @@ class SubtitleAnalyzer:
|
||||
"narration_script": narration_script,
|
||||
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
|
||||
"model": self.model,
|
||||
"temperature": self.temperature
|
||||
"temperature": temperature
|
||||
}
|
||||
else:
|
||||
logger.error("解说文案生成失败: 未获取到有效响应")
|
||||
logger.error("OpenAI兼容API解说文案生成失败: 未获取到有效响应")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "未获取到有效响应",
|
||||
"temperature": self.temperature
|
||||
"temperature": temperature
|
||||
}
|
||||
else:
|
||||
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||
logger.error(error_msg)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"temperature": self.temperature
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
|
||||
logger.error(f"OpenAI兼容API解说文案生成失败: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"temperature": self.temperature
|
||||
"message": f"OpenAI兼容API解说文案生成失败: {str(e)}",
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str:
|
||||
@ -324,7 +606,8 @@ def analyze_subtitle(
|
||||
custom_prompt: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
save_result: bool = False,
|
||||
output_path: Optional[str] = None
|
||||
output_path: Optional[str] = None,
|
||||
provider: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
分析字幕内容的便捷函数
|
||||
@ -339,6 +622,7 @@ def analyze_subtitle(
|
||||
temperature: 模型温度
|
||||
save_result: 是否保存结果到文件
|
||||
output_path: 输出文件路径
|
||||
provider: 提供商类型
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含分析结果的字典
|
||||
@ -349,7 +633,8 @@ def analyze_subtitle(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
custom_prompt=custom_prompt
|
||||
custom_prompt=custom_prompt,
|
||||
provider=provider
|
||||
)
|
||||
logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}")
|
||||
# 分析字幕
|
||||
@ -379,7 +664,8 @@ def generate_narration_script(
|
||||
base_url: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
save_result: bool = False,
|
||||
output_path: Optional[str] = None
|
||||
output_path: Optional[str] = None,
|
||||
provider: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据剧情分析生成解说文案的便捷函数
|
||||
@ -393,6 +679,7 @@ def generate_narration_script(
|
||||
temperature: 生成温度,控制创造性
|
||||
save_result: 是否保存结果到文件
|
||||
output_path: 输出文件路径
|
||||
provider: 提供商类型
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含生成结果的字典
|
||||
@ -402,7 +689,8 @@ def generate_narration_script(
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url
|
||||
base_url=base_url,
|
||||
provider=provider
|
||||
)
|
||||
|
||||
# 生成解说文案
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : audio_normalizer
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/1/7
|
||||
@Description: 音频响度分析和标准化工具
|
||||
'''
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : clip_video
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/5/6 下午6:14
|
||||
'''
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : 生成介绍文案
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/5/8 上午11:33
|
||||
'''
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : generate_video
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/5/7 上午11:55
|
||||
'''
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : merger_video
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/5/6 下午7:38
|
||||
'''
|
||||
|
||||
|
||||
@ -140,14 +140,27 @@ class ScriptGenerator:
|
||||
# 获取Gemini配置
|
||||
vision_api_key = config.app.get("vision_gemini_api_key")
|
||||
vision_model = config.app.get("vision_gemini_model_name")
|
||||
vision_base_url = config.app.get("vision_gemini_base_url")
|
||||
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError("未配置 Gemini API Key 或者模型")
|
||||
|
||||
analyzer = gemini_analyzer.VisionAnalyzer(
|
||||
model_name=vision_model,
|
||||
api_key=vision_api_key,
|
||||
)
|
||||
# 根据提供商类型选择合适的分析器
|
||||
if vision_provider == 'gemini(openai)':
|
||||
# 使用OpenAI兼容的Gemini代理
|
||||
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
|
||||
analyzer = GeminiOpenAIAnalyzer(
|
||||
model_name=vision_model,
|
||||
api_key=vision_api_key,
|
||||
base_url=vision_base_url
|
||||
)
|
||||
else:
|
||||
# 使用原生Gemini分析器
|
||||
analyzer = gemini_analyzer.VisionAnalyzer(
|
||||
model_name=vision_model,
|
||||
api_key=vision_api_key,
|
||||
base_url=vision_base_url
|
||||
)
|
||||
|
||||
progress_callback(40, "正在分析关键帧...")
|
||||
|
||||
@ -213,13 +226,35 @@ class ScriptGenerator:
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt=custom_prompt,
|
||||
video_theme=video_theme
|
||||
)
|
||||
# 根据提供商类型选择合适的处理器
|
||||
if text_provider == 'gemini(openai)':
|
||||
# 使用OpenAI兼容的Gemini代理
|
||||
from app.utils.script_generator import GeminiOpenAIGenerator
|
||||
generator = GeminiOpenAIGenerator(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt=custom_prompt,
|
||||
base_url=text_base_url
|
||||
)
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
base_url=text_base_url,
|
||||
prompt=custom_prompt,
|
||||
video_theme=video_theme
|
||||
)
|
||||
processor.generator = generator
|
||||
else:
|
||||
# 使用标准处理器(包括原生Gemini)
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
base_url=text_base_url,
|
||||
prompt=custom_prompt,
|
||||
video_theme=video_theme
|
||||
)
|
||||
|
||||
return processor.process_frames(frame_content_list)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
'''
|
||||
@Project: NarratoAI
|
||||
@File : update_script
|
||||
@Author : 小林同学
|
||||
@Author : Viccy同学
|
||||
@Date : 2025/5/6 下午11:00
|
||||
'''
|
||||
|
||||
|
||||
@ -5,53 +5,162 @@ from pathlib import Path
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
|
||||
from google.api_core import exceptions
|
||||
import google.generativeai as genai
|
||||
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
|
||||
import requests
|
||||
import PIL.Image
|
||||
import traceback
|
||||
import base64
|
||||
import io
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
class VisionAnalyzer:
|
||||
"""视觉分析器类"""
|
||||
"""原生Gemini视觉分析器类"""
|
||||
|
||||
def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
|
||||
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
|
||||
"""初始化视觉分析器"""
|
||||
if not api_key:
|
||||
raise ValueError("必须提供API密钥")
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
# 初始化配置
|
||||
self._configure_client()
|
||||
|
||||
def _configure_client(self):
|
||||
"""配置API客户端"""
|
||||
genai.configure(api_key=self.api_key)
|
||||
# 开放 Gemini 模型安全设置
|
||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
||||
safety_settings = {
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
}
|
||||
self.model = genai.GenerativeModel(self.model_name, safety_settings=safety_settings)
|
||||
"""配置原生Gemini API客户端"""
|
||||
# 使用原生Gemini REST API
|
||||
self.client = None
|
||||
logger.info(f"配置原生Gemini API,端点: {self.base_url}, 模型: {self.model_name}")
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(exceptions.ResourceExhausted)
|
||||
retry=retry_if_exception_type(requests.exceptions.RequestException)
|
||||
)
|
||||
async def _generate_content_with_retry(self, prompt, batch):
|
||||
"""使用重试机制的内部方法来调用 generate_content_async"""
|
||||
"""使用重试机制调用原生Gemini API"""
|
||||
try:
|
||||
return await self.model.generate_content_async([prompt, *batch])
|
||||
except exceptions.ResourceExhausted as e:
|
||||
print(f"API配额限制: {str(e)}")
|
||||
raise RetryError("API调用失败")
|
||||
return await self._generate_with_gemini_api(prompt, batch)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Gemini API请求异常: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini API生成内容时发生错误: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_with_gemini_api(self, prompt, batch):
    """Send one prompt plus a batch of images to the native Gemini REST API.

    Args:
        prompt: Instruction text placed before the images.
        batch: Iterable of PIL images; each is sent inline as a base64 JPEG.

    Returns:
        An object exposing the generated text as ``.text``.

    Raises:
        requests.exceptions.RequestException: On transport failures, HTTP 429
            and HTTP 5xx, i.e. the failures worth retrying.
        Exception: On other HTTP errors or when the response is blocked,
            malformed or empty.
    """
    image_parts = []
    for img in batch:
        # JPEG has no alpha/palette support; saving e.g. RGBA or P frames
        # raises, so normalise everything except RGB and greyscale first.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")

        img_buffer = io.BytesIO()
        img.save(img_buffer, format='JPEG', quality=85)
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
        image_parts.append({
            "inline_data": {
                "mime_type": "image/jpeg",
                "data": img_base64
            }
        })

    request_data = {
        "contents": [{
            "parts": [
                {"text": prompt},
                *image_parts
            ]
        }],
        "generationConfig": {
            "temperature": 1.0,
            "topK": 40,
            "topP": 0.95,
            "maxOutputTokens": 8192,
            "candidateCount": 1,
            "stopSequences": []
        },
        "safetySettings": [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in (
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]
    }

    # Tolerate a configured base URL that ends with "/".
    url = f"{self.base_url.rstrip('/')}/models/{self.model_name}:generateContent"

    try:
        # requests is blocking, so run it in a worker thread.
        response = await asyncio.to_thread(
            requests.post,
            url,
            params={"key": self.api_key},
            json=request_data,
            headers={
                "Content-Type": "application/json",
                "User-Agent": "NarratoAI/1.0"
            },
            timeout=120
        )
    except requests.exceptions.RequestException as e:
        # requests embeds the full URL, including ``?key=...``, in its
        # exception text; re-raise with the key masked so callers that log
        # the error do not leak it.
        raise requests.exceptions.RequestException(
            str(e).replace(self.api_key, "***")
        ) from None

    status = response.status_code
    if status == 429:
        raise requests.exceptions.RequestException(f"API配额限制: {response.text}")
    if status == 400:
        raise Exception(f"请求参数错误: {response.text}")
    if status == 403:
        raise Exception(f"API密钥无效或权限不足: {response.text}")
    if status >= 500:
        # Server-side failures are transient; signal them with the same
        # exception type as quota errors so a retry policy keyed on
        # RequestException covers them too.
        raise requests.exceptions.RequestException(
            f"Gemini API请求失败: {status} - {response.text}"
        )
    if status != 200:
        raise Exception(f"Gemini API请求失败: {status} - {response.text}")

    response_data = response.json()

    if "candidates" not in response_data or not response_data["candidates"]:
        raise Exception("Gemini API返回无效响应,可能触发了安全过滤")

    candidate = response_data["candidates"][0]

    # A SAFETY finish reason means the answer was withheld by the filter.
    if candidate.get("finishReason") == "SAFETY":
        raise Exception("内容被Gemini安全过滤器阻止")

    if "content" not in candidate or "parts" not in candidate["content"]:
        raise Exception("Gemini API返回内容格式错误")

    # The answer may be split over several parts; keep only text parts.
    text_content = "".join(
        part["text"] for part in candidate["content"]["parts"] if "text" in part
    )
    if not text_content.strip():
        raise Exception("Gemini API返回空内容")

    class CompatibleResponse:
        """Minimal response wrapper exposing the generated text as ``.text``."""

        def __init__(self, text):
            self.text = text

    return CompatibleResponse(text_content)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: Union[List[str], List[PIL.Image.Image]],
|
||||
|
||||
177
app/utils/gemini_openai_analyzer.py
Normal file
177
app/utils/gemini_openai_analyzer.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""
|
||||
OpenAI兼容的Gemini视觉分析器
|
||||
使用标准OpenAI格式调用Gemini代理服务
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Union, Dict
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
|
||||
import requests
|
||||
import PIL.Image
|
||||
import traceback
|
||||
import base64
|
||||
import io
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
class GeminiOpenAIAnalyzer:
    """Vision analyzer that calls a Gemini proxy through the OpenAI-compatible chat API."""

    def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
        """Store the connection settings and build the OpenAI client.

        Args:
            model_name: Model identifier passed to the proxy.
            api_key: API key for the proxy; required.
            base_url: Root URL of the OpenAI-compatible endpoint; required.

        Raises:
            ValueError: If ``api_key`` or ``base_url`` is missing.
        """
        if not api_key:
            raise ValueError("必须提供API密钥")

        if not base_url:
            raise ValueError("必须提供OpenAI兼容的代理端点URL")

        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')

        self._configure_client()

    def _configure_client(self):
        """Create the OpenAI client bound to the configured proxy endpoint."""
        # Imported lazily so the module loads even when openai is only
        # needed for this provider.
        from openai import OpenAI
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        logger.info(f"配置OpenAI兼容Gemini代理,端点: {self.base_url}, 模型: {self.model_name}")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        # Every exception is retried (Exception already covers the requests
        # errors the previous tuple listed separately).
        retry=retry_if_exception_type(Exception),
        # Surface the real failure after the last attempt instead of an
        # opaque tenacity RetryError, so logged messages stay meaningful.
        reraise=True
    )
    async def _generate_content_with_retry(self, prompt, batch):
        """Call the proxy for one batch, retrying up to three times with backoff."""
        try:
            return await self._generate_with_openai_api(prompt, batch)
        except Exception as e:
            logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}")
            raise

    async def _generate_with_openai_api(self, prompt, batch):
        """Send one prompt plus a batch of PIL images as a single chat completion.

        Returns:
            An object exposing the generated text as ``.text``.

        Raises:
            ValueError: If the proxy returns no choices or empty content.
        """
        image_contents = []
        for img in batch:
            # JPEG has no alpha/palette support; saving e.g. RGBA or P frames
            # raises, so normalise everything except RGB and greyscale first.
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")

            img_buffer = io.BytesIO()
            img.save(img_buffer, format='JPEG', quality=85)
            img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
            image_contents.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{img_base64}"
                }
            })

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    *image_contents
                ]
            }
        ]

        # The OpenAI client call is blocking, so run it in a worker thread.
        response = await asyncio.to_thread(
            self.client.chat.completions.create,
            model=self.model_name,
            messages=messages,
            max_tokens=4000,
            temperature=1.0
        )

        # Guard against an empty answer so callers never receive None as text.
        if not response.choices or response.choices[0].message.content is None:
            raise ValueError("OpenAI兼容Gemini代理返回空内容")

        class CompatibleResponse:
            """Minimal response wrapper exposing the generated text as ``.text``."""

            def __init__(self, text):
                self.text = text

        return CompatibleResponse(response.choices[0].message.content)

    async def analyze_images(self,
                             images: List[Union[str, Path, PIL.Image.Image]],
                             prompt: str,
                             batch_size: int = 10) -> List[str]:
        """Analyze images in batches and return one result string per batch.

        Args:
            images: Image paths and/or already loaded PIL images.
            prompt: Instruction text sent with every batch.
            batch_size: Number of images per request; must be positive.

        Returns:
            One entry per batch: the model's text, or a string starting with
            ``分析失败`` when that batch failed.

        Raises:
            ValueError: If ``batch_size`` is not positive or no image could
                be loaded.
        """
        if batch_size < 1:
            raise ValueError("batch_size 必须为正整数")

        logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理")

        loaded_images = []
        for img in images:
            if isinstance(img, (str, Path)):
                try:
                    # convert() returns a fully loaded copy, so the file
                    # handle can be released as soon as the block exits.
                    with PIL.Image.open(img) as opened:
                        pil_img = opened.convert("RGB")
                    # Downscale large frames to keep the request small.
                    if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
                        pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
                    loaded_images.append(pil_img)
                except Exception as e:
                    logger.error(f"加载图片失败 {img}: {str(e)}")
                    continue
            elif isinstance(img, PIL.Image.Image):
                loaded_images.append(img)
            else:
                logger.warning(f"不支持的图片类型: {type(img)}")
                continue

        if not loaded_images:
            raise ValueError("没有有效的图片可以分析")

        results = []
        total_batches = (len(loaded_images) + batch_size - 1) // batch_size

        for i in tqdm(range(0, len(loaded_images), batch_size),
                      desc="分析图片批次", total=total_batches):
            batch = loaded_images[i:i + batch_size]

            try:
                response = await self._generate_content_with_retry(prompt, batch)
                results.append(response.text)

                # Pause between batches to stay under the API rate limit.
                if i + batch_size < len(loaded_images):
                    await asyncio.sleep(1)

            except Exception as e:
                # A failed batch is recorded in place so the remaining
                # batches still run and results stay aligned with batches.
                logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}")
                results.append(f"分析失败: {str(e)}")

        logger.info(f"完成图片分析,共处理 {len(results)} 个批次")
        return results

    def analyze_images_sync(self,
                            images: List[Union[str, Path, PIL.Image.Image]],
                            prompt: str,
                            batch_size: int = 10) -> List[str]:
        """Blocking wrapper around :meth:`analyze_images`.

        Uses ``asyncio.run``, so it must not be called from inside a running
        event loop.
        """
        return asyncio.run(self.analyze_images(images, prompt, batch_size))
|
||||
@ -6,7 +6,7 @@ from loguru import logger
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
from openai import OpenAI
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
import time
|
||||
|
||||
|
||||
@ -134,59 +134,182 @@ class OpenAIGenerator(BaseGenerator):
|
||||
|
||||
|
||||
class _CompatibleResponse:
    """Wrap a decoded native Gemini payload and expose its text as ``.text``.

    ``data`` keeps the decoded JSON payload. ``text`` is the concatenation of
    every ``text`` part of the first candidate, or ``""`` when the candidate
    carries no content parts. ``GeminiGenerator._process_response`` reads
    ``.text``.
    """

    def __init__(self, data):
        self.data = data
        candidate = data["candidates"][0]
        if "content" in candidate and "parts" in candidate["content"]:
            parts = candidate["content"]["parts"]
        else:
            parts = []
        self.text = "".join(part["text"] for part in parts if "text" in part)


class GeminiGenerator(BaseGenerator):
    """Generator that calls the native Gemini REST API (``generateContent``)."""

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
        """
        Args:
            model_name: Gemini model identifier used in the request path.
            api_key: API key sent as the ``key`` query parameter.
            prompt: Prompt forwarded to the base generator.
            base_url: Optional API root; defaults to the public v1beta endpoint.
        """
        super().__init__(model_name, api_key, prompt)

        # Strip a trailing slash so the request path never contains "//models".
        self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip('/')
        self.client = None

        # Native Gemini generationConfig uses camelCase keys.
        self.default_params = {
            "temperature": self.default_params["temperature"],
            "topP": self.default_params["top_p"],
            "topK": 40,
            "maxOutputTokens": 4000,
            "candidateCount": 1,
            "stopSequences": []
        }

    def _redact(self, text: str) -> str:
        """Mask the API key in text that is about to be logged.

        The key is part of the request URL, and ``requests`` exceptions embed
        that URL in their message.
        """
        return text.replace(self.api_key, "***") if self.api_key else text

    def _generate(self, messages: list, params: dict) -> any:
        """Send the messages to the native Gemini API, retrying transient failures.

        Up to three attempts are made. HTTP 429, an empty candidate list and
        network errors are retried after a pause; other HTTP errors and
        safety blocks are raised immediately. Every path either returns a
        response wrapper or raises, so ``None`` is never returned.

        Args:
            messages: Chat messages; their ``content`` values are joined with
                newlines into a single prompt.
            params: Passed through unchanged as ``generationConfig``.

        Returns:
            A ``_CompatibleResponse`` whose ``.text`` holds the generated text.

        Raises:
            requests.exceptions.RequestException: network failure on the last attempt.
            Exception: non-200 response, safety block, empty candidates or
                rate limiting once the retries are used up.
        """
        max_retries = 3
        for attempt in range(max_retries):
            has_retries_left = attempt < max_retries - 1
            try:
                # Flatten the chat messages into a single Gemini prompt.
                prompt = "\n".join([m["content"] for m in messages])

                request_data = {
                    "contents": [{
                        "parts": [{"text": prompt}]
                    }],
                    "generationConfig": params,
                    "safetySettings": [
                        {
                            "category": "HARM_CATEGORY_HARASSMENT",
                            "threshold": "BLOCK_NONE"
                        },
                        {
                            "category": "HARM_CATEGORY_HATE_SPEECH",
                            "threshold": "BLOCK_NONE"
                        },
                        {
                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                            "threshold": "BLOCK_NONE"
                        },
                        {
                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                            "threshold": "BLOCK_NONE"
                        }
                    ]
                }

                url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"

                response = requests.post(
                    url,
                    json=request_data,
                    headers={
                        "Content-Type": "application/json",
                        "User-Agent": "NarratoAI/1.0"
                    },
                    timeout=120
                )

                if response.status_code == 429:
                    # Rate limited: back off and retry, but do not sleep (and
                    # then silently return None) once the retries are used up.
                    if not has_retries_left:
                        raise RuntimeError("原生Gemini API 触发限流,重试次数已用尽 (HTTP 429)")
                    wait_time = 65 if attempt == 0 else 30
                    logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...")
                    time.sleep(wait_time)
                    continue

                if response.status_code == 400:
                    raise Exception(f"请求参数错误: {response.text}")
                elif response.status_code == 403:
                    raise Exception(f"API密钥无效或权限不足: {response.text}")
                elif response.status_code != 200:
                    raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}")

                response_data = response.json()

                # A 200 response without candidates is treated as transient.
                if "candidates" not in response_data or not response_data["candidates"]:
                    if has_retries_left:
                        logger.warning("原生Gemini API 返回无效响应,等待30秒后重试...")
                        time.sleep(30)
                        continue
                    else:
                        raise Exception("原生Gemini API返回无效响应,可能触发了安全过滤")

                candidate = response_data["candidates"][0]

                # The request was answered but the content was blocked.
                if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
                    raise Exception("内容被Gemini安全过滤器阻止")

                return _CompatibleResponse(response_data)

            except requests.exceptions.RequestException as e:
                if has_retries_left:
                    logger.warning(f"网络请求失败,等待30秒后重试: {self._redact(str(e))}")
                    time.sleep(30)
                    continue
                else:
                    logger.error(f"原生Gemini API请求失败: {self._redact(str(e))}")
                    raise
            except Exception as e:
                if has_retries_left and "429" in str(e):
                    logger.warning("原生Gemini API 触发限流,等待65秒后重试...")
                    time.sleep(65)
                    continue
                else:
                    logger.error(f"原生Gemini 生成文案错误: {self._redact(str(e))}")
                    raise

    def _process_response(self, response: any) -> str:
        """Return the stripped text of a native Gemini response.

        Raises:
            ValueError: if the response is missing or carries no text.
        """
        if not response or not response.text:
            raise ValueError("原生Gemini API返回无效响应")
        return response.text.strip()


class GeminiOpenAIGenerator(BaseGenerator):
    """Generator for Gemini served through an OpenAI-compatible endpoint."""

    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
        """
        Args:
            model_name: Model identifier sent in the chat-completions request.
            api_key: API key handed to the OpenAI client.
            prompt: Prompt forwarded to the base generator.
            base_url: Root URL of the OpenAI-compatible endpoint (required).

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        super().__init__(model_name, api_key, prompt)

        if not base_url:
            raise ValueError("OpenAI兼容的Gemini代理必须提供base_url")

        self.base_url = base_url.rstrip('/')

        # Local import, as in the original block.
        from openai import OpenAI
        self.client = OpenAI(
            api_key=api_key,
            base_url=self.base_url
        )

        # Parameters understood by the OpenAI chat-completions interface.
        self.default_params = {
            "temperature": self.default_params["temperature"],
            "max_tokens": 4000,
            "stream": False
        }

    def _generate(self, messages: list, params: dict) -> any:
        """Call the chat-completions endpoint and return the raw response.

        Errors are logged and re-raised unchanged.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params
            )
            return response
        except Exception as e:
            logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}")
            raise

    def _process_response(self, response: any) -> str:
        """Return the stripped content of the first choice.

        Raises:
            ValueError: if there is no response, no choice, or the first
                choice has no text content.
        """
        if not response or not response.choices:
            raise ValueError("OpenAI兼容Gemini代理返回无效响应")
        content = response.choices[0].message.content
        # content can be None; report it as an invalid response instead of
        # failing with AttributeError on .strip().
        if content is None:
            raise ValueError("OpenAI兼容Gemini代理返回无效响应")
        return content.strip()
|
||||
|
||||
|
||||
@ -318,7 +441,7 @@ class ScriptProcessor:
|
||||
# 根据模型名称选择对应的生成器
|
||||
logger.info(f"文本 LLM 提供商: {model_name}")
|
||||
if 'gemini' in model_name.lower():
|
||||
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
|
||||
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url)
|
||||
elif 'qwen' in model_name.lower():
|
||||
self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
|
||||
elif 'moonshot' in model_name.lower():
|
||||
|
||||
@ -7,6 +7,45 @@ from app.utils import utils
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def validate_api_key(api_key: str, provider: str) -> tuple[bool, str]:
    """Check that an API key is present and plausibly long enough.

    Args:
        api_key: Key text entered by the user; surrounding whitespace is
            ignored for the checks.
        provider: Label placed at the start of the error message.

    Returns:
        ``(True, "")`` when the key passes, otherwise ``(False, message)``.
    """
    trimmed = api_key.strip() if api_key else ""
    if not trimmed:
        return False, f"{provider} API密钥不能为空"

    # Only a minimum length is enforced; the key's format is not inspected.
    if len(trimmed) < 10:
        return False, f"{provider} API密钥长度过短,请检查是否正确"

    return True, ""
|
||||
|
||||
|
||||
def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
    """Validate the format of an optional API base URL.

    An empty or whitespace-only value is accepted because the base URL is
    optional. Any other value must start with ``http://`` or ``https://``
    and must name a host.

    Args:
        base_url: URL text entered by the user; surrounding whitespace is
            ignored.
        provider: Label placed at the start of the error message.

    Returns:
        ``(True, "")`` when the value is acceptable, otherwise
        ``(False, message)``.
    """
    from urllib.parse import urlsplit

    if not base_url or not base_url.strip():
        return True, ""  # the base URL may be left empty

    candidate = base_url.strip()
    if not candidate.startswith(('http://', 'https://')):
        return False, f"{provider} Base URL必须以http://或https://开头"

    # A bare scheme such as "https://" passes the prefix check but can never
    # be requested, so a host component is required as well.
    try:
        host = urlsplit(candidate).hostname
    except ValueError:
        # urlsplit rejects malformed authorities (e.g. an unclosed IPv6 bracket).
        host = None
    if not host:
        return False, f"{provider} Base URL缺少主机名,请检查是否正确"

    return True, ""
|
||||
|
||||
|
||||
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
    """Check that a model name was supplied.

    Args:
        model_name: Name entered by the user.
        provider: Label placed at the start of the error message.

    Returns:
        ``(True, "")`` when the name contains non-whitespace text, otherwise
        ``(False, message)``.
    """
    has_name = bool(model_name) and bool(model_name.strip())
    if has_name:
        return True, ""
    return False, f"{provider} 模型名称不能为空"
|
||||
|
||||
|
||||
def show_config_validation_errors(errors: list):
    """Show every validation error through ``st.error``.

    Nothing is displayed when ``errors`` is empty or ``None``.
    """
    for message in errors or ():
        st.error(message)
|
||||
|
||||
|
||||
def render_basic_settings(tr):
|
||||
"""渲染基础设置面板"""
|
||||
with st.expander(tr("Basic Settings"), expanded=False):
|
||||
@ -98,18 +137,85 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
bool: 连接是否成功
|
||||
str: 测试结果消息
|
||||
"""
|
||||
import requests
|
||||
if provider.lower() == 'gemini':
|
||||
import google.generativeai as genai
|
||||
|
||||
# 原生Gemini API测试
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model = genai.GenerativeModel(model_name)
|
||||
model.generate_content("直接回复我文本'当前网络可用'")
|
||||
return True, tr("gemini model is available")
|
||||
# 构建请求数据
|
||||
request_data = {
|
||||
"contents": [{
|
||||
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"temperature": 1.0,
|
||||
"topK": 40,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
"safetySettings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 构建请求URL
|
||||
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||
|
||||
# 发送请求
|
||||
response = requests.post(
|
||||
url,
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True, tr("原生Gemini模型连接成功")
|
||||
else:
|
||||
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
|
||||
except Exception as e:
|
||||
return False, f"{tr('gemini model is not available')}: {str(e)}"
|
||||
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
|
||||
|
||||
elif provider.lower() == 'gemini(openai)':
|
||||
# OpenAI兼容的Gemini代理测试
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||
test_data = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
|
||||
],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
|
||||
if response.status_code == 200:
|
||||
return True, tr("OpenAI兼容Gemini代理连接成功")
|
||||
else:
|
||||
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
|
||||
except Exception as e:
|
||||
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: {str(e)}"
|
||||
elif provider.lower() == 'narratoapi':
|
||||
import requests
|
||||
try:
|
||||
# 构建测试请求
|
||||
headers = {
|
||||
@ -172,7 +278,7 @@ def render_vision_llm_settings(tr):
|
||||
st.subheader(tr("Vision Model Settings"))
|
||||
|
||||
# 视频分析模型提供商选择
|
||||
vision_providers = ['Siliconflow', 'Gemini', 'QwenVL', 'OpenAI']
|
||||
vision_providers = ['Siliconflow', 'Gemini', 'Gemini(OpenAI)', 'QwenVL', 'OpenAI']
|
||||
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
|
||||
saved_provider_index = 0
|
||||
|
||||
@ -191,9 +297,15 @@ def render_vision_llm_settings(tr):
|
||||
st.session_state['vision_llm_providers'] = vision_provider
|
||||
|
||||
# 获取已保存的视觉模型配置
|
||||
vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
|
||||
vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
|
||||
vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
|
||||
# 处理特殊的提供商名称映射
|
||||
if vision_provider == 'gemini(openai)':
|
||||
vision_config_key = 'vision_gemini_openai'
|
||||
else:
|
||||
vision_config_key = f'vision_{vision_provider}'
|
||||
|
||||
vision_api_key = config.app.get(f"{vision_config_key}_api_key", "")
|
||||
vision_base_url = config.app.get(f"{vision_config_key}_base_url", "")
|
||||
vision_model_name = config.app.get(f"{vision_config_key}_model_name", "")
|
||||
|
||||
# 渲染视觉模型配置输入框
|
||||
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
|
||||
@ -202,14 +314,24 @@ def render_vision_llm_settings(tr):
|
||||
if vision_provider == 'gemini':
|
||||
st_vision_base_url = st.text_input(
|
||||
tr("Vision Base URL"),
|
||||
value=vision_base_url,
|
||||
disabled=True,
|
||||
help=tr("Gemini API does not require a base URL")
|
||||
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta",
|
||||
help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta")
|
||||
)
|
||||
st_vision_model_name = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=vision_model_name or "gemini-2.0-flash-lite",
|
||||
help=tr("Default: gemini-2.0-flash-lite")
|
||||
value=vision_model_name or "gemini-2.0-flash-exp",
|
||||
help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp")
|
||||
)
|
||||
elif vision_provider == 'gemini(openai)':
|
||||
st_vision_base_url = st.text_input(
|
||||
tr("Vision Base URL"),
|
||||
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1")
|
||||
)
|
||||
st_vision_model_name = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=vision_model_name or "gemini-2.0-flash-exp",
|
||||
help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp")
|
||||
)
|
||||
elif vision_provider == 'qwenvl':
|
||||
st_vision_base_url = st.text_input(
|
||||
@ -228,30 +350,81 @@ def render_vision_llm_settings(tr):
|
||||
|
||||
# 在配置输入框后添加测试按钮
|
||||
if st.button(tr("Test Connection"), key="test_vision_connection"):
|
||||
with st.spinner(tr("Testing connection...")):
|
||||
success, message = test_vision_model_connection(
|
||||
api_key=st_vision_api_key,
|
||||
base_url=st_vision_base_url,
|
||||
model_name=st_vision_model_name,
|
||||
provider=vision_provider,
|
||||
tr=tr
|
||||
)
|
||||
# 先验证配置
|
||||
test_errors = []
|
||||
if not st_vision_api_key:
|
||||
test_errors.append("请先输入API密钥")
|
||||
if not st_vision_model_name:
|
||||
test_errors.append("请先输入模型名称")
|
||||
|
||||
if success:
|
||||
st.success(tr(message))
|
||||
else:
|
||||
st.error(tr(message))
|
||||
if test_errors:
|
||||
for error in test_errors:
|
||||
st.error(error)
|
||||
else:
|
||||
with st.spinner(tr("Testing connection...")):
|
||||
try:
|
||||
success, message = test_vision_model_connection(
|
||||
api_key=st_vision_api_key,
|
||||
base_url=st_vision_base_url,
|
||||
model_name=st_vision_model_name,
|
||||
provider=vision_provider,
|
||||
tr=tr
|
||||
)
|
||||
|
||||
# 保存视觉模型配置
|
||||
if success:
|
||||
st.success(message)
|
||||
else:
|
||||
st.error(message)
|
||||
except Exception as e:
|
||||
st.error(f"测试连接时发生错误: {str(e)}")
|
||||
logger.error(f"视频分析模型连接测试失败: {str(e)}")
|
||||
|
||||
# 验证和保存视觉模型配置
|
||||
validation_errors = []
|
||||
config_changed = False
|
||||
|
||||
# 验证API密钥
|
||||
if st_vision_api_key:
|
||||
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
||||
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
||||
is_valid, error_msg = validate_api_key(st_vision_api_key, f"视频分析({vision_provider})")
|
||||
if is_valid:
|
||||
config.app[f"{vision_config_key}_api_key"] = st_vision_api_key
|
||||
st.session_state[f"{vision_config_key}_api_key"] = st_vision_api_key
|
||||
config_changed = True
|
||||
else:
|
||||
validation_errors.append(error_msg)
|
||||
|
||||
# 验证Base URL
|
||||
if st_vision_base_url:
|
||||
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
||||
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
||||
is_valid, error_msg = validate_base_url(st_vision_base_url, f"视频分析({vision_provider})")
|
||||
if is_valid:
|
||||
config.app[f"{vision_config_key}_base_url"] = st_vision_base_url
|
||||
st.session_state[f"{vision_config_key}_base_url"] = st_vision_base_url
|
||||
config_changed = True
|
||||
else:
|
||||
validation_errors.append(error_msg)
|
||||
|
||||
# 验证模型名称
|
||||
if st_vision_model_name:
|
||||
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
||||
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
||||
is_valid, error_msg = validate_model_name(st_vision_model_name, f"视频分析({vision_provider})")
|
||||
if is_valid:
|
||||
config.app[f"{vision_config_key}_model_name"] = st_vision_model_name
|
||||
st.session_state[f"{vision_config_key}_model_name"] = st_vision_model_name
|
||||
config_changed = True
|
||||
else:
|
||||
validation_errors.append(error_msg)
|
||||
|
||||
# 显示验证错误
|
||||
show_config_validation_errors(validation_errors)
|
||||
|
||||
# 如果配置有变化且没有验证错误,保存到文件
|
||||
if config_changed and not validation_errors:
|
||||
try:
|
||||
config.save_config()
|
||||
if st_vision_api_key or st_vision_base_url or st_vision_model_name:
|
||||
st.success(f"视频分析模型({vision_provider})配置已保存")
|
||||
except Exception as e:
|
||||
st.error(f"保存配置失败: {str(e)}")
|
||||
logger.error(f"保存视频分析配置失败: {str(e)}")
|
||||
|
||||
|
||||
def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
@ -278,14 +451,74 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
|
||||
# 特殊处理Gemini
|
||||
if provider.lower() == 'gemini':
|
||||
import google.generativeai as genai
|
||||
# 原生Gemini API测试
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model = genai.GenerativeModel(model_name)
|
||||
model.generate_content("直接回复我文本'当前网络可用'")
|
||||
return True, tr("Gemini model is available")
|
||||
# 构建请求数据
|
||||
request_data = {
|
||||
"contents": [{
|
||||
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"temperature": 1.0,
|
||||
"topK": 40,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
"safetySettings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 构建请求URL
|
||||
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||
|
||||
# 发送请求
|
||||
response = requests.post(
|
||||
url,
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True, tr("原生Gemini模型连接成功")
|
||||
else:
|
||||
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
|
||||
except Exception as e:
|
||||
return False, f"{tr('Gemini model is not available')}: {str(e)}"
|
||||
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
|
||||
|
||||
elif provider.lower() == 'gemini(openai)':
|
||||
# OpenAI兼容的Gemini代理测试
|
||||
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||
test_data = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
|
||||
],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
|
||||
if response.status_code == 200:
|
||||
return True, tr("OpenAI兼容Gemini代理连接成功")
|
||||
else:
|
||||
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
|
||||
else:
|
||||
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||
|
||||
@ -322,7 +555,7 @@ def render_text_llm_settings(tr):
|
||||
st.subheader(tr("Text Generation Model Settings"))
|
||||
|
||||
# 文案生成模型提供商选择
|
||||
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Qwen', 'Moonshot']
|
||||
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Gemini(OpenAI)', 'Qwen', 'Moonshot']
|
||||
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
|
||||
saved_provider_index = 0
|
||||
|
||||
@ -346,32 +579,108 @@ def render_text_llm_settings(tr):
|
||||
|
||||
# 渲染文本模型配置输入框
|
||||
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
|
||||
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
|
||||
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
|
||||
|
||||
# 根据不同提供商设置默认值和帮助信息
|
||||
if text_provider == 'gemini':
|
||||
st_text_base_url = st.text_input(
|
||||
tr("Text Base URL"),
|
||||
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta",
|
||||
help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta")
|
||||
)
|
||||
st_text_model_name = st.text_input(
|
||||
tr("Text Model Name"),
|
||||
value=text_model_name or "gemini-2.0-flash-exp",
|
||||
help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp")
|
||||
)
|
||||
elif text_provider == 'gemini(openai)':
|
||||
st_text_base_url = st.text_input(
|
||||
tr("Text Base URL"),
|
||||
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1")
|
||||
)
|
||||
st_text_model_name = st.text_input(
|
||||
tr("Text Model Name"),
|
||||
value=text_model_name or "gemini-2.0-flash-exp",
|
||||
help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp")
|
||||
)
|
||||
else:
|
||||
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
|
||||
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
|
||||
|
||||
# 添加测试按钮
|
||||
if st.button(tr("Test Connection"), key="test_text_connection"):
|
||||
with st.spinner(tr("Testing connection...")):
|
||||
success, message = test_text_model_connection(
|
||||
api_key=st_text_api_key,
|
||||
base_url=st_text_base_url,
|
||||
model_name=st_text_model_name,
|
||||
provider=text_provider,
|
||||
tr=tr
|
||||
)
|
||||
# 先验证配置
|
||||
test_errors = []
|
||||
if not st_text_api_key:
|
||||
test_errors.append("请先输入API密钥")
|
||||
if not st_text_model_name:
|
||||
test_errors.append("请先输入模型名称")
|
||||
|
||||
if success:
|
||||
st.success(message)
|
||||
else:
|
||||
st.error(message)
|
||||
if test_errors:
|
||||
for error in test_errors:
|
||||
st.error(error)
|
||||
else:
|
||||
with st.spinner(tr("Testing connection...")):
|
||||
try:
|
||||
success, message = test_text_model_connection(
|
||||
api_key=st_text_api_key,
|
||||
base_url=st_text_base_url,
|
||||
model_name=st_text_model_name,
|
||||
provider=text_provider,
|
||||
tr=tr
|
||||
)
|
||||
|
||||
# 保存文本模型配置
|
||||
if success:
|
||||
st.success(message)
|
||||
else:
|
||||
st.error(message)
|
||||
except Exception as e:
|
||||
st.error(f"测试连接时发生错误: {str(e)}")
|
||||
logger.error(f"文案生成模型连接测试失败: {str(e)}")
|
||||
|
||||
# 验证和保存文本模型配置
|
||||
text_validation_errors = []
|
||||
text_config_changed = False
|
||||
|
||||
# 验证API密钥
|
||||
if st_text_api_key:
|
||||
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
|
||||
is_valid, error_msg = validate_api_key(st_text_api_key, f"文案生成({text_provider})")
|
||||
if is_valid:
|
||||
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
|
||||
text_config_changed = True
|
||||
else:
|
||||
text_validation_errors.append(error_msg)
|
||||
|
||||
# 验证Base URL
|
||||
if st_text_base_url:
|
||||
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
|
||||
is_valid, error_msg = validate_base_url(st_text_base_url, f"文案生成({text_provider})")
|
||||
if is_valid:
|
||||
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
|
||||
text_config_changed = True
|
||||
else:
|
||||
text_validation_errors.append(error_msg)
|
||||
|
||||
# 验证模型名称
|
||||
if st_text_model_name:
|
||||
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
|
||||
is_valid, error_msg = validate_model_name(st_text_model_name, f"文案生成({text_provider})")
|
||||
if is_valid:
|
||||
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
|
||||
text_config_changed = True
|
||||
else:
|
||||
text_validation_errors.append(error_msg)
|
||||
|
||||
# 显示验证错误
|
||||
show_config_validation_errors(text_validation_errors)
|
||||
|
||||
# 如果配置有变化且没有验证错误,保存到文件
|
||||
if text_config_changed and not text_validation_errors:
|
||||
try:
|
||||
config.save_config()
|
||||
if st_text_api_key or st_text_base_url or st_text_model_name:
|
||||
st.success(f"文案生成模型({text_provider})配置已保存")
|
||||
except Exception as e:
|
||||
st.error(f"保存配置失败: {str(e)}")
|
||||
logger.error(f"保存文案生成配置失败: {str(e)}")
|
||||
|
||||
# # Cloudflare 特殊配置
|
||||
# if text_provider == 'cloudflare':
|
||||
|
||||
@ -23,7 +23,10 @@ def create_vision_analyzer(provider, api_key, model, base_url):
|
||||
VisionAnalyzer 或 QwenAnalyzer 实例
|
||||
"""
|
||||
if provider == 'gemini':
|
||||
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
|
||||
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||
elif provider == 'gemini(openai)':
|
||||
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
|
||||
return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
# 只传入必要的参数
|
||||
return qwenvl_analyzer.QwenAnalyzer(
|
||||
|
||||
@ -16,6 +16,89 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
|
||||
import re
|
||||
|
||||
|
||||
def _strip_hash_comments(text):
    """Remove ``# ...`` comments that sit outside double-quoted strings.

    The newline ending each comment is kept. A ``#`` inside a string value
    is left untouched.
    """
    kept = []
    in_string = False
    escaped = False
    pos = 0
    size = len(text)
    while pos < size:
        ch = text[pos]
        if in_string:
            kept.append(ch)
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
            kept.append(ch)
        elif ch == "#":
            # Skip to the end of the line; the newline is handled next pass.
            while pos < size and text[pos] != "\n":
                pos += 1
            continue
        else:
            kept.append(ch)
        pos += 1
    return "".join(kept)


def _repair_json_text(text):
    """Fix common JSON mistakes: comments, trailing commas, single-quoted keys."""
    text = _strip_hash_comments(text)
    text = re.sub(r',\s*}', '}', text)
    text = re.sub(r',\s*]', ']', text)
    return re.sub(r"'([^']*)':", r'"\1":', text)


def parse_and_fix_json(json_string):
    """
    Parse a JSON string, repairing common formatting problems on the way.

    Strategies are tried in order and the first one that parses wins:
      1. the stripped text as is;
      2. the content of a ```json fenced code block;
      3. the span from the first ``{`` to the last ``}``;
      4. each of the candidates above after textual repair (``#`` comments
         outside strings removed, trailing commas removed, single-quoted
         keys converted to double quotes).

    Args:
        json_string: Text expected to contain JSON.

    Returns:
        The parsed value. ``None`` is returned only when the input is empty
        or blank. When no strategy succeeds, a placeholder dict with a single
        item is returned whose ``narration`` holds the first 100 characters
        of the input.
    """
    if not json_string or not json_string.strip():
        logger.error("JSON字符串为空")
        return None

    # Keep the stripped original untouched: it is needed for logging and for
    # the fallback structure.
    text = json_string.strip()

    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        logger.warning(f"直接JSON解析失败: {e}")

    # Narrower candidates extracted from the text, each with its log note.
    extracted = []
    fence_match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    if fence_match:
        extracted.append(("从代码块中提取JSON内容", fence_match.group(1).strip()))

    start_idx = text.find('{')
    end_idx = text.rfind('}')
    if start_idx != -1 and end_idx > start_idx:
        extracted.append(("提取大括号包围的JSON内容", text[start_idx:end_idx + 1]))

    for note, candidate in extracted:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        logger.info(note)
        return parsed

    # Repair the extracted candidates too, not only the whole text; otherwise
    # fenced output with a trailing comma can never be recovered.
    logger.info("尝试修复JSON格式问题后解析")
    for candidate in [content for _, content in extracted] + [text]:
        try:
            return json.loads(_repair_json_text(candidate))
        except json.JSONDecodeError:
            continue

    logger.error(f"所有JSON解析方法都失败,原始内容: {text[:200]}...")

    # Last-resort fallback: one placeholder item carrying the start of the
    # unparsable text as narration.
    narration = (text[:100] + "...") if len(text) > 100 else text
    return {
        "items": [
            {
                "_id": 1,
                "timestamp": "00:00:00,000-00:00:10,000",
                "picture": "解析失败,使用默认内容",
                "narration": narration,
                "OST": 0
            }
        ]
    }
|
||||
|
||||
|
||||
def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature):
|
||||
@ -61,7 +144,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
||||
model=text_model,
|
||||
base_url=text_base_url,
|
||||
save_result=True,
|
||||
temperature=temperature
|
||||
temperature=temperature,
|
||||
provider=text_provider
|
||||
)
|
||||
"""
|
||||
3. 根据剧情生成解说文案
|
||||
@ -78,7 +162,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
||||
model=text_model,
|
||||
base_url=text_base_url,
|
||||
save_result=True,
|
||||
temperature=temperature
|
||||
temperature=temperature,
|
||||
provider=text_provider
|
||||
)
|
||||
|
||||
if narration_result["status"] == "success":
|
||||
@ -100,7 +185,20 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
||||
|
||||
# 结果转换为JSON字符串
|
||||
narration_script = narration_result["narration_script"]
|
||||
narration_dict = json.loads(narration_script)
|
||||
|
||||
# 增强JSON解析,包含错误处理和修复
|
||||
narration_dict = parse_and_fix_json(narration_script)
|
||||
if narration_dict is None:
|
||||
st.error("生成的解说文案格式错误,无法解析为JSON")
|
||||
logger.error(f"JSON解析失败,原始内容: {narration_script}")
|
||||
st.stop()
|
||||
|
||||
# 验证JSON结构
|
||||
if 'items' not in narration_dict:
|
||||
st.error("生成的解说文案缺少必要的'items'字段")
|
||||
logger.error(f"JSON结构错误,缺少items字段: {narration_dict}")
|
||||
st.stop()
|
||||
|
||||
script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2)
|
||||
|
||||
if script is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user