From dd59d5295d13a7ce289d2d49c3414e3552eb5b86 Mon Sep 17 00:00:00 2001 From: linyq Date: Mon, 7 Jul 2025 15:40:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E4=BD=9C=E8=80=85?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=E5=B9=B6=E5=A2=9E=E5=BC=BAAPI=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E9=AA=8C=E8=AF=81=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在基础设置中新增API密钥、基础URL和模型名称的验证功能,确保用户输入的配置有效性,提升系统的稳定性和用户体验。 --- app/config/audio_config.py | 2 +- app/services/SDE/prompt.py | 25 +- app/services/SDE/short_drama_explanation.py | 384 ++++++++++++++--- app/services/audio_normalizer.py | 2 +- app/services/clip_video.py | 2 +- app/services/generate_narration_script.py | 2 +- app/services/generate_video.py | 2 +- app/services/merger_video.py | 2 +- app/services/script_service.py | 57 ++- app/services/update_script.py | 2 +- app/utils/gemini_analyzer.py | 153 ++++++- app/utils/gemini_openai_analyzer.py | 177 ++++++++ app/utils/script_generator.py | 199 +++++++-- webui/components/basic_settings.py | 443 +++++++++++++++++--- webui/tools/base.py | 7 +- webui/tools/generate_short_summary.py | 104 ++++- 16 files changed, 1354 insertions(+), 209 deletions(-) create mode 100644 app/utils/gemini_openai_analyzer.py diff --git a/app/config/audio_config.py b/app/config/audio_config.py index da4cf47..1e2a18a 100644 --- a/app/config/audio_config.py +++ b/app/config/audio_config.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : audio_config -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/1/7 @Description: 音频配置管理 ''' diff --git a/app/services/SDE/prompt.py b/app/services/SDE/prompt.py index 78385cc..2ee9ea8 100644 --- a/app/services/SDE/prompt.py +++ b/app/services/SDE/prompt.py @@ -74,24 +74,27 @@ plot_writing = """ %s -请使用 json 格式进行输出;使用 中的输出格式: - +请严格按照以下JSON格式输出,不要添加任何其他文字、说明或代码块标记: + { "items": [ { - "_id": 1, # 唯一递增id + "_id": 1, "timestamp": "00:00:05,390-00:00:10,430", "picture": "剧情描述或者备注", "narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id‘ 进行占位", - "OST": "值为 0 表示当前片段为解说片段,值为 1 表示当前片段为穿插的原片" + "OST": 0 } + ] } - - -1. 只输出 json 内容,不要输出其他任何说明性的文字 -2. 解说文案的语言使用 简体中文 -3. 严禁虚构剧情,所有画面只能从 中摘取 -4. 严禁虚构时间戳,所有时间戳范围只能从 中摘取 - +重要要求: +1. 必须输出有效的JSON格式,不能包含注释 +2. OST字段必须是数字:0表示解说片段,1表示原片片段 +3. _id必须是递增的数字 +4. 只输出JSON内容,不要输出任何说明文字 +5. 不要使用代码块标记(如```json) +6. 解说文案使用简体中文 +7. 严禁虚构剧情,所有内容只能从中摘取 +8. 严禁虚构时间戳,所有时间戳只能从中摘取 """ \ No newline at end of file diff --git a/app/services/SDE/short_drama_explanation.py b/app/services/SDE/short_drama_explanation.py index 56a460d..c563171 100644 --- a/app/services/SDE/short_drama_explanation.py +++ b/app/services/SDE/short_drama_explanation.py @@ -22,34 +22,44 @@ class SubtitleAnalyzer: """字幕剧情分析器,负责分析字幕内容并提取关键剧情段落""" def __init__( - self, + self, api_key: Optional[str] = None, model: Optional[str] = None, base_url: Optional[str] = None, custom_prompt: Optional[str] = None, temperature: Optional[float] = 1.0, + provider: Optional[str] = None, ): """ 初始化字幕分析器 - + Args: api_key: API密钥,如果不提供则从配置中读取 model: 模型名称,如果不提供则从配置中读取 base_url: API基础URL,如果不提供则从配置中读取或使用默认值 custom_prompt: 自定义提示词,如果不提供则使用默认值 temperature: 模型温度 + provider: 提供商类型,用于确定API调用格式 """ # 使用传入的参数或从配置中获取 self.api_key = api_key self.model = model self.base_url = base_url self.temperature = temperature - + self.provider = provider or self._detect_provider() + # 设置提示词模板 self.prompt_template = custom_prompt or subtitle_plot_analysis_v1 - + + # 根据提供商类型确定是否为原生Gemini + self.is_native_gemini = self.provider.lower() == 'gemini' + # 初始化HTTP请求所需的头信息 self._init_headers() + + def _detect_provider(self): + """根据配置自动检测提供商类型""" + return config.app.get('text_llm_provider', 'gemini').lower() def _init_headers(self): """初始化HTTP请求头""" @@ -67,18 +77,152 @@ class SubtitleAnalyzer: def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]: """ 分析字幕内容 - + Args: subtitle_content: 字幕内容文本 - + Returns: Dict[str, Any]: 包含分析结果的字典 """ try: # 构建完整提示词 prompt = f"{self.prompt_template}\n\n{subtitle_content}" - - # 构建请求体数据 + + if self.is_native_gemini: + # 使用原生Gemini API格式 + return self._call_native_gemini_api(prompt) + else: + # 使用OpenAI兼容格式 + return self._call_openai_compatible_api(prompt) + + except Exception as e: + logger.error(f"字幕分析过程中发生错误: {str(e)}") + return { + "status": "error", + "message": str(e), + "temperature": self.temperature + } + + def _call_native_gemini_api(self, prompt: str) -> Dict[str, Any]: + """调用原生Gemini API""" + try: + # 构建原生Gemini API请求数据 + payload = { + "systemInstruction": { + "parts": [{"text": "你是一位专业的剧本分析师和剧情概括助手。请严格按照要求的格式输出分析结果。"}] + }, + "contents": [{ + "parts": [{"text": prompt}] + }], + "generationConfig": { + "temperature": self.temperature, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 4000, + "candidateCount": 1 + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}" + + # 发送请求 + response = requests.post( + url, + json=payload, + headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"}, + timeout=120 + ) + + if response.status_code == 200: + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + return { + "status": "error", + "message": "原生Gemini API返回无效响应,可能触发了安全过滤", + "temperature": self.temperature + } + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + return { + "status": "error", + "message": "内容被Gemini安全过滤器阻止", + "temperature": self.temperature + } + + if "content" not in candidate or "parts" not in candidate["content"]: + return { + "status": "error", + "message": "原生Gemini API返回内容格式错误", + "temperature": self.temperature + } + + # 提取文本内容 + analysis_result = "" + for part in candidate["content"]["parts"]: + if "text" in part: + analysis_result += part["text"] + + if not analysis_result.strip(): + return { + "status": "error", + "message": "原生Gemini API返回空内容", + "temperature": self.temperature + } + + logger.debug(f"原生Gemini字幕分析完成") + + return { + "status": "success", + "analysis": analysis_result, + "tokens_used": response_data.get("usage", {}).get("total_tokens", 0), + "model": self.model, + "temperature": self.temperature + } + else: + error_msg = f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}" + logger.error(error_msg) + return { + "status": "error", + "message": error_msg, + "temperature": self.temperature + } + + except Exception as e: + logger.error(f"原生Gemini API调用失败: {str(e)}") + return { + "status": "error", + "message": f"原生Gemini API调用失败: {str(e)}", + "temperature": self.temperature + } + + def _call_openai_compatible_api(self, prompt: str) -> Dict[str, Any]: + """调用OpenAI兼容的API""" + try: + # 构建OpenAI格式的请求数据 payload = { "model": self.model, "messages": [ @@ -87,22 +231,22 @@ class SubtitleAnalyzer: ], "temperature": self.temperature } - + # 构建请求地址 url = f"{self.base_url}/chat/completions" - + # 发送HTTP请求 - response = requests.post(url, headers=self.headers, json=payload) - + response = requests.post(url, headers=self.headers, json=payload, timeout=120) + # 解析响应 if response.status_code == 200: response_data = response.json() - + # 提取响应内容 if "choices" in response_data and len(response_data["choices"]) > 0: analysis_result = response_data["choices"][0]["message"]["content"] - logger.debug(f"字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") - + logger.debug(f"OpenAI兼容API字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") + # 返回结果 return { "status": "success", @@ -112,26 +256,26 @@ class SubtitleAnalyzer: "temperature": self.temperature } else: - logger.error("字幕分析失败: 未获取到有效响应") + logger.error("OpenAI兼容API字幕分析失败: 未获取到有效响应") return { "status": "error", "message": "未获取到有效响应", "temperature": self.temperature } else: - error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}" + error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}" logger.error(error_msg) return { "status": "error", "message": error_msg, "temperature": self.temperature } - + except Exception as e: - logger.error(f"字幕分析过程中发生错误: {str(e)}") + logger.error(f"OpenAI兼容API调用失败: {str(e)}") return { "status": "error", - "message": str(e), + "message": f"OpenAI兼容API调用失败: {str(e)}", "temperature": self.temperature } @@ -206,12 +350,12 @@ class SubtitleAnalyzer: def generate_narration_script(self, short_name:str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]: """ 根据剧情分析生成解说文案 - + Args: short_name: 短剧名称 plot_analysis: 剧情分析内容 temperature: 生成温度,控制创造性,默认0.7 - + Returns: Dict[str, Any]: 包含生成结果的字典 """ @@ -219,7 +363,145 @@ class SubtitleAnalyzer: # 构建完整提示词 prompt = plot_writing % (short_name, plot_analysis) - # 构建请求体数据 + if self.is_native_gemini: + # 使用原生Gemini API格式 + return self._generate_narration_with_native_gemini(prompt, temperature) + else: + # 使用OpenAI兼容格式 + return self._generate_narration_with_openai_compatible(prompt, temperature) + + except Exception as e: + logger.error(f"解说文案生成过程中发生错误: {str(e)}") + return { + "status": "error", + "message": str(e), + "temperature": self.temperature + } + + def _generate_narration_with_native_gemini(self, prompt: str, temperature: float) -> Dict[str, Any]: + """使用原生Gemini API生成解说文案""" + try: + # 构建原生Gemini API请求数据 + # 为了确保JSON输出,在提示词中添加更强的约束 + enhanced_prompt = f"{prompt}\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + payload = { + "systemInstruction": { + "parts": [{"text": "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出,不能包含任何其他文字、说明或代码块标记。"}] + }, + "contents": [{ + "parts": [{"text": enhanced_prompt}] + }], + "generationConfig": { + "temperature": temperature, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 4000, + "candidateCount": 1, + "stopSequences": ["```", "注意", "说明"] + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}" + + # 发送请求 + response = requests.post( + url, + json=payload, + headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"}, + timeout=120 + ) + + if response.status_code == 200: + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + return { + "status": "error", + "message": "原生Gemini API返回无效响应,可能触发了安全过滤", + "temperature": temperature + } + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + return { + "status": "error", + "message": "内容被Gemini安全过滤器阻止", + "temperature": temperature + } + + if "content" not in candidate or "parts" not in candidate["content"]: + return { + "status": "error", + "message": "原生Gemini API返回内容格式错误", + "temperature": temperature + } + + # 提取文本内容 + narration_script = "" + for part in candidate["content"]["parts"]: + if "text" in part: + narration_script += part["text"] + + if not narration_script.strip(): + return { + "status": "error", + "message": "原生Gemini API返回空内容", + "temperature": temperature + } + + logger.debug(f"原生Gemini解说文案生成完成") + + return { + "status": "success", + "narration_script": narration_script, + "tokens_used": response_data.get("usage", {}).get("total_tokens", 0), + "model": self.model, + "temperature": temperature + } + else: + error_msg = f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}" + logger.error(error_msg) + return { + "status": "error", + "message": error_msg, + "temperature": temperature + } + + except Exception as e: + logger.error(f"原生Gemini API解说文案生成失败: {str(e)}") + return { + "status": "error", + "message": f"原生Gemini API解说文案生成失败: {str(e)}", + "temperature": temperature + } + + def _generate_narration_with_openai_compatible(self, prompt: str, temperature: float) -> Dict[str, Any]: + """使用OpenAI兼容API生成解说文案""" + try: + # 构建OpenAI格式的请求数据 payload = { "model": self.model, "messages": [ @@ -228,56 +510,56 @@ class SubtitleAnalyzer: ], "temperature": temperature } - + # 对特定模型添加响应格式设置 if self.model not in ["deepseek-reasoner"]: payload["response_format"] = {"type": "json_object"} - + # 构建请求地址 url = f"{self.base_url}/chat/completions" - + # 发送HTTP请求 - response = requests.post(url, headers=self.headers, json=payload) - + response = requests.post(url, headers=self.headers, json=payload, timeout=120) + # 解析响应 if response.status_code == 200: response_data = response.json() - + # 提取响应内容 if "choices" in response_data and len(response_data["choices"]) > 0: narration_script = response_data["choices"][0]["message"]["content"] - logger.debug(f"解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") - + logger.debug(f"OpenAI兼容API解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") + # 返回结果 return { "status": "success", "narration_script": narration_script, "tokens_used": response_data.get("usage", {}).get("total_tokens", 0), "model": self.model, - "temperature": self.temperature + "temperature": temperature } else: - logger.error("解说文案生成失败: 未获取到有效响应") + logger.error("OpenAI兼容API解说文案生成失败: 未获取到有效响应") return { "status": "error", "message": "未获取到有效响应", - "temperature": self.temperature + "temperature": temperature } else: - error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}" + error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}" logger.error(error_msg) return { "status": "error", "message": error_msg, - "temperature": self.temperature + "temperature": temperature } - + except Exception as e: - logger.error(f"解说文案生成过程中发生错误: {str(e)}") + logger.error(f"OpenAI兼容API解说文案生成失败: {str(e)}") return { "status": "error", - "message": str(e), - "temperature": self.temperature + "message": f"OpenAI兼容API解说文案生成失败: {str(e)}", + "temperature": temperature } def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str: @@ -324,11 +606,12 @@ def analyze_subtitle( custom_prompt: Optional[str] = None, temperature: float = 1.0, save_result: bool = False, - output_path: Optional[str] = None + output_path: Optional[str] = None, + provider: Optional[str] = None ) -> Dict[str, Any]: """ 分析字幕内容的便捷函数 - + Args: subtitle_content: 字幕内容文本 subtitle_file_path: 字幕文件路径 @@ -339,7 +622,8 @@ def analyze_subtitle( temperature: 模型温度 save_result: 是否保存结果到文件 output_path: 输出文件路径 - + provider: 提供商类型 + Returns: Dict[str, Any]: 包含分析结果的字典 """ @@ -349,7 +633,8 @@ def analyze_subtitle( api_key=api_key, model=model, base_url=base_url, - custom_prompt=custom_prompt + custom_prompt=custom_prompt, + provider=provider ) logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}") # 分析字幕 @@ -379,11 +664,12 @@ def generate_narration_script( base_url: Optional[str] = None, temperature: float = 1.0, save_result: bool = False, - output_path: Optional[str] = None + output_path: Optional[str] = None, + provider: Optional[str] = None ) -> Dict[str, Any]: """ 根据剧情分析生成解说文案的便捷函数 - + Args: short_name: 短剧名称 plot_analysis: 剧情分析内容,直接提供 @@ -393,7 +679,8 @@ def generate_narration_script( temperature: 生成温度,控制创造性 save_result: 是否保存结果到文件 output_path: 输出文件路径 - + provider: 提供商类型 + Returns: Dict[str, Any]: 包含生成结果的字典 """ @@ -402,7 +689,8 @@ def generate_narration_script( temperature=temperature, api_key=api_key, model=model, - base_url=base_url + base_url=base_url, + provider=provider ) # 生成解说文案 diff --git a/app/services/audio_normalizer.py b/app/services/audio_normalizer.py index 25ba4ee..b0796b8 100644 --- a/app/services/audio_normalizer.py +++ b/app/services/audio_normalizer.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : audio_normalizer -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/1/7 @Description: 音频响度分析和标准化工具 ''' diff --git a/app/services/clip_video.py b/app/services/clip_video.py index 81794e4..65b97ea 100644 --- a/app/services/clip_video.py +++ b/app/services/clip_video.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : clip_video -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/6 下午6:14 ''' diff --git a/app/services/generate_narration_script.py b/app/services/generate_narration_script.py index f6640db..91b6281 100644 --- a/app/services/generate_narration_script.py +++ b/app/services/generate_narration_script.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : 生成介绍文案 -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/8 上午11:33 ''' diff --git a/app/services/generate_video.py b/app/services/generate_video.py index 9b03b52..8395aef 100644 --- a/app/services/generate_video.py +++ b/app/services/generate_video.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : generate_video -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/7 上午11:55 ''' diff --git a/app/services/merger_video.py b/app/services/merger_video.py index 75c686b..82ed20e 100644 --- a/app/services/merger_video.py +++ b/app/services/merger_video.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : merger_video -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/6 下午7:38 ''' diff --git a/app/services/script_service.py b/app/services/script_service.py index 461978b..e9ff042 100644 --- a/app/services/script_service.py +++ b/app/services/script_service.py @@ -140,14 +140,27 @@ class ScriptGenerator: # 获取Gemini配置 vision_api_key = config.app.get("vision_gemini_api_key") vision_model = config.app.get("vision_gemini_model_name") - + vision_base_url = config.app.get("vision_gemini_base_url") + if not vision_api_key or not vision_model: raise ValueError("未配置 Gemini API Key 或者模型") - analyzer = gemini_analyzer.VisionAnalyzer( - model_name=vision_model, - api_key=vision_api_key, - ) + # 根据提供商类型选择合适的分析器 + if vision_provider == 'gemini(openai)': + # 使用OpenAI兼容的Gemini代理 + from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer + analyzer = GeminiOpenAIAnalyzer( + model_name=vision_model, + api_key=vision_api_key, + base_url=vision_base_url + ) + else: + # 使用原生Gemini分析器 + analyzer = gemini_analyzer.VisionAnalyzer( + model_name=vision_model, + api_key=vision_api_key, + base_url=vision_base_url + ) progress_callback(40, "正在分析关键帧...") @@ -213,13 +226,35 @@ class ScriptGenerator: text_provider = config.app.get('text_llm_provider', 'gemini').lower() text_api_key = config.app.get(f'text_{text_provider}_api_key') text_model = config.app.get(f'text_{text_provider}_model_name') + text_base_url = config.app.get(f'text_{text_provider}_base_url') - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - prompt=custom_prompt, - video_theme=video_theme - ) + # 根据提供商类型选择合适的处理器 + if text_provider == 'gemini(openai)': + # 使用OpenAI兼容的Gemini代理 + from app.utils.script_generator import GeminiOpenAIGenerator + generator = GeminiOpenAIGenerator( + model_name=text_model, + api_key=text_api_key, + prompt=custom_prompt, + base_url=text_base_url + ) + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + base_url=text_base_url, + prompt=custom_prompt, + video_theme=video_theme + ) + processor.generator = generator + else: + # 使用标准处理器(包括原生Gemini) + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + base_url=text_base_url, + prompt=custom_prompt, + video_theme=video_theme + ) return processor.process_frames(frame_content_list) diff --git a/app/services/update_script.py b/app/services/update_script.py index 2eb9663..9dd32d2 100644 --- a/app/services/update_script.py +++ b/app/services/update_script.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : update_script -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/6 下午11:00 ''' diff --git a/app/utils/gemini_analyzer.py b/app/utils/gemini_analyzer.py index 7236a9e..c3685ab 100644 --- a/app/utils/gemini_analyzer.py +++ b/app/utils/gemini_analyzer.py @@ -5,53 +5,162 @@ from pathlib import Path from loguru import logger from tqdm import tqdm import asyncio -from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential -from google.api_core import exceptions -import google.generativeai as genai +from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential +import requests import PIL.Image import traceback +import base64 +import io from app.utils import utils class VisionAnalyzer: - """视觉分析器类""" + """原生Gemini视觉分析器类""" - def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None): + def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None): """初始化视觉分析器""" if not api_key: raise ValueError("必须提供API密钥") self.model_name = model_name self.api_key = api_key + self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" # 初始化配置 self._configure_client() def _configure_client(self): - """配置API客户端""" - genai.configure(api_key=self.api_key) - # 开放 Gemini 模型安全设置 - from google.generativeai.types import HarmCategory, HarmBlockThreshold - safety_settings = { - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - } - self.model = genai.GenerativeModel(self.model_name, safety_settings=safety_settings) + """配置原生Gemini API客户端""" + # 使用原生Gemini REST API + self.client = None + logger.info(f"配置原生Gemini API,端点: {self.base_url}, 模型: {self.model_name}") @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type(exceptions.ResourceExhausted) + retry=retry_if_exception_type(requests.exceptions.RequestException) ) async def _generate_content_with_retry(self, prompt, batch): - """使用重试机制的内部方法来调用 generate_content_async""" + """使用重试机制调用原生Gemini API""" try: - return await self.model.generate_content_async([prompt, *batch]) - except exceptions.ResourceExhausted as e: - print(f"API配额限制: {str(e)}") - raise RetryError("API调用失败") + return await self._generate_with_gemini_api(prompt, batch) + except requests.exceptions.RequestException as e: + logger.warning(f"Gemini API请求异常: {str(e)}") + raise + except Exception as e: + logger.error(f"Gemini API生成内容时发生错误: {str(e)}") + raise + + async def _generate_with_gemini_api(self, prompt, batch): + """使用原生Gemini REST API生成内容""" + # 将PIL图片转换为base64编码 + image_parts = [] + for img in batch: + # 将PIL图片转换为字节流 + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) # 优化图片质量 + img_bytes = img_buffer.getvalue() + + # 转换为base64 + img_base64 = base64.b64encode(img_bytes).decode('utf-8') + image_parts.append({ + "inline_data": { + "mime_type": "image/jpeg", + "data": img_base64 + } + }) + + # 构建符合官方文档的请求数据 + request_data = { + "contents": [{ + "parts": [ + {"text": prompt}, + *image_parts + ] + }], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 8192, + "candidateCount": 1, + "stopSequences": [] + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}" + + # 发送请求 + response = await asyncio.to_thread( + requests.post, + url, + json=request_data, + headers={ + "Content-Type": "application/json", + "User-Agent": "NarratoAI/1.0" + }, + timeout=120 # 增加超时时间 + ) + + # 处理HTTP错误 + if response.status_code == 429: + raise requests.exceptions.RequestException(f"API配额限制: {response.text}") + elif response.status_code == 400: + raise Exception(f"请求参数错误: {response.text}") + elif response.status_code == 403: + raise Exception(f"API密钥无效或权限不足: {response.text}") + elif response.status_code != 200: + raise Exception(f"Gemini API请求失败: {response.status_code} - {response.text}") + + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + raise Exception("Gemini API返回无效响应,可能触发了安全过滤") + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + raise Exception("内容被Gemini安全过滤器阻止") + + if "content" not in candidate or "parts" not in candidate["content"]: + raise Exception("Gemini API返回内容格式错误") + + # 提取文本内容 + text_content = "" + for part in candidate["content"]["parts"]: + if "text" in part: + text_content += part["text"] + + if not text_content.strip(): + raise Exception("Gemini API返回空内容") + + # 创建兼容的响应对象 + class CompatibleResponse: + def __init__(self, text): + self.text = text + + return CompatibleResponse(text_content) async def analyze_images(self, images: Union[List[str], List[PIL.Image.Image]], diff --git a/app/utils/gemini_openai_analyzer.py b/app/utils/gemini_openai_analyzer.py new file mode 100644 index 0000000..9d2ca0b --- /dev/null +++ b/app/utils/gemini_openai_analyzer.py @@ -0,0 +1,177 @@ +""" +OpenAI兼容的Gemini视觉分析器 +使用标准OpenAI格式调用Gemini代理服务 +""" + +import json +from typing import List, Union, Dict +import os +from pathlib import Path +from loguru import logger +from tqdm import tqdm +import asyncio +from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential +import requests +import PIL.Image +import traceback +import base64 +import io +from app.utils import utils + + +class GeminiOpenAIAnalyzer: + """OpenAI兼容的Gemini视觉分析器类""" + + def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None): + """初始化OpenAI兼容的Gemini分析器""" + if not api_key: + raise ValueError("必须提供API密钥") + + if not base_url: + raise ValueError("必须提供OpenAI兼容的代理端点URL") + + self.model_name = model_name + self.api_key = api_key + self.base_url = base_url.rstrip('/') + + # 初始化OpenAI客户端 + self._configure_client() + + def _configure_client(self): + """配置OpenAI兼容的客户端""" + from openai import OpenAI + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + logger.info(f"配置OpenAI兼容Gemini代理,端点: {self.base_url}, 模型: {self.model_name}") + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((requests.exceptions.RequestException, Exception)) + ) + async def _generate_content_with_retry(self, prompt, batch): + """使用重试机制调用OpenAI兼容的Gemini代理""" + try: + return await self._generate_with_openai_api(prompt, batch) + except Exception as e: + logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}") + raise + + async def _generate_with_openai_api(self, prompt, batch): + """使用OpenAI兼容接口生成内容""" + # 将PIL图片转换为base64编码 + image_contents = [] + for img in batch: + # 将PIL图片转换为字节流 + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) + img_bytes = img_buffer.getvalue() + + # 转换为base64 + img_base64 = base64.b64encode(img_bytes).decode('utf-8') + image_contents.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{img_base64}" + } + }) + + # 构建OpenAI格式的消息 + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + *image_contents + ] + } + ] + + # 调用OpenAI兼容接口 + response = await asyncio.to_thread( + self.client.chat.completions.create, + model=self.model_name, + messages=messages, + max_tokens=4000, + temperature=1.0 + ) + + # 创建兼容的响应对象 + class CompatibleResponse: + def __init__(self, text): + self.text = text + + return CompatibleResponse(response.choices[0].message.content) + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10) -> List[str]: + """ + 分析图片并返回结果 + + Args: + images: 图片路径列表或PIL图片对象列表 + prompt: 分析提示词 + batch_size: 批处理大小 + + Returns: + 分析结果列表 + """ + logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理") + + # 加载图片 + loaded_images = [] + for img in images: + if isinstance(img, (str, Path)): + try: + pil_img = PIL.Image.open(img) + # 调整图片大小以优化性能 + if pil_img.size[0] > 1024 or pil_img.size[1] > 1024: + pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS) + loaded_images.append(pil_img) + except Exception as e: + logger.error(f"加载图片失败 {img}: {str(e)}") + continue + elif isinstance(img, PIL.Image.Image): + loaded_images.append(img) + else: + logger.warning(f"不支持的图片类型: {type(img)}") + continue + + if not loaded_images: + raise ValueError("没有有效的图片可以分析") + + # 分批处理 + results = [] + total_batches = (len(loaded_images) + batch_size - 1) // batch_size + + for i in tqdm(range(0, len(loaded_images), batch_size), + desc="分析图片批次", total=total_batches): + batch = loaded_images[i:i + batch_size] + + try: + response = await self._generate_content_with_retry(prompt, batch) + results.append(response.text) + + # 添加延迟以避免API限流 + if i + batch_size < len(loaded_images): + await asyncio.sleep(1) + + except Exception as e: + logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}") + results.append(f"分析失败: {str(e)}") + + logger.info(f"完成图片分析,共处理 {len(results)} 个批次") + return results + + def analyze_images_sync(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10) -> List[str]: + """ + 同步版本的图片分析方法 + """ + return asyncio.run(self.analyze_images(images, prompt, batch_size)) diff --git a/app/utils/script_generator.py b/app/utils/script_generator.py index 7020782..e6d7cea 100644 --- a/app/utils/script_generator.py +++ b/app/utils/script_generator.py @@ -6,7 +6,7 @@ from loguru import logger from typing import List, Dict from datetime import datetime from openai import OpenAI -import google.generativeai as genai +import requests import time @@ -134,59 +134,182 @@ class OpenAIGenerator(BaseGenerator): class GeminiGenerator(BaseGenerator): - """Google Gemini API 生成器实现""" - def __init__(self, model_name: str, api_key: str, prompt: str): + """原生Gemini API 生成器实现""" + def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None): super().__init__(model_name, api_key, prompt) - genai.configure(api_key=api_key) - self.model = genai.GenerativeModel(model_name) - - # Gemini特定参数 + + self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" + self.client = None + + # 原生Gemini API参数 self.default_params = { "temperature": self.default_params["temperature"], - "top_p": self.default_params["top_p"], - "candidate_count": 1, - "stop_sequences": None + "topP": self.default_params["top_p"], + "topK": 40, + "maxOutputTokens": 4000, + "candidateCount": 1, + "stopSequences": [] + } + + +class GeminiOpenAIGenerator(BaseGenerator): + """OpenAI兼容的Gemini代理生成器实现""" + def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None): + super().__init__(model_name, api_key, prompt) + + if not base_url: + raise ValueError("OpenAI兼容的Gemini代理必须提供base_url") + + self.base_url = base_url.rstrip('/') + + # 使用OpenAI兼容接口 + from openai import OpenAI + self.client = OpenAI( + api_key=api_key, + base_url=base_url + ) + + # OpenAI兼容接口参数 + self.default_params = { + "temperature": self.default_params["temperature"], + "max_tokens": 4000, + "stream": False } def _generate(self, messages: list, params: dict) -> any: - """实现Gemini特定的生成逻辑""" - while True: + """实现OpenAI兼容Gemini代理的生成逻辑""" + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + **params + ) + return response + except Exception as e: + logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}") + raise + + def _process_response(self, response: any) -> str: + """处理OpenAI兼容接口的响应""" + if not response or not response.choices: + raise ValueError("OpenAI兼容Gemini代理返回无效响应") + return response.choices[0].message.content.strip() + + def _generate(self, messages: list, params: dict) -> any: + """实现原生Gemini API的生成逻辑""" + max_retries = 3 + for attempt in range(max_retries): try: # 转换消息格式为Gemini格式 prompt = "\n".join([m["content"] for m in messages]) - response = self.model.generate_content( - prompt, - generation_config=params + + # 构建请求数据 + request_data = { + "contents": [{ + "parts": [{"text": prompt}] + }], + "generationConfig": params, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}" + + # 发送请求 + response = requests.post( + url, + json=request_data, + headers={ + "Content-Type": "application/json", + "User-Agent": "NarratoAI/1.0" + }, + timeout=120 ) - - # 检查响应是否包含有效内容 - if (hasattr(response, 'result') and - hasattr(response.result, 'candidates') and - response.result.candidates): - - candidate = response.result.candidates[0] - - # 检查是否有内容字段 - if not hasattr(candidate, 'content'): - logger.warning("Gemini API 返回速率限制响应,等待30秒后重试...") - time.sleep(30) # 等待3秒后重试 + + if response.status_code == 429: + # 处理限流 + wait_time = 65 if attempt == 0 else 30 + logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...") + time.sleep(wait_time) + continue + + if response.status_code == 400: + raise Exception(f"请求参数错误: {response.text}") + elif response.status_code == 403: + raise Exception(f"API密钥无效或权限不足: {response.text}") + elif response.status_code != 200: + raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}") + + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + if attempt < max_retries - 1: + logger.warning("原生Gemini API 返回无效响应,等待30秒后重试...") + time.sleep(30) continue - return response - - except Exception as e: - error_str = str(e) - if "429" in error_str: - logger.warning("Gemini API 触发限流,等待65秒后重试...") - time.sleep(65) # 等待65秒后重试 + else: + raise Exception("原生Gemini API返回无效响应,可能触发了安全过滤") + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + raise Exception("内容被Gemini安全过滤器阻止") + + # 创建兼容的响应对象 + class CompatibleResponse: + def __init__(self, data): + self.data = data + candidate = data["candidates"][0] + if "content" in candidate and "parts" in candidate["content"]: + self.text = "" + for part in candidate["content"]["parts"]: + if "text" in part: + self.text += part["text"] + else: + self.text = "" + + return CompatibleResponse(response_data) + + except requests.exceptions.RequestException as e: + if attempt < max_retries - 1: + logger.warning(f"网络请求失败,等待30秒后重试: {str(e)}") + time.sleep(30) continue else: - logger.error(f"Gemini 生成文案错误: \n{error_str}") + logger.error(f"原生Gemini API请求失败: {str(e)}") + raise + except Exception as e: + if attempt < max_retries - 1 and "429" in str(e): + logger.warning("原生Gemini API 触发限流,等待65秒后重试...") + time.sleep(65) + continue + else: + logger.error(f"原生Gemini 生成文案错误: {str(e)}") raise def _process_response(self, response: any) -> str: - """处理Gemini的响应""" + """处理原生Gemini API的响应""" if not response or not response.text: - raise ValueError("Invalid response from Gemini API") + raise ValueError("原生Gemini API返回无效响应") return response.text.strip() @@ -318,7 +441,7 @@ class ScriptProcessor: # 根据模型名称选择对应的生成器 logger.info(f"文本 LLM 提供商: {model_name}") if 'gemini' in model_name.lower(): - self.generator = GeminiGenerator(model_name, self.api_key, self.prompt) + self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url) elif 'qwen' in model_name.lower(): self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url) elif 'moonshot' in model_name.lower(): diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index a5f3c62..a887246 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -7,6 +7,45 @@ from app.utils import utils from loguru import logger +def validate_api_key(api_key: str, provider: str) -> tuple[bool, str]: + """验证API密钥格式""" + if not api_key or not api_key.strip(): + return False, f"{provider} API密钥不能为空" + + # 基本长度检查 + if len(api_key.strip()) < 10: + return False, f"{provider} API密钥长度过短,请检查是否正确" + + return True, "" + + +def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]: + """验证Base URL格式""" + if not base_url or not base_url.strip(): + return True, "" # base_url可以为空 + + base_url = base_url.strip() + if not (base_url.startswith('http://') or base_url.startswith('https://')): + return False, f"{provider} Base URL必须以http://或https://开头" + + return True, "" + + +def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]: + """验证模型名称""" + if not model_name or not model_name.strip(): + return False, f"{provider} 模型名称不能为空" + + return True, "" + + +def show_config_validation_errors(errors: list): + """显示配置验证错误""" + if errors: + for error in errors: + st.error(error) + + def render_basic_settings(tr): """渲染基础设置面板""" with st.expander(tr("Basic Settings"), expanded=False): @@ -87,29 +126,96 @@ def render_proxy_settings(tr): def test_vision_model_connection(api_key, base_url, model_name, provider, tr): """测试视觉模型连接 - + Args: api_key: API密钥 base_url: 基础URL model_name: 模型名称 provider: 提供商名称 - + Returns: bool: 连接是否成功 str: 测试结果消息 """ + import requests if provider.lower() == 'gemini': - import google.generativeai as genai - + # 原生Gemini API测试 try: - genai.configure(api_key=api_key) - model = genai.GenerativeModel(model_name) - model.generate_content("直接回复我文本'当前网络可用'") - return True, tr("gemini model is available") + # 构建请求数据 + request_data = { + "contents": [{ + "parts": [{"text": "直接回复我文本'当前网络可用'"}] + }], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 100, + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" + url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}" + + # 发送请求 + response = requests.post( + url, + json=request_data, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + if response.status_code == 200: + return True, tr("原生Gemini模型连接成功") + else: + return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}" except Exception as e: - return False, f"{tr('gemini model is not available')}: {str(e)}" + return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}" + + elif provider.lower() == 'gemini(openai)': + # OpenAI兼容的Gemini代理测试 + try: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + test_url = f"{base_url.rstrip('/')}/chat/completions" + test_data = { + "model": model_name, + "messages": [ + {"role": "user", "content": "直接回复我文本'当前网络可用'"} + ], + "stream": False + } + + response = requests.post(test_url, headers=headers, json=test_data, timeout=10) + if response.status_code == 200: + return True, tr("OpenAI兼容Gemini代理连接成功") + else: + return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}" + except Exception as e: + return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: {str(e)}" elif provider.lower() == 'narratoapi': - import requests try: # 构建测试请求 headers = { @@ -172,7 +278,7 @@ def render_vision_llm_settings(tr): st.subheader(tr("Vision Model Settings")) # 视频分析模型提供商选择 - vision_providers = ['Siliconflow', 'Gemini', 'QwenVL', 'OpenAI'] + vision_providers = ['Siliconflow', 'Gemini', 'Gemini(OpenAI)', 'QwenVL', 'OpenAI'] saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower() saved_provider_index = 0 @@ -191,9 +297,15 @@ def render_vision_llm_settings(tr): st.session_state['vision_llm_providers'] = vision_provider # 获取已保存的视觉模型配置 - vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "") - vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "") - vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "") + # 处理特殊的提供商名称映射 + if vision_provider == 'gemini(openai)': + vision_config_key = 'vision_gemini_openai' + else: + vision_config_key = f'vision_{vision_provider}' + + vision_api_key = config.app.get(f"{vision_config_key}_api_key", "") + vision_base_url = config.app.get(f"{vision_config_key}_base_url", "") + vision_model_name = config.app.get(f"{vision_config_key}_model_name", "") # 渲染视觉模型配置输入框 st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password") @@ -201,15 +313,25 @@ def render_vision_llm_settings(tr): # 根据不同提供商设置默认值和帮助信息 if vision_provider == 'gemini': st_vision_base_url = st.text_input( - tr("Vision Base URL"), - value=vision_base_url, - disabled=True, - help=tr("Gemini API does not require a base URL") + tr("Vision Base URL"), + value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta", + help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta") ) st_vision_model_name = st.text_input( - tr("Vision Model Name"), - value=vision_model_name or "gemini-2.0-flash-lite", - help=tr("Default: gemini-2.0-flash-lite") + tr("Vision Model Name"), + value=vision_model_name or "gemini-2.0-flash-exp", + help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp") + ) + elif vision_provider == 'gemini(openai)': + st_vision_base_url = st.text_input( + tr("Vision Base URL"), + value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta/openai", + help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1") + ) + st_vision_model_name = st.text_input( + tr("Vision Model Name"), + value=vision_model_name or "gemini-2.0-flash-exp", + help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp") ) elif vision_provider == 'qwenvl': st_vision_base_url = st.text_input( @@ -228,30 +350,81 @@ def render_vision_llm_settings(tr): # 在配置输入框后添加测试按钮 if st.button(tr("Test Connection"), key="test_vision_connection"): - with st.spinner(tr("Testing connection...")): - success, message = test_vision_model_connection( - api_key=st_vision_api_key, - base_url=st_vision_base_url, - model_name=st_vision_model_name, - provider=vision_provider, - tr=tr - ) - - if success: - st.success(tr(message)) - else: - st.error(tr(message)) + # 先验证配置 + test_errors = [] + if not st_vision_api_key: + test_errors.append("请先输入API密钥") + if not st_vision_model_name: + test_errors.append("请先输入模型名称") - # 保存视觉模型配置 + if test_errors: + for error in test_errors: + st.error(error) + else: + with st.spinner(tr("Testing connection...")): + try: + success, message = test_vision_model_connection( + api_key=st_vision_api_key, + base_url=st_vision_base_url, + model_name=st_vision_model_name, + provider=vision_provider, + tr=tr + ) + + if success: + st.success(message) + else: + st.error(message) + except Exception as e: + st.error(f"测试连接时发生错误: {str(e)}") + logger.error(f"视频分析模型连接测试失败: {str(e)}") + + # 验证和保存视觉模型配置 + validation_errors = [] + config_changed = False + + # 验证API密钥 if st_vision_api_key: - config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key - st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key + is_valid, error_msg = validate_api_key(st_vision_api_key, f"视频分析({vision_provider})") + if is_valid: + config.app[f"{vision_config_key}_api_key"] = st_vision_api_key + st.session_state[f"{vision_config_key}_api_key"] = st_vision_api_key + config_changed = True + else: + validation_errors.append(error_msg) + + # 验证Base URL if st_vision_base_url: - config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url - st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url + is_valid, error_msg = validate_base_url(st_vision_base_url, f"视频分析({vision_provider})") + if is_valid: + config.app[f"{vision_config_key}_base_url"] = st_vision_base_url + st.session_state[f"{vision_config_key}_base_url"] = st_vision_base_url + config_changed = True + else: + validation_errors.append(error_msg) + + # 验证模型名称 if st_vision_model_name: - config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name - st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name + is_valid, error_msg = validate_model_name(st_vision_model_name, f"视频分析({vision_provider})") + if is_valid: + config.app[f"{vision_config_key}_model_name"] = st_vision_model_name + st.session_state[f"{vision_config_key}_model_name"] = st_vision_model_name + config_changed = True + else: + validation_errors.append(error_msg) + + # 显示验证错误 + show_config_validation_errors(validation_errors) + + # 如果配置有变化且没有验证错误,保存到文件 + if config_changed and not validation_errors: + try: + config.save_config() + if st_vision_api_key or st_vision_base_url or st_vision_model_name: + st.success(f"视频分析模型({vision_provider})配置已保存") + except Exception as e: + st.error(f"保存配置失败: {str(e)}") + logger.error(f"保存视频分析配置失败: {str(e)}") def test_text_model_connection(api_key, base_url, model_name, provider, tr): @@ -278,14 +451,74 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr): # 特殊处理Gemini if provider.lower() == 'gemini': - import google.generativeai as genai + # 原生Gemini API测试 try: - genai.configure(api_key=api_key) - model = genai.GenerativeModel(model_name) - model.generate_content("直接回复我文本'当前网络可用'") - return True, tr("Gemini model is available") + # 构建请求数据 + request_data = { + "contents": [{ + "parts": [{"text": "直接回复我文本'当前网络可用'"}] + }], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 100, + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" + url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}" + + # 发送请求 + response = requests.post( + url, + json=request_data, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + if response.status_code == 200: + return True, tr("原生Gemini模型连接成功") + else: + return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}" except Exception as e: - return False, f"{tr('Gemini model is not available')}: {str(e)}" + return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}" + + elif provider.lower() == 'gemini(openai)': + # OpenAI兼容的Gemini代理测试 + test_url = f"{base_url.rstrip('/')}/chat/completions" + test_data = { + "model": model_name, + "messages": [ + {"role": "user", "content": "直接回复我文本'当前网络可用'"} + ], + "stream": False + } + + response = requests.post(test_url, headers=headers, json=test_data, timeout=10) + if response.status_code == 200: + return True, tr("OpenAI兼容Gemini代理连接成功") + else: + return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}" else: test_url = f"{base_url.rstrip('/')}/chat/completions" @@ -322,7 +555,7 @@ def render_text_llm_settings(tr): st.subheader(tr("Text Generation Model Settings")) # 文案生成模型提供商选择 - text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Qwen', 'Moonshot'] + text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Gemini(OpenAI)', 'Qwen', 'Moonshot'] saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower() saved_provider_index = 0 @@ -346,32 +579,108 @@ def render_text_llm_settings(tr): # 渲染文本模型配置输入框 st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password") - st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url) - st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name) + + # 根据不同提供商设置默认值和帮助信息 + if text_provider == 'gemini': + st_text_base_url = st.text_input( + tr("Text Base URL"), + value=text_base_url or "https://generativelanguage.googleapis.com/v1beta", + help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta") + ) + st_text_model_name = st.text_input( + tr("Text Model Name"), + value=text_model_name or "gemini-2.0-flash-exp", + help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp") + ) + elif text_provider == 'gemini(openai)': + st_text_base_url = st.text_input( + tr("Text Base URL"), + value=text_base_url or "https://generativelanguage.googleapis.com/v1beta/openai", + help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1") + ) + st_text_model_name = st.text_input( + tr("Text Model Name"), + value=text_model_name or "gemini-2.0-flash-exp", + help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp") + ) + else: + st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url) + st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name) # 添加测试按钮 if st.button(tr("Test Connection"), key="test_text_connection"): - with st.spinner(tr("Testing connection...")): - success, message = test_text_model_connection( - api_key=st_text_api_key, - base_url=st_text_base_url, - model_name=st_text_model_name, - provider=text_provider, - tr=tr - ) - - if success: - st.success(message) - else: - st.error(message) + # 先验证配置 + test_errors = [] + if not st_text_api_key: + test_errors.append("请先输入API密钥") + if not st_text_model_name: + test_errors.append("请先输入模型名称") - # 保存文本模型配置 + if test_errors: + for error in test_errors: + st.error(error) + else: + with st.spinner(tr("Testing connection...")): + try: + success, message = test_text_model_connection( + api_key=st_text_api_key, + base_url=st_text_base_url, + model_name=st_text_model_name, + provider=text_provider, + tr=tr + ) + + if success: + st.success(message) + else: + st.error(message) + except Exception as e: + st.error(f"测试连接时发生错误: {str(e)}") + logger.error(f"文案生成模型连接测试失败: {str(e)}") + + # 验证和保存文本模型配置 + text_validation_errors = [] + text_config_changed = False + + # 验证API密钥 if st_text_api_key: - config.app[f"text_{text_provider}_api_key"] = st_text_api_key + is_valid, error_msg = validate_api_key(st_text_api_key, f"文案生成({text_provider})") + if is_valid: + config.app[f"text_{text_provider}_api_key"] = st_text_api_key + text_config_changed = True + else: + text_validation_errors.append(error_msg) + + # 验证Base URL if st_text_base_url: - config.app[f"text_{text_provider}_base_url"] = st_text_base_url + is_valid, error_msg = validate_base_url(st_text_base_url, f"文案生成({text_provider})") + if is_valid: + config.app[f"text_{text_provider}_base_url"] = st_text_base_url + text_config_changed = True + else: + text_validation_errors.append(error_msg) + + # 验证模型名称 if st_text_model_name: - config.app[f"text_{text_provider}_model_name"] = st_text_model_name + is_valid, error_msg = validate_model_name(st_text_model_name, f"文案生成({text_provider})") + if is_valid: + config.app[f"text_{text_provider}_model_name"] = st_text_model_name + text_config_changed = True + else: + text_validation_errors.append(error_msg) + + # 显示验证错误 + show_config_validation_errors(text_validation_errors) + + # 如果配置有变化且没有验证错误,保存到文件 + if text_config_changed and not text_validation_errors: + try: + config.save_config() + if st_text_api_key or st_text_base_url or st_text_model_name: + st.success(f"文案生成模型({text_provider})配置已保存") + except Exception as e: + st.error(f"保存配置失败: {str(e)}") + logger.error(f"保存文案生成配置失败: {str(e)}") # # Cloudflare 特殊配置 # if text_provider == 'cloudflare': diff --git a/webui/tools/base.py b/webui/tools/base.py index b8aff6a..465c27e 100644 --- a/webui/tools/base.py +++ b/webui/tools/base.py @@ -23,11 +23,14 @@ def create_vision_analyzer(provider, api_key, model, base_url): VisionAnalyzer 或 QwenAnalyzer 实例 """ if provider == 'gemini': - return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key) + return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url) + elif provider == 'gemini(openai)': + from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer + return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url) else: # 只传入必要的参数 return qwenvl_analyzer.QwenAnalyzer( - model_name=model, + model_name=model, api_key=api_key, base_url=base_url ) diff --git a/webui/tools/generate_short_summary.py b/webui/tools/generate_short_summary.py index eb2a6f4..73fcd5b 100644 --- a/webui/tools/generate_short_summary.py +++ b/webui/tools/generate_short_summary.py @@ -16,6 +16,89 @@ from loguru import logger from app.config import config from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script +import re + + +def parse_and_fix_json(json_string): + """ + 解析并修复JSON字符串 + + Args: + json_string: 待解析的JSON字符串 + + Returns: + dict: 解析后的字典,如果解析失败返回None + """ + if not json_string or not json_string.strip(): + logger.error("JSON字符串为空") + return None + + # 清理字符串 + json_string = json_string.strip() + + # 尝试直接解析 + try: + return json.loads(json_string) + except json.JSONDecodeError as e: + logger.warning(f"直接JSON解析失败: {e}") + + # 尝试提取JSON部分 + try: + # 查找JSON代码块 + json_match = re.search(r'```json\s*(.*?)\s*```', json_string, re.DOTALL) + if json_match: + json_content = json_match.group(1).strip() + logger.info("从代码块中提取JSON内容") + return json.loads(json_content) + except json.JSONDecodeError: + pass + + # 尝试查找大括号包围的内容 + try: + # 查找第一个 { 到最后一个 } 的内容 + start_idx = json_string.find('{') + end_idx = json_string.rfind('}') + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + json_content = json_string[start_idx:end_idx+1] + logger.info("提取大括号包围的JSON内容") + return json.loads(json_content) + except json.JSONDecodeError: + pass + + # 尝试修复常见的JSON格式问题 + try: + # 移除注释 + json_string = re.sub(r'#.*', '', json_string) + # 移除多余的逗号 + json_string = re.sub(r',\s*}', '}', json_string) + json_string = re.sub(r',\s*]', ']', json_string) + # 修复单引号 + json_string = re.sub(r"'([^']*)':", r'"\1":', json_string) + + logger.info("尝试修复JSON格式问题后解析") + return json.loads(json_string) + except json.JSONDecodeError: + pass + + # 如果所有方法都失败,尝试创建一个基本的结构 + logger.error(f"所有JSON解析方法都失败,原始内容: {json_string[:200]}...") + + # 尝试从文本中提取关键信息创建基本结构 + try: + # 这是一个简单的回退方案 + return { + "items": [ + { + "_id": 1, + "timestamp": "00:00:00,000-00:00:10,000", + "picture": "解析失败,使用默认内容", + "narration": json_string[:100] + "..." if len(json_string) > 100 else json_string, + "OST": 0 + } + ] + } + except Exception: + return None def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature): @@ -61,7 +144,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu model=text_model, base_url=text_base_url, save_result=True, - temperature=temperature + temperature=temperature, + provider=text_provider ) """ 3. 根据剧情生成解说文案 @@ -78,7 +162,8 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu model=text_model, base_url=text_base_url, save_result=True, - temperature=temperature + temperature=temperature, + provider=text_provider ) if narration_result["status"] == "success": @@ -100,7 +185,20 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu # 结果转换为JSON字符串 narration_script = narration_result["narration_script"] - narration_dict = json.loads(narration_script) + + # 增强JSON解析,包含错误处理和修复 + narration_dict = parse_and_fix_json(narration_script) + if narration_dict is None: + st.error("生成的解说文案格式错误,无法解析为JSON") + logger.error(f"JSON解析失败,原始内容: {narration_script}") + st.stop() + + # 验证JSON结构 + if 'items' not in narration_dict: + st.error("生成的解说文案缺少必要的'items'字段") + logger.error(f"JSON结构错误,缺少items字段: {narration_dict}") + st.stop() + script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2) if script is None: