mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 10:32:49 +00:00
feat: 更新作者信息并增强API配置验证功能
在基础设置中新增API密钥、基础URL和模型名称的验证功能,确保用户输入的配置有效性,提升系统的稳定性和用户体验。
This commit is contained in:
parent
04ffda297f
commit
dd59d5295d
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : audio_config
|
@File : audio_config
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/1/7
|
@Date : 2025/1/7
|
||||||
@Description: 音频配置管理
|
@Description: 音频配置管理
|
||||||
'''
|
'''
|
||||||
|
|||||||
@ -74,24 +74,27 @@ plot_writing = """
|
|||||||
%s
|
%s
|
||||||
</plot>
|
</plot>
|
||||||
|
|
||||||
请使用 json 格式进行输出;使用 <output> 中的输出格式:
|
请严格按照以下JSON格式输出,不要添加任何其他文字、说明或代码块标记:
|
||||||
<output>
|
|
||||||
{
|
{
|
||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
"_id": 1, # 唯一递增id
|
"_id": 1,
|
||||||
"timestamp": "00:00:05,390-00:00:10,430",
|
"timestamp": "00:00:05,390-00:00:10,430",
|
||||||
"picture": "剧情描述或者备注",
|
"picture": "剧情描述或者备注",
|
||||||
"narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id‘ 进行占位",
|
"narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id‘ 进行占位",
|
||||||
"OST": "值为 0 表示当前片段为解说片段,值为 1 表示当前片段为穿插的原片"
|
"OST": 0
|
||||||
}
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
</output>
|
|
||||||
|
|
||||||
<restriction>
|
重要要求:
|
||||||
1. 只输出 json 内容,不要输出其他任何说明性的文字
|
1. 必须输出有效的JSON格式,不能包含注释
|
||||||
2. 解说文案的语言使用 简体中文
|
2. OST字段必须是数字:0表示解说片段,1表示原片片段
|
||||||
3. 严禁虚构剧情,所有画面只能从 <polt> 中摘取
|
3. _id必须是递增的数字
|
||||||
4. 严禁虚构时间戳,所有时间戳范围只能从 <polt> 中摘取
|
4. 只输出JSON内容,不要输出任何说明文字
|
||||||
</restriction>
|
5. 不要使用代码块标记(如```json)
|
||||||
|
6. 解说文案使用简体中文
|
||||||
|
7. 严禁虚构剧情,所有内容只能从<plot>中摘取
|
||||||
|
8. 严禁虚构时间戳,所有时间戳只能从<plot>中摘取
|
||||||
"""
|
"""
|
||||||
@ -22,34 +22,44 @@ class SubtitleAnalyzer:
|
|||||||
"""字幕剧情分析器,负责分析字幕内容并提取关键剧情段落"""
|
"""字幕剧情分析器,负责分析字幕内容并提取关键剧情段落"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
temperature: Optional[float] = 1.0,
|
temperature: Optional[float] = 1.0,
|
||||||
|
provider: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化字幕分析器
|
初始化字幕分析器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: API密钥,如果不提供则从配置中读取
|
api_key: API密钥,如果不提供则从配置中读取
|
||||||
model: 模型名称,如果不提供则从配置中读取
|
model: 模型名称,如果不提供则从配置中读取
|
||||||
base_url: API基础URL,如果不提供则从配置中读取或使用默认值
|
base_url: API基础URL,如果不提供则从配置中读取或使用默认值
|
||||||
custom_prompt: 自定义提示词,如果不提供则使用默认值
|
custom_prompt: 自定义提示词,如果不提供则使用默认值
|
||||||
temperature: 模型温度
|
temperature: 模型温度
|
||||||
|
provider: 提供商类型,用于确定API调用格式
|
||||||
"""
|
"""
|
||||||
# 使用传入的参数或从配置中获取
|
# 使用传入的参数或从配置中获取
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
self.provider = provider or self._detect_provider()
|
||||||
|
|
||||||
# 设置提示词模板
|
# 设置提示词模板
|
||||||
self.prompt_template = custom_prompt or subtitle_plot_analysis_v1
|
self.prompt_template = custom_prompt or subtitle_plot_analysis_v1
|
||||||
|
|
||||||
|
# 根据提供商类型确定是否为原生Gemini
|
||||||
|
self.is_native_gemini = self.provider.lower() == 'gemini'
|
||||||
|
|
||||||
# 初始化HTTP请求所需的头信息
|
# 初始化HTTP请求所需的头信息
|
||||||
self._init_headers()
|
self._init_headers()
|
||||||
|
|
||||||
|
def _detect_provider(self):
|
||||||
|
"""根据配置自动检测提供商类型"""
|
||||||
|
return config.app.get('text_llm_provider', 'gemini').lower()
|
||||||
|
|
||||||
def _init_headers(self):
|
def _init_headers(self):
|
||||||
"""初始化HTTP请求头"""
|
"""初始化HTTP请求头"""
|
||||||
@ -67,18 +77,152 @@ class SubtitleAnalyzer:
|
|||||||
def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]:
|
def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
分析字幕内容
|
分析字幕内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subtitle_content: 字幕内容文本
|
subtitle_content: 字幕内容文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 包含分析结果的字典
|
Dict[str, Any]: 包含分析结果的字典
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 构建完整提示词
|
# 构建完整提示词
|
||||||
prompt = f"{self.prompt_template}\n\n{subtitle_content}"
|
prompt = f"{self.prompt_template}\n\n{subtitle_content}"
|
||||||
|
|
||||||
# 构建请求体数据
|
if self.is_native_gemini:
|
||||||
|
# 使用原生Gemini API格式
|
||||||
|
return self._call_native_gemini_api(prompt)
|
||||||
|
else:
|
||||||
|
# 使用OpenAI兼容格式
|
||||||
|
return self._call_openai_compatible_api(prompt)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"字幕分析过程中发生错误: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
def _call_native_gemini_api(self, prompt: str) -> Dict[str, Any]:
|
||||||
|
"""调用原生Gemini API"""
|
||||||
|
try:
|
||||||
|
# 构建原生Gemini API请求数据
|
||||||
|
payload = {
|
||||||
|
"systemInstruction": {
|
||||||
|
"parts": [{"text": "你是一位专业的剧本分析师和剧情概括助手。请严格按照要求的格式输出分析结果。"}]
|
||||||
|
},
|
||||||
|
"contents": [{
|
||||||
|
"parts": [{"text": prompt}]
|
||||||
|
}],
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"topK": 40,
|
||||||
|
"topP": 0.95,
|
||||||
|
"maxOutputTokens": 4000,
|
||||||
|
"candidateCount": 1
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
|
||||||
|
timeout=120
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
# 检查响应格式
|
||||||
|
if "candidates" not in response_data or not response_data["candidates"]:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "原生Gemini API返回无效响应,可能触发了安全过滤",
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate = response_data["candidates"][0]
|
||||||
|
|
||||||
|
# 检查是否被安全过滤阻止
|
||||||
|
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "内容被Gemini安全过滤器阻止",
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "原生Gemini API返回内容格式错误",
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
# 提取文本内容
|
||||||
|
analysis_result = ""
|
||||||
|
for part in candidate["content"]["parts"]:
|
||||||
|
if "text" in part:
|
||||||
|
analysis_result += part["text"]
|
||||||
|
|
||||||
|
if not analysis_result.strip():
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "原生Gemini API返回空内容",
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"原生Gemini字幕分析完成")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"analysis": analysis_result,
|
||||||
|
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
|
||||||
|
"model": self.model,
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_msg = f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": error_msg,
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"原生Gemini API调用失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"原生Gemini API调用失败: {str(e)}",
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
def _call_openai_compatible_api(self, prompt: str) -> Dict[str, Any]:
|
||||||
|
"""调用OpenAI兼容的API"""
|
||||||
|
try:
|
||||||
|
# 构建OpenAI格式的请求数据
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": [
|
"messages": [
|
||||||
@ -87,22 +231,22 @@ class SubtitleAnalyzer:
|
|||||||
],
|
],
|
||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
# 构建请求地址
|
# 构建请求地址
|
||||||
url = f"{self.base_url}/chat/completions"
|
url = f"{self.base_url}/chat/completions"
|
||||||
|
|
||||||
# 发送HTTP请求
|
# 发送HTTP请求
|
||||||
response = requests.post(url, headers=self.headers, json=payload)
|
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
|
|
||||||
# 提取响应内容
|
# 提取响应内容
|
||||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||||
analysis_result = response_data["choices"][0]["message"]["content"]
|
analysis_result = response_data["choices"][0]["message"]["content"]
|
||||||
logger.debug(f"字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
logger.debug(f"OpenAI兼容API字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
||||||
|
|
||||||
# 返回结果
|
# 返回结果
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@ -112,26 +256,26 @@ class SubtitleAnalyzer:
|
|||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.error("字幕分析失败: 未获取到有效响应")
|
logger.error("OpenAI兼容API字幕分析失败: 未获取到有效响应")
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "未获取到有效响应",
|
"message": "未获取到有效响应",
|
||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": error_msg,
|
"message": error_msg,
|
||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"字幕分析过程中发生错误: {str(e)}")
|
logger.error(f"OpenAI兼容API调用失败: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": str(e),
|
"message": f"OpenAI兼容API调用失败: {str(e)}",
|
||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,12 +350,12 @@ class SubtitleAnalyzer:
|
|||||||
def generate_narration_script(self, short_name:str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]:
|
def generate_narration_script(self, short_name:str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
根据剧情分析生成解说文案
|
根据剧情分析生成解说文案
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
short_name: 短剧名称
|
short_name: 短剧名称
|
||||||
plot_analysis: 剧情分析内容
|
plot_analysis: 剧情分析内容
|
||||||
temperature: 生成温度,控制创造性,默认0.7
|
temperature: 生成温度,控制创造性,默认0.7
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 包含生成结果的字典
|
Dict[str, Any]: 包含生成结果的字典
|
||||||
"""
|
"""
|
||||||
@ -219,7 +363,145 @@ class SubtitleAnalyzer:
|
|||||||
# 构建完整提示词
|
# 构建完整提示词
|
||||||
prompt = plot_writing % (short_name, plot_analysis)
|
prompt = plot_writing % (short_name, plot_analysis)
|
||||||
|
|
||||||
# 构建请求体数据
|
if self.is_native_gemini:
|
||||||
|
# 使用原生Gemini API格式
|
||||||
|
return self._generate_narration_with_native_gemini(prompt, temperature)
|
||||||
|
else:
|
||||||
|
# 使用OpenAI兼容格式
|
||||||
|
return self._generate_narration_with_openai_compatible(prompt, temperature)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"temperature": self.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_narration_with_native_gemini(self, prompt: str, temperature: float) -> Dict[str, Any]:
|
||||||
|
"""使用原生Gemini API生成解说文案"""
|
||||||
|
try:
|
||||||
|
# 构建原生Gemini API请求数据
|
||||||
|
# 为了确保JSON输出,在提示词中添加更强的约束
|
||||||
|
enhanced_prompt = f"{prompt}\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"systemInstruction": {
|
||||||
|
"parts": [{"text": "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出,不能包含任何其他文字、说明或代码块标记。"}]
|
||||||
|
},
|
||||||
|
"contents": [{
|
||||||
|
"parts": [{"text": enhanced_prompt}]
|
||||||
|
}],
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"topK": 40,
|
||||||
|
"topP": 0.95,
|
||||||
|
"maxOutputTokens": 4000,
|
||||||
|
"candidateCount": 1,
|
||||||
|
"stopSequences": ["```", "注意", "说明"]
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"},
|
||||||
|
timeout=120
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
# 检查响应格式
|
||||||
|
if "candidates" not in response_data or not response_data["candidates"]:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "原生Gemini API返回无效响应,可能触发了安全过滤",
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate = response_data["candidates"][0]
|
||||||
|
|
||||||
|
# 检查是否被安全过滤阻止
|
||||||
|
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "内容被Gemini安全过滤器阻止",
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "原生Gemini API返回内容格式错误",
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
# 提取文本内容
|
||||||
|
narration_script = ""
|
||||||
|
for part in candidate["content"]["parts"]:
|
||||||
|
if "text" in part:
|
||||||
|
narration_script += part["text"]
|
||||||
|
|
||||||
|
if not narration_script.strip():
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "原生Gemini API返回空内容",
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"原生Gemini解说文案生成完成")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"narration_script": narration_script,
|
||||||
|
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
|
||||||
|
"model": self.model,
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_msg = f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": error_msg,
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"原生Gemini API解说文案生成失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"原生Gemini API解说文案生成失败: {str(e)}",
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_narration_with_openai_compatible(self, prompt: str, temperature: float) -> Dict[str, Any]:
|
||||||
|
"""使用OpenAI兼容API生成解说文案"""
|
||||||
|
try:
|
||||||
|
# 构建OpenAI格式的请求数据
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": [
|
"messages": [
|
||||||
@ -228,56 +510,56 @@ class SubtitleAnalyzer:
|
|||||||
],
|
],
|
||||||
"temperature": temperature
|
"temperature": temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
# 对特定模型添加响应格式设置
|
# 对特定模型添加响应格式设置
|
||||||
if self.model not in ["deepseek-reasoner"]:
|
if self.model not in ["deepseek-reasoner"]:
|
||||||
payload["response_format"] = {"type": "json_object"}
|
payload["response_format"] = {"type": "json_object"}
|
||||||
|
|
||||||
# 构建请求地址
|
# 构建请求地址
|
||||||
url = f"{self.base_url}/chat/completions"
|
url = f"{self.base_url}/chat/completions"
|
||||||
|
|
||||||
# 发送HTTP请求
|
# 发送HTTP请求
|
||||||
response = requests.post(url, headers=self.headers, json=payload)
|
response = requests.post(url, headers=self.headers, json=payload, timeout=120)
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
|
|
||||||
# 提取响应内容
|
# 提取响应内容
|
||||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||||
narration_script = response_data["choices"][0]["message"]["content"]
|
narration_script = response_data["choices"][0]["message"]["content"]
|
||||||
logger.debug(f"解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
logger.debug(f"OpenAI兼容API解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}")
|
||||||
|
|
||||||
# 返回结果
|
# 返回结果
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"narration_script": narration_script,
|
"narration_script": narration_script,
|
||||||
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
|
"tokens_used": response_data.get("usage", {}).get("total_tokens", 0),
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"temperature": self.temperature
|
"temperature": temperature
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.error("解说文案生成失败: 未获取到有效响应")
|
logger.error("OpenAI兼容API解说文案生成失败: 未获取到有效响应")
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "未获取到有效响应",
|
"message": "未获取到有效响应",
|
||||||
"temperature": self.temperature
|
"temperature": temperature
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": error_msg,
|
"message": error_msg,
|
||||||
"temperature": self.temperature
|
"temperature": temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解说文案生成过程中发生错误: {str(e)}")
|
logger.error(f"OpenAI兼容API解说文案生成失败: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": str(e),
|
"message": f"OpenAI兼容API解说文案生成失败: {str(e)}",
|
||||||
"temperature": self.temperature
|
"temperature": temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str:
|
def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str:
|
||||||
@ -324,11 +606,12 @@ def analyze_subtitle(
|
|||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
save_result: bool = False,
|
save_result: bool = False,
|
||||||
output_path: Optional[str] = None
|
output_path: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
分析字幕内容的便捷函数
|
分析字幕内容的便捷函数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subtitle_content: 字幕内容文本
|
subtitle_content: 字幕内容文本
|
||||||
subtitle_file_path: 字幕文件路径
|
subtitle_file_path: 字幕文件路径
|
||||||
@ -339,7 +622,8 @@ def analyze_subtitle(
|
|||||||
temperature: 模型温度
|
temperature: 模型温度
|
||||||
save_result: 是否保存结果到文件
|
save_result: 是否保存结果到文件
|
||||||
output_path: 输出文件路径
|
output_path: 输出文件路径
|
||||||
|
provider: 提供商类型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 包含分析结果的字典
|
Dict[str, Any]: 包含分析结果的字典
|
||||||
"""
|
"""
|
||||||
@ -349,7 +633,8 @@ def analyze_subtitle(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=model,
|
model=model,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
custom_prompt=custom_prompt
|
custom_prompt=custom_prompt,
|
||||||
|
provider=provider
|
||||||
)
|
)
|
||||||
logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}")
|
logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}")
|
||||||
# 分析字幕
|
# 分析字幕
|
||||||
@ -379,11 +664,12 @@ def generate_narration_script(
|
|||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
save_result: bool = False,
|
save_result: bool = False,
|
||||||
output_path: Optional[str] = None
|
output_path: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
根据剧情分析生成解说文案的便捷函数
|
根据剧情分析生成解说文案的便捷函数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
short_name: 短剧名称
|
short_name: 短剧名称
|
||||||
plot_analysis: 剧情分析内容,直接提供
|
plot_analysis: 剧情分析内容,直接提供
|
||||||
@ -393,7 +679,8 @@ def generate_narration_script(
|
|||||||
temperature: 生成温度,控制创造性
|
temperature: 生成温度,控制创造性
|
||||||
save_result: 是否保存结果到文件
|
save_result: 是否保存结果到文件
|
||||||
output_path: 输出文件路径
|
output_path: 输出文件路径
|
||||||
|
provider: 提供商类型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 包含生成结果的字典
|
Dict[str, Any]: 包含生成结果的字典
|
||||||
"""
|
"""
|
||||||
@ -402,7 +689,8 @@ def generate_narration_script(
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=model,
|
model=model,
|
||||||
base_url=base_url
|
base_url=base_url,
|
||||||
|
provider=provider
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成解说文案
|
# 生成解说文案
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : audio_normalizer
|
@File : audio_normalizer
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/1/7
|
@Date : 2025/1/7
|
||||||
@Description: 音频响度分析和标准化工具
|
@Description: 音频响度分析和标准化工具
|
||||||
'''
|
'''
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : clip_video
|
@File : clip_video
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/5/6 下午6:14
|
@Date : 2025/5/6 下午6:14
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : 生成介绍文案
|
@File : 生成介绍文案
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/5/8 上午11:33
|
@Date : 2025/5/8 上午11:33
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : generate_video
|
@File : generate_video
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/5/7 上午11:55
|
@Date : 2025/5/7 上午11:55
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : merger_video
|
@File : merger_video
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/5/6 下午7:38
|
@Date : 2025/5/6 下午7:38
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|||||||
@ -140,14 +140,27 @@ class ScriptGenerator:
|
|||||||
# 获取Gemini配置
|
# 获取Gemini配置
|
||||||
vision_api_key = config.app.get("vision_gemini_api_key")
|
vision_api_key = config.app.get("vision_gemini_api_key")
|
||||||
vision_model = config.app.get("vision_gemini_model_name")
|
vision_model = config.app.get("vision_gemini_model_name")
|
||||||
|
vision_base_url = config.app.get("vision_gemini_base_url")
|
||||||
|
|
||||||
if not vision_api_key or not vision_model:
|
if not vision_api_key or not vision_model:
|
||||||
raise ValueError("未配置 Gemini API Key 或者模型")
|
raise ValueError("未配置 Gemini API Key 或者模型")
|
||||||
|
|
||||||
analyzer = gemini_analyzer.VisionAnalyzer(
|
# 根据提供商类型选择合适的分析器
|
||||||
model_name=vision_model,
|
if vision_provider == 'gemini(openai)':
|
||||||
api_key=vision_api_key,
|
# 使用OpenAI兼容的Gemini代理
|
||||||
)
|
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
|
||||||
|
analyzer = GeminiOpenAIAnalyzer(
|
||||||
|
model_name=vision_model,
|
||||||
|
api_key=vision_api_key,
|
||||||
|
base_url=vision_base_url
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 使用原生Gemini分析器
|
||||||
|
analyzer = gemini_analyzer.VisionAnalyzer(
|
||||||
|
model_name=vision_model,
|
||||||
|
api_key=vision_api_key,
|
||||||
|
base_url=vision_base_url
|
||||||
|
)
|
||||||
|
|
||||||
progress_callback(40, "正在分析关键帧...")
|
progress_callback(40, "正在分析关键帧...")
|
||||||
|
|
||||||
@ -213,13 +226,35 @@ class ScriptGenerator:
|
|||||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||||
|
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||||
|
|
||||||
processor = ScriptProcessor(
|
# 根据提供商类型选择合适的处理器
|
||||||
model_name=text_model,
|
if text_provider == 'gemini(openai)':
|
||||||
api_key=text_api_key,
|
# 使用OpenAI兼容的Gemini代理
|
||||||
prompt=custom_prompt,
|
from app.utils.script_generator import GeminiOpenAIGenerator
|
||||||
video_theme=video_theme
|
generator = GeminiOpenAIGenerator(
|
||||||
)
|
model_name=text_model,
|
||||||
|
api_key=text_api_key,
|
||||||
|
prompt=custom_prompt,
|
||||||
|
base_url=text_base_url
|
||||||
|
)
|
||||||
|
processor = ScriptProcessor(
|
||||||
|
model_name=text_model,
|
||||||
|
api_key=text_api_key,
|
||||||
|
base_url=text_base_url,
|
||||||
|
prompt=custom_prompt,
|
||||||
|
video_theme=video_theme
|
||||||
|
)
|
||||||
|
processor.generator = generator
|
||||||
|
else:
|
||||||
|
# 使用标准处理器(包括原生Gemini)
|
||||||
|
processor = ScriptProcessor(
|
||||||
|
model_name=text_model,
|
||||||
|
api_key=text_api_key,
|
||||||
|
base_url=text_base_url,
|
||||||
|
prompt=custom_prompt,
|
||||||
|
video_theme=video_theme
|
||||||
|
)
|
||||||
|
|
||||||
return processor.process_frames(frame_content_list)
|
return processor.process_frames(frame_content_list)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
'''
|
'''
|
||||||
@Project: NarratoAI
|
@Project: NarratoAI
|
||||||
@File : update_script
|
@File : update_script
|
||||||
@Author : 小林同学
|
@Author : Viccy同学
|
||||||
@Date : 2025/5/6 下午11:00
|
@Date : 2025/5/6 下午11:00
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|||||||
@ -5,53 +5,162 @@ from pathlib import Path
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import asyncio
|
import asyncio
|
||||||
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
|
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
|
||||||
from google.api_core import exceptions
|
import requests
|
||||||
import google.generativeai as genai
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import traceback
|
import traceback
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
from app.utils import utils
|
from app.utils import utils
|
||||||
|
|
||||||
|
|
||||||
class VisionAnalyzer:
|
class VisionAnalyzer:
|
||||||
"""视觉分析器类"""
|
"""原生Gemini视觉分析器类"""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
|
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
|
||||||
"""初始化视觉分析器"""
|
"""初始化视觉分析器"""
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("必须提供API密钥")
|
raise ValueError("必须提供API密钥")
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
|
||||||
# 初始化配置
|
# 初始化配置
|
||||||
self._configure_client()
|
self._configure_client()
|
||||||
|
|
||||||
def _configure_client(self):
|
def _configure_client(self):
|
||||||
"""配置API客户端"""
|
"""配置原生Gemini API客户端"""
|
||||||
genai.configure(api_key=self.api_key)
|
# 使用原生Gemini REST API
|
||||||
# 开放 Gemini 模型安全设置
|
self.client = None
|
||||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
logger.info(f"配置原生Gemini API,端点: {self.base_url}, 模型: {self.model_name}")
|
||||||
safety_settings = {
|
|
||||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
|
||||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
|
||||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
||||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
|
||||||
}
|
|
||||||
self.model = genai.GenerativeModel(self.model_name, safety_settings=safety_settings)
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type(exceptions.ResourceExhausted)
|
retry=retry_if_exception_type(requests.exceptions.RequestException)
|
||||||
)
|
)
|
||||||
async def _generate_content_with_retry(self, prompt, batch):
|
async def _generate_content_with_retry(self, prompt, batch):
|
||||||
"""使用重试机制的内部方法来调用 generate_content_async"""
|
"""使用重试机制调用原生Gemini API"""
|
||||||
try:
|
try:
|
||||||
return await self.model.generate_content_async([prompt, *batch])
|
return await self._generate_with_gemini_api(prompt, batch)
|
||||||
except exceptions.ResourceExhausted as e:
|
except requests.exceptions.RequestException as e:
|
||||||
print(f"API配额限制: {str(e)}")
|
logger.warning(f"Gemini API请求异常: {str(e)}")
|
||||||
raise RetryError("API调用失败")
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Gemini API生成内容时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _generate_with_gemini_api(self, prompt, batch):
|
||||||
|
"""使用原生Gemini REST API生成内容"""
|
||||||
|
# 将PIL图片转换为base64编码
|
||||||
|
image_parts = []
|
||||||
|
for img in batch:
|
||||||
|
# 将PIL图片转换为字节流
|
||||||
|
img_buffer = io.BytesIO()
|
||||||
|
img.save(img_buffer, format='JPEG', quality=85) # 优化图片质量
|
||||||
|
img_bytes = img_buffer.getvalue()
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
||||||
|
image_parts.append({
|
||||||
|
"inline_data": {
|
||||||
|
"mime_type": "image/jpeg",
|
||||||
|
"data": img_base64
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 构建符合官方文档的请求数据
|
||||||
|
request_data = {
|
||||||
|
"contents": [{
|
||||||
|
"parts": [
|
||||||
|
{"text": prompt},
|
||||||
|
*image_parts
|
||||||
|
]
|
||||||
|
}],
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": 1.0,
|
||||||
|
"topK": 40,
|
||||||
|
"topP": 0.95,
|
||||||
|
"maxOutputTokens": 8192,
|
||||||
|
"candidateCount": 1,
|
||||||
|
"stopSequences": []
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = await asyncio.to_thread(
|
||||||
|
requests.post,
|
||||||
|
url,
|
||||||
|
json=request_data,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": "NarratoAI/1.0"
|
||||||
|
},
|
||||||
|
timeout=120 # 增加超时时间
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理HTTP错误
|
||||||
|
if response.status_code == 429:
|
||||||
|
raise requests.exceptions.RequestException(f"API配额限制: {response.text}")
|
||||||
|
elif response.status_code == 400:
|
||||||
|
raise Exception(f"请求参数错误: {response.text}")
|
||||||
|
elif response.status_code == 403:
|
||||||
|
raise Exception(f"API密钥无效或权限不足: {response.text}")
|
||||||
|
elif response.status_code != 200:
|
||||||
|
raise Exception(f"Gemini API请求失败: {response.status_code} - {response.text}")
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
# 检查响应格式
|
||||||
|
if "candidates" not in response_data or not response_data["candidates"]:
|
||||||
|
raise Exception("Gemini API返回无效响应,可能触发了安全过滤")
|
||||||
|
|
||||||
|
candidate = response_data["candidates"][0]
|
||||||
|
|
||||||
|
# 检查是否被安全过滤阻止
|
||||||
|
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
|
||||||
|
raise Exception("内容被Gemini安全过滤器阻止")
|
||||||
|
|
||||||
|
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||||
|
raise Exception("Gemini API返回内容格式错误")
|
||||||
|
|
||||||
|
# 提取文本内容
|
||||||
|
text_content = ""
|
||||||
|
for part in candidate["content"]["parts"]:
|
||||||
|
if "text" in part:
|
||||||
|
text_content += part["text"]
|
||||||
|
|
||||||
|
if not text_content.strip():
|
||||||
|
raise Exception("Gemini API返回空内容")
|
||||||
|
|
||||||
|
# 创建兼容的响应对象
|
||||||
|
class CompatibleResponse:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
return CompatibleResponse(text_content)
|
||||||
|
|
||||||
async def analyze_images(self,
|
async def analyze_images(self,
|
||||||
images: Union[List[str], List[PIL.Image.Image]],
|
images: Union[List[str], List[PIL.Image.Image]],
|
||||||
|
|||||||
177
app/utils/gemini_openai_analyzer.py
Normal file
177
app/utils/gemini_openai_analyzer.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
OpenAI兼容的Gemini视觉分析器
|
||||||
|
使用标准OpenAI格式调用Gemini代理服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import List, Union, Dict
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
from tqdm import tqdm
|
||||||
|
import asyncio
|
||||||
|
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
|
||||||
|
import requests
|
||||||
|
import PIL.Image
|
||||||
|
import traceback
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
from app.utils import utils
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiOpenAIAnalyzer:
|
||||||
|
"""OpenAI兼容的Gemini视觉分析器类"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
|
||||||
|
"""初始化OpenAI兼容的Gemini分析器"""
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("必须提供API密钥")
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("必须提供OpenAI兼容的代理端点URL")
|
||||||
|
|
||||||
|
self.model_name = model_name
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
|
||||||
|
# 初始化OpenAI客户端
|
||||||
|
self._configure_client()
|
||||||
|
|
||||||
|
def _configure_client(self):
|
||||||
|
"""配置OpenAI兼容的客户端"""
|
||||||
|
from openai import OpenAI
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=self.api_key,
|
||||||
|
base_url=self.base_url
|
||||||
|
)
|
||||||
|
logger.info(f"配置OpenAI兼容Gemini代理,端点: {self.base_url}, 模型: {self.model_name}")
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
retry=retry_if_exception_type((requests.exceptions.RequestException, Exception))
|
||||||
|
)
|
||||||
|
async def _generate_content_with_retry(self, prompt, batch):
|
||||||
|
"""使用重试机制调用OpenAI兼容的Gemini代理"""
|
||||||
|
try:
|
||||||
|
return await self._generate_with_openai_api(prompt, batch)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _generate_with_openai_api(self, prompt, batch):
|
||||||
|
"""使用OpenAI兼容接口生成内容"""
|
||||||
|
# 将PIL图片转换为base64编码
|
||||||
|
image_contents = []
|
||||||
|
for img in batch:
|
||||||
|
# 将PIL图片转换为字节流
|
||||||
|
img_buffer = io.BytesIO()
|
||||||
|
img.save(img_buffer, format='JPEG', quality=85)
|
||||||
|
img_bytes = img_buffer.getvalue()
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
||||||
|
image_contents.append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{img_base64}"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 构建OpenAI格式的消息
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
*image_contents
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 调用OpenAI兼容接口
|
||||||
|
response = await asyncio.to_thread(
|
||||||
|
self.client.chat.completions.create,
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=4000,
|
||||||
|
temperature=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建兼容的响应对象
|
||||||
|
class CompatibleResponse:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
return CompatibleResponse(response.choices[0].message.content)
|
||||||
|
|
||||||
|
async def analyze_images(self,
|
||||||
|
images: List[Union[str, Path, PIL.Image.Image]],
|
||||||
|
prompt: str,
|
||||||
|
batch_size: int = 10) -> List[str]:
|
||||||
|
"""
|
||||||
|
分析图片并返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: 图片路径列表或PIL图片对象列表
|
||||||
|
prompt: 分析提示词
|
||||||
|
batch_size: 批处理大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分析结果列表
|
||||||
|
"""
|
||||||
|
logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理")
|
||||||
|
|
||||||
|
# 加载图片
|
||||||
|
loaded_images = []
|
||||||
|
for img in images:
|
||||||
|
if isinstance(img, (str, Path)):
|
||||||
|
try:
|
||||||
|
pil_img = PIL.Image.open(img)
|
||||||
|
# 调整图片大小以优化性能
|
||||||
|
if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
|
||||||
|
pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
|
||||||
|
loaded_images.append(pil_img)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载图片失败 {img}: {str(e)}")
|
||||||
|
continue
|
||||||
|
elif isinstance(img, PIL.Image.Image):
|
||||||
|
loaded_images.append(img)
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的图片类型: {type(img)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not loaded_images:
|
||||||
|
raise ValueError("没有有效的图片可以分析")
|
||||||
|
|
||||||
|
# 分批处理
|
||||||
|
results = []
|
||||||
|
total_batches = (len(loaded_images) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
|
for i in tqdm(range(0, len(loaded_images), batch_size),
|
||||||
|
desc="分析图片批次", total=total_batches):
|
||||||
|
batch = loaded_images[i:i + batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._generate_content_with_retry(prompt, batch)
|
||||||
|
results.append(response.text)
|
||||||
|
|
||||||
|
# 添加延迟以避免API限流
|
||||||
|
if i + batch_size < len(loaded_images):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}")
|
||||||
|
results.append(f"分析失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"完成图片分析,共处理 {len(results)} 个批次")
|
||||||
|
return results
|
||||||
|
|
||||||
|
def analyze_images_sync(self,
|
||||||
|
images: List[Union[str, Path, PIL.Image.Image]],
|
||||||
|
prompt: str,
|
||||||
|
batch_size: int = 10) -> List[str]:
|
||||||
|
"""
|
||||||
|
同步版本的图片分析方法
|
||||||
|
"""
|
||||||
|
return asyncio.run(self.analyze_images(images, prompt, batch_size))
|
||||||
@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import google.generativeai as genai
|
import requests
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
@ -134,59 +134,182 @@ class OpenAIGenerator(BaseGenerator):
|
|||||||
|
|
||||||
|
|
||||||
class GeminiGenerator(BaseGenerator):
|
class GeminiGenerator(BaseGenerator):
|
||||||
"""Google Gemini API 生成器实现"""
|
"""原生Gemini API 生成器实现"""
|
||||||
def __init__(self, model_name: str, api_key: str, prompt: str):
|
def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
|
||||||
super().__init__(model_name, api_key, prompt)
|
super().__init__(model_name, api_key, prompt)
|
||||||
genai.configure(api_key=api_key)
|
|
||||||
self.model = genai.GenerativeModel(model_name)
|
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
self.client = None
|
||||||
# Gemini特定参数
|
|
||||||
|
# 原生Gemini API参数
|
||||||
self.default_params = {
|
self.default_params = {
|
||||||
"temperature": self.default_params["temperature"],
|
"temperature": self.default_params["temperature"],
|
||||||
"top_p": self.default_params["top_p"],
|
"topP": self.default_params["top_p"],
|
||||||
"candidate_count": 1,
|
"topK": 40,
|
||||||
"stop_sequences": None
|
"maxOutputTokens": 4000,
|
||||||
|
"candidateCount": 1,
|
||||||
|
"stopSequences": []
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiOpenAIGenerator(BaseGenerator):
|
||||||
|
"""OpenAI兼容的Gemini代理生成器实现"""
|
||||||
|
def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None):
|
||||||
|
super().__init__(model_name, api_key, prompt)
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("OpenAI兼容的Gemini代理必须提供base_url")
|
||||||
|
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
|
||||||
|
# 使用OpenAI兼容接口
|
||||||
|
from openai import OpenAI
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI兼容接口参数
|
||||||
|
self.default_params = {
|
||||||
|
"temperature": self.default_params["temperature"],
|
||||||
|
"max_tokens": 4000,
|
||||||
|
"stream": False
|
||||||
}
|
}
|
||||||
|
|
||||||
def _generate(self, messages: list, params: dict) -> any:
|
def _generate(self, messages: list, params: dict) -> any:
|
||||||
"""实现Gemini特定的生成逻辑"""
|
"""实现OpenAI兼容Gemini代理的生成逻辑"""
|
||||||
while True:
|
try:
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
**params
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _process_response(self, response: any) -> str:
|
||||||
|
"""处理OpenAI兼容接口的响应"""
|
||||||
|
if not response or not response.choices:
|
||||||
|
raise ValueError("OpenAI兼容Gemini代理返回无效响应")
|
||||||
|
return response.choices[0].message.content.strip()
|
||||||
|
|
||||||
|
def _generate(self, messages: list, params: dict) -> any:
|
||||||
|
"""实现原生Gemini API的生成逻辑"""
|
||||||
|
max_retries = 3
|
||||||
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
# 转换消息格式为Gemini格式
|
# 转换消息格式为Gemini格式
|
||||||
prompt = "\n".join([m["content"] for m in messages])
|
prompt = "\n".join([m["content"] for m in messages])
|
||||||
response = self.model.generate_content(
|
|
||||||
prompt,
|
# 构建请求数据
|
||||||
generation_config=params
|
request_data = {
|
||||||
|
"contents": [{
|
||||||
|
"parts": [{"text": prompt}]
|
||||||
|
}],
|
||||||
|
"generationConfig": params,
|
||||||
|
"safetySettings": [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=request_data,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": "NarratoAI/1.0"
|
||||||
|
},
|
||||||
|
timeout=120
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查响应是否包含有效内容
|
if response.status_code == 429:
|
||||||
if (hasattr(response, 'result') and
|
# 处理限流
|
||||||
hasattr(response.result, 'candidates') and
|
wait_time = 65 if attempt == 0 else 30
|
||||||
response.result.candidates):
|
logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
candidate = response.result.candidates[0]
|
continue
|
||||||
|
|
||||||
# 检查是否有内容字段
|
if response.status_code == 400:
|
||||||
if not hasattr(candidate, 'content'):
|
raise Exception(f"请求参数错误: {response.text}")
|
||||||
logger.warning("Gemini API 返回速率限制响应,等待30秒后重试...")
|
elif response.status_code == 403:
|
||||||
time.sleep(30) # 等待3秒后重试
|
raise Exception(f"API密钥无效或权限不足: {response.text}")
|
||||||
|
elif response.status_code != 200:
|
||||||
|
raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}")
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
# 检查响应格式
|
||||||
|
if "candidates" not in response_data or not response_data["candidates"]:
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning("原生Gemini API 返回无效响应,等待30秒后重试...")
|
||||||
|
time.sleep(30)
|
||||||
continue
|
continue
|
||||||
return response
|
else:
|
||||||
|
raise Exception("原生Gemini API返回无效响应,可能触发了安全过滤")
|
||||||
except Exception as e:
|
|
||||||
error_str = str(e)
|
candidate = response_data["candidates"][0]
|
||||||
if "429" in error_str:
|
|
||||||
logger.warning("Gemini API 触发限流,等待65秒后重试...")
|
# 检查是否被安全过滤阻止
|
||||||
time.sleep(65) # 等待65秒后重试
|
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
|
||||||
|
raise Exception("内容被Gemini安全过滤器阻止")
|
||||||
|
|
||||||
|
# 创建兼容的响应对象
|
||||||
|
class CompatibleResponse:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
candidate = data["candidates"][0]
|
||||||
|
if "content" in candidate and "parts" in candidate["content"]:
|
||||||
|
self.text = ""
|
||||||
|
for part in candidate["content"]["parts"]:
|
||||||
|
if "text" in part:
|
||||||
|
self.text += part["text"]
|
||||||
|
else:
|
||||||
|
self.text = ""
|
||||||
|
|
||||||
|
return CompatibleResponse(response_data)
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning(f"网络请求失败,等待30秒后重试: {str(e)}")
|
||||||
|
time.sleep(30)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.error(f"Gemini 生成文案错误: \n{error_str}")
|
logger.error(f"原生Gemini API请求失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if attempt < max_retries - 1 and "429" in str(e):
|
||||||
|
logger.warning("原生Gemini API 触发限流,等待65秒后重试...")
|
||||||
|
time.sleep(65)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"原生Gemini 生成文案错误: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _process_response(self, response: any) -> str:
|
def _process_response(self, response: any) -> str:
|
||||||
"""处理Gemini的响应"""
|
"""处理原生Gemini API的响应"""
|
||||||
if not response or not response.text:
|
if not response or not response.text:
|
||||||
raise ValueError("Invalid response from Gemini API")
|
raise ValueError("原生Gemini API返回无效响应")
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
|
|
||||||
@ -318,7 +441,7 @@ class ScriptProcessor:
|
|||||||
# 根据模型名称选择对应的生成器
|
# 根据模型名称选择对应的生成器
|
||||||
logger.info(f"文本 LLM 提供商: {model_name}")
|
logger.info(f"文本 LLM 提供商: {model_name}")
|
||||||
if 'gemini' in model_name.lower():
|
if 'gemini' in model_name.lower():
|
||||||
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
|
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url)
|
||||||
elif 'qwen' in model_name.lower():
|
elif 'qwen' in model_name.lower():
|
||||||
self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
|
self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
|
||||||
elif 'moonshot' in model_name.lower():
|
elif 'moonshot' in model_name.lower():
|
||||||
|
|||||||
@ -7,6 +7,45 @@ from app.utils import utils
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
def validate_api_key(api_key: str, provider: str) -> tuple[bool, str]:
|
||||||
|
"""验证API密钥格式"""
|
||||||
|
if not api_key or not api_key.strip():
|
||||||
|
return False, f"{provider} API密钥不能为空"
|
||||||
|
|
||||||
|
# 基本长度检查
|
||||||
|
if len(api_key.strip()) < 10:
|
||||||
|
return False, f"{provider} API密钥长度过短,请检查是否正确"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
|
||||||
|
"""验证Base URL格式"""
|
||||||
|
if not base_url or not base_url.strip():
|
||||||
|
return True, "" # base_url可以为空
|
||||||
|
|
||||||
|
base_url = base_url.strip()
|
||||||
|
if not (base_url.startswith('http://') or base_url.startswith('https://')):
|
||||||
|
return False, f"{provider} Base URL必须以http://或https://开头"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
|
||||||
|
"""验证模型名称"""
|
||||||
|
if not model_name or not model_name.strip():
|
||||||
|
return False, f"{provider} 模型名称不能为空"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def show_config_validation_errors(errors: list):
|
||||||
|
"""显示配置验证错误"""
|
||||||
|
if errors:
|
||||||
|
for error in errors:
|
||||||
|
st.error(error)
|
||||||
|
|
||||||
|
|
||||||
def render_basic_settings(tr):
|
def render_basic_settings(tr):
|
||||||
"""渲染基础设置面板"""
|
"""渲染基础设置面板"""
|
||||||
with st.expander(tr("Basic Settings"), expanded=False):
|
with st.expander(tr("Basic Settings"), expanded=False):
|
||||||
@ -87,29 +126,96 @@ def render_proxy_settings(tr):
|
|||||||
|
|
||||||
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||||
"""测试视觉模型连接
|
"""测试视觉模型连接
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: API密钥
|
api_key: API密钥
|
||||||
base_url: 基础URL
|
base_url: 基础URL
|
||||||
model_name: 模型名称
|
model_name: 模型名称
|
||||||
provider: 提供商名称
|
provider: 提供商名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 连接是否成功
|
bool: 连接是否成功
|
||||||
str: 测试结果消息
|
str: 测试结果消息
|
||||||
"""
|
"""
|
||||||
|
import requests
|
||||||
if provider.lower() == 'gemini':
|
if provider.lower() == 'gemini':
|
||||||
import google.generativeai as genai
|
# 原生Gemini API测试
|
||||||
|
|
||||||
try:
|
try:
|
||||||
genai.configure(api_key=api_key)
|
# 构建请求数据
|
||||||
model = genai.GenerativeModel(model_name)
|
request_data = {
|
||||||
model.generate_content("直接回复我文本'当前网络可用'")
|
"contents": [{
|
||||||
return True, tr("gemini model is available")
|
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
|
||||||
|
}],
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": 1.0,
|
||||||
|
"topK": 40,
|
||||||
|
"topP": 0.95,
|
||||||
|
"maxOutputTokens": 100,
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=request_data,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return True, tr("原生Gemini模型连接成功")
|
||||||
|
else:
|
||||||
|
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False, f"{tr('gemini model is not available')}: {str(e)}"
|
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
|
||||||
|
|
||||||
|
elif provider.lower() == 'gemini(openai)':
|
||||||
|
# OpenAI兼容的Gemini代理测试
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||||
|
test_data = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
|
||||||
|
],
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return True, tr("OpenAI兼容Gemini代理连接成功")
|
||||||
|
else:
|
||||||
|
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: {str(e)}"
|
||||||
elif provider.lower() == 'narratoapi':
|
elif provider.lower() == 'narratoapi':
|
||||||
import requests
|
|
||||||
try:
|
try:
|
||||||
# 构建测试请求
|
# 构建测试请求
|
||||||
headers = {
|
headers = {
|
||||||
@ -172,7 +278,7 @@ def render_vision_llm_settings(tr):
|
|||||||
st.subheader(tr("Vision Model Settings"))
|
st.subheader(tr("Vision Model Settings"))
|
||||||
|
|
||||||
# 视频分析模型提供商选择
|
# 视频分析模型提供商选择
|
||||||
vision_providers = ['Siliconflow', 'Gemini', 'QwenVL', 'OpenAI']
|
vision_providers = ['Siliconflow', 'Gemini', 'Gemini(OpenAI)', 'QwenVL', 'OpenAI']
|
||||||
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
|
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
|
||||||
saved_provider_index = 0
|
saved_provider_index = 0
|
||||||
|
|
||||||
@ -191,9 +297,15 @@ def render_vision_llm_settings(tr):
|
|||||||
st.session_state['vision_llm_providers'] = vision_provider
|
st.session_state['vision_llm_providers'] = vision_provider
|
||||||
|
|
||||||
# 获取已保存的视觉模型配置
|
# 获取已保存的视觉模型配置
|
||||||
vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
|
# 处理特殊的提供商名称映射
|
||||||
vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
|
if vision_provider == 'gemini(openai)':
|
||||||
vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
|
vision_config_key = 'vision_gemini_openai'
|
||||||
|
else:
|
||||||
|
vision_config_key = f'vision_{vision_provider}'
|
||||||
|
|
||||||
|
vision_api_key = config.app.get(f"{vision_config_key}_api_key", "")
|
||||||
|
vision_base_url = config.app.get(f"{vision_config_key}_base_url", "")
|
||||||
|
vision_model_name = config.app.get(f"{vision_config_key}_model_name", "")
|
||||||
|
|
||||||
# 渲染视觉模型配置输入框
|
# 渲染视觉模型配置输入框
|
||||||
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
|
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
|
||||||
@ -201,15 +313,25 @@ def render_vision_llm_settings(tr):
|
|||||||
# 根据不同提供商设置默认值和帮助信息
|
# 根据不同提供商设置默认值和帮助信息
|
||||||
if vision_provider == 'gemini':
|
if vision_provider == 'gemini':
|
||||||
st_vision_base_url = st.text_input(
|
st_vision_base_url = st.text_input(
|
||||||
tr("Vision Base URL"),
|
tr("Vision Base URL"),
|
||||||
value=vision_base_url,
|
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta",
|
||||||
disabled=True,
|
help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta")
|
||||||
help=tr("Gemini API does not require a base URL")
|
|
||||||
)
|
)
|
||||||
st_vision_model_name = st.text_input(
|
st_vision_model_name = st.text_input(
|
||||||
tr("Vision Model Name"),
|
tr("Vision Model Name"),
|
||||||
value=vision_model_name or "gemini-2.0-flash-lite",
|
value=vision_model_name or "gemini-2.0-flash-exp",
|
||||||
help=tr("Default: gemini-2.0-flash-lite")
|
help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp")
|
||||||
|
)
|
||||||
|
elif vision_provider == 'gemini(openai)':
|
||||||
|
st_vision_base_url = st.text_input(
|
||||||
|
tr("Vision Base URL"),
|
||||||
|
value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||||
|
help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1")
|
||||||
|
)
|
||||||
|
st_vision_model_name = st.text_input(
|
||||||
|
tr("Vision Model Name"),
|
||||||
|
value=vision_model_name or "gemini-2.0-flash-exp",
|
||||||
|
help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp")
|
||||||
)
|
)
|
||||||
elif vision_provider == 'qwenvl':
|
elif vision_provider == 'qwenvl':
|
||||||
st_vision_base_url = st.text_input(
|
st_vision_base_url = st.text_input(
|
||||||
@ -228,30 +350,81 @@ def render_vision_llm_settings(tr):
|
|||||||
|
|
||||||
# 在配置输入框后添加测试按钮
|
# 在配置输入框后添加测试按钮
|
||||||
if st.button(tr("Test Connection"), key="test_vision_connection"):
|
if st.button(tr("Test Connection"), key="test_vision_connection"):
|
||||||
with st.spinner(tr("Testing connection...")):
|
# 先验证配置
|
||||||
success, message = test_vision_model_connection(
|
test_errors = []
|
||||||
api_key=st_vision_api_key,
|
if not st_vision_api_key:
|
||||||
base_url=st_vision_base_url,
|
test_errors.append("请先输入API密钥")
|
||||||
model_name=st_vision_model_name,
|
if not st_vision_model_name:
|
||||||
provider=vision_provider,
|
test_errors.append("请先输入模型名称")
|
||||||
tr=tr
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
st.success(tr(message))
|
|
||||||
else:
|
|
||||||
st.error(tr(message))
|
|
||||||
|
|
||||||
# 保存视觉模型配置
|
if test_errors:
|
||||||
|
for error in test_errors:
|
||||||
|
st.error(error)
|
||||||
|
else:
|
||||||
|
with st.spinner(tr("Testing connection...")):
|
||||||
|
try:
|
||||||
|
success, message = test_vision_model_connection(
|
||||||
|
api_key=st_vision_api_key,
|
||||||
|
base_url=st_vision_base_url,
|
||||||
|
model_name=st_vision_model_name,
|
||||||
|
provider=vision_provider,
|
||||||
|
tr=tr
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
st.success(message)
|
||||||
|
else:
|
||||||
|
st.error(message)
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"测试连接时发生错误: {str(e)}")
|
||||||
|
logger.error(f"视频分析模型连接测试失败: {str(e)}")
|
||||||
|
|
||||||
|
# 验证和保存视觉模型配置
|
||||||
|
validation_errors = []
|
||||||
|
config_changed = False
|
||||||
|
|
||||||
|
# 验证API密钥
|
||||||
if st_vision_api_key:
|
if st_vision_api_key:
|
||||||
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
is_valid, error_msg = validate_api_key(st_vision_api_key, f"视频分析({vision_provider})")
|
||||||
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
if is_valid:
|
||||||
|
config.app[f"{vision_config_key}_api_key"] = st_vision_api_key
|
||||||
|
st.session_state[f"{vision_config_key}_api_key"] = st_vision_api_key
|
||||||
|
config_changed = True
|
||||||
|
else:
|
||||||
|
validation_errors.append(error_msg)
|
||||||
|
|
||||||
|
# 验证Base URL
|
||||||
if st_vision_base_url:
|
if st_vision_base_url:
|
||||||
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
is_valid, error_msg = validate_base_url(st_vision_base_url, f"视频分析({vision_provider})")
|
||||||
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
if is_valid:
|
||||||
|
config.app[f"{vision_config_key}_base_url"] = st_vision_base_url
|
||||||
|
st.session_state[f"{vision_config_key}_base_url"] = st_vision_base_url
|
||||||
|
config_changed = True
|
||||||
|
else:
|
||||||
|
validation_errors.append(error_msg)
|
||||||
|
|
||||||
|
# 验证模型名称
|
||||||
if st_vision_model_name:
|
if st_vision_model_name:
|
||||||
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
is_valid, error_msg = validate_model_name(st_vision_model_name, f"视频分析({vision_provider})")
|
||||||
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
if is_valid:
|
||||||
|
config.app[f"{vision_config_key}_model_name"] = st_vision_model_name
|
||||||
|
st.session_state[f"{vision_config_key}_model_name"] = st_vision_model_name
|
||||||
|
config_changed = True
|
||||||
|
else:
|
||||||
|
validation_errors.append(error_msg)
|
||||||
|
|
||||||
|
# 显示验证错误
|
||||||
|
show_config_validation_errors(validation_errors)
|
||||||
|
|
||||||
|
# 如果配置有变化且没有验证错误,保存到文件
|
||||||
|
if config_changed and not validation_errors:
|
||||||
|
try:
|
||||||
|
config.save_config()
|
||||||
|
if st_vision_api_key or st_vision_base_url or st_vision_model_name:
|
||||||
|
st.success(f"视频分析模型({vision_provider})配置已保存")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"保存配置失败: {str(e)}")
|
||||||
|
logger.error(f"保存视频分析配置失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
||||||
@ -278,14 +451,74 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
|||||||
|
|
||||||
# 特殊处理Gemini
|
# 特殊处理Gemini
|
||||||
if provider.lower() == 'gemini':
|
if provider.lower() == 'gemini':
|
||||||
import google.generativeai as genai
|
# 原生Gemini API测试
|
||||||
try:
|
try:
|
||||||
genai.configure(api_key=api_key)
|
# 构建请求数据
|
||||||
model = genai.GenerativeModel(model_name)
|
request_data = {
|
||||||
model.generate_content("直接回复我文本'当前网络可用'")
|
"contents": [{
|
||||||
return True, tr("Gemini model is available")
|
"parts": [{"text": "直接回复我文本'当前网络可用'"}]
|
||||||
|
}],
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": 1.0,
|
||||||
|
"topK": 40,
|
||||||
|
"topP": 0.95,
|
||||||
|
"maxOutputTokens": 100,
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_NONE"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=request_data,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return True, tr("原生Gemini模型连接成功")
|
||||||
|
else:
|
||||||
|
return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False, f"{tr('Gemini model is not available')}: {str(e)}"
|
return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}"
|
||||||
|
|
||||||
|
elif provider.lower() == 'gemini(openai)':
|
||||||
|
# OpenAI兼容的Gemini代理测试
|
||||||
|
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||||
|
test_data = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "直接回复我文本'当前网络可用'"}
|
||||||
|
],
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(test_url, headers=headers, json=test_data, timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return True, tr("OpenAI兼容Gemini代理连接成功")
|
||||||
|
else:
|
||||||
|
return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}"
|
||||||
else:
|
else:
|
||||||
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
test_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
@ -322,7 +555,7 @@ def render_text_llm_settings(tr):
|
|||||||
st.subheader(tr("Text Generation Model Settings"))
|
st.subheader(tr("Text Generation Model Settings"))
|
||||||
|
|
||||||
# 文案生成模型提供商选择
|
# 文案生成模型提供商选择
|
||||||
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Qwen', 'Moonshot']
|
text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Gemini(OpenAI)', 'Qwen', 'Moonshot']
|
||||||
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
|
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
|
||||||
saved_provider_index = 0
|
saved_provider_index = 0
|
||||||
|
|
||||||
@ -346,32 +579,108 @@ def render_text_llm_settings(tr):
|
|||||||
|
|
||||||
# 渲染文本模型配置输入框
|
# 渲染文本模型配置输入框
|
||||||
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
|
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
|
||||||
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
|
|
||||||
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
|
# 根据不同提供商设置默认值和帮助信息
|
||||||
|
if text_provider == 'gemini':
|
||||||
|
st_text_base_url = st.text_input(
|
||||||
|
tr("Text Base URL"),
|
||||||
|
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta",
|
||||||
|
help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta")
|
||||||
|
)
|
||||||
|
st_text_model_name = st.text_input(
|
||||||
|
tr("Text Model Name"),
|
||||||
|
value=text_model_name or "gemini-2.0-flash-exp",
|
||||||
|
help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp")
|
||||||
|
)
|
||||||
|
elif text_provider == 'gemini(openai)':
|
||||||
|
st_text_base_url = st.text_input(
|
||||||
|
tr("Text Base URL"),
|
||||||
|
value=text_base_url or "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||||
|
help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1")
|
||||||
|
)
|
||||||
|
st_text_model_name = st.text_input(
|
||||||
|
tr("Text Model Name"),
|
||||||
|
value=text_model_name or "gemini-2.0-flash-exp",
|
||||||
|
help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
|
||||||
|
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
|
||||||
|
|
||||||
# 添加测试按钮
|
# 添加测试按钮
|
||||||
if st.button(tr("Test Connection"), key="test_text_connection"):
|
if st.button(tr("Test Connection"), key="test_text_connection"):
|
||||||
with st.spinner(tr("Testing connection...")):
|
# 先验证配置
|
||||||
success, message = test_text_model_connection(
|
test_errors = []
|
||||||
api_key=st_text_api_key,
|
if not st_text_api_key:
|
||||||
base_url=st_text_base_url,
|
test_errors.append("请先输入API密钥")
|
||||||
model_name=st_text_model_name,
|
if not st_text_model_name:
|
||||||
provider=text_provider,
|
test_errors.append("请先输入模型名称")
|
||||||
tr=tr
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
st.success(message)
|
|
||||||
else:
|
|
||||||
st.error(message)
|
|
||||||
|
|
||||||
# 保存文本模型配置
|
if test_errors:
|
||||||
|
for error in test_errors:
|
||||||
|
st.error(error)
|
||||||
|
else:
|
||||||
|
with st.spinner(tr("Testing connection...")):
|
||||||
|
try:
|
||||||
|
success, message = test_text_model_connection(
|
||||||
|
api_key=st_text_api_key,
|
||||||
|
base_url=st_text_base_url,
|
||||||
|
model_name=st_text_model_name,
|
||||||
|
provider=text_provider,
|
||||||
|
tr=tr
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
st.success(message)
|
||||||
|
else:
|
||||||
|
st.error(message)
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"测试连接时发生错误: {str(e)}")
|
||||||
|
logger.error(f"文案生成模型连接测试失败: {str(e)}")
|
||||||
|
|
||||||
|
# 验证和保存文本模型配置
|
||||||
|
text_validation_errors = []
|
||||||
|
text_config_changed = False
|
||||||
|
|
||||||
|
# 验证API密钥
|
||||||
if st_text_api_key:
|
if st_text_api_key:
|
||||||
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
|
is_valid, error_msg = validate_api_key(st_text_api_key, f"文案生成({text_provider})")
|
||||||
|
if is_valid:
|
||||||
|
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
|
||||||
|
text_config_changed = True
|
||||||
|
else:
|
||||||
|
text_validation_errors.append(error_msg)
|
||||||
|
|
||||||
|
# 验证Base URL
|
||||||
if st_text_base_url:
|
if st_text_base_url:
|
||||||
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
|
is_valid, error_msg = validate_base_url(st_text_base_url, f"文案生成({text_provider})")
|
||||||
|
if is_valid:
|
||||||
|
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
|
||||||
|
text_config_changed = True
|
||||||
|
else:
|
||||||
|
text_validation_errors.append(error_msg)
|
||||||
|
|
||||||
|
# 验证模型名称
|
||||||
if st_text_model_name:
|
if st_text_model_name:
|
||||||
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
|
is_valid, error_msg = validate_model_name(st_text_model_name, f"文案生成({text_provider})")
|
||||||
|
if is_valid:
|
||||||
|
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
|
||||||
|
text_config_changed = True
|
||||||
|
else:
|
||||||
|
text_validation_errors.append(error_msg)
|
||||||
|
|
||||||
|
# 显示验证错误
|
||||||
|
show_config_validation_errors(text_validation_errors)
|
||||||
|
|
||||||
|
# 如果配置有变化且没有验证错误,保存到文件
|
||||||
|
if text_config_changed and not text_validation_errors:
|
||||||
|
try:
|
||||||
|
config.save_config()
|
||||||
|
if st_text_api_key or st_text_base_url or st_text_model_name:
|
||||||
|
st.success(f"文案生成模型({text_provider})配置已保存")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"保存配置失败: {str(e)}")
|
||||||
|
logger.error(f"保存文案生成配置失败: {str(e)}")
|
||||||
|
|
||||||
# # Cloudflare 特殊配置
|
# # Cloudflare 特殊配置
|
||||||
# if text_provider == 'cloudflare':
|
# if text_provider == 'cloudflare':
|
||||||
|
|||||||
@ -23,11 +23,14 @@ def create_vision_analyzer(provider, api_key, model, base_url):
|
|||||||
VisionAnalyzer 或 QwenAnalyzer 实例
|
VisionAnalyzer 或 QwenAnalyzer 实例
|
||||||
"""
|
"""
|
||||||
if provider == 'gemini':
|
if provider == 'gemini':
|
||||||
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
|
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||||
|
elif provider == 'gemini(openai)':
|
||||||
|
from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer
|
||||||
|
return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url)
|
||||||
else:
|
else:
|
||||||
# 只传入必要的参数
|
# 只传入必要的参数
|
||||||
return qwenvl_analyzer.QwenAnalyzer(
|
return qwenvl_analyzer.QwenAnalyzer(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url
|
base_url=base_url
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,6 +16,89 @@ from loguru import logger
|
|||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
|
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_fix_json(json_string):
|
||||||
|
"""
|
||||||
|
解析并修复JSON字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_string: 待解析的JSON字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 解析后的字典,如果解析失败返回None
|
||||||
|
"""
|
||||||
|
if not json_string or not json_string.strip():
|
||||||
|
logger.error("JSON字符串为空")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 清理字符串
|
||||||
|
json_string = json_string.strip()
|
||||||
|
|
||||||
|
# 尝试直接解析
|
||||||
|
try:
|
||||||
|
return json.loads(json_string)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"直接JSON解析失败: {e}")
|
||||||
|
|
||||||
|
# 尝试提取JSON部分
|
||||||
|
try:
|
||||||
|
# 查找JSON代码块
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', json_string, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_content = json_match.group(1).strip()
|
||||||
|
logger.info("从代码块中提取JSON内容")
|
||||||
|
return json.loads(json_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 尝试查找大括号包围的内容
|
||||||
|
try:
|
||||||
|
# 查找第一个 { 到最后一个 } 的内容
|
||||||
|
start_idx = json_string.find('{')
|
||||||
|
end_idx = json_string.rfind('}')
|
||||||
|
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||||
|
json_content = json_string[start_idx:end_idx+1]
|
||||||
|
logger.info("提取大括号包围的JSON内容")
|
||||||
|
return json.loads(json_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 尝试修复常见的JSON格式问题
|
||||||
|
try:
|
||||||
|
# 移除注释
|
||||||
|
json_string = re.sub(r'#.*', '', json_string)
|
||||||
|
# 移除多余的逗号
|
||||||
|
json_string = re.sub(r',\s*}', '}', json_string)
|
||||||
|
json_string = re.sub(r',\s*]', ']', json_string)
|
||||||
|
# 修复单引号
|
||||||
|
json_string = re.sub(r"'([^']*)':", r'"\1":', json_string)
|
||||||
|
|
||||||
|
logger.info("尝试修复JSON格式问题后解析")
|
||||||
|
return json.loads(json_string)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 如果所有方法都失败,尝试创建一个基本的结构
|
||||||
|
logger.error(f"所有JSON解析方法都失败,原始内容: {json_string[:200]}...")
|
||||||
|
|
||||||
|
# 尝试从文本中提取关键信息创建基本结构
|
||||||
|
try:
|
||||||
|
# 这是一个简单的回退方案
|
||||||
|
return {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"_id": 1,
|
||||||
|
"timestamp": "00:00:00,000-00:00:10,000",
|
||||||
|
"picture": "解析失败,使用默认内容",
|
||||||
|
"narration": json_string[:100] + "..." if len(json_string) > 100 else json_string,
|
||||||
|
"OST": 0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature):
|
def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature):
|
||||||
@ -61,7 +144,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
|||||||
model=text_model,
|
model=text_model,
|
||||||
base_url=text_base_url,
|
base_url=text_base_url,
|
||||||
save_result=True,
|
save_result=True,
|
||||||
temperature=temperature
|
temperature=temperature,
|
||||||
|
provider=text_provider
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
3. 根据剧情生成解说文案
|
3. 根据剧情生成解说文案
|
||||||
@ -78,7 +162,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
|||||||
model=text_model,
|
model=text_model,
|
||||||
base_url=text_base_url,
|
base_url=text_base_url,
|
||||||
save_result=True,
|
save_result=True,
|
||||||
temperature=temperature
|
temperature=temperature,
|
||||||
|
provider=text_provider
|
||||||
)
|
)
|
||||||
|
|
||||||
if narration_result["status"] == "success":
|
if narration_result["status"] == "success":
|
||||||
@ -100,7 +185,20 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
|||||||
|
|
||||||
# 结果转换为JSON字符串
|
# 结果转换为JSON字符串
|
||||||
narration_script = narration_result["narration_script"]
|
narration_script = narration_result["narration_script"]
|
||||||
narration_dict = json.loads(narration_script)
|
|
||||||
|
# 增强JSON解析,包含错误处理和修复
|
||||||
|
narration_dict = parse_and_fix_json(narration_script)
|
||||||
|
if narration_dict is None:
|
||||||
|
st.error("生成的解说文案格式错误,无法解析为JSON")
|
||||||
|
logger.error(f"JSON解析失败,原始内容: {narration_script}")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
# 验证JSON结构
|
||||||
|
if 'items' not in narration_dict:
|
||||||
|
st.error("生成的解说文案缺少必要的'items'字段")
|
||||||
|
logger.error(f"JSON结构错误,缺少items字段: {narration_dict}")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2)
|
script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
if script is None:
|
if script is None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user