diff --git a/app/config/audio_config.py b/app/config/audio_config.py index da4cf47..1e2a18a 100644 --- a/app/config/audio_config.py +++ b/app/config/audio_config.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : audio_config -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/1/7 @Description: 音频配置管理 ''' diff --git a/app/services/SDE/short_drama_explanation.py b/app/services/SDE/short_drama_explanation.py index 56a460d..439f63c 100644 --- a/app/services/SDE/short_drama_explanation.py +++ b/app/services/SDE/short_drama_explanation.py @@ -15,41 +15,60 @@ from typing import Dict, Any, Optional from loguru import logger from app.config import config from app.utils.utils import get_uuid, storage_dir -from app.services.SDE.prompt import subtitle_plot_analysis_v1, plot_writing +# 导入新的提示词管理系统 +from app.services.prompts import PromptManager class SubtitleAnalyzer: """字幕剧情分析器,负责分析字幕内容并提取关键剧情段落""" def __init__( - self, + self, api_key: Optional[str] = None, model: Optional[str] = None, base_url: Optional[str] = None, custom_prompt: Optional[str] = None, temperature: Optional[float] = 1.0, + provider: Optional[str] = None, ): """ 初始化字幕分析器 - + Args: api_key: API密钥,如果不提供则从配置中读取 model: 模型名称,如果不提供则从配置中读取 base_url: API基础URL,如果不提供则从配置中读取或使用默认值 custom_prompt: 自定义提示词,如果不提供则使用默认值 temperature: 模型温度 + provider: 提供商类型,用于确定API调用格式 """ # 使用传入的参数或从配置中获取 self.api_key = api_key self.model = model self.base_url = base_url self.temperature = temperature - + self.provider = provider or self._detect_provider() + # 设置提示词模板 - self.prompt_template = custom_prompt or subtitle_plot_analysis_v1 - + if custom_prompt: + self.prompt_template = custom_prompt + else: + # 使用新的提示词管理系统 + self.prompt_template = PromptManager.get_prompt( + category="short_drama_narration", + name="plot_analysis", + parameters={} + ) + + # 根据提供商类型确定是否为原生Gemini + self.is_native_gemini = self.provider.lower() == 'gemini' + # 初始化HTTP请求所需的头信息 self._init_headers() + + def _detect_provider(self): + """根据配置自动检测提供商类型""" + return config.app.get('text_llm_provider', 'gemini').lower() def _init_headers(self): """初始化HTTP请求头""" @@ -67,18 +86,152 @@ class SubtitleAnalyzer: def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]: """ 分析字幕内容 - + Args: subtitle_content: 字幕内容文本 - + Returns: Dict[str, Any]: 包含分析结果的字典 """ try: # 构建完整提示词 prompt = f"{self.prompt_template}\n\n{subtitle_content}" - - # 构建请求体数据 + + if self.is_native_gemini: + # 使用原生Gemini API格式 + return self._call_native_gemini_api(prompt) + else: + # 使用OpenAI兼容格式 + return self._call_openai_compatible_api(prompt) + + except Exception as e: + logger.error(f"字幕分析过程中发生错误: {str(e)}") + return { + "status": "error", + "message": str(e), + "temperature": self.temperature + } + + def _call_native_gemini_api(self, prompt: str) -> Dict[str, Any]: + """调用原生Gemini API""" + try: + # 构建原生Gemini API请求数据 + payload = { + "systemInstruction": { + "parts": [{"text": "你是一位专业的剧本分析师和剧情概括助手。请严格按照要求的格式输出分析结果。"}] + }, + "contents": [{ + "parts": [{"text": prompt}] + }], + "generationConfig": { + "temperature": self.temperature, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 4000, + "candidateCount": 1 + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}" + + # 发送请求 + response = requests.post( + url, + json=payload, + headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"}, + timeout=120 + ) + + if response.status_code == 200: + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + return { + "status": "error", + "message": "原生Gemini API返回无效响应,可能触发了安全过滤", + "temperature": self.temperature + } + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + return { + "status": "error", + "message": "内容被Gemini安全过滤器阻止", + "temperature": self.temperature + } + + if "content" not in candidate or "parts" not in candidate["content"]: + return { + "status": "error", + "message": "原生Gemini API返回内容格式错误", + "temperature": self.temperature + } + + # 提取文本内容 + analysis_result = "" + for part in candidate["content"]["parts"]: + if "text" in part: + analysis_result += part["text"] + + if not analysis_result.strip(): + return { + "status": "error", + "message": "原生Gemini API返回空内容", + "temperature": self.temperature + } + + logger.debug(f"原生Gemini字幕分析完成") + + return { + "status": "success", + "analysis": analysis_result, + "tokens_used": response_data.get("usage", {}).get("total_tokens", 0), + "model": self.model, + "temperature": self.temperature + } + else: + error_msg = f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}" + logger.error(error_msg) + return { + "status": "error", + "message": error_msg, + "temperature": self.temperature + } + + except Exception as e: + logger.error(f"原生Gemini API调用失败: {str(e)}") + return { + "status": "error", + "message": f"原生Gemini API调用失败: {str(e)}", + "temperature": self.temperature + } + + def _call_openai_compatible_api(self, prompt: str) -> Dict[str, Any]: + """调用OpenAI兼容的API""" + try: + # 构建OpenAI格式的请求数据 payload = { "model": self.model, "messages": [ @@ -87,22 +240,22 @@ class SubtitleAnalyzer: ], "temperature": self.temperature } - + # 构建请求地址 url = f"{self.base_url}/chat/completions" - + # 发送HTTP请求 - response = requests.post(url, headers=self.headers, json=payload) - + response = requests.post(url, headers=self.headers, json=payload, timeout=120) + # 解析响应 if response.status_code == 200: response_data = response.json() - + # 提取响应内容 if "choices" in response_data and len(response_data["choices"]) > 0: analysis_result = response_data["choices"][0]["message"]["content"] - logger.debug(f"字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") - + logger.debug(f"OpenAI兼容API字幕分析完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") + # 返回结果 return { "status": "success", @@ -112,26 +265,26 @@ class SubtitleAnalyzer: "temperature": self.temperature } else: - logger.error("字幕分析失败: 未获取到有效响应") + logger.error("OpenAI兼容API字幕分析失败: 未获取到有效响应") return { "status": "error", "message": "未获取到有效响应", "temperature": self.temperature } else: - error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}" + error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}" logger.error(error_msg) return { "status": "error", "message": error_msg, "temperature": self.temperature } - + except Exception as e: - logger.error(f"字幕分析过程中发生错误: {str(e)}") + logger.error(f"OpenAI兼容API调用失败: {str(e)}") return { "status": "error", - "message": str(e), + "message": f"OpenAI兼容API调用失败: {str(e)}", "temperature": self.temperature } @@ -206,20 +359,165 @@ class SubtitleAnalyzer: def generate_narration_script(self, short_name:str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]: """ 根据剧情分析生成解说文案 - + Args: short_name: 短剧名称 plot_analysis: 剧情分析内容 temperature: 生成温度,控制创造性,默认0.7 - + Returns: Dict[str, Any]: 包含生成结果的字典 """ try: - # 构建完整提示词 - prompt = plot_writing % (short_name, plot_analysis) + # 使用新的提示词管理系统构建提示词 + prompt = PromptManager.get_prompt( + category="short_drama_narration", + name="script_generation", + parameters={ + "drama_name": short_name, + "plot_analysis": plot_analysis + } + ) - # 构建请求体数据 + if self.is_native_gemini: + # 使用原生Gemini API格式 + return self._generate_narration_with_native_gemini(prompt, temperature) + else: + # 使用OpenAI兼容格式 + return self._generate_narration_with_openai_compatible(prompt, temperature) + + except Exception as e: + logger.error(f"解说文案生成过程中发生错误: {str(e)}") + return { + "status": "error", + "message": str(e), + "temperature": self.temperature + } + + def _generate_narration_with_native_gemini(self, prompt: str, temperature: float) -> Dict[str, Any]: + """使用原生Gemini API生成解说文案""" + try: + # 构建原生Gemini API请求数据 + # 为了确保JSON输出,在提示词中添加更强的约束 + enhanced_prompt = f"{prompt}\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + payload = { + "systemInstruction": { + "parts": [{"text": "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出,不能包含任何其他文字、说明或代码块标记。"}] + }, + "contents": [{ + "parts": [{"text": enhanced_prompt}] + }], + "generationConfig": { + "temperature": temperature, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 4000, + "candidateCount": 1, + "stopSequences": ["```", "注意", "说明"] + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}" + + # 发送请求 + response = requests.post( + url, + json=payload, + headers={"Content-Type": "application/json", "User-Agent": "NarratoAI/1.0"}, + timeout=120 + ) + + if response.status_code == 200: + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + return { + "status": "error", + "message": "原生Gemini API返回无效响应,可能触发了安全过滤", + "temperature": temperature + } + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + return { + "status": "error", + "message": "内容被Gemini安全过滤器阻止", + "temperature": temperature + } + + if "content" not in candidate or "parts" not in candidate["content"]: + return { + "status": "error", + "message": "原生Gemini API返回内容格式错误", + "temperature": temperature + } + + # 提取文本内容 + narration_script = "" + for part in candidate["content"]["parts"]: + if "text" in part: + narration_script += part["text"] + + if not narration_script.strip(): + return { + "status": "error", + "message": "原生Gemini API返回空内容", + "temperature": temperature + } + + logger.debug(f"原生Gemini解说文案生成完成") + + return { + "status": "success", + "narration_script": narration_script, + "tokens_used": response_data.get("usage", {}).get("total_tokens", 0), + "model": self.model, + "temperature": temperature + } + else: + error_msg = f"原生Gemini API请求失败,状态码: {response.status_code}, 响应: {response.text}" + logger.error(error_msg) + return { + "status": "error", + "message": error_msg, + "temperature": temperature + } + + except Exception as e: + logger.error(f"原生Gemini API解说文案生成失败: {str(e)}") + return { + "status": "error", + "message": f"原生Gemini API解说文案生成失败: {str(e)}", + "temperature": temperature + } + + def _generate_narration_with_openai_compatible(self, prompt: str, temperature: float) -> Dict[str, Any]: + """使用OpenAI兼容API生成解说文案""" + try: + # 构建OpenAI格式的请求数据 payload = { "model": self.model, "messages": [ @@ -228,56 +526,56 @@ class SubtitleAnalyzer: ], "temperature": temperature } - + # 对特定模型添加响应格式设置 if self.model not in ["deepseek-reasoner"]: payload["response_format"] = {"type": "json_object"} - + # 构建请求地址 url = f"{self.base_url}/chat/completions" - + # 发送HTTP请求 - response = requests.post(url, headers=self.headers, json=payload) - + response = requests.post(url, headers=self.headers, json=payload, timeout=120) + # 解析响应 if response.status_code == 200: response_data = response.json() - + # 提取响应内容 if "choices" in response_data and len(response_data["choices"]) > 0: narration_script = response_data["choices"][0]["message"]["content"] - logger.debug(f"解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") - + logger.debug(f"OpenAI兼容API解说文案生成完成,消耗的tokens: {response_data.get('usage', {}).get('total_tokens', 0)}") + # 返回结果 return { "status": "success", "narration_script": narration_script, "tokens_used": response_data.get("usage", {}).get("total_tokens", 0), "model": self.model, - "temperature": self.temperature + "temperature": temperature } else: - logger.error("解说文案生成失败: 未获取到有效响应") + logger.error("OpenAI兼容API解说文案生成失败: 未获取到有效响应") return { "status": "error", "message": "未获取到有效响应", - "temperature": self.temperature + "temperature": temperature } else: - error_msg = f"请求失败,状态码: {response.status_code}, 响应: {response.text}" + error_msg = f"OpenAI兼容API请求失败,状态码: {response.status_code}, 响应: {response.text}" logger.error(error_msg) return { "status": "error", "message": error_msg, - "temperature": self.temperature + "temperature": temperature } - + except Exception as e: - logger.error(f"解说文案生成过程中发生错误: {str(e)}") + logger.error(f"OpenAI兼容API解说文案生成失败: {str(e)}") return { "status": "error", - "message": str(e), - "temperature": self.temperature + "message": f"OpenAI兼容API解说文案生成失败: {str(e)}", + "temperature": temperature } def save_narration_script(self, narration_result: Dict[str, Any], output_path: Optional[str] = None) -> str: @@ -324,11 +622,12 @@ def analyze_subtitle( custom_prompt: Optional[str] = None, temperature: float = 1.0, save_result: bool = False, - output_path: Optional[str] = None + output_path: Optional[str] = None, + provider: Optional[str] = None ) -> Dict[str, Any]: """ 分析字幕内容的便捷函数 - + Args: subtitle_content: 字幕内容文本 subtitle_file_path: 字幕文件路径 @@ -339,7 +638,8 @@ def analyze_subtitle( temperature: 模型温度 save_result: 是否保存结果到文件 output_path: 输出文件路径 - + provider: 提供商类型 + Returns: Dict[str, Any]: 包含分析结果的字典 """ @@ -349,7 +649,8 @@ def analyze_subtitle( api_key=api_key, model=model, base_url=base_url, - custom_prompt=custom_prompt + custom_prompt=custom_prompt, + provider=provider ) logger.debug(f"使用模型: {analyzer.model} 开始分析, 温度: {analyzer.temperature}") # 分析字幕 @@ -379,11 +680,12 @@ def generate_narration_script( base_url: Optional[str] = None, temperature: float = 1.0, save_result: bool = False, - output_path: Optional[str] = None + output_path: Optional[str] = None, + provider: Optional[str] = None ) -> Dict[str, Any]: """ 根据剧情分析生成解说文案的便捷函数 - + Args: short_name: 短剧名称 plot_analysis: 剧情分析内容,直接提供 @@ -393,7 +695,8 @@ def generate_narration_script( temperature: 生成温度,控制创造性 save_result: 是否保存结果到文件 output_path: 输出文件路径 - + provider: 提供商类型 + Returns: Dict[str, Any]: 包含生成结果的字典 """ @@ -402,7 +705,8 @@ def generate_narration_script( temperature=temperature, api_key=api_key, model=model, - base_url=base_url + base_url=base_url, + provider=provider ) # 生成解说文案 diff --git a/app/services/SDP/generate_script_short.py b/app/services/SDP/generate_script_short.py index caaad93..713d26c 100644 --- a/app/services/SDP/generate_script_short.py +++ b/app/services/SDP/generate_script_short.py @@ -6,12 +6,17 @@ from .utils.step1_subtitle_analyzer_openai import analyze_subtitle from .utils.step5_merge_script import merge_script -def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5): +def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5, provider: str = None): """生成视频混剪脚本 Args: srt_path: 字幕文件路径 + api_key: API密钥 + model_name: 模型名称 output_path: 输出文件路径,可选 + base_url: API基础URL + custom_clips: 自定义片段数量 + provider: LLM服务提供商 Returns: str: 生成的脚本内容 @@ -27,7 +32,8 @@ def generate_script(srt_path: str, api_key: str, model_name: str, output_path: s api_key=api_key, model_name=model_name, base_url=base_url, - custom_clips=custom_clips + custom_clips=custom_clips, + provider=provider ) # 合并生成最终脚本 diff --git a/app/services/SDP/utils/step1_subtitle_analyzer_openai.py b/app/services/SDP/utils/step1_subtitle_analyzer_openai.py index 59ea3b0..8752d38 100644 --- a/app/services/SDP/utils/step1_subtitle_analyzer_openai.py +++ b/app/services/SDP/utils/step1_subtitle_analyzer_openai.py @@ -1,12 +1,18 @@ """ -使用OpenAI API,分析字幕文件,返回剧情梗概和爆点 +使用统一LLM服务,分析字幕文件,返回剧情梗概和爆点 """ import traceback -from openai import OpenAI, BadRequestError -import os import json +import asyncio +from loguru import logger from .utils import load_srt +# 导入新的提示词管理系统 +from app.services.prompts import PromptManager +# 导入统一LLM服务 +from app.services.llm.unified_service import UnifiedLLMService +# 导入安全的异步执行函数 +from app.services.llm.migration_adapter import _run_async_safely def analyze_subtitle( @@ -14,15 +20,18 @@ def analyze_subtitle( model_name: str, api_key: str = None, base_url: str = None, - custom_clips: int = 5 + custom_clips: int = 5, + provider: str = None ) -> dict: """分析字幕内容,返回完整的分析结果 Args: srt_path (str): SRT字幕文件路径 + model_name (str): 大模型名称 api_key (str, optional): 大模型API密钥. Defaults to None. - model_name (str, optional): 大模型名称. Defaults to "gpt-4o-2024-11-20". base_url (str, optional): 大模型API基础URL. Defaults to None. + custom_clips (int): 需要提取的片段数量. Defaults to 5. + provider (str, optional): LLM服务提供商. Defaults to None. Returns: dict: 包含剧情梗概和结构化的时间段分析的字典 @@ -32,126 +41,103 @@ def analyze_subtitle( subtitles = load_srt(srt_path) subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles]) - # 初始化客户端 - global client - if "deepseek" in model_name.lower(): - client = OpenAI( - api_key=api_key or os.getenv('DeepSeek_API_KEY'), - base_url="https://api.siliconflow.cn/v1" # 使用第三方 硅基流动 API - ) - else: - client = OpenAI( - api_key=api_key or os.getenv('OPENAI_API_KEY'), - base_url=base_url - ) + # 初始化统一LLM服务 + llm_service = UnifiedLLMService() - messages = [ - { - "role": "system", - "content": """你是一名经验丰富的短剧编剧,擅长根据字幕内容按照先后顺序分析关键剧情,并找出 %s 个关键片段。 - 请返回一个JSON对象,包含以下字段: - { - "summary": "整体剧情梗概", - "plot_titles": [ - "关键剧情1", - "关键剧情2", - "关键剧情3", - "关键剧情4", - "关键剧情5", - "..." - ] - } - 请确保返回的是合法的JSON格式, 请确保返回的是 %s 个片段。 - """ % (custom_clips, custom_clips) - }, - { - "role": "user", - "content": f"srt字幕如下:{subtitle_content}" + # 如果没有指定provider,根据model_name推断 + if not provider: + if "deepseek" in model_name.lower(): + provider = "deepseek" + elif "gpt" in model_name.lower(): + provider = "openai" + elif "gemini" in model_name.lower(): + provider = "gemini" + else: + provider = "openai" # 默认使用openai + + logger.info(f"使用LLM服务分析字幕,提供商: {provider}, 模型: {model_name}") + + # 使用新的提示词管理系统 + subtitle_analysis_prompt = PromptManager.get_prompt( + category="short_drama_editing", + name="subtitle_analysis", + parameters={ + "subtitle_content": subtitle_content, + "custom_clips": custom_clips } - ] - # DeepSeek R1 和 V3 不支持 response_format=json_object - try: - completion = client.chat.completions.create( - model=model_name, - messages=messages, - response_format={"type": "json_object"} - ) - summary_data = json.loads(completion.choices[0].message.content) - except BadRequestError as e: - completion = client.chat.completions.create( - model=model_name, - messages=messages - ) - # 去除 completion 字符串前的 ```json 和 结尾的 ``` - completion = completion.choices[0].message.content.replace("```json", "").replace("```", "") - summary_data = json.loads(completion) - except Exception as e: - raise Exception(f"大模型解析发生错误:{str(e)}\n{traceback.format_exc()}") + ) + # 使用统一LLM服务生成文本 + logger.info("开始分析字幕内容...") + response = _run_async_safely( + UnifiedLLMService.generate_text, + prompt=subtitle_analysis_prompt, + provider=provider, + model=model_name, + api_key=api_key, + base_url=base_url, + temperature=0.1, # 使用较低的温度以获得更稳定的结果 + max_tokens=4000 + ) + + # 解析JSON响应 + from webui.tools.generate_short_summary import parse_and_fix_json + summary_data = parse_and_fix_json(response) + + if not summary_data: + raise Exception("无法解析LLM返回的JSON数据") + + logger.info(f"字幕分析完成,找到 {len(summary_data.get('plot_titles', []))} 个关键情节") print(json.dumps(summary_data, indent=4, ensure_ascii=False)) - # 获取爆点时间段分析 - prompt = f"""剧情梗概: - {summary_data['summary']} - - 需要定位的爆点内容: - """ + # 构建爆点标题列表 + plot_titles_text = "" print(f"找到 {len(summary_data['plot_titles'])} 个片段") for i, point in enumerate(summary_data['plot_titles'], 1): - prompt += f"{i}. {point}\n" + plot_titles_text += f"{i}. {point}\n" - messages = [ - { - "role": "system", - "content": """你是一名短剧编剧,非常擅长根据字幕中分析视频中关键剧情出现的具体时间段。 - 请仔细阅读剧情梗概和爆点内容,然后在字幕中找出每个爆点发生的具体时间段和爆点前后的详细剧情。 - - 请返回一个JSON对象,包含一个名为"plot_points"的数组,数组中包含多个对象,每个对象都要包含以下字段: - { - "plot_points": [ - { - "timestamp": "时间段,格式为xx:xx:xx,xxx-xx:xx:xx,xxx", - "title": "关键剧情的主题", - "picture": "关键剧情前后的详细剧情描述" - } - ] - } - 请确保返回的是合法的JSON格式。""" - }, - { - "role": "user", - "content": f"""字幕内容: -{subtitle_content} - -{prompt}""" + # 使用新的提示词管理系统 + plot_extraction_prompt = PromptManager.get_prompt( + category="short_drama_editing", + name="plot_extraction", + parameters={ + "subtitle_content": subtitle_content, + "plot_summary": summary_data['summary'], + "plot_titles": plot_titles_text } - ] - # DeepSeek R1 和 V3 不支持 response_format=json_object - try: - completion = client.chat.completions.create( - model=model_name, - messages=messages, - response_format={"type": "json_object"} - ) - plot_points_data = json.loads(completion.choices[0].message.content) - except BadRequestError as e: - completion = client.chat.completions.create( - model=model_name, - messages=messages - ) - # 去除 completion 字符串前的 ```json 和 结尾的 ``` - completion = completion.choices[0].message.content.replace("```json", "").replace("```", "") - plot_points_data = json.loads(completion) - except Exception as e: - raise Exception(f"大模型解析错误:{str(e)}\n{traceback.format_exc()}") + ) - print(json.dumps(plot_points_data, indent=4, ensure_ascii=False)) + # 使用统一LLM服务进行爆点时间段分析 + logger.info("开始分析爆点时间段...") + response = _run_async_safely( + UnifiedLLMService.generate_text, + prompt=plot_extraction_prompt, + provider=provider, + model=model_name, + api_key=api_key, + base_url=base_url, + temperature=0.1, + max_tokens=4000 + ) + + # 解析JSON响应 + plot_data = parse_and_fix_json(response) + + if not plot_data: + raise Exception("无法解析爆点分析的JSON数据") + + logger.info(f"爆点分析完成,找到 {len(plot_data.get('plot_points', []))} 个时间段") # 合并结果 - return { - "plot_summary": summary_data, - "plot_points": plot_points_data["plot_points"] + result = { + "summary": summary_data.get("summary", ""), + "plot_titles": summary_data.get("plot_titles", []), + "plot_points": plot_data.get("plot_points", []) } + return result + except Exception as e: + logger.error(f"分析字幕时发生错误: {str(e)}") raise Exception(f"分析字幕时发生错误:{str(e)}\n{traceback.format_exc()}") + diff --git a/app/services/audio_normalizer.py b/app/services/audio_normalizer.py index 25ba4ee..b0796b8 100644 --- a/app/services/audio_normalizer.py +++ b/app/services/audio_normalizer.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : audio_normalizer -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/1/7 @Description: 音频响度分析和标准化工具 ''' diff --git a/app/services/clip_video.py b/app/services/clip_video.py index 81794e4..65b97ea 100644 --- a/app/services/clip_video.py +++ b/app/services/clip_video.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : clip_video -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/6 下午6:14 ''' diff --git a/app/services/generate_narration_script.py b/app/services/generate_narration_script.py index f6640db..80fcf1a 100644 --- a/app/services/generate_narration_script.py +++ b/app/services/generate_narration_script.py @@ -4,16 +4,23 @@ ''' @Project: NarratoAI @File : 生成介绍文案 -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/8 上午11:33 ''' import json import os import traceback +import asyncio from openai import OpenAI from loguru import logger +# 导入新的LLM服务模块 - 确保提供商被注册 +import app.services.llm # 这会触发提供商注册 +from app.services.llm.migration_adapter import generate_narration as generate_narration_new +# 导入新的提示词管理系统 +from app.services.prompts import PromptManager + def parse_frame_analysis_to_markdown(json_file_path): """ @@ -79,104 +86,52 @@ def parse_frame_analysis_to_markdown(json_file_path): def generate_narration(markdown_content, api_key, base_url, model): """ - 调用OpenAI API根据视频帧分析的Markdown内容生成解说文案 - + 调用大模型API根据视频帧分析的Markdown内容生成解说文案 - 已重构为使用新的LLM服务架构 + :param markdown_content: Markdown格式的视频帧分析内容 - :param api_key: OpenAI API密钥 - :param base_url: API基础URL,如果使用非官方API + :param api_key: API密钥 + :param base_url: API基础URL :param model: 使用的模型名称 :return: 生成的解说文案 """ try: - # 构建提示词 - prompt = """ -我是一名荒野建造解说的博主,以下是一些同行的对标文案,请你深度学习并总结这些文案的风格特点跟内容特点: + # 优先使用新的LLM服务架构 + logger.info("使用新的LLM服务架构生成解说文案") + result = generate_narration_new(markdown_content, api_key, base_url, model) + return result - -解压助眠的天花板就是荒野建造,沉浸丝滑的搭建过程可以说每一帧都是极致享受,我保证强迫症来了都找不出一丁点毛病。更别说全屋严丝合缝的拼接工艺,还能轻松抵御零下二十度气温,让你居住的每一天都温暖如春。 -在家闲不住的西姆今天也打算来一次野外建造,行走没多久他就发现许多倒塌的树,任由它们自生自灭不如将其利用起来。想到这他就开始挥舞铲子要把地基挖掘出来,虽然每次只能挖一点点,但架不住他体能惊人。没多长时间一个 2x3 的深坑就赫然出现,这深度住他一人绰绰有余。 -随后他去附近收集来原木,这些都是搭建墙壁的最好材料。而在投入使用前自然要把表皮刮掉,防止森林中的白蚁蛀虫。处理好一大堆后西姆还在两端打孔,使用木钉固定在一起。这可不是用来做墙壁的,而是做庇护所的承重柱。只要木头间的缝隙足够紧密,那搭建出的木屋就能足够坚固。 -每向上搭建一层,他都会在中间塞入苔藓防寒,保证不会泄露一丝热量。其他几面也是用相同方法,很快西姆就做好了三面墙壁,每一根木头都极其工整,保证强迫症来了都要点个赞再走。 -在继续搭建墙壁前西姆决定将壁炉制作出来,毕竟森林夜晚的气温会很低,保暖措施可是重中之重。完成后他找来一块大树皮用来充当庇护所的大门,而上面刮掉的木屑还能作为壁炉的引火物,可以说再完美不过。 -测试了排烟没问题后他才开始搭建最后一面墙壁,这一面要预留门和窗,所以在搭建到一半后还需要在原木中间开出卡口,让自己劈砍时能轻松许多。此时只需将另外一根如法炮制,两端拼接在一起后就是一扇大小适中的窗户。而随着随后一层苔藓铺好,最后一根原木落位,这个庇护所的雏形就算完成。 -大门的安装他没选择用合页,而是在底端雕刻出榫头,门框上则雕刻出榫眼,只能说西姆的眼就是一把尺,这完全就是严丝合缝。此时他才开始搭建屋顶。这里西姆用的方法不同,他先把最外围的原木固定好,随后将原木平铺在上面,就能得到完美的斜面屋顶。等他将四周的围栏也装好后,工整的屋顶看起来十分舒服,西姆躺上去都不想动。 -稍作休息后,他利用剩余的苔藓,对屋顶的缝隙处密封。可这样西姆觉得不够保险,于是他找来一些黏土,再次对原本的缝隙二次加工,保管这庇护所冬天也暖和。最后只需要平铺上枯叶,以及挖掘出的泥土,整个屋顶就算完成。 -考虑到庇护所的美观性,自然少不了覆盖上苔藓,翠绿的颜色看起来十分舒服。就连门口的庭院旁,他都移植了许多小树做点缀,让这木屋与周边环境融为一体。西姆才刚完成好这件事,一场大雨就骤然降临。好在此时的他已经不用淋雨,更别说这屋顶防水十分不错,室内没一点雨水渗透进来。 -等待温度回升的过程,西姆利用墙壁本身的凹槽,把床框镶嵌在上面,只需要铺上苔藓,以及自带的床单枕头,一张完美的单人床就做好。辛苦劳作一整天,西姆可不会亏待自己。他将自带的牛肉腌制好后,直接放到壁炉中烤,只需要等待三十分钟,就能享受这美味的一顿。 -在辛苦建造一星期后,他终于可以在自己搭建的庇护所中,享受最纯正的野外露营。后面西姆回家补给了一堆物资,再次回来时森林已经大雪纷飞,让他原本翠绿的小屋,更换上了冬季限定皮肤。好在内部设施没受什么影响,和他离开时一样整洁。 -就是房间中已经没多少柴火,让西姆今天又得劈柴。寒冷干燥的天气,让木头劈起来十分轻松。没多久他就收集到一大堆,这些足够燃烧好几天。虽然此时外面大雪纷飞,但小屋中却开始逐渐温暖。这次他除了带来一些食物外,还有几瓶调味料,以及一整套被褥,让自己的居住舒适度提高一大截。 -而秋天他有收集干草的缘故,只需要塞入枕套中密封起来,就能作为靠垫用。就这居住条件,比一般人在家过的还要奢侈。趁着壁炉木头变木炭的过程,西姆则开始不紧不慢的处理食物。他取出一块牛排,改好花刀以后,撒上一堆调料腌制起来。接着用锡纸包裹好,放到壁炉中直接炭烤,搭配上自带的红酒,是一个非常好的选择。 -随着时间来到第二天,外面的积雪融化了不少,西姆简单做顿煎蛋补充体力后,决定制作一个室外篝火堆,用来晚上驱散周边野兽。搭建这玩意没什么技巧,只需要找到一大堆木棍,利用大树的夹缝将其掰弯,然后将其堆积在一起,就是一个简易版的篝火堆。看这外形有点像帐篷,好在西姆没想那么多。 -等待天色暗淡下来后,他才来到室外将其点燃,顺便处理下多余的废料。只可惜这场景没朋友陪在身边,对西姆来说可能是个遗憾。而哪怕森林只有他一个人,都依旧做了好几个小时。等到里面的篝火彻底燃尽后,西姆还找来雪球,覆盖到上面将火熄灭,这防火意识可谓十分好。最后在室内二十五度的高温下,裹着被子睡觉。 - + except Exception as e: + logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}") + + # 回退到旧的实现以确保兼容性 + return _generate_narration_legacy(markdown_content, api_key, base_url, model) + + +def _generate_narration_legacy(markdown_content, api_key, base_url, model): + """ + 旧的解说文案生成实现 - 保留作为备用方案 + + :param markdown_content: Markdown格式的视频帧分析内容 + :param api_key: API密钥 + :param base_url: API基础URL + :param model: 使用的模型名称 + :return: 生成的解说文案 + """ + try: + # 使用新的提示词管理系统构建提示词 + prompt = PromptManager.get_prompt( + category="documentary", + name="narration_generation", + parameters={ + "video_frame_description": markdown_content + } + ) - -解压助眠的天花板就是荒野建造,沉浸丝滑的搭建过程每一帧都是极致享受,全屋严丝合缝的拼接工艺,能轻松抵御零下二十度气温,居住体验温暖如春。 -在家闲不住的西姆开启野外建造。他发现倒塌的树,决定加以利用。先挖掘出 2x3 的深坑作为地基,接着收集原木,刮掉表皮防白蚁蛀虫,打孔用木钉固定制作承重柱。搭建墙壁时,每一层都塞入苔藓防寒,很快做好三面墙。 -为应对森林夜晚低温,西姆制作壁炉,用大树皮当大门,刮下的木屑做引火物。搭建最后一面墙时预留门窗,通过在原木中间开口拼接做出窗户。大门采用榫卯结构安装,严丝合缝。 -搭建屋顶时,先固定外围原木,再平铺原木形成斜面屋顶,之后用苔藓、黏土密封缝隙,铺上枯叶和泥土。为美观,在木屋覆盖苔藓,移植小树点缀。完工时遇大雨,木屋防水良好。 -西姆利用墙壁凹槽镶嵌床框,铺上苔藓、床单枕头做成床。劳作一天后,他用壁炉烤牛肉享用。建造一星期后,他开始野外露营。 -后来西姆回家补给物资,回来时森林大雪纷飞。他劈柴储备,带回食物、调味料和被褥,提高居住舒适度,还用干草做靠垫。他用壁炉烤牛排,搭配红酒。 -第二天,积雪融化,西姆制作室外篝火堆防野兽。用大树夹缝掰弯木棍堆积而成,晚上点燃处理废料,结束后用雪球灭火,最后在室内二十五度的环境中裹被入睡。 - - -如果战争到来,这个深埋地下十几米的庇护所绝对是 bug 般的存在。即使被敌人发现,还能通过快速通道一秒逃出。里面不仅有竹子、地暖、地下水井,还自制抽水机。在解决用水问题的同时,甚至自研无土栽培技术,过上完全自给自足的生活。 -阿伟的老婆美如花,但阿伟从来不回家,来到野外他乐哈哈,一言不合就开挖。众所周知当战争来临时,地下堡垒的安全性是最高的。阿伟苦苦研习两载半,只为练就一身挖洞本领。在这双逆天麒麟臂的加持下,如此坚硬的泥土都只能当做炮灰。 -得到了充足的空间后,他便开始对这些边缘进行打磨。随后阿伟将细线捆在木棍上,以此描绘出圆柱的轮廓。接着再一点点铲掉多余的部分。虽然是由泥土一体式打造,但这样的桌子保准用上千年都不成问题。 -考虑到十几米的深度进出非常不方便,于是阿伟找来两根长达 66.6 米的木头,打算为庇护所打造一条快速通道。只见他将木桩牢牢地插入地下,并顺着洞口的方向延伸出去,直到贯穿整个山洞。接着在每个木桩的连接处钉入铁钉,确保轨道不能有一毫米的偏差。完成后再制作一个木质框架,从而达到前后滑动的效果。 -不得不说阿伟这手艺简直就是大钢管子杵青蛙。在上面放上一个木制的车斗,还能加快搬运泥土的速度。没多久庇护所的内部就已经初见雏形。为了住起来更加舒适,还需要为自己打造一张床。虽然深处的泥土同样很坚固,但好处就是不用担心垮塌的风险。 -阿伟不仅设计了更加符合人体工学的拱形,并且还在一旁雕刻处壁龛。就是这氛围怎么看着有点不太吉利。别看阿伟一身腱子肉,但这身体里的艺术细菌可不少。每个边缘的地方他都做了精雕细琢,瞬间让整个卧室的颜值提升一大截。 -住在地下的好处就是房子面积全靠挖,每平方消耗两个半馒头。不仅没有了房贷的压力,就连买墓地的钱也省了。阿伟将中间的墙壁挖空,从而得到取暖的壁炉。当然最重要的还有排烟问题,要想从上往下打通十几米的山体是件极其困难的事。好在阿伟年轻时报过忆坤年的古墓派补习班,这打洞技术堪比隔壁学校的土拨鼠专业。虽然深度长达十几米,但排烟效果却一点不受影响,一个字专业! -随后阿伟继续对壁炉底部雕刻,打通了底部放柴火的空间,并制作出放锅的灶头。完成后阿伟从侧面将壁炉打通,并制作出一条导热的通道,以此连接到床铺的位置。毕竟住在这么一个风湿宝地,不注意保暖除湿很容易得老寒腿。 -阿伟在床面上挖出一条条管道,以便于温度能传输到床的每个角落。接下来就可以根据这些通道的长度裁切出同样长短的竹子,根据竹筒的大小凿出相互连接的孔洞,最后再将竹筒内部打通,以达到温度传送的效果。 -而后阿伟将这些管道安装到凹槽内,在他严谨的制作工艺下,每根竹子刚好都能镶嵌进去。在铺设床面之前还需要用木塞把圆孔堵住,防止泥土掉落进管道。泥土虽然不能隔绝湿气,但却是十分优良的导热材料。等他把床面都压平后就可以小心的将这些木塞拔出来,最后再用黏土把剩余的管道也遮盖起来,直到整个墙面恢复原样。 -接下来还需要测试一下加热效果,当他把火点起来后,温度很快就传送到了管道内,把火力一点点加大,直到热气流淌到更远的床面。随着小孔里的青烟冒出,也预示着阿伟的地暖可以投入使用。而后阿伟制作了一些竹条,并用细绳将它们喜结连理。 -千里之行始于足下,美好的家园要靠自己双手打造。明明可以靠才艺吃饭的阿伟偏偏要用八块腹肌征服大家,就问这样的男人哪个野生婆娘不喜欢?完成后阿伟还用自己 35 码的大腚感受了一下,真烫! -随后阿伟来到野区找到一根上好的雷击木,他当即就把木头咔嚓成两段,并取下两节较为完整的带了回去,刚好能和圆桌配套。另外一个在里面凿出凹槽,并插入木棍连接,得到一个夯土的木锤。住过农村的小伙伴都知道,这样夯出来的地面堪比水泥地,不仅坚硬耐磨,还不用担心脚底打滑。忙碌了一天的阿伟已经饥渴难耐,拿出野生小烤肠,安安心心住新房,光脚爬上大热炕,一觉能睡到天亮。 -第二天阿伟打算将房间扩宽,毕竟吃住的地方有了,还要解决个人卫生的问题。阿伟在另一侧增加了一个房间,他打算将这里打造成洗澡的地方。为了防止泥土垮塌,他将顶部做成圆弧形,等挖出足够的空间后,旁边的泥土已经堆成了小山。 -为了方便清理这些泥土,阿伟在之前的轨道增加了转弯,交接处依然是用铁钉固定,一直延伸到房间的最里面。有了运输车的帮助,这些成吨的泥土也能轻松的运送出去,并且还能体验过山车的感觉。很快他就完成了清理工作。 -为了更方便的在里面洗澡,他将底部一点点挖空,这么大的浴缸,看来阿伟并不打算一个人住。完成后他将墙面雕刻的凹凸有致,让这里看起来更加豪华。接着用洛阳铲挖出排水口,并用一根相同大小的竹筒作为开关。 -由于四周都是泥土还不能防水,阿伟特意找了一些白蚁巢,用来制作可以防水的野生水泥。现在就可以将里里外外,能接触到水的地方都涂抹一遍。细心的阿伟还找来这种 500 克一斤的鹅卵石,对池子表面进行装饰。 -没错,水源问题阿伟早已经考虑在内,他打算直接在旁边挖个水井,毕竟已经挖了这么深,再向下挖一挖,应该就能到达地下水的深度。经过几日的奋战,能看得出阿伟已经消瘦了不少,但一想到马上就能拥有的豪宅,他直接化身为无情的挖土机器,很快就挖到了好几米的深度。 -考虑到自己的弹跳力有限,阿伟在一旁定入木桩,然后通过绳子爬上爬下。随着深度越来越深,井底已经开始渗出水来,这也预示着打井成功。没多久这里面将渗满泉水,仅凭一次就能挖到水源,看来这里还真是块风湿宝地。 -随后阿伟在井口四周挖出凹槽,以便于井盖的安置。这一量才知道,井的深度已经达到了足足的 5 米。阿伟把木板组合在一起,再沿着标记切掉多余部分,他甚至还给井盖做了把手。可是如何从这么深的井里打水还是个问题,但从阿伟坚定的眼神来看,他应该想到了解决办法。 -只见他将树桩锯成两半,然后用凿子把里面一点点掏空,另外一半也是如法炮制。接着还要在底部挖出圆孔,要想成功将水从 5 米深的地方抽上来,那就不得不提到大家熟知的勾股定理。没错,这跟勾股定理没什么关系。 -阿伟给竹筒做了一个木塞,并在里面打上安装连接轴的孔。为了增加密闭性,阿伟不得不牺牲了自己的 AJ,剪出与木塞相同的大小后,再用木钉固定住。随后他收集了一些树胶,并放到火上加热融化。接下来就可以涂在木塞上增加使用寿命。 -现在将竹筒组装完成,就可以利用虹吸原理将水抽上来。完成后就可以把井盖盖上去,再用泥土在上面覆盖,现在就不用担心失足掉下去了。 -接下来阿伟去采集了一些大漆,将它涂抹在木桶接缝处,就能将其二合为一。完了再接入旁边浴缸的入水口,每个连接的地方都要做好密封,不然后面很容易漏水。随后就可以安装上活塞,并用一根木桩作为省力杠杆,根据空气压强的原理将井水抽上来。 -经过半小时的来回拉扯,硕大的浴缸终于被灌满,阿伟也是忍不住洗了把脸。接下来还需要解决排水的问题,阿伟在地上挖出沟渠,一直贯穿到屋外,然后再用竹筒从出水口连接,每个接口处都要抹上胶水,就连门外的出水口他都做了隐藏。 -在野外最重要的就是庇护所、水源还有食物。既然已经完成了前二者,那么阿伟还需要拥有可持续发展的食物来源。他先是在地上挖了两排地洞,然后在每根竹筒的表面都打上无数孔洞,这就是他打算用来种植的载体。在此之前,还需要用大火对竹筒进行杀菌消毒。 -趁着这时候,他去搬了一麻袋的木屑,先用芭蕉叶覆盖在上面,再铺上厚厚的黏土隔绝温度。在火焰的温度下,能让里面的木屑达到生长条件。 -等到第二天所有材料都晾凉后,阿伟才将竹筒内部掏空,并将木屑一点点地塞入竹筒。一切准备就绪,就可以将竹筒插入提前挖好的地洞。最后再往竹筒里塞入种子,依靠房间内的湿度和温度,就能达到大棚种植的效果。稍加时日,这些种子就会慢慢发芽。 -虽然暂时还吃不上自己培养的食物,但好在阿伟从表哥贺强那里学到不少钓鱼本领,哪怕只有一根小小的竹竿,也能让他钓上两斤半的大鲶鱼。新鲜的食材,那肯定是少不了高温消毒的过程。趁着鱼没熟,阿伟直接爬进浴缸,冰凉的井水瞬间洗去了身上的疲惫。这一刻的阿伟是无比的享受。 -不久后鱼也烤得差不多了,阿伟的生活现在可以说是有滋有味。住在十几米的地下,不仅能安全感满满,哪怕遇到危险,还能通过轨道快速逃生。 - - -%s - -我正在尝试做这个内容的解说纪录片视频,我需要你以 中的内容为解说目标,根据我刚才提供给你的对标文案 特点,以及你总结的特点,帮我生成一段关于荒野建造的解说文案,文案需要符合平台受欢迎的解说风格,请使用 json 格式进行输出;使用 中的输出格式: - -{ - "items": [ - { - "_id": 1, # 唯一递增id - "timestamp": "00:00:05,390-00:00:10,430", - "picture": "画面描述", - "narration": "解说文案", - } -} - - -1. 只输出 json 内容,不要输出其他任何说明性的文字 -2. 解说文案的语言使用 简体中文 -3. 严禁虚构画面,所有画面只能从 中摘取 - -""" % (markdown_content) # 使用OpenAI SDK初始化客户端 client = OpenAI( diff --git a/app/services/generate_video.py b/app/services/generate_video.py index 9b03b52..8395aef 100644 --- a/app/services/generate_video.py +++ b/app/services/generate_video.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : generate_video -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/7 上午11:55 ''' diff --git a/app/services/llm/__init__.py b/app/services/llm/__init__.py new file mode 100644 index 0000000..d05b43c --- /dev/null +++ b/app/services/llm/__init__.py @@ -0,0 +1,52 @@ +""" +NarratoAI 大模型服务模块 + +统一的大模型服务抽象层,支持多供应商切换和严格的输出格式验证 +包含视觉模型和文本生成模型的统一接口 + +主要组件: +- BaseLLMProvider: 大模型服务提供商基类 +- VisionModelProvider: 视觉模型提供商基类 +- TextModelProvider: 文本模型提供商基类 +- LLMServiceManager: 大模型服务管理器 +- OutputValidator: 输出格式验证器 + +支持的供应商: +视觉模型: Gemini, QwenVL, Siliconflow +文本模型: OpenAI, DeepSeek, Gemini, Qwen, Moonshot, Siliconflow +""" + +from .manager import LLMServiceManager +from .base import BaseLLMProvider, VisionModelProvider, TextModelProvider +from .validators import OutputValidator, ValidationError +from .exceptions import LLMServiceError, ProviderNotFoundError, ConfigurationError + +# 确保提供商在模块导入时被注册 +def _ensure_providers_registered(): + """确保所有提供商都已注册""" + try: + # 导入providers模块会自动执行注册 + from . import providers + from loguru import logger + logger.debug("LLM服务提供商注册完成") + except Exception as e: + from loguru import logger + logger.error(f"LLM服务提供商注册失败: {str(e)}") + +# 自动注册提供商 +_ensure_providers_registered() + +__all__ = [ + 'LLMServiceManager', + 'BaseLLMProvider', + 'VisionModelProvider', + 'TextModelProvider', + 'OutputValidator', + 'ValidationError', + 'LLMServiceError', + 'ProviderNotFoundError', + 'ConfigurationError' +] + +# 版本信息 +__version__ = '1.0.0' diff --git a/app/services/llm/base.py b/app/services/llm/base.py new file mode 100644 index 0000000..91f6c33 --- /dev/null +++ b/app/services/llm/base.py @@ -0,0 +1,175 @@ +""" +大模型服务提供商基类定义 + +定义了统一的大模型服务接口,包括视觉模型和文本生成模型的抽象基类 +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from loguru import logger + +from .exceptions import LLMServiceError, ConfigurationError + + +class BaseLLMProvider(ABC): + """大模型服务提供商基类""" + + def __init__(self, + api_key: str, + model_name: str, + base_url: Optional[str] = None, + **kwargs): + """ + 初始化大模型服务提供商 + + Args: + api_key: API密钥 + model_name: 模型名称 + base_url: API基础URL + **kwargs: 其他配置参数 + """ + self.api_key = api_key + self.model_name = model_name + self.base_url = base_url + self.config = kwargs + + # 验证必要配置 + self._validate_config() + + # 初始化提供商特定设置 + self._initialize() + + @property + @abstractmethod + def provider_name(self) -> str: + """供应商名称""" + pass + + @property + @abstractmethod + def supported_models(self) -> List[str]: + """支持的模型列表""" + pass + + def _validate_config(self): + """验证配置参数""" + if not self.api_key: + raise ConfigurationError("API密钥不能为空", "api_key") + + if not self.model_name: + raise ConfigurationError("模型名称不能为空", "model_name") + + if self.model_name not in self.supported_models: + from .exceptions import ModelNotSupportedError + raise ModelNotSupportedError(self.model_name, self.provider_name) + + def _initialize(self): + """初始化提供商特定设置,子类可重写""" + pass + + @abstractmethod + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用,子类必须实现""" + pass + + def _handle_api_error(self, status_code: int, response_text: str) -> LLMServiceError: + """处理API错误,返回适当的异常""" + from .exceptions import APICallError, RateLimitError, AuthenticationError + + if status_code == 401: + return AuthenticationError() + elif status_code == 429: + return RateLimitError() + else: + return APICallError(f"HTTP {status_code}", status_code, response_text) + + +class VisionModelProvider(BaseLLMProvider): + """视觉模型提供商基类""" + + @abstractmethod + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + **kwargs) -> List[str]: + """ + 分析图片并返回结果 + + Args: + images: 图片路径列表或PIL图片对象列表 + prompt: 分析提示词 + batch_size: 批处理大小 + **kwargs: 其他参数 + + Returns: + 分析结果列表 + """ + pass + + def _prepare_images(self, images: List[Union[str, Path, PIL.Image.Image]]) -> List[PIL.Image.Image]: + """预处理图片,统一转换为PIL.Image对象""" + processed_images = [] + + for img in images: + try: + if isinstance(img, (str, Path)): + pil_img = PIL.Image.open(img) + elif isinstance(img, PIL.Image.Image): + pil_img = img + else: + logger.warning(f"不支持的图片类型: {type(img)}") + continue + + # 调整图片大小以优化性能 + if pil_img.size[0] > 1024 or pil_img.size[1] > 1024: + pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS) + + processed_images.append(pil_img) + + except Exception as e: + logger.error(f"加载图片失败 {img}: {str(e)}") + continue + + return processed_images + + +class TextModelProvider(BaseLLMProvider): + """文本生成模型提供商基类""" + + @abstractmethod + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 生成文本内容 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + pass + + def _build_messages(self, prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]: + """构建消息列表""" + messages = [] + + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + messages.append({"role": "user", "content": prompt}) + + return messages diff --git a/app/services/llm/config_validator.py b/app/services/llm/config_validator.py new file mode 100644 index 0000000..0bfe287 --- /dev/null +++ b/app/services/llm/config_validator.py @@ -0,0 +1,307 @@ +""" +LLM服务配置验证器 + +验证大模型服务的配置是否正确,并提供配置建议 +""" + +from typing import Dict, List, Any, Optional +from loguru import logger + +from app.config import config +from .manager import LLMServiceManager +from .exceptions import ConfigurationError + + +class LLMConfigValidator: + """LLM服务配置验证器""" + + @staticmethod + def validate_all_configs() -> Dict[str, Any]: + """ + 验证所有LLM服务配置 + + Returns: + 验证结果字典 + """ + results = { + "vision_providers": {}, + "text_providers": {}, + "summary": { + "total_vision_providers": 0, + "valid_vision_providers": 0, + "total_text_providers": 0, + "valid_text_providers": 0, + "errors": [], + "warnings": [] + } + } + + # 验证视觉模型提供商 + vision_providers = LLMServiceManager.list_vision_providers() + results["summary"]["total_vision_providers"] = len(vision_providers) + + for provider in vision_providers: + try: + validation_result = LLMConfigValidator.validate_vision_provider(provider) + results["vision_providers"][provider] = validation_result + + if validation_result["is_valid"]: + results["summary"]["valid_vision_providers"] += 1 + else: + results["summary"]["errors"].extend(validation_result["errors"]) + + except Exception as e: + error_msg = f"验证视觉模型提供商 {provider} 时发生错误: {str(e)}" + results["vision_providers"][provider] = { + "is_valid": False, + "errors": [error_msg], + "warnings": [] + } + results["summary"]["errors"].append(error_msg) + + # 验证文本模型提供商 + text_providers = LLMServiceManager.list_text_providers() + results["summary"]["total_text_providers"] = len(text_providers) + + for provider in text_providers: + try: + validation_result = LLMConfigValidator.validate_text_provider(provider) + results["text_providers"][provider] = validation_result + + if validation_result["is_valid"]: + results["summary"]["valid_text_providers"] += 1 + else: + results["summary"]["errors"].extend(validation_result["errors"]) + + except Exception as e: + error_msg = f"验证文本模型提供商 {provider} 时发生错误: {str(e)}" + results["text_providers"][provider] = { + "is_valid": False, + "errors": [error_msg], + "warnings": [] + } + results["summary"]["errors"].append(error_msg) + + return results + + @staticmethod + def validate_vision_provider(provider_name: str) -> Dict[str, Any]: + """ + 验证视觉模型提供商配置 + + Args: + provider_name: 提供商名称 + + Returns: + 验证结果字典 + """ + result = { + "is_valid": False, + "errors": [], + "warnings": [], + "config": {} + } + + try: + # 获取配置 + config_prefix = f"vision_{provider_name}" + api_key = config.app.get(f'{config_prefix}_api_key') + model_name = config.app.get(f'{config_prefix}_model_name') + base_url = config.app.get(f'{config_prefix}_base_url') + + result["config"] = { + "api_key": "***" if api_key else None, + "model_name": model_name, + "base_url": base_url + } + + # 验证必需配置 + if not api_key: + result["errors"].append(f"缺少API密钥配置: {config_prefix}_api_key") + + if not model_name: + result["errors"].append(f"缺少模型名称配置: {config_prefix}_model_name") + + # 尝试创建提供商实例 + if api_key and model_name: + try: + provider_instance = LLMServiceManager.get_vision_provider(provider_name) + result["is_valid"] = True + logger.debug(f"视觉模型提供商 {provider_name} 配置验证成功") + + except Exception as e: + result["errors"].append(f"创建提供商实例失败: {str(e)}") + + # 添加警告 + if not base_url: + result["warnings"].append(f"未配置base_url,将使用默认值") + + except Exception as e: + result["errors"].append(f"配置验证过程中发生错误: {str(e)}") + + return result + + @staticmethod + def validate_text_provider(provider_name: str) -> Dict[str, Any]: + """ + 验证文本模型提供商配置 + + Args: + provider_name: 提供商名称 + + Returns: + 验证结果字典 + """ + result = { + "is_valid": False, + "errors": [], + "warnings": [], + "config": {} + } + + try: + # 获取配置 + config_prefix = f"text_{provider_name}" + api_key = config.app.get(f'{config_prefix}_api_key') + model_name = config.app.get(f'{config_prefix}_model_name') + base_url = config.app.get(f'{config_prefix}_base_url') + + result["config"] = { + "api_key": "***" if api_key else None, + "model_name": model_name, + "base_url": base_url + } + + # 验证必需配置 + if not api_key: + result["errors"].append(f"缺少API密钥配置: {config_prefix}_api_key") + + if not model_name: + result["errors"].append(f"缺少模型名称配置: {config_prefix}_model_name") + + # 尝试创建提供商实例 + if api_key and model_name: + try: + provider_instance = LLMServiceManager.get_text_provider(provider_name) + result["is_valid"] = True + logger.debug(f"文本模型提供商 {provider_name} 配置验证成功") + + except Exception as e: + result["errors"].append(f"创建提供商实例失败: {str(e)}") + + # 添加警告 + if not base_url: + result["warnings"].append(f"未配置base_url,将使用默认值") + + except Exception as e: + result["errors"].append(f"配置验证过程中发生错误: {str(e)}") + + return result + + @staticmethod + def get_config_suggestions() -> Dict[str, Any]: + """ + 获取配置建议 + + Returns: + 配置建议字典 + """ + suggestions = { + "vision_providers": {}, + "text_providers": {}, + "general_tips": [ + "确保所有API密钥都已正确配置", + "建议为每个提供商配置base_url以提高稳定性", + "定期检查模型名称是否为最新版本", + "建议配置多个提供商作为备用方案" + ] + } + + # 为每个视觉模型提供商提供建议 + vision_providers = LLMServiceManager.list_vision_providers() + for provider in vision_providers: + suggestions["vision_providers"][provider] = { + "required_configs": [ + f"vision_{provider}_api_key", + f"vision_{provider}_model_name" + ], + "optional_configs": [ + f"vision_{provider}_base_url" + ], + "example_models": LLMConfigValidator._get_example_models(provider, "vision") + } + + # 为每个文本模型提供商提供建议 + text_providers = LLMServiceManager.list_text_providers() + for provider in text_providers: + suggestions["text_providers"][provider] = { + "required_configs": [ + f"text_{provider}_api_key", + f"text_{provider}_model_name" + ], + "optional_configs": [ + f"text_{provider}_base_url" + ], + "example_models": LLMConfigValidator._get_example_models(provider, "text") + } + + return suggestions + + @staticmethod + def _get_example_models(provider_name: str, model_type: str) -> List[str]: + """获取示例模型名称""" + examples = { + "gemini": { + "vision": ["gemini-2.0-flash-lite", "gemini-2.0-flash"], + "text": ["gemini-2.0-flash", "gemini-1.5-pro"] + }, + "openai": { + "vision": [], + "text": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"] + }, + "qwen": { + "vision": ["qwen2.5-vl-32b-instruct"], + "text": ["qwen-plus-1127", "qwen-turbo"] + }, + "deepseek": { + "vision": [], + "text": ["deepseek-chat", "deepseek-reasoner"] + }, + "siliconflow": { + "vision": ["Qwen/Qwen2.5-VL-32B-Instruct"], + "text": ["deepseek-ai/DeepSeek-R1", "Qwen/Qwen2.5-72B-Instruct"] + } + } + + return examples.get(provider_name, {}).get(model_type, []) + + @staticmethod + def print_validation_report(validation_results: Dict[str, Any]): + """ + 打印验证报告 + + Args: + validation_results: 验证结果 + """ + summary = validation_results["summary"] + + print("\n" + "="*60) + print("LLM服务配置验证报告") + print("="*60) + + print(f"\n📊 总体统计:") + print(f" 视觉模型提供商: {summary['valid_vision_providers']}/{summary['total_vision_providers']} 有效") + print(f" 文本模型提供商: {summary['valid_text_providers']}/{summary['total_text_providers']} 有效") + + if summary["errors"]: + print(f"\n❌ 错误 ({len(summary['errors'])}):") + for error in summary["errors"]: + print(f" - {error}") + + if summary["warnings"]: + print(f"\n⚠️ 警告 ({len(summary['warnings'])}):") + for warning in summary["warnings"]: + print(f" - {warning}") + + print(f"\n✅ 配置验证完成") + print("="*60) diff --git a/app/services/llm/exceptions.py b/app/services/llm/exceptions.py new file mode 100644 index 0000000..545bacd --- /dev/null +++ b/app/services/llm/exceptions.py @@ -0,0 +1,118 @@ +""" +大模型服务异常类定义 + +定义了大模型服务中可能出现的各种异常类型, +提供统一的错误处理机制 +""" + +from typing import Optional, Dict, Any + + +class LLMServiceError(Exception): + """大模型服务基础异常类""" + + def __init__(self, message: str, error_code: Optional[str] = None, details: Optional[Dict[str, Any]] = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + def __str__(self): + if self.error_code: + return f"[{self.error_code}] {self.message}" + return self.message + + +class ProviderNotFoundError(LLMServiceError): + """供应商未找到异常""" + + def __init__(self, provider_name: str): + super().__init__( + message=f"未找到大模型供应商: {provider_name}", + error_code="PROVIDER_NOT_FOUND", + details={"provider_name": provider_name} + ) + + +class ConfigurationError(LLMServiceError): + """配置错误异常""" + + def __init__(self, message: str, config_key: Optional[str] = None): + super().__init__( + message=f"配置错误: {message}", + error_code="CONFIGURATION_ERROR", + details={"config_key": config_key} if config_key else {} + ) + + +class APICallError(LLMServiceError): + """API调用错误异常""" + + def __init__(self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None): + super().__init__( + message=f"API调用失败: {message}", + error_code="API_CALL_ERROR", + details={ + "status_code": status_code, + "response_text": response_text + } + ) + + +class ValidationError(LLMServiceError): + """输出验证错误异常""" + + def __init__(self, message: str, validation_type: Optional[str] = None, invalid_data: Optional[Any] = None): + super().__init__( + message=f"输出验证失败: {message}", + error_code="VALIDATION_ERROR", + details={ + "validation_type": validation_type, + "invalid_data": str(invalid_data) if invalid_data else None + } + ) + + +class ModelNotSupportedError(LLMServiceError): + """模型不支持异常""" + + def __init__(self, model_name: str, provider_name: str): + super().__init__( + message=f"供应商 {provider_name} 不支持模型 {model_name}", + error_code="MODEL_NOT_SUPPORTED", + details={ + "model_name": model_name, + "provider_name": provider_name + } + ) + + +class RateLimitError(LLMServiceError): + """API速率限制异常""" + + def __init__(self, message: str = "API调用频率超限", retry_after: Optional[int] = None): + super().__init__( + message=message, + error_code="RATE_LIMIT_ERROR", + details={"retry_after": retry_after} + ) + + +class AuthenticationError(LLMServiceError): + """认证错误异常""" + + def __init__(self, message: str = "API密钥无效或权限不足"): + super().__init__( + message=message, + error_code="AUTHENTICATION_ERROR" + ) + + +class ContentFilterError(LLMServiceError): + """内容过滤异常""" + + def __init__(self, message: str = "内容被安全过滤器阻止"): + super().__init__( + message=message, + error_code="CONTENT_FILTER_ERROR" + ) diff --git a/app/services/llm/manager.py b/app/services/llm/manager.py new file mode 100644 index 0000000..ac32932 --- /dev/null +++ b/app/services/llm/manager.py @@ -0,0 +1,214 @@ +""" +大模型服务管理器 + +统一管理所有大模型服务提供商,提供简单的工厂方法来创建和获取服务实例 +""" + +from typing import Dict, Type, Optional +from loguru import logger + +from app.config import config +from .base import VisionModelProvider, TextModelProvider +from .exceptions import ProviderNotFoundError, ConfigurationError + + +class LLMServiceManager: + """大模型服务管理器""" + + # 注册的视觉模型提供商 + _vision_providers: Dict[str, Type[VisionModelProvider]] = {} + + # 注册的文本模型提供商 + _text_providers: Dict[str, Type[TextModelProvider]] = {} + + # 缓存的提供商实例 + _vision_instance_cache: Dict[str, VisionModelProvider] = {} + _text_instance_cache: Dict[str, TextModelProvider] = {} + + @classmethod + def register_vision_provider(cls, name: str, provider_class: Type[VisionModelProvider]): + """注册视觉模型提供商""" + cls._vision_providers[name.lower()] = provider_class + logger.debug(f"注册视觉模型提供商: {name}") + + @classmethod + def register_text_provider(cls, name: str, provider_class: Type[TextModelProvider]): + """注册文本模型提供商""" + cls._text_providers[name.lower()] = provider_class + logger.debug(f"注册文本模型提供商: {name}") + + @classmethod + def _ensure_providers_registered(cls): + """确保提供商已注册""" + try: + # 如果没有注册的提供商,强制导入providers模块 + if not cls._vision_providers or not cls._text_providers: + from . import providers + logger.debug("LLMServiceManager强制注册提供商") + except Exception as e: + logger.error(f"LLMServiceManager确保提供商注册时发生错误: {str(e)}") + + @classmethod + def get_vision_provider(cls, provider_name: Optional[str] = None) -> VisionModelProvider: + """ + 获取视觉模型提供商实例 + + Args: + provider_name: 提供商名称,如果不指定则从配置中获取 + + Returns: + 视觉模型提供商实例 + + Raises: + ProviderNotFoundError: 提供商未找到 + ConfigurationError: 配置错误 + """ + # 确保提供商已注册 + cls._ensure_providers_registered() + + # 确定提供商名称 + if not provider_name: + provider_name = config.app.get('vision_llm_provider', 'gemini').lower() + else: + provider_name = provider_name.lower() + + # 检查缓存 + cache_key = f"vision_{provider_name}" + if cache_key in cls._vision_instance_cache: + return cls._vision_instance_cache[cache_key] + + # 检查提供商是否已注册 + if provider_name not in cls._vision_providers: + raise ProviderNotFoundError(provider_name) + + # 获取配置 + config_prefix = f"vision_{provider_name}" + api_key = config.app.get(f'{config_prefix}_api_key') + model_name = config.app.get(f'{config_prefix}_model_name') + base_url = config.app.get(f'{config_prefix}_base_url') + + if not api_key: + raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key") + + if not model_name: + raise ConfigurationError(f"缺少模型名称配置: {config_prefix}_model_name") + + # 创建提供商实例 + provider_class = cls._vision_providers[provider_name] + try: + instance = provider_class( + api_key=api_key, + model_name=model_name, + base_url=base_url + ) + + # 缓存实例 + cls._vision_instance_cache[cache_key] = instance + + logger.info(f"创建视觉模型提供商实例: {provider_name} - {model_name}") + return instance + + except Exception as e: + logger.error(f"创建视觉模型提供商实例失败: {provider_name} - {str(e)}") + raise ConfigurationError(f"创建提供商实例失败: {str(e)}") + + @classmethod + def get_text_provider(cls, provider_name: Optional[str] = None) -> TextModelProvider: + """ + 获取文本模型提供商实例 + + Args: + provider_name: 提供商名称,如果不指定则从配置中获取 + + Returns: + 文本模型提供商实例 + + Raises: + ProviderNotFoundError: 提供商未找到 + ConfigurationError: 配置错误 + """ + # 确保提供商已注册 + cls._ensure_providers_registered() + + # 确定提供商名称 + if not provider_name: + provider_name = config.app.get('text_llm_provider', 'openai').lower() + else: + provider_name = provider_name.lower() + + # 检查缓存 + cache_key = f"text_{provider_name}" + if cache_key in cls._text_instance_cache: + return cls._text_instance_cache[cache_key] + + # 检查提供商是否已注册 + if provider_name not in cls._text_providers: + raise ProviderNotFoundError(provider_name) + + # 获取配置 + config_prefix = f"text_{provider_name}" + api_key = config.app.get(f'{config_prefix}_api_key') + model_name = config.app.get(f'{config_prefix}_model_name') + base_url = config.app.get(f'{config_prefix}_base_url') + + if not api_key: + raise ConfigurationError(f"缺少API密钥配置: {config_prefix}_api_key") + + if not model_name: + raise ConfigurationError(f"缺少模型名称配置: {config_prefix}_model_name") + + # 创建提供商实例 + provider_class = cls._text_providers[provider_name] + try: + instance = provider_class( + api_key=api_key, + model_name=model_name, + base_url=base_url + ) + + # 缓存实例 + cls._text_instance_cache[cache_key] = instance + + logger.info(f"创建文本模型提供商实例: {provider_name} - {model_name}") + return instance + + except Exception as e: + logger.error(f"创建文本模型提供商实例失败: {provider_name} - {str(e)}") + raise ConfigurationError(f"创建提供商实例失败: {str(e)}") + + @classmethod + def clear_cache(cls): + """清空提供商实例缓存""" + cls._vision_instance_cache.clear() + cls._text_instance_cache.clear() + logger.info("已清空提供商实例缓存") + + @classmethod + def list_vision_providers(cls) -> list[str]: + """列出所有已注册的视觉模型提供商""" + return list(cls._vision_providers.keys()) + + @classmethod + def list_text_providers(cls) -> list[str]: + """列出所有已注册的文本模型提供商""" + return list(cls._text_providers.keys()) + + @classmethod + def get_provider_info(cls) -> Dict[str, Dict[str, any]]: + """获取所有提供商信息""" + return { + "vision_providers": { + name: { + "class": provider_class.__name__, + "module": provider_class.__module__ + } + for name, provider_class in cls._vision_providers.items() + }, + "text_providers": { + name: { + "class": provider_class.__name__, + "module": provider_class.__module__ + } + for name, provider_class in cls._text_providers.items() + } + } diff --git a/app/services/llm/migration_adapter.py b/app/services/llm/migration_adapter.py new file mode 100644 index 0000000..b991910 --- /dev/null +++ b/app/services/llm/migration_adapter.py @@ -0,0 +1,348 @@ +""" +迁移适配器 + +为现有代码提供向后兼容的接口,方便逐步迁移到新的LLM服务架构 +""" + +import asyncio +import json +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from loguru import logger + +from .unified_service import UnifiedLLMService +from .exceptions import LLMServiceError +# 导入新的提示词管理系统 +from app.services.prompts import PromptManager + +# 确保提供商已注册 +def _ensure_providers_registered(): + """确保所有提供商都已注册""" + try: + from .manager import LLMServiceManager + # 检查是否有已注册的提供商 + if not LLMServiceManager.list_text_providers() or not LLMServiceManager.list_vision_providers(): + # 如果没有注册的提供商,强制导入providers模块 + from . import providers + logger.debug("迁移适配器强制注册LLM服务提供商") + except Exception as e: + logger.error(f"迁移适配器确保LLM服务提供商注册时发生错误: {str(e)}") + +# 在模块加载时确保提供商已注册 +_ensure_providers_registered() + + +def _run_async_safely(coro_func, *args, **kwargs): + """ + 安全地运行异步协程,处理各种事件循环情况 + + Args: + coro_func: 协程函数(不是协程对象) + *args: 协程函数的位置参数 + **kwargs: 协程函数的关键字参数 + + Returns: + 协程的执行结果 + """ + def run_in_new_loop(): + """在新的事件循环中运行协程""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro_func(*args, **kwargs)) + finally: + loop.close() + asyncio.set_event_loop(None) + + try: + # 尝试获取当前事件循环 + try: + loop = asyncio.get_running_loop() + # 如果有运行中的事件循环,使用线程池执行 + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_new_loop) + return future.result() + except RuntimeError: + # 没有运行中的事件循环,直接运行 + return run_in_new_loop() + except Exception as e: + logger.error(f"异步执行失败: {str(e)}") + raise LLMServiceError(f"异步执行失败: {str(e)}") + + +class LegacyLLMAdapter: + """传统LLM接口适配器""" + + @staticmethod + def create_vision_analyzer(provider: str, api_key: str, model: str, base_url: str = None): + """ + 创建视觉分析器实例 - 兼容原有接口 + + Args: + provider: 提供商名称 + api_key: API密钥 + model: 模型名称 + base_url: API基础URL + + Returns: + 适配器实例 + """ + return VisionAnalyzerAdapter(provider, api_key, model, base_url) + + @staticmethod + def generate_narration(markdown_content: str, api_key: str, base_url: str, model: str) -> str: + """ + 生成解说文案 - 兼容原有接口 + + Args: + markdown_content: Markdown格式的视频帧分析内容 + api_key: API密钥 + base_url: API基础URL + model: 模型名称 + + Returns: + 生成的解说文案JSON字符串 + """ + try: + # 使用新的提示词管理系统 + prompt = PromptManager.get_prompt( + category="documentary", + name="narration_generation", + parameters={ + "video_frame_description": markdown_content + } + ) + + # 使用统一服务生成文案 + result = _run_async_safely( + UnifiedLLMService.generate_text, + prompt=prompt, + system_prompt="你是一名专业的短视频解说文案撰写专家。", + temperature=1.5, + response_format="json" + ) + + # 使用增强的JSON解析器 + from webui.tools.generate_short_summary import parse_and_fix_json + parsed_result = parse_and_fix_json(result) + + if not parsed_result: + logger.error("无法解析LLM返回的JSON数据") + # 返回一个基本的JSON结构而不是错误字符串 + return json.dumps({ + "items": [ + { + "_id": 1, + "timestamp": "00:00:00-00:00:10", + "picture": "解析失败,请检查LLM输出", + "narration": "解说文案生成失败,请重试" + } + ] + }, ensure_ascii=False) + + # 确保返回的是JSON字符串 + return json.dumps(parsed_result, ensure_ascii=False) + + except Exception as e: + logger.error(f"生成解说文案失败: {str(e)}") + # 返回一个基本的JSON结构而不是错误字符串 + return json.dumps({ + "items": [ + { + "_id": 1, + "timestamp": "00:00:00-00:00:10", + "picture": "生成失败", + "narration": f"解说文案生成失败: {str(e)}" + } + ] + }, ensure_ascii=False) + + +class VisionAnalyzerAdapter: + """视觉分析器适配器""" + + def __init__(self, provider: str, api_key: str, model: str, base_url: str = None): + self.provider = provider + self.api_key = api_key + self.model = model + self.base_url = base_url + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10) -> List[Dict[str, Any]]: + """ + 分析图片 - 兼容原有接口 + + Args: + images: 图片列表 + prompt: 分析提示词 + batch_size: 批处理大小 + + Returns: + 分析结果列表,格式与旧实现兼容 + """ + try: + # 使用统一服务分析图片 + results = await UnifiedLLMService.analyze_images( + images=images, + prompt=prompt, + provider=self.provider, + batch_size=batch_size + ) + + # 转换为旧格式以保持向后兼容性 + # 新实现返回 List[str],需要转换为 List[Dict] + compatible_results = [] + for i, result in enumerate(results): + # 计算这个批次处理的图片数量 + start_idx = i * batch_size + end_idx = min(start_idx + batch_size, len(images)) + images_processed = end_idx - start_idx + + compatible_results.append({ + 'batch_index': i, + 'images_processed': images_processed, + 'response': result, + 'model_used': self.model + }) + + logger.info(f"图片分析完成,共处理 {len(images)} 张图片,生成 {len(compatible_results)} 个批次结果") + return compatible_results + + except Exception as e: + logger.error(f"图片分析失败: {str(e)}") + raise + + +class SubtitleAnalyzerAdapter: + """字幕分析器适配器""" + + def __init__(self, api_key: str, model: str, base_url: str, provider: str = None): + self.api_key = api_key + self.model = model + self.base_url = base_url + self.provider = provider or "openai" + + def _run_async_safely(self, coro_func, *args, **kwargs): + """安全地运行异步协程""" + return _run_async_safely(coro_func, *args, **kwargs) + + def _clean_json_output(self, output: str) -> str: + """清理JSON输出,移除markdown标记等""" + import re + + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除开头和结尾的```标记 + output = re.sub(r'^```', '', output) + output = re.sub(r'```$', '', output) + + # 移除前后空白字符 + output = output.strip() + + return output + + def analyze_subtitle(self, subtitle_content: str) -> Dict[str, Any]: + """ + 分析字幕内容 - 兼容原有接口 + + Args: + subtitle_content: 字幕内容 + + Returns: + 分析结果字典 + """ + try: + # 使用统一服务分析字幕 + result = self._run_async_safely( + UnifiedLLMService.analyze_subtitle, + subtitle_content=subtitle_content, + provider=self.provider, + temperature=1.0 + ) + + return { + "status": "success", + "analysis": result, + "model": self.model, + "temperature": 1.0 + } + + except Exception as e: + logger.error(f"字幕分析失败: {str(e)}") + return { + "status": "error", + "message": str(e), + "temperature": 1.0 + } + + def generate_narration_script(self, short_name: str, plot_analysis: str, temperature: float = 0.7) -> Dict[str, Any]: + """ + 生成解说文案 - 兼容原有接口 + + Args: + short_name: 短剧名称 + plot_analysis: 剧情分析内容 + temperature: 生成温度 + + Returns: + 生成结果字典 + """ + try: + # 使用新的提示词管理系统构建提示词 + prompt = PromptManager.get_prompt( + category="short_drama_narration", + name="script_generation", + parameters={ + "drama_name": short_name, + "plot_analysis": plot_analysis + } + ) + + # 使用统一服务生成文案 + result = self._run_async_safely( + UnifiedLLMService.generate_text, + prompt=prompt, + system_prompt="你是一位专业的短视频解说脚本撰写专家。", + provider=self.provider, + temperature=temperature, + response_format="json" + ) + + # 清理JSON输出 + cleaned_result = self._clean_json_output(result) + + # 新的提示词系统返回的是包含items数组的JSON格式 + # 为了保持向后兼容,我们需要直接返回这个JSON字符串 + # 调用方会期望这是一个包含items数组的JSON字符串 + return { + "status": "success", + "narration_script": cleaned_result, + "model": self.model, + "temperature": temperature + } + + except Exception as e: + logger.error(f"解说文案生成失败: {str(e)}") + return { + "status": "error", + "message": str(e), + "temperature": temperature + } + + +# 为了向后兼容,提供一些全局函数 +def create_vision_analyzer(provider: str, api_key: str, model: str, base_url: str = None): + """创建视觉分析器 - 全局函数""" + return LegacyLLMAdapter.create_vision_analyzer(provider, api_key, model, base_url) + + +def generate_narration(markdown_content: str, api_key: str, base_url: str, model: str) -> str: + """生成解说文案 - 全局函数""" + return LegacyLLMAdapter.generate_narration(markdown_content, api_key, base_url, model) diff --git a/app/services/llm/providers/__init__.py b/app/services/llm/providers/__init__.py new file mode 100644 index 0000000..ea1509d --- /dev/null +++ b/app/services/llm/providers/__init__.py @@ -0,0 +1,47 @@ +""" +大模型服务提供商实现 + +包含各种大模型服务提供商的具体实现 +""" + +from .gemini_provider import GeminiVisionProvider, GeminiTextProvider +from .gemini_openai_provider import GeminiOpenAIVisionProvider, GeminiOpenAITextProvider +from .openai_provider import OpenAITextProvider +from .qwen_provider import QwenVisionProvider, QwenTextProvider +from .deepseek_provider import DeepSeekTextProvider +from .siliconflow_provider import SiliconflowVisionProvider, SiliconflowTextProvider + +# 自动注册所有提供商 +from ..manager import LLMServiceManager + +def register_all_providers(): + """注册所有提供商""" + # 注册视觉模型提供商 + LLMServiceManager.register_vision_provider('gemini', GeminiVisionProvider) + LLMServiceManager.register_vision_provider('gemini(openai)', GeminiOpenAIVisionProvider) + LLMServiceManager.register_vision_provider('qwenvl', QwenVisionProvider) + LLMServiceManager.register_vision_provider('siliconflow', SiliconflowVisionProvider) + + # 注册文本模型提供商 + LLMServiceManager.register_text_provider('gemini', GeminiTextProvider) + LLMServiceManager.register_text_provider('gemini(openai)', GeminiOpenAITextProvider) + LLMServiceManager.register_text_provider('openai', OpenAITextProvider) + LLMServiceManager.register_text_provider('qwen', QwenTextProvider) + LLMServiceManager.register_text_provider('deepseek', DeepSeekTextProvider) + LLMServiceManager.register_text_provider('siliconflow', SiliconflowTextProvider) + +# 自动注册 +register_all_providers() + +__all__ = [ + 'GeminiVisionProvider', + 'GeminiTextProvider', + 'GeminiOpenAIVisionProvider', + 'GeminiOpenAITextProvider', + 'OpenAITextProvider', + 'QwenVisionProvider', + 'QwenTextProvider', + 'DeepSeekTextProvider', + 'SiliconflowVisionProvider', + 'SiliconflowTextProvider' +] diff --git a/app/services/llm/providers/deepseek_provider.py b/app/services/llm/providers/deepseek_provider.py new file mode 100644 index 0000000..1a4836f --- /dev/null +++ b/app/services/llm/providers/deepseek_provider.py @@ -0,0 +1,157 @@ +""" +DeepSeek API提供商实现 + +支持DeepSeek的文本生成模型 +""" + +import asyncio +from typing import List, Dict, Any, Optional +from openai import OpenAI, BadRequestError +from loguru import logger + +from ..base import TextModelProvider +from ..exceptions import APICallError + + +class DeepSeekTextProvider(TextModelProvider): + """DeepSeek文本生成提供商""" + + @property + def provider_name(self) -> str: + return "deepseek" + + @property + def supported_models(self) -> List[str]: + return [ + "deepseek-chat", + "deepseek-reasoner", + "deepseek-r1", + "deepseek-v3" + ] + + def _initialize(self): + """初始化DeepSeek客户端""" + if not self.base_url: + self.base_url = "https://api.deepseek.com" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 使用DeepSeek API生成文本 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + # 构建消息列表 + messages = self._build_messages(prompt, system_prompt) + + # 构建请求参数 + request_params = { + "model": self.model_name, + "messages": messages, + "temperature": temperature + } + + if max_tokens: + request_params["max_tokens"] = max_tokens + + # 处理JSON格式输出 + # DeepSeek R1 和 V3 不支持 response_format=json_object + if response_format == "json": + if self._supports_response_format(): + request_params["response_format"] = {"type": "json_object"} + else: + # 对于不支持response_format的模型,在提示词中添加约束 + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + try: + # 发送API请求 + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + # 提取生成的内容 + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + + # 对于不支持response_format的模型,清理输出 + if response_format == "json" and not self._supports_response_format(): + content = self._clean_json_output(content) + + logger.debug(f"DeepSeek API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}") + return content + else: + raise APICallError("DeepSeek API返回空响应") + + except BadRequestError as e: + # 处理不支持response_format的情况 + if "response_format" in str(e) and response_format == "json": + logger.warning(f"DeepSeek模型 {self.model_name} 不支持response_format,重试不带格式约束的请求") + request_params.pop("response_format", None) + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + content = self._clean_json_output(content) + return content + else: + raise APICallError("DeepSeek API返回空响应") + else: + raise APICallError(f"DeepSeek API请求失败: {str(e)}") + + except Exception as e: + logger.error(f"DeepSeek API调用失败: {str(e)}") + raise APICallError(f"DeepSeek API调用失败: {str(e)}") + + def _supports_response_format(self) -> bool: + """检查模型是否支持response_format参数""" + # DeepSeek R1 和 V3 不支持 response_format=json_object + unsupported_models = [ + "deepseek-reasoner", + "deepseek-r1", + "deepseek-v3" + ] + + return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models) + + def _clean_json_output(self, output: str) -> str: + """清理JSON输出,移除markdown标记等""" + import re + + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除前后空白字符 + output = output.strip() + + return output + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass diff --git a/app/services/llm/providers/gemini_openai_provider.py b/app/services/llm/providers/gemini_openai_provider.py new file mode 100644 index 0000000..45c30cb --- /dev/null +++ b/app/services/llm/providers/gemini_openai_provider.py @@ -0,0 +1,235 @@ +""" +OpenAI兼容的Gemini API提供商实现 + +使用OpenAI兼容接口调用Gemini服务,支持视觉分析和文本生成 +""" + +import asyncio +import base64 +import io +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from openai import OpenAI +from loguru import logger + +from ..base import VisionModelProvider, TextModelProvider +from ..exceptions import APICallError + + +class GeminiOpenAIVisionProvider(VisionModelProvider): + """OpenAI兼容的Gemini视觉模型提供商""" + + @property + def provider_name(self) -> str: + return "gemini(openai)" + + @property + def supported_models(self) -> List[str]: + return [ + "gemini-2.0-flash-lite", + "gemini-2.0-flash", + "gemini-1.5-pro", + "gemini-1.5-flash" + ] + + def _initialize(self): + """初始化OpenAI兼容的Gemini客户端""" + if not self.base_url: + self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + **kwargs) -> List[str]: + """ + 使用OpenAI兼容的Gemini API分析图片 + + Args: + images: 图片列表 + prompt: 分析提示词 + batch_size: 批处理大小 + **kwargs: 其他参数 + + Returns: + 分析结果列表 + """ + logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理") + + # 预处理图片 + processed_images = self._prepare_images(images) + + # 分批处理 + results = [] + for i in range(0, len(processed_images), batch_size): + batch = processed_images[i:i + batch_size] + logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片") + + try: + result = await self._analyze_batch(batch, prompt) + results.append(result) + except Exception as e: + logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}") + results.append(f"批次处理失败: {str(e)}") + + return results + + async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str: + """分析一批图片""" + # 构建OpenAI格式的消息内容 + content = [{"type": "text", "text": prompt}] + + # 添加图片 + for img in batch: + base64_image = self._image_to_base64(img) + content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + }) + + # 构建消息 + messages = [{ + "role": "user", + "content": content + }] + + # 调用API + response = await asyncio.to_thread( + self.client.chat.completions.create, + model=self.model_name, + messages=messages, + max_tokens=4000, + temperature=1.0 + ) + + if response.choices and len(response.choices) > 0: + return response.choices[0].message.content + else: + raise APICallError("OpenAI兼容Gemini API返回空响应") + + def _image_to_base64(self, img: PIL.Image.Image) -> str: + """将PIL图片转换为base64编码""" + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) + img_bytes = img_buffer.getvalue() + return base64.b64encode(img_bytes).decode('utf-8') + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass + + +class GeminiOpenAITextProvider(TextModelProvider): + """OpenAI兼容的Gemini文本生成提供商""" + + @property + def provider_name(self) -> str: + return "gemini(openai)" + + @property + def supported_models(self) -> List[str]: + return [ + "gemini-2.0-flash-lite", + "gemini-2.0-flash", + "gemini-1.5-pro", + "gemini-1.5-flash" + ] + + def _initialize(self): + """初始化OpenAI兼容的Gemini客户端""" + if not self.base_url: + self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 使用OpenAI兼容的Gemini API生成文本 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + # 构建消息列表 + messages = self._build_messages(prompt, system_prompt) + + # 构建请求参数 + request_params = { + "model": self.model_name, + "messages": messages, + "temperature": temperature + } + + if max_tokens: + request_params["max_tokens"] = max_tokens + + # 处理JSON格式输出 - Gemini通过OpenAI接口可能不完全支持response_format + if response_format == "json": + # 在提示词中添加JSON格式约束 + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + try: + # 发送API请求 + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + # 提取生成的内容 + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + + # 对于JSON格式,清理输出 + if response_format == "json": + content = self._clean_json_output(content) + + logger.debug(f"OpenAI兼容Gemini API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}") + return content + else: + raise APICallError("OpenAI兼容Gemini API返回空响应") + + except Exception as e: + logger.error(f"OpenAI兼容Gemini API调用失败: {str(e)}") + raise APICallError(f"OpenAI兼容Gemini API调用失败: {str(e)}") + + def _clean_json_output(self, output: str) -> str: + """清理JSON输出,移除markdown标记等""" + import re + + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除前后空白字符 + output = output.strip() + + return output + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass diff --git a/app/services/llm/providers/gemini_provider.py b/app/services/llm/providers/gemini_provider.py new file mode 100644 index 0000000..bba4fce --- /dev/null +++ b/app/services/llm/providers/gemini_provider.py @@ -0,0 +1,346 @@ +""" +原生Gemini API提供商实现 + +使用Google原生Gemini API进行视觉分析和文本生成 +""" + +import asyncio +import base64 +import io +import requests +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from loguru import logger + +from ..base import VisionModelProvider, TextModelProvider +from ..exceptions import APICallError, ContentFilterError + + +class GeminiVisionProvider(VisionModelProvider): + """原生Gemini视觉模型提供商""" + + @property + def provider_name(self) -> str: + return "gemini" + + @property + def supported_models(self) -> List[str]: + return [ + "gemini-2.0-flash-lite", + "gemini-2.0-flash", + "gemini-1.5-pro", + "gemini-1.5-flash" + ] + + def _initialize(self): + """初始化Gemini特定设置""" + if not self.base_url: + self.base_url = "https://generativelanguage.googleapis.com/v1beta" + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + **kwargs) -> List[str]: + """ + 使用原生Gemini API分析图片 + + Args: + images: 图片列表 + prompt: 分析提示词 + batch_size: 批处理大小 + **kwargs: 其他参数 + + Returns: + 分析结果列表 + """ + logger.info(f"开始分析 {len(images)} 张图片,使用原生Gemini API") + + # 预处理图片 + processed_images = self._prepare_images(images) + + # 分批处理 + results = [] + for i in range(0, len(processed_images), batch_size): + batch = processed_images[i:i + batch_size] + logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片") + + try: + result = await self._analyze_batch(batch, prompt) + results.append(result) + except Exception as e: + logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}") + results.append(f"批次处理失败: {str(e)}") + + return results + + async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str: + """分析一批图片""" + # 构建请求数据 + parts = [{"text": prompt}] + + # 添加图片数据 + for img in batch: + img_data = self._image_to_base64(img) + parts.append({ + "inline_data": { + "mime_type": "image/jpeg", + "data": img_data + } + }) + + payload = { + "systemInstruction": { + "parts": [{"text": "你是一位专业的视觉内容分析师,请仔细分析图片内容并提供详细描述。"}] + }, + "contents": [{"parts": parts}], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 4000, + "candidateCount": 1 + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 发送API请求 + response_data = await self._make_api_call(payload) + + # 解析响应 + return self._parse_vision_response(response_data) + + def _image_to_base64(self, img: PIL.Image.Image) -> str: + """将PIL图片转换为base64编码""" + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) + img_bytes = img_buffer.getvalue() + return base64.b64encode(img_bytes).decode('utf-8') + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行原生Gemini API调用""" + url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}" + + response = await asyncio.to_thread( + requests.post, + url, + json=payload, + headers={ + "Content-Type": "application/json", + "User-Agent": "NarratoAI/1.0" + }, + timeout=120 + ) + + if response.status_code != 200: + error = self._handle_api_error(response.status_code, response.text) + raise error + + return response.json() + + def _parse_vision_response(self, response_data: Dict[str, Any]) -> str: + """解析视觉分析响应""" + if "candidates" not in response_data or not response_data["candidates"]: + raise APICallError("原生Gemini API返回无效响应") + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + raise ContentFilterError("内容被Gemini安全过滤器阻止") + + if "content" not in candidate or "parts" not in candidate["content"]: + raise APICallError("原生Gemini API返回内容格式错误") + + # 提取文本内容 + result = "" + for part in candidate["content"]["parts"]: + if "text" in part: + result += part["text"] + + if not result.strip(): + raise APICallError("原生Gemini API返回空内容") + + return result + + +class GeminiTextProvider(TextModelProvider): + """原生Gemini文本生成提供商""" + + @property + def provider_name(self) -> str: + return "gemini" + + @property + def supported_models(self) -> List[str]: + return [ + "gemini-2.0-flash-lite", + "gemini-2.0-flash", + "gemini-1.5-pro", + "gemini-1.5-flash" + ] + + def _initialize(self): + """初始化Gemini特定设置""" + if not self.base_url: + self.base_url = "https://generativelanguage.googleapis.com/v1beta" + + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 使用原生Gemini API生成文本 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + # 构建请求数据 + payload = { + "contents": [{"parts": [{"text": prompt}]}], + "generationConfig": { + "temperature": temperature, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": max_tokens or 4000, + "candidateCount": 1 + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 添加系统提示词 + if system_prompt: + payload["systemInstruction"] = { + "parts": [{"text": system_prompt}] + } + + # 如果需要JSON格式,调整提示词和配置 + if response_format == "json": + # 使用更温和的JSON格式约束 + enhanced_prompt = f"{prompt}\n\n请以JSON格式输出结果。" + payload["contents"][0]["parts"][0]["text"] = enhanced_prompt + # 移除可能导致问题的stopSequences + # payload["generationConfig"]["stopSequences"] = ["```", "注意", "说明"] + + # 记录请求信息 + logger.debug(f"Gemini文本生成请求: {payload}") + + # 发送API请求 + response_data = await self._make_api_call(payload) + + # 解析响应 + return self._parse_text_response(response_data) + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行原生Gemini API调用""" + url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}" + + response = await asyncio.to_thread( + requests.post, + url, + json=payload, + headers={ + "Content-Type": "application/json", + "User-Agent": "NarratoAI/1.0" + }, + timeout=120 + ) + + if response.status_code != 200: + error = self._handle_api_error(response.status_code, response.text) + raise error + + return response.json() + + def _parse_text_response(self, response_data: Dict[str, Any]) -> str: + """解析文本生成响应""" + logger.debug(f"Gemini API响应数据: {response_data}") + + if "candidates" not in response_data or not response_data["candidates"]: + logger.error(f"Gemini API返回无效响应结构: {response_data}") + raise APICallError("原生Gemini API返回无效响应") + + candidate = response_data["candidates"][0] + logger.debug(f"Gemini候选响应: {candidate}") + + # 检查完成原因 + finish_reason = candidate.get("finishReason", "UNKNOWN") + logger.debug(f"Gemini完成原因: {finish_reason}") + + # 检查是否被安全过滤阻止 + if finish_reason == "SAFETY": + safety_ratings = candidate.get("safetyRatings", []) + logger.warning(f"内容被Gemini安全过滤器阻止,安全评级: {safety_ratings}") + raise ContentFilterError("内容被Gemini安全过滤器阻止") + + # 检查是否因为其他原因停止 + if finish_reason in ["RECITATION", "OTHER"]: + logger.warning(f"Gemini因为{finish_reason}原因停止生成") + raise APICallError(f"Gemini因为{finish_reason}原因停止生成") + + if "content" not in candidate: + logger.error(f"Gemini候选响应中缺少content字段: {candidate}") + raise APICallError("原生Gemini API返回内容格式错误") + + if "parts" not in candidate["content"]: + logger.error(f"Gemini内容中缺少parts字段: {candidate['content']}") + raise APICallError("原生Gemini API返回内容格式错误") + + # 提取文本内容 + result = "" + for part in candidate["content"]["parts"]: + if "text" in part: + result += part["text"] + + if not result.strip(): + logger.error(f"Gemini API返回空文本内容,完整响应: {response_data}") + raise APICallError("原生Gemini API返回空内容") + + logger.debug(f"Gemini成功生成内容,长度: {len(result)}") + return result diff --git a/app/services/llm/providers/openai_provider.py b/app/services/llm/providers/openai_provider.py new file mode 100644 index 0000000..f700f83 --- /dev/null +++ b/app/services/llm/providers/openai_provider.py @@ -0,0 +1,168 @@ +""" +OpenAI API提供商实现 + +使用OpenAI API进行文本生成,也支持OpenAI兼容的其他服务 +""" + +import asyncio +from typing import List, Dict, Any, Optional +from openai import OpenAI, BadRequestError +from loguru import logger + +from ..base import TextModelProvider +from ..exceptions import APICallError, RateLimitError, AuthenticationError + + +class OpenAITextProvider(TextModelProvider): + """OpenAI文本生成提供商""" + + @property + def provider_name(self) -> str: + return "openai" + + @property + def supported_models(self) -> List[str]: + return [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + # 支持其他OpenAI兼容模型 + "deepseek-chat", + "deepseek-reasoner", + "qwen-plus", + "qwen-turbo", + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k" + ] + + def _initialize(self): + """初始化OpenAI客户端""" + if not self.base_url: + self.base_url = "https://api.openai.com/v1" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 使用OpenAI API生成文本 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + # 构建消息列表 + messages = self._build_messages(prompt, system_prompt) + + # 构建请求参数 + request_params = { + "model": self.model_name, + "messages": messages, + "temperature": temperature + } + + if max_tokens: + request_params["max_tokens"] = max_tokens + + # 处理JSON格式输出 + if response_format == "json": + # 检查模型是否支持response_format + if self._supports_response_format(): + request_params["response_format"] = {"type": "json_object"} + else: + # 对于不支持response_format的模型,在提示词中添加约束 + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + try: + # 发送API请求 + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + # 提取生成的内容 + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + + # 对于不支持response_format的模型,清理输出 + if response_format == "json" and not self._supports_response_format(): + content = self._clean_json_output(content) + + logger.debug(f"OpenAI API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}") + return content + else: + raise APICallError("OpenAI API返回空响应") + + except BadRequestError as e: + # 处理不支持response_format的情况 + if "response_format" in str(e) and response_format == "json": + logger.warning(f"模型 {self.model_name} 不支持response_format,重试不带格式约束的请求") + request_params.pop("response_format", None) + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + content = self._clean_json_output(content) + return content + else: + raise APICallError("OpenAI API返回空响应") + else: + raise APICallError(f"OpenAI API请求失败: {str(e)}") + + except Exception as e: + logger.error(f"OpenAI API调用失败: {str(e)}") + raise APICallError(f"OpenAI API调用失败: {str(e)}") + + def _supports_response_format(self) -> bool: + """检查模型是否支持response_format参数""" + # 已知不支持response_format的模型 + unsupported_models = [ + "deepseek-reasoner", + "deepseek-r1" + ] + + return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models) + + def _clean_json_output(self, output: str) -> str: + """清理JSON输出,移除markdown标记等""" + import re + + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除前后空白字符 + output = output.strip() + + return output + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + # 这个方法在OpenAI提供商中不直接使用,因为我们使用OpenAI SDK + # 但为了兼容基类接口,保留此方法 + pass diff --git a/app/services/llm/providers/qwen_provider.py b/app/services/llm/providers/qwen_provider.py new file mode 100644 index 0000000..7a71f97 --- /dev/null +++ b/app/services/llm/providers/qwen_provider.py @@ -0,0 +1,247 @@ +""" +通义千问API提供商实现 + +支持通义千问的视觉模型和文本生成模型 +""" + +import asyncio +import base64 +import io +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from openai import OpenAI +from loguru import logger + +from ..base import VisionModelProvider, TextModelProvider +from ..exceptions import APICallError + + +class QwenVisionProvider(VisionModelProvider): + """通义千问视觉模型提供商""" + + @property + def provider_name(self) -> str: + return "qwenvl" + + @property + def supported_models(self) -> List[str]: + return [ + "qwen2.5-vl-32b-instruct", + "qwen2-vl-72b-instruct", + "qwen-vl-max", + "qwen-vl-plus" + ] + + def _initialize(self): + """初始化通义千问客户端""" + if not self.base_url: + self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + **kwargs) -> List[str]: + """ + 使用通义千问VL分析图片 + + Args: + images: 图片列表 + prompt: 分析提示词 + batch_size: 批处理大小 + **kwargs: 其他参数 + + Returns: + 分析结果列表 + """ + logger.info(f"开始分析 {len(images)} 张图片,使用通义千问VL") + + # 预处理图片 + processed_images = self._prepare_images(images) + + # 分批处理 + results = [] + for i in range(0, len(processed_images), batch_size): + batch = processed_images[i:i + batch_size] + logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片") + + try: + result = await self._analyze_batch(batch, prompt) + results.append(result) + except Exception as e: + logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}") + results.append(f"批次处理失败: {str(e)}") + + return results + + async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str: + """分析一批图片""" + # 构建消息内容 + content = [] + + # 添加图片 + for img in batch: + base64_image = self._image_to_base64(img) + content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + }) + + # 添加文本提示,使用占位符来引用图片数量 + content.append({ + "type": "text", + "text": prompt % (len(batch), len(batch), len(batch)) + }) + + # 构建消息 + messages = [{ + "role": "user", + "content": content + }] + + # 调用API + response = await asyncio.to_thread( + self.client.chat.completions.create, + model=self.model_name, + messages=messages + ) + + if response.choices and len(response.choices) > 0: + return response.choices[0].message.content + else: + raise APICallError("通义千问VL API返回空响应") + + def _image_to_base64(self, img: PIL.Image.Image) -> str: + """将PIL图片转换为base64编码""" + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) + img_bytes = img_buffer.getvalue() + return base64.b64encode(img_bytes).decode('utf-8') + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass + + +class QwenTextProvider(TextModelProvider): + """通义千问文本生成提供商""" + + @property + def provider_name(self) -> str: + return "qwen" + + @property + def supported_models(self) -> List[str]: + return [ + "qwen-plus-1127", + "qwen-plus", + "qwen-turbo", + "qwen-max", + "qwen2.5-72b-instruct", + "qwen2.5-32b-instruct", + "qwen2.5-14b-instruct", + "qwen2.5-7b-instruct" + ] + + def _initialize(self): + """初始化通义千问客户端""" + if not self.base_url: + self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 使用通义千问API生成文本 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + # 构建消息列表 + messages = self._build_messages(prompt, system_prompt) + + # 构建请求参数 + request_params = { + "model": self.model_name, + "messages": messages, + "temperature": temperature + } + + if max_tokens: + request_params["max_tokens"] = max_tokens + + # 处理JSON格式输出 + if response_format == "json": + # 通义千问支持response_format + try: + request_params["response_format"] = {"type": "json_object"} + except: + # 如果不支持,在提示词中添加约束 + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + try: + # 发送API请求 + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + # 提取生成的内容 + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + + # 对于JSON格式,清理输出 + if response_format == "json" and "response_format" not in request_params: + content = self._clean_json_output(content) + + logger.debug(f"通义千问API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}") + return content + else: + raise APICallError("通义千问API返回空响应") + + except Exception as e: + logger.error(f"通义千问API调用失败: {str(e)}") + raise APICallError(f"通义千问API调用失败: {str(e)}") + + def _clean_json_output(self, output: str) -> str: + """清理JSON输出,移除markdown标记等""" + import re + + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除前后空白字符 + output = output.strip() + + return output + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass diff --git a/app/services/llm/providers/siliconflow_provider.py b/app/services/llm/providers/siliconflow_provider.py new file mode 100644 index 0000000..948be3a --- /dev/null +++ b/app/services/llm/providers/siliconflow_provider.py @@ -0,0 +1,251 @@ +""" +硅基流动API提供商实现 + +支持硅基流动的视觉模型和文本生成模型 +""" + +import asyncio +import base64 +import io +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from openai import OpenAI +from loguru import logger + +from ..base import VisionModelProvider, TextModelProvider +from ..exceptions import APICallError + + +class SiliconflowVisionProvider(VisionModelProvider): + """硅基流动视觉模型提供商""" + + @property + def provider_name(self) -> str: + return "siliconflow" + + @property + def supported_models(self) -> List[str]: + return [ + "Qwen/Qwen2.5-VL-32B-Instruct", + "Qwen/Qwen2-VL-72B-Instruct", + "deepseek-ai/deepseek-vl2", + "OpenGVLab/InternVL2-26B" + ] + + def _initialize(self): + """初始化硅基流动客户端""" + if not self.base_url: + self.base_url = "https://api.siliconflow.cn/v1" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + **kwargs) -> List[str]: + """ + 使用硅基流动API分析图片 + + Args: + images: 图片列表 + prompt: 分析提示词 + batch_size: 批处理大小 + **kwargs: 其他参数 + + Returns: + 分析结果列表 + """ + logger.info(f"开始分析 {len(images)} 张图片,使用硅基流动") + + # 预处理图片 + processed_images = self._prepare_images(images) + + # 分批处理 + results = [] + for i in range(0, len(processed_images), batch_size): + batch = processed_images[i:i + batch_size] + logger.info(f"处理第 {i//batch_size + 1} 批,共 {len(batch)} 张图片") + + try: + result = await self._analyze_batch(batch, prompt) + results.append(result) + except Exception as e: + logger.error(f"批次 {i//batch_size + 1} 处理失败: {str(e)}") + results.append(f"批次处理失败: {str(e)}") + + return results + + async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str) -> str: + """分析一批图片""" + # 构建消息内容 + content = [{"type": "text", "text": prompt}] + + # 添加图片 + for img in batch: + base64_image = self._image_to_base64(img) + content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + }) + + # 构建消息 + messages = [{ + "role": "user", + "content": content + }] + + # 调用API + response = await asyncio.to_thread( + self.client.chat.completions.create, + model=self.model_name, + messages=messages, + max_tokens=4000, + temperature=1.0 + ) + + if response.choices and len(response.choices) > 0: + return response.choices[0].message.content + else: + raise APICallError("硅基流动API返回空响应") + + def _image_to_base64(self, img: PIL.Image.Image) -> str: + """将PIL图片转换为base64编码""" + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) + img_bytes = img_buffer.getvalue() + return base64.b64encode(img_bytes).decode('utf-8') + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass + + +class SiliconflowTextProvider(TextModelProvider): + """硅基流动文本生成提供商""" + + @property + def provider_name(self) -> str: + return "siliconflow" + + @property + def supported_models(self) -> List[str]: + return [ + "deepseek-ai/DeepSeek-R1", + "deepseek-ai/DeepSeek-V3", + "Qwen/Qwen2.5-72B-Instruct", + "Qwen/Qwen2.5-32B-Instruct", + "meta-llama/Llama-3.1-70B-Instruct", + "meta-llama/Llama-3.1-8B-Instruct", + "01-ai/Yi-1.5-34B-Chat" + ] + + def _initialize(self): + """初始化硅基流动客户端""" + if not self.base_url: + self.base_url = "https://api.siliconflow.cn/v1" + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + async def generate_text(self, + prompt: str, + system_prompt: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 使用硅基流动API生成文本 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + """ + # 构建消息列表 + messages = self._build_messages(prompt, system_prompt) + + # 构建请求参数 + request_params = { + "model": self.model_name, + "messages": messages, + "temperature": temperature + } + + if max_tokens: + request_params["max_tokens"] = max_tokens + + # 处理JSON格式输出 + if response_format == "json": + if self._supports_response_format(): + request_params["response_format"] = {"type": "json_object"} + else: + # 对于不支持response_format的模型,在提示词中添加约束 + messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。" + + try: + # 发送API请求 + response = await asyncio.to_thread( + self.client.chat.completions.create, + **request_params + ) + + # 提取生成的内容 + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + + # 对于不支持response_format的模型,清理输出 + if response_format == "json" and not self._supports_response_format(): + content = self._clean_json_output(content) + + logger.debug(f"硅基流动API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}") + return content + else: + raise APICallError("硅基流动API返回空响应") + + except Exception as e: + logger.error(f"硅基流动API调用失败: {str(e)}") + raise APICallError(f"硅基流动API调用失败: {str(e)}") + + def _supports_response_format(self) -> bool: + """检查模型是否支持response_format参数""" + # DeepSeek R1 和 V3 不支持 response_format=json_object + unsupported_models = [ + "deepseek-ai/deepseek-r1", + "deepseek-ai/deepseek-v3" + ] + + return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models) + + def _clean_json_output(self, output: str) -> str: + """清理JSON输出,移除markdown标记等""" + import re + + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除前后空白字符 + output = output.strip() + + return output + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类""" + pass diff --git a/app/services/llm/test_llm_service.py b/app/services/llm/test_llm_service.py new file mode 100644 index 0000000..fe03ca5 --- /dev/null +++ b/app/services/llm/test_llm_service.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +LLM服务测试脚本 + +测试新的LLM服务架构是否正常工作 +""" + +import asyncio +import sys +import os +from pathlib import Path +from loguru import logger + +# 添加项目根目录到Python路径 +project_root = Path(__file__).parent.parent.parent.parent +sys.path.insert(0, str(project_root)) + +from app.services.llm.config_validator import LLMConfigValidator +from app.services.llm.unified_service import UnifiedLLMService +from app.services.llm.exceptions import LLMServiceError + + +async def test_text_generation(): + """测试文本生成功能""" + print("\n🔤 测试文本生成功能...") + + try: + # 简单的文本生成测试 + prompt = "请用一句话介绍人工智能。" + + result = await UnifiedLLMService.generate_text( + prompt=prompt, + system_prompt="你是一个专业的AI助手。", + temperature=0.7 + ) + + print(f"✅ 文本生成成功:") + print(f" 提示词: {prompt}") + print(f" 生成结果: {result[:100]}...") + + return True + + except Exception as e: + print(f"❌ 文本生成失败: {str(e)}") + return False + + +async def test_json_generation(): + """测试JSON格式生成功能""" + print("\n📄 测试JSON格式生成功能...") + + try: + prompt = """ +请生成一个简单的解说文案示例,包含以下字段: +- title: 标题 +- content: 内容 +- duration: 时长(秒) + +输出JSON格式。 +""" + + result = await UnifiedLLMService.generate_text( + prompt=prompt, + system_prompt="你是一个专业的文案撰写专家。", + temperature=0.7, + response_format="json" + ) + + # 尝试解析JSON + import json + parsed_result = json.loads(result) + + print(f"✅ JSON生成成功:") + print(f" 生成结果: {json.dumps(parsed_result, ensure_ascii=False, indent=2)}") + + return True + + except json.JSONDecodeError as e: + print(f"❌ JSON解析失败: {str(e)}") + print(f" 原始结果: {result}") + return False + except Exception as e: + print(f"❌ JSON生成失败: {str(e)}") + return False + + +async def test_narration_script_generation(): + """测试解说文案生成功能""" + print("\n🎬 测试解说文案生成功能...") + + try: + prompt = """ +根据以下视频描述生成解说文案: + +视频内容:一个人在森林中建造木屋,首先挖掘地基,然后搭建墙壁,最后安装屋顶。 + +请生成JSON格式的解说文案,包含items数组,每个item包含: +- _id: 序号 +- timestamp: 时间戳(格式:HH:MM:SS,mmm-HH:MM:SS,mmm) +- picture: 画面描述 +- narration: 解说文案 +""" + + result = await UnifiedLLMService.generate_narration_script( + prompt=prompt, + temperature=0.8, + validate_output=True + ) + + print(f"✅ 解说文案生成成功:") + print(f" 生成了 {len(result)} 个片段") + for item in result[:2]: # 只显示前2个 + print(f" - {item.get('timestamp', 'N/A')}: {item.get('narration', 'N/A')[:50]}...") + + return True + + except Exception as e: + print(f"❌ 解说文案生成失败: {str(e)}") + return False + + +async def test_subtitle_analysis(): + """测试字幕分析功能""" + print("\n📝 测试字幕分析功能...") + + try: + subtitle_content = """ +1 +00:00:01,000 --> 00:00:05,000 +大家好,欢迎来到我的频道。 + +2 +00:00:05,000 --> 00:00:10,000 +今天我们要学习如何使用人工智能。 + +3 +00:00:10,000 --> 00:00:15,000 +人工智能是一项非常有趣的技术。 +""" + + result = await UnifiedLLMService.analyze_subtitle( + subtitle_content=subtitle_content, + temperature=0.7, + validate_output=True + ) + + print(f"✅ 字幕分析成功:") + print(f" 分析结果: {result[:100]}...") + + return True + + except Exception as e: + print(f"❌ 字幕分析失败: {str(e)}") + return False + + +def test_config_validation(): + """测试配置验证功能""" + print("\n⚙️ 测试配置验证功能...") + + try: + # 验证所有配置 + validation_results = LLMConfigValidator.validate_all_configs() + + summary = validation_results["summary"] + print(f"✅ 配置验证完成:") + print(f" 视觉模型提供商: {summary['valid_vision_providers']}/{summary['total_vision_providers']} 有效") + print(f" 文本模型提供商: {summary['valid_text_providers']}/{summary['total_text_providers']} 有效") + + if summary["errors"]: + print(f" 发现 {len(summary['errors'])} 个错误") + for error in summary["errors"][:3]: # 只显示前3个错误 + print(f" - {error}") + + return summary['valid_text_providers'] > 0 + + except Exception as e: + print(f"❌ 配置验证失败: {str(e)}") + return False + + +def test_provider_info(): + """测试提供商信息获取""" + print("\n📋 测试提供商信息获取...") + + try: + provider_info = UnifiedLLMService.get_provider_info() + + vision_providers = list(provider_info["vision_providers"].keys()) + text_providers = list(provider_info["text_providers"].keys()) + + print(f"✅ 提供商信息获取成功:") + print(f" 视觉模型提供商: {', '.join(vision_providers)}") + print(f" 文本模型提供商: {', '.join(text_providers)}") + + return True + + except Exception as e: + print(f"❌ 提供商信息获取失败: {str(e)}") + return False + + +async def run_all_tests(): + """运行所有测试""" + print("🚀 开始LLM服务测试...") + print("="*60) + + # 测试结果统计 + test_results = [] + + # 1. 测试配置验证 + test_results.append(("配置验证", test_config_validation())) + + # 2. 测试提供商信息 + test_results.append(("提供商信息", test_provider_info())) + + # 3. 测试文本生成 + test_results.append(("文本生成", await test_text_generation())) + + # 4. 测试JSON生成 + test_results.append(("JSON生成", await test_json_generation())) + + # 5. 测试字幕分析 + test_results.append(("字幕分析", await test_subtitle_analysis())) + + # 6. 测试解说文案生成 + test_results.append(("解说文案生成", await test_narration_script_generation())) + + # 输出测试结果 + print("\n" + "="*60) + print("📊 测试结果汇总:") + print("="*60) + + passed = 0 + total = len(test_results) + + for test_name, result in test_results: + status = "✅ 通过" if result else "❌ 失败" + print(f" {test_name:<15} {status}") + if result: + passed += 1 + + print(f"\n总计: {passed}/{total} 个测试通过") + + if passed == total: + print("🎉 所有测试通过!LLM服务工作正常。") + elif passed > 0: + print("⚠️ 部分测试通过,请检查失败的测试项。") + else: + print("💥 所有测试失败,请检查配置和网络连接。") + + print("="*60) + + +if __name__ == "__main__": + # 设置日志级别 + logger.remove() + logger.add(sys.stderr, level="INFO") + + # 运行测试 + asyncio.run(run_all_tests()) diff --git a/app/services/llm/unified_service.py b/app/services/llm/unified_service.py new file mode 100644 index 0000000..0d04ee0 --- /dev/null +++ b/app/services/llm/unified_service.py @@ -0,0 +1,274 @@ +""" +统一的大模型服务接口 + +提供简化的API接口,方便现有代码迁移到新的架构 +""" + +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +import PIL.Image +from loguru import logger + +from .manager import LLMServiceManager +from .validators import OutputValidator +from .exceptions import LLMServiceError + +# 确保提供商已注册 +def _ensure_providers_registered(): + """确保所有提供商都已注册""" + try: + # 检查是否有已注册的提供商 + if not LLMServiceManager.list_text_providers() or not LLMServiceManager.list_vision_providers(): + # 如果没有注册的提供商,强制导入providers模块 + from . import providers + logger.debug("强制注册LLM服务提供商") + except Exception as e: + logger.error(f"确保LLM服务提供商注册时发生错误: {str(e)}") + +# 在模块加载时确保提供商已注册 +_ensure_providers_registered() + + +class UnifiedLLMService: + """统一的大模型服务接口""" + + @staticmethod + async def analyze_images(images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + provider: Optional[str] = None, + batch_size: int = 10, + **kwargs) -> List[str]: + """ + 分析图片内容 + + Args: + images: 图片路径列表或PIL图片对象列表 + prompt: 分析提示词 + provider: 视觉模型提供商名称,如果不指定则使用配置中的默认值 + batch_size: 批处理大小 + **kwargs: 其他参数 + + Returns: + 分析结果列表 + + Raises: + LLMServiceError: 服务调用失败时抛出 + """ + try: + # 获取视觉模型提供商 + vision_provider = LLMServiceManager.get_vision_provider(provider) + + # 执行图片分析 + results = await vision_provider.analyze_images( + images=images, + prompt=prompt, + batch_size=batch_size, + **kwargs + ) + + logger.info(f"图片分析完成,共处理 {len(images)} 张图片,生成 {len(results)} 个结果") + return results + + except Exception as e: + logger.error(f"图片分析失败: {str(e)}") + raise LLMServiceError(f"图片分析失败: {str(e)}") + + @staticmethod + async def generate_text(prompt: str, + system_prompt: Optional[str] = None, + provider: Optional[str] = None, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + response_format: Optional[str] = None, + **kwargs) -> str: + """ + 生成文本内容 + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + provider: 文本模型提供商名称,如果不指定则使用配置中的默认值 + temperature: 生成温度 + max_tokens: 最大token数 + response_format: 响应格式 ('json' 或 None) + **kwargs: 其他参数 + + Returns: + 生成的文本内容 + + Raises: + LLMServiceError: 服务调用失败时抛出 + """ + try: + # 获取文本模型提供商 + text_provider = LLMServiceManager.get_text_provider(provider) + + # 执行文本生成 + result = await text_provider.generate_text( + prompt=prompt, + system_prompt=system_prompt, + temperature=temperature, + max_tokens=max_tokens, + response_format=response_format, + **kwargs + ) + + logger.info(f"文本生成完成,生成内容长度: {len(result)} 字符") + return result + + except Exception as e: + logger.error(f"文本生成失败: {str(e)}") + raise LLMServiceError(f"文本生成失败: {str(e)}") + + @staticmethod + async def generate_narration_script(prompt: str, + provider: Optional[str] = None, + temperature: float = 1.0, + validate_output: bool = True, + **kwargs) -> List[Dict[str, Any]]: + """ + 生成解说文案 + + Args: + prompt: 提示词 + provider: 文本模型提供商名称 + temperature: 生成温度 + validate_output: 是否验证输出格式 + **kwargs: 其他参数 + + Returns: + 解说文案列表 + + Raises: + LLMServiceError: 服务调用失败时抛出 + """ + try: + # 生成文本 + result = await UnifiedLLMService.generate_text( + prompt=prompt, + provider=provider, + temperature=temperature, + response_format="json", + **kwargs + ) + + # 验证输出格式 + if validate_output: + narration_items = OutputValidator.validate_narration_script(result) + logger.info(f"解说文案生成并验证完成,共 {len(narration_items)} 个片段") + return narration_items + else: + # 简单的JSON解析 + import json + parsed_result = json.loads(result) + if "items" in parsed_result: + return parsed_result["items"] + else: + return parsed_result + + except Exception as e: + logger.error(f"解说文案生成失败: {str(e)}") + raise LLMServiceError(f"解说文案生成失败: {str(e)}") + + @staticmethod + async def analyze_subtitle(subtitle_content: str, + provider: Optional[str] = None, + temperature: float = 1.0, + validate_output: bool = True, + **kwargs) -> str: + """ + 分析字幕内容 + + Args: + subtitle_content: 字幕内容 + provider: 文本模型提供商名称 + temperature: 生成温度 + validate_output: 是否验证输出格式 + **kwargs: 其他参数 + + Returns: + 分析结果 + + Raises: + LLMServiceError: 服务调用失败时抛出 + """ + try: + # 构建分析提示词 + system_prompt = "你是一位专业的剧本分析师和剧情概括助手。请仔细分析字幕内容,提取关键剧情信息。" + + # 生成分析结果 + result = await UnifiedLLMService.generate_text( + prompt=subtitle_content, + system_prompt=system_prompt, + provider=provider, + temperature=temperature, + **kwargs + ) + + # 验证输出格式 + if validate_output: + validated_result = OutputValidator.validate_subtitle_analysis(result) + logger.info("字幕分析完成并验证通过") + return validated_result + else: + return result + + except Exception as e: + logger.error(f"字幕分析失败: {str(e)}") + raise LLMServiceError(f"字幕分析失败: {str(e)}") + + @staticmethod + def get_provider_info() -> Dict[str, Any]: + """ + 获取所有提供商信息 + + Returns: + 提供商信息字典 + """ + return LLMServiceManager.get_provider_info() + + @staticmethod + def list_vision_providers() -> List[str]: + """ + 列出所有视觉模型提供商 + + Returns: + 提供商名称列表 + """ + return LLMServiceManager.list_vision_providers() + + @staticmethod + def list_text_providers() -> List[str]: + """ + 列出所有文本模型提供商 + + Returns: + 提供商名称列表 + """ + return LLMServiceManager.list_text_providers() + + @staticmethod + def clear_cache(): + """清空提供商实例缓存""" + LLMServiceManager.clear_cache() + logger.info("已清空大模型服务缓存") + + +# 为了向后兼容,提供一些便捷函数 +async def analyze_images_unified(images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + provider: Optional[str] = None, + batch_size: int = 10) -> List[str]: + """便捷的图片分析函数""" + return await UnifiedLLMService.analyze_images(images, prompt, provider, batch_size) + + +async def generate_text_unified(prompt: str, + system_prompt: Optional[str] = None, + provider: Optional[str] = None, + temperature: float = 1.0, + response_format: Optional[str] = None) -> str: + """便捷的文本生成函数""" + return await UnifiedLLMService.generate_text( + prompt, system_prompt, provider, temperature, response_format=response_format + ) diff --git a/app/services/llm/validators.py b/app/services/llm/validators.py new file mode 100644 index 0000000..1614e14 --- /dev/null +++ b/app/services/llm/validators.py @@ -0,0 +1,200 @@ +""" +输出格式验证器 + +提供严格的输出格式验证机制,确保大模型输出符合预期格式 +""" + +import json +import re +from typing import Any, Dict, List, Optional, Union +from loguru import logger + +from .exceptions import ValidationError + + +class OutputValidator: + """输出格式验证器""" + + @staticmethod + def validate_json_output(output: str, schema: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + 验证JSON输出格式 + + Args: + output: 待验证的输出字符串 + schema: JSON Schema (可选) + + Returns: + 解析后的JSON对象 + + Raises: + ValidationError: 验证失败时抛出 + """ + try: + # 清理输出字符串,移除可能的markdown代码块标记 + cleaned_output = OutputValidator._clean_json_output(output) + + # 解析JSON + parsed_json = json.loads(cleaned_output) + + # 如果提供了schema,进行schema验证 + if schema: + OutputValidator._validate_json_schema(parsed_json, schema) + + return parsed_json + + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {str(e)}") + logger.error(f"原始输出: {output}") + raise ValidationError(f"JSON格式无效: {str(e)}", "json_parse", output) + except Exception as e: + logger.error(f"JSON验证失败: {str(e)}") + raise ValidationError(f"JSON验证失败: {str(e)}", "json_validation", output) + + @staticmethod + def _clean_json_output(output: str) -> str: + """清理JSON输出,移除markdown标记等""" + # 移除可能的markdown代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE) + + # 移除开头和结尾的```标记 + output = re.sub(r'^```', '', output) + output = re.sub(r'```$', '', output) + + # 移除前后空白字符 + output = output.strip() + + return output + + @staticmethod + def _validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]): + """验证JSON Schema (简化版本)""" + # 这里可以集成jsonschema库进行更严格的验证 + # 目前实现基础的类型检查 + + if "type" in schema: + expected_type = schema["type"] + if expected_type == "object" and not isinstance(data, dict): + raise ValidationError(f"期望对象类型,实际为 {type(data)}", "schema_type") + elif expected_type == "array" and not isinstance(data, list): + raise ValidationError(f"期望数组类型,实际为 {type(data)}", "schema_type") + + if "required" in schema and isinstance(data, dict): + for required_field in schema["required"]: + if required_field not in data: + raise ValidationError(f"缺少必需字段: {required_field}", "schema_required") + + @staticmethod + def validate_narration_script(output: str) -> List[Dict[str, Any]]: + """ + 验证解说文案输出格式 + + Args: + output: 待验证的解说文案输出 + + Returns: + 解析后的解说文案列表 + + Raises: + ValidationError: 验证失败时抛出 + """ + try: + # 定义解说文案的JSON Schema + narration_schema = { + "type": "object", + "required": ["items"], + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "required": ["_id", "timestamp", "picture", "narration"], + "properties": { + "_id": {"type": "number"}, + "timestamp": {"type": "string"}, + "picture": {"type": "string"}, + "narration": {"type": "string"}, + "OST": {"type": "number"} + } + } + } + } + } + + # 验证JSON格式 + parsed_data = OutputValidator.validate_json_output(output, narration_schema) + + # 提取items数组 + items = parsed_data.get("items", []) + + # 验证每个item的具体内容 + for i, item in enumerate(items): + OutputValidator._validate_narration_item(item, i) + + logger.info(f"解说文案验证成功,共 {len(items)} 个片段") + return items + + except ValidationError: + raise + except Exception as e: + logger.error(f"解说文案验证失败: {str(e)}") + raise ValidationError(f"解说文案验证失败: {str(e)}", "narration_validation", output) + + @staticmethod + def _validate_narration_item(item: Dict[str, Any], index: int): + """验证单个解说文案项目""" + # 验证时间戳格式 + timestamp = item.get("timestamp", "") + if not re.match(r'\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}', timestamp): + raise ValidationError(f"第{index+1}项时间戳格式无效: {timestamp}", "timestamp_format") + + # 验证内容不为空 + if not item.get("picture", "").strip(): + raise ValidationError(f"第{index+1}项画面描述不能为空", "empty_picture") + + if not item.get("narration", "").strip(): + raise ValidationError(f"第{index+1}项解说文案不能为空", "empty_narration") + + # 验证ID为正整数 + item_id = item.get("_id") + if not isinstance(item_id, (int, float)) or item_id <= 0: + raise ValidationError(f"第{index+1}项ID必须为正整数: {item_id}", "invalid_id") + + @staticmethod + def validate_subtitle_analysis(output: str) -> str: + """ + 验证字幕分析输出格式 + + Args: + output: 待验证的字幕分析输出 + + Returns: + 验证后的分析内容 + + Raises: + ValidationError: 验证失败时抛出 + """ + try: + # 基础验证:内容不能为空 + if not output or not output.strip(): + raise ValidationError("字幕分析结果不能为空", "empty_analysis") + + # 验证内容长度合理 + if len(output.strip()) < 50: + raise ValidationError("字幕分析结果过短,可能不完整", "analysis_too_short") + + # 验证是否包含基本的分析要素(可根据需要调整) + analysis_keywords = ["剧情", "情节", "角色", "故事", "内容"] + if not any(keyword in output for keyword in analysis_keywords): + logger.warning("字幕分析结果可能缺少关键分析要素") + + logger.info("字幕分析验证成功") + return output.strip() + + except ValidationError: + raise + except Exception as e: + logger.error(f"字幕分析验证失败: {str(e)}") + raise ValidationError(f"字幕分析验证失败: {str(e)}", "analysis_validation", output) diff --git a/app/services/merger_video.py b/app/services/merger_video.py index 75c686b..82ed20e 100644 --- a/app/services/merger_video.py +++ b/app/services/merger_video.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : merger_video -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/6 下午7:38 ''' diff --git a/app/services/prompts/__init__.py b/app/services/prompts/__init__.py new file mode 100644 index 0000000..3338673 --- /dev/null +++ b/app/services/prompts/__init__.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : __init__.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 统一提示词管理模块 +""" + +from .manager import PromptManager +from .base import BasePrompt, VisionPrompt, TextPrompt, ParameterizedPrompt +from .registry import PromptRegistry +from .template import TemplateRenderer +from .validators import PromptOutputValidator +from .exceptions import ( + PromptError, + PromptNotFoundError, + PromptValidationError, + TemplateRenderError +) + +# 版本信息 +__version__ = "1.0.0" +__author__ = "viccy同学" + +# 导出的公共接口 +__all__ = [ + # 核心管理器 + "PromptManager", + + # 基础类 + "BasePrompt", + "VisionPrompt", + "TextPrompt", + "ParameterizedPrompt", + + # 工具类 + "PromptRegistry", + "TemplateRenderer", + "PromptOutputValidator", + + # 异常类 + "PromptError", + "PromptNotFoundError", + "PromptValidationError", + "TemplateRenderError", + + # 版本信息 + "__version__", + "__author__" +] + +# 模块初始化 +def initialize_prompts(): + """初始化提示词模块,注册所有提示词""" + from . import documentary + from . import short_drama_editing + from . import short_drama_narration + + # 注册各模块的提示词 + documentary.register_prompts() + short_drama_editing.register_prompts() + short_drama_narration.register_prompts() + +# 自动初始化 +initialize_prompts() diff --git a/app/services/prompts/base.py b/app/services/prompts/base.py new file mode 100644 index 0000000..d19a5a0 --- /dev/null +++ b/app/services/prompts/base.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : base.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 提示词基础类定义 +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from enum import Enum +from dataclasses import dataclass, field +from datetime import datetime + + +class ModelType(Enum): + """模型类型枚举""" + TEXT = "text" # 文本模型 + VISION = "vision" # 视觉模型 + MULTIMODAL = "multimodal" # 多模态模型 + + +class OutputFormat(Enum): + """输出格式枚举""" + TEXT = "text" # 纯文本 + JSON = "json" # JSON格式 + MARKDOWN = "markdown" # Markdown格式 + STRUCTURED = "structured" # 结构化数据 + + +@dataclass +class PromptMetadata: + """提示词元数据""" + name: str # 提示词名称 + category: str # 分类 + version: str # 版本 + description: str # 描述 + model_type: ModelType # 适用的模型类型 + output_format: OutputFormat # 输出格式 + author: str = "viccy同学" # 作者 + created_at: datetime = field(default_factory=datetime.now) # 创建时间 + updated_at: datetime = field(default_factory=datetime.now) # 更新时间 + tags: List[str] = field(default_factory=list) # 标签 + parameters: List[str] = field(default_factory=list) # 支持的参数列表 + + +class BasePrompt(ABC): + """提示词基础类""" + + def __init__(self, metadata: PromptMetadata): + self.metadata = metadata + self._template = None + self._system_prompt = None + self._examples = [] + + @property + def name(self) -> str: + """获取提示词名称""" + return self.metadata.name + + @property + def category(self) -> str: + """获取提示词分类""" + return self.metadata.category + + @property + def version(self) -> str: + """获取提示词版本""" + return self.metadata.version + + @property + def model_type(self) -> ModelType: + """获取适用的模型类型""" + return self.metadata.model_type + + @property + def output_format(self) -> OutputFormat: + """获取输出格式""" + return self.metadata.output_format + + @abstractmethod + def get_template(self) -> str: + """获取提示词模板""" + pass + + def get_system_prompt(self) -> Optional[str]: + """获取系统提示词""" + return self._system_prompt + + def get_examples(self) -> List[str]: + """获取示例""" + return self._examples.copy() + + def validate_parameters(self, parameters: Dict[str, Any]) -> bool: + """验证参数""" + required_params = set(self.metadata.parameters) + provided_params = set(parameters.keys()) + + missing_params = required_params - provided_params + if missing_params: + from .exceptions import TemplateRenderError + raise TemplateRenderError( + template_name=self.name, + error_message="缺少必需参数", + missing_params=list(missing_params) + ) + return True + + def render(self, parameters: Dict[str, Any] = None) -> str: + """渲染提示词""" + parameters = parameters or {} + + # 验证参数 + if self.metadata.parameters: + self.validate_parameters(parameters) + + # 渲染模板 - 使用自定义的模板渲染器 + template = self.get_template() + try: + from .template import get_renderer + renderer = get_renderer() + return renderer.render(template, parameters) + except Exception as e: + from .exceptions import TemplateRenderError + raise TemplateRenderError( + template_name=self.name, + error_message=f"模板渲染错误: {str(e)}", + missing_params=[] + ) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "metadata": { + "name": self.metadata.name, + "category": self.metadata.category, + "version": self.metadata.version, + "description": self.metadata.description, + "model_type": self.metadata.model_type.value, + "output_format": self.metadata.output_format.value, + "author": self.metadata.author, + "created_at": self.metadata.created_at.isoformat(), + "updated_at": self.metadata.updated_at.isoformat(), + "tags": self.metadata.tags, + "parameters": self.metadata.parameters + }, + "template": self.get_template(), + "system_prompt": self.get_system_prompt(), + "examples": self.get_examples() + } + + +class TextPrompt(BasePrompt): + """文本模型专用提示词""" + + def __init__(self, metadata: PromptMetadata): + if metadata.model_type not in [ModelType.TEXT, ModelType.MULTIMODAL]: + raise ValueError(f"TextPrompt只支持TEXT或MULTIMODAL模型类型,当前: {metadata.model_type}") + super().__init__(metadata) + + +class VisionPrompt(BasePrompt): + """视觉模型专用提示词""" + + def __init__(self, metadata: PromptMetadata): + if metadata.model_type not in [ModelType.VISION, ModelType.MULTIMODAL]: + raise ValueError(f"VisionPrompt只支持VISION或MULTIMODAL模型类型,当前: {metadata.model_type}") + super().__init__(metadata) + + +class ParameterizedPrompt(BasePrompt): + """支持参数化的提示词""" + + def __init__(self, metadata: PromptMetadata, required_parameters: List[str] = None): + super().__init__(metadata) + if required_parameters: + self.metadata.parameters.extend(required_parameters) + # 去重 + self.metadata.parameters = list(set(self.metadata.parameters)) diff --git a/app/services/prompts/documentary/__init__.py b/app/services/prompts/documentary/__init__.py new file mode 100644 index 0000000..0c5455c --- /dev/null +++ b/app/services/prompts/documentary/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : __init__.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 纪录片解说提示词模块 +""" + +from .frame_analysis import FrameAnalysisPrompt +from .narration_generation import NarrationGenerationPrompt +from ..manager import PromptManager + + +def register_prompts(): + """注册纪录片解说相关的提示词""" + + # 注册视频帧分析提示词 + frame_analysis_prompt = FrameAnalysisPrompt() + PromptManager.register_prompt(frame_analysis_prompt, is_default=True) + + # 注册解说文案生成提示词 + narration_prompt = NarrationGenerationPrompt() + PromptManager.register_prompt(narration_prompt, is_default=True) + + +__all__ = [ + "FrameAnalysisPrompt", + "NarrationGenerationPrompt", + "register_prompts" +] diff --git a/app/services/prompts/documentary/frame_analysis.py b/app/services/prompts/documentary/frame_analysis.py new file mode 100644 index 0000000..ad69986 --- /dev/null +++ b/app/services/prompts/documentary/frame_analysis.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : frame_analysis.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 纪录片视频帧分析提示词 +""" + +from ..base import VisionPrompt, PromptMetadata, ModelType, OutputFormat + + +class FrameAnalysisPrompt(VisionPrompt): + """纪录片视频帧分析提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="frame_analysis", + category="documentary", + version="v1.0", + description="分析纪录片视频关键帧,提取画面内容和场景描述", + model_type=ModelType.VISION, + output_format=OutputFormat.JSON, + tags=["纪录片", "视频分析", "关键帧", "画面描述"], + parameters=["video_theme", "custom_instructions"] + ) + super().__init__(metadata) + + self._system_prompt = "你是一名专业的视频内容分析师,擅长分析纪录片视频帧内容,提取关键信息和场景描述。" + + def get_template(self) -> str: + return """请仔细分析这些视频关键帧图片,我需要你提供详细的画面分析。 + +视频主题:${video_theme} + +分析要求: +1. 按时间顺序分析每一帧画面 +2. 详细描述画面中的主要内容、人物、物体、环境 +3. 注意画面的构图、色彩、光线等视觉元素 +4. 识别画面中的关键动作或变化 +5. 提供准确的时间戳信息 + +${custom_instructions} + +请按照以下JSON格式输出分析结果: + +{ + "analysis": [ + { + "timestamp": "00:00:05,390", + "picture": "详细的画面描述,包括场景、人物、物体、动作等", + "scene_type": "场景类型(如:建造、准备、完成等)", + "key_elements": ["关键元素1", "关键元素2"], + "visual_quality": "画面质量描述(构图、光线、色彩等)" + } + ], + "summary": "整体视频内容概述", + "total_frames": "分析的帧数" +} + +重要要求: +1. 只输出JSON格式,不要添加任何其他文字或代码块标记 +2. 画面描述要详细准确,为后续解说文案生成提供充分信息 +3. 时间戳必须准确对应视频帧 +4. 严禁虚构不存在的内容""" diff --git a/app/services/prompts/documentary/narration_generation.py b/app/services/prompts/documentary/narration_generation.py new file mode 100644 index 0000000..f60af4b --- /dev/null +++ b/app/services/prompts/documentary/narration_generation.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : narration_generation.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 纪录片解说文案生成提示词 +""" + +from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat + + +class NarrationGenerationPrompt(TextPrompt): + """纪录片解说文案生成提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="narration_generation", + category="documentary", + version="v1.0", + description="根据视频帧分析结果生成纪录片解说文案,特别适用于荒野建造类内容", + model_type=ModelType.TEXT, + output_format=OutputFormat.JSON, + tags=["纪录片", "解说文案", "荒野建造", "文案生成"], + parameters=["video_frame_description"] + ) + super().__init__(metadata) + + self._system_prompt = "你是一名专业的短视频解说文案撰写专家,擅长创作引人入胜的纪录片解说内容。" + + def get_template(self) -> str: + return """我是一名荒野建造解说的博主,以下是一些同行的对标文案,请你深度学习并总结这些文案的风格特点跟内容特点: + + +解压助眠的天花板就是荒野建造,沉浸丝滑的搭建过程可以说每一帧都是极致享受,我保证强迫症来了都找不出一丁点毛病。更别说全屋严丝合缝的拼接工艺,还能轻松抵御零下二十度气温,让你居住的每一天都温暖如春。 +在家闲不住的西姆今天也打算来一次野外建造,行走没多久他就发现许多倒塌的树,任由它们自生自灭不如将其利用起来。想到这他就开始挥舞铲子要把地基挖掘出来,虽然每次只能挖一点点,但架不住他体能惊人。没多长时间一个 2x3 的深坑就赫然出现,这深度住他一人绰绰有余。 +随后他去附近收集来原木,这些都是搭建墙壁的最好材料。而在投入使用前自然要把表皮刮掉,防止森林中的白蚁蛀虫。处理好一大堆后西姆还在两端打孔,使用木钉固定在一起。这可不是用来做墙壁的,而是做庇护所的承重柱。只要木头间的缝隙足够紧密,那搭建出的木屋就能足够坚固。 +每向上搭建一层,他都会在中间塞入苔藓防寒,保证不会泄露一丝热量。其他几面也是用相同方法,很快西姆就做好了三面墙壁,每一根木头都极其工整,保证强迫症来了都要点个赞再走。 +在继续搭建墙壁前西姆决定将壁炉制作出来,毕竟森林夜晚的气温会很低,保暖措施可是重中之重。完成后他找来一块大树皮用来充当庇护所的大门,而上面刮掉的木屑还能作为壁炉的引火物,可以说再完美不过。 +测试了排烟没问题后他才开始搭建最后一面墙壁,这一面要预留门和窗,所以在搭建到一半后还需要在原木中间开出卡口,让自己劈砍时能轻松许多。此时只需将另外一根如法炮制,两端拼接在一起后就是一扇大小适中的窗户。而随着随后一层苔藓铺好,最后一根原木落位,这个庇护所的雏形就算完成。 + + + +解压助眠的天花板就是荒野建造,沉浸丝滑的搭建过程每一帧都是极致享受,全屋严丝合缝的拼接工艺,能轻松抵御零下二十度气温,居住体验温暖如春。 +在家闲不住的西姆开启野外建造。他发现倒塌的树,决定加以利用。先挖掘出 2x3 的深坑作为地基,接着收集原木,刮掉表皮防白蚁蛀虫,打孔用木钉固定制作承重柱。搭建墙壁时,每一层都塞入苔藓防寒,很快做好三面墙。 +为应对森林夜晚低温,西姆制作壁炉,用大树皮当大门,刮下的木屑做引火物。搭建最后一面墙时预留门窗,通过在原木中间开口拼接做出窗户。大门采用榫卯结构安装,严丝合缝。 +搭建屋顶时,先固定外围原木,再平铺原木形成斜面屋顶,之后用苔藓、黏土密封缝隙,铺上枯叶和泥土。为美观,在木屋覆盖苔藓,移植小树点缀。完工时遇大雨,木屋防水良好。 +西姆利用墙壁凹槽镶嵌床框,铺上苔藓、床单枕头做成床。劳作一天后,他用壁炉烤牛肉享用。建造一星期后,他开始野外露营。 +后来西姆回家补给物资,回来时森林大雪纷飞。他劈柴储备,带回食物、调味料和被褥,提高居住舒适度,还用干草做靠垫。他用壁炉烤牛排,搭配红酒。 +第二天,积雪融化,西姆制作室外篝火堆防野兽。用大树夹缝掰弯木棍堆积而成,晚上点燃处理废料,结束后用雪球灭火,最后在室内二十五度的环境中裹被入睡。 + + + +${video_frame_description} + + +我正在尝试做这个内容的解说纪录片视频,我需要你以 中的内容为解说目标,根据我刚才提供给你的对标文案特点,以及你总结的特点,帮我生成一段关于荒野建造的解说文案,文案需要符合平台受欢迎的解说风格,请使用 json 格式进行输出;使用 中的输出格式: + + +{ + "items": [ + { + "_id": 1, + "timestamp": "00:00:05,390-00:00:10,430", + "picture": "画面描述", + "narration": "解说文案" + } + ] +} + + + +1. 只输出 json 内容,不要输出其他任何说明性的文字 +2. 解说文案的语言使用 简体中文 +3. 严禁虚构画面,所有画面只能从 中摘取 +4. 严禁虚构时间戳,所有时间戳只能从 中摘取 +5. 解说文案要生动有趣,符合荒野建造解说的风格特点 +6. 每个片段的解说文案要与画面内容高度匹配 +7. 保持解说的连贯性和故事性 +""" diff --git a/app/services/prompts/exceptions.py b/app/services/prompts/exceptions.py new file mode 100644 index 0000000..1c6a361 --- /dev/null +++ b/app/services/prompts/exceptions.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : exceptions.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 提示词管理模块异常定义 +""" + + +class PromptError(Exception): + """提示词模块基础异常类""" + pass + + +class PromptNotFoundError(PromptError): + """提示词未找到异常""" + + def __init__(self, category: str, name: str, version: str = None): + self.category = category + self.name = name + self.version = version + + if version: + message = f"提示词未找到: {category}.{name} (版本: {version})" + else: + message = f"提示词未找到: {category}.{name}" + + super().__init__(message) + + +class PromptValidationError(PromptError): + """提示词验证异常""" + + def __init__(self, message: str, validation_errors: list = None): + self.validation_errors = validation_errors or [] + super().__init__(message) + + +class TemplateRenderError(PromptError): + """模板渲染异常""" + + def __init__(self, template_name: str, error_message: str, missing_params: list = None): + self.template_name = template_name + self.error_message = error_message + self.missing_params = missing_params or [] + + message = f"模板渲染失败 '{template_name}': {error_message}" + if missing_params: + message += f" (缺少参数: {', '.join(missing_params)})" + + super().__init__(message) + + +class PromptRegistrationError(PromptError): + """提示词注册异常""" + + def __init__(self, category: str, name: str, reason: str): + self.category = category + self.name = name + self.reason = reason + + message = f"提示词注册失败 {category}.{name}: {reason}" + super().__init__(message) + + +class PromptVersionError(PromptError): + """提示词版本异常""" + + def __init__(self, category: str, name: str, version: str, reason: str): + self.category = category + self.name = name + self.version = version + self.reason = reason + + message = f"提示词版本错误 {category}.{name} v{version}: {reason}" + super().__init__(message) diff --git a/app/services/prompts/manager.py b/app/services/prompts/manager.py new file mode 100644 index 0000000..5dd65f2 --- /dev/null +++ b/app/services/prompts/manager.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : manager.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 提示词管理器 +""" + +from typing import Dict, Any, List, Optional, Union +from loguru import logger + +from .base import BasePrompt, ModelType, OutputFormat +from .registry import get_registry +from .template import get_renderer +from .validators import PromptOutputValidator +from .exceptions import ( + PromptNotFoundError, + PromptValidationError, + TemplateRenderError +) + + +class PromptManager: + """提示词管理器 - 统一的提示词管理接口""" + + def __init__(self): + self._registry = get_registry() + self._renderer = get_renderer() + + @classmethod + def get_prompt(cls, + category: str, + name: str, + version: Optional[str] = None, + parameters: Optional[Dict[str, Any]] = None) -> str: + """ + 获取渲染后的提示词 + + Args: + category: 分类 + name: 名称 + version: 版本(可选,默认使用最新版本) + parameters: 模板参数(可选) + + Returns: + 渲染后的提示词字符串 + """ + instance = cls() + prompt_obj = instance._registry.get(category, name, version) + + try: + rendered = prompt_obj.render(parameters) + logger.debug(f"提示词渲染成功: {category}.{name} v{prompt_obj.version}") + return rendered + except Exception as e: + logger.error(f"提示词渲染失败: {category}.{name} - {str(e)}") + raise + + @classmethod + def get_prompt_object(cls, + category: str, + name: str, + version: Optional[str] = None) -> BasePrompt: + """ + 获取提示词对象 + + Args: + category: 分类 + name: 名称 + version: 版本(可选) + + Returns: + 提示词对象 + """ + instance = cls() + return instance._registry.get(category, name, version) + + @classmethod + def register_prompt(cls, prompt: BasePrompt, is_default: bool = True) -> None: + """ + 注册提示词 + + Args: + prompt: 提示词对象 + is_default: 是否设为默认版本 + """ + instance = cls() + instance._registry.register(prompt, is_default) + + @classmethod + def list_categories(cls) -> List[str]: + """列出所有分类""" + instance = cls() + return instance._registry.list_categories() + + @classmethod + def list_prompts(cls, category: str) -> List[str]: + """列出指定分类下的所有提示词""" + instance = cls() + return instance._registry.list_prompts(category) + + @classmethod + def list_versions(cls, category: str, name: str) -> List[str]: + """列出指定提示词的所有版本""" + instance = cls() + return instance._registry.list_versions(category, name) + + @classmethod + def exists(cls, category: str, name: str, version: Optional[str] = None) -> bool: + """检查提示词是否存在""" + instance = cls() + return instance._registry.exists(category, name, version) + + @classmethod + def search_prompts(cls, + keyword: str = None, + category: str = None, + model_type: ModelType = None, + output_format: OutputFormat = None) -> List[Dict[str, str]]: + """ + 搜索提示词 + + Args: + keyword: 关键词 + category: 分类过滤 + model_type: 模型类型过滤 + output_format: 输出格式过滤 + + Returns: + 匹配的提示词列表 + """ + instance = cls() + results = instance._registry.search(keyword, category, model_type, output_format) + + return [ + { + "category": cat, + "name": name, + "version": ver, + "full_name": f"{cat}.{name}", + "identifier": f"{cat}.{name}@{ver}" + } + for cat, name, ver in results + ] + + @classmethod + def get_stats(cls) -> Dict[str, Any]: + """获取统计信息""" + instance = cls() + registry_stats = instance._registry.get_stats() + + return { + "registry": registry_stats, + "categories": cls.list_categories(), + "total_categories": registry_stats["categories"], + "total_prompts": registry_stats["prompts"], + "total_versions": registry_stats["versions"] + } + + @classmethod + def validate_output(cls, + output: Union[str, Dict], + category: str, + name: str, + version: Optional[str] = None) -> Any: + """ + 验证提示词输出 + + Args: + output: 输出内容 + category: 提示词分类 + name: 提示词名称 + version: 提示词版本 + + Returns: + 验证后的数据 + """ + instance = cls() + prompt_obj = instance._registry.get(category, name, version) + + # 根据输出格式进行验证 + output_format = prompt_obj.metadata.output_format + + try: + if output_format == OutputFormat.JSON: + # 特殊处理解说文案和剧情分析 + if "narration" in name.lower() or "script" in name.lower(): + return PromptOutputValidator.validate_narration_script(output) + elif "plot" in name.lower() or "analysis" in name.lower(): + return PromptOutputValidator.validate_plot_analysis(output) + else: + return PromptOutputValidator.validate_json(output) + else: + return PromptOutputValidator.validate_by_format(output, output_format) + + except Exception as e: + logger.error(f"输出验证失败 {category}.{name}: {str(e)}") + raise PromptValidationError(f"输出验证失败: {str(e)}") + + @classmethod + def get_prompt_info(cls, category: str, name: str, version: Optional[str] = None) -> Dict[str, Any]: + """ + 获取提示词详细信息 + + Args: + category: 分类 + name: 名称 + version: 版本 + + Returns: + 提示词详细信息 + """ + instance = cls() + prompt_obj = instance._registry.get(category, name, version) + + return { + "metadata": { + "name": prompt_obj.metadata.name, + "category": prompt_obj.metadata.category, + "version": prompt_obj.metadata.version, + "description": prompt_obj.metadata.description, + "model_type": prompt_obj.metadata.model_type.value, + "output_format": prompt_obj.metadata.output_format.value, + "author": prompt_obj.metadata.author, + "created_at": prompt_obj.metadata.created_at.isoformat(), + "updated_at": prompt_obj.metadata.updated_at.isoformat(), + "tags": prompt_obj.metadata.tags, + "parameters": prompt_obj.metadata.parameters + }, + "template_preview": prompt_obj.get_template()[:500] + "..." if len(prompt_obj.get_template()) > 500 else prompt_obj.get_template(), + "system_prompt": prompt_obj.get_system_prompt(), + "examples_count": len(prompt_obj.get_examples()), + "has_parameters": bool(prompt_obj.metadata.parameters) + } + + @classmethod + def export_prompts(cls, category: Optional[str] = None) -> Dict[str, Any]: + """ + 导出提示词配置 + + Args: + category: 分类过滤(可选) + + Returns: + 提示词配置数据 + """ + instance = cls() + categories = [category] if category else instance._registry.list_categories() + + export_data = { + "version": "1.0.0", + "exported_at": instance._get_current_time(), + "categories": {} + } + + for cat in categories: + export_data["categories"][cat] = {} + prompts = instance._registry.list_prompts(cat) + + for prompt_name in prompts: + versions = instance._registry.list_versions(cat, prompt_name) + export_data["categories"][cat][prompt_name] = {} + + for ver in versions: + prompt_obj = instance._registry.get(cat, prompt_name, ver) + export_data["categories"][cat][prompt_name][ver] = prompt_obj.to_dict() + + return export_data + + def _get_current_time(self) -> str: + """获取当前时间字符串""" + from datetime import datetime + return datetime.now().isoformat() + + +# 便捷函数 +def get_prompt(category: str, name: str, version: str = None, **parameters) -> str: + """获取提示词的便捷函数""" + return PromptManager.get_prompt(category, name, version, parameters) + + +def validate_prompt_output(output: Union[str, Dict], category: str, name: str, version: str = None) -> Any: + """验证提示词输出的便捷函数""" + return PromptManager.validate_output(output, category, name, version) diff --git a/app/services/prompts/registry.py b/app/services/prompts/registry.py new file mode 100644 index 0000000..2720522 --- /dev/null +++ b/app/services/prompts/registry.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : registry.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 提示词注册机制 +""" + +from typing import Dict, List, Optional, Tuple +from collections import defaultdict +from loguru import logger + +from .base import BasePrompt, ModelType, OutputFormat +from .exceptions import ( + PromptNotFoundError, + PromptRegistrationError, + PromptVersionError +) + + +class PromptRegistry: + """提示词注册表""" + + def __init__(self): + # 存储结构: {category: {name: {version: prompt}}} + self._prompts: Dict[str, Dict[str, Dict[str, BasePrompt]]] = defaultdict( + lambda: defaultdict(dict) + ) + # 默认版本映射: {category: {name: default_version}} + self._default_versions: Dict[str, Dict[str, str]] = defaultdict(dict) + + def register(self, prompt: BasePrompt, is_default: bool = True) -> None: + """ + 注册提示词 + + Args: + prompt: 提示词实例 + is_default: 是否设为默认版本 + """ + category = prompt.category + name = prompt.name + version = prompt.version + + # 检查是否已存在相同版本 + if version in self._prompts[category][name]: + raise PromptRegistrationError( + category=category, + name=name, + reason=f"版本 {version} 已存在" + ) + + # 注册提示词 + self._prompts[category][name][version] = prompt + + # 设置默认版本 + if is_default or name not in self._default_versions[category]: + self._default_versions[category][name] = version + + logger.info(f"已注册提示词: {category}.{name} v{version}") + + def get(self, category: str, name: str, version: Optional[str] = None) -> BasePrompt: + """ + 获取提示词 + + Args: + category: 分类 + name: 名称 + version: 版本,为None时使用默认版本 + + Returns: + 提示词实例 + """ + if category not in self._prompts: + raise PromptNotFoundError(category, name, version) + + if name not in self._prompts[category]: + raise PromptNotFoundError(category, name, version) + + # 确定版本 + if version is None: + if name not in self._default_versions[category]: + raise PromptNotFoundError(category, name, version) + version = self._default_versions[category][name] + + if version not in self._prompts[category][name]: + raise PromptNotFoundError(category, name, version) + + return self._prompts[category][name][version] + + def list_categories(self) -> List[str]: + """列出所有分类""" + return list(self._prompts.keys()) + + def list_prompts(self, category: str) -> List[str]: + """列出指定分类下的所有提示词名称""" + if category not in self._prompts: + return [] + return list(self._prompts[category].keys()) + + def list_versions(self, category: str, name: str) -> List[str]: + """列出指定提示词的所有版本""" + if category not in self._prompts or name not in self._prompts[category]: + return [] + return list(self._prompts[category][name].keys()) + + def get_default_version(self, category: str, name: str) -> Optional[str]: + """获取默认版本""" + return self._default_versions.get(category, {}).get(name) + + def set_default_version(self, category: str, name: str, version: str) -> None: + """设置默认版本""" + if (category not in self._prompts or + name not in self._prompts[category] or + version not in self._prompts[category][name]): + raise PromptVersionError(category, name, version, "版本不存在") + + self._default_versions[category][name] = version + logger.info(f"已设置默认版本: {category}.{name} -> v{version}") + + def exists(self, category: str, name: str, version: Optional[str] = None) -> bool: + """检查提示词是否存在""" + try: + self.get(category, name, version) + return True + except PromptNotFoundError: + return False + + def remove(self, category: str, name: str, version: Optional[str] = None) -> None: + """移除提示词""" + if version is None: + # 移除所有版本 + if category in self._prompts and name in self._prompts[category]: + del self._prompts[category][name] + if name in self._default_versions.get(category, {}): + del self._default_versions[category][name] + logger.info(f"已移除提示词所有版本: {category}.{name}") + else: + # 移除指定版本 + if (category in self._prompts and + name in self._prompts[category] and + version in self._prompts[category][name]): + del self._prompts[category][name][version] + + # 如果移除的是默认版本,需要重新设置默认版本 + if (self._default_versions.get(category, {}).get(name) == version and + self._prompts[category][name]): + # 选择最新版本作为默认版本 + new_default = max(self._prompts[category][name].keys()) + self._default_versions[category][name] = new_default + logger.info(f"默认版本已更新: {category}.{name} -> v{new_default}") + + logger.info(f"已移除提示词版本: {category}.{name} v{version}") + + def search(self, + keyword: str = None, + category: str = None, + model_type: ModelType = None, + output_format: OutputFormat = None) -> List[Tuple[str, str, str]]: + """ + 搜索提示词 + + Args: + keyword: 关键词(在名称和描述中搜索) + category: 分类过滤 + model_type: 模型类型过滤 + output_format: 输出格式过滤 + + Returns: + 匹配的提示词列表 [(category, name, version), ...] + """ + results = [] + + categories = [category] if category else self._prompts.keys() + + for cat in categories: + for name in self._prompts[cat]: + for version, prompt in self._prompts[cat][name].items(): + # 关键词过滤 + if keyword: + if (keyword.lower() not in name.lower() and + keyword.lower() not in prompt.metadata.description.lower()): + continue + + # 模型类型过滤 + if model_type and prompt.metadata.model_type != model_type: + continue + + # 输出格式过滤 + if output_format and prompt.metadata.output_format != output_format: + continue + + results.append((cat, name, version)) + + return results + + def get_stats(self) -> Dict[str, int]: + """获取注册表统计信息""" + total_prompts = 0 + total_versions = 0 + + for category in self._prompts: + for name in self._prompts[category]: + total_prompts += 1 + total_versions += len(self._prompts[category][name]) + + return { + "categories": len(self._prompts), + "prompts": total_prompts, + "versions": total_versions + } + + +# 全局注册表实例 +_global_registry = PromptRegistry() + + +def get_registry() -> PromptRegistry: + """获取全局注册表实例""" + return _global_registry diff --git a/app/services/prompts/short_drama_editing/__init__.py b/app/services/prompts/short_drama_editing/__init__.py new file mode 100644 index 0000000..0f3bd04 --- /dev/null +++ b/app/services/prompts/short_drama_editing/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : __init__.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 短剧混剪提示词模块 +""" + +from .subtitle_analysis import SubtitleAnalysisPrompt +from .plot_extraction import PlotExtractionPrompt +from ..manager import PromptManager + + +def register_prompts(): + """注册短剧混剪相关的提示词""" + + # 注册字幕分析提示词 + subtitle_analysis_prompt = SubtitleAnalysisPrompt() + PromptManager.register_prompt(subtitle_analysis_prompt, is_default=True) + + # 注册爆点提取提示词 + plot_extraction_prompt = PlotExtractionPrompt() + PromptManager.register_prompt(plot_extraction_prompt, is_default=True) + + +__all__ = [ + "SubtitleAnalysisPrompt", + "PlotExtractionPrompt", + "register_prompts" +] diff --git a/app/services/prompts/short_drama_editing/plot_extraction.py b/app/services/prompts/short_drama_editing/plot_extraction.py new file mode 100644 index 0000000..49c2ab3 --- /dev/null +++ b/app/services/prompts/short_drama_editing/plot_extraction.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : plot_extraction.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 短剧爆点提取提示词 +""" + +from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat + + +class PlotExtractionPrompt(TextPrompt): + """短剧爆点提取提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="plot_extraction", + category="short_drama_editing", + version="v1.0", + description="根据剧情梗概和字幕内容,精确定位关键剧情的时间段", + model_type=ModelType.TEXT, + output_format=OutputFormat.JSON, + tags=["短剧", "爆点定位", "时间戳", "剧情提取"], + parameters=["subtitle_content", "plot_summary", "plot_titles"] + ) + super().__init__(metadata) + + self._system_prompt = "你是一名短剧编剧,非常擅长根据字幕中分析视频中关键剧情出现的具体时间段。" + + def get_template(self) -> str: + return """请仔细阅读剧情梗概和爆点内容,然后在字幕中找出每个爆点发生的具体时间段和爆点前后的详细剧情。 + +剧情梗概: +${plot_summary} + +需要定位的爆点内容: +${plot_titles} + +字幕内容: +${subtitle_content} + +分析要求: +1. 为每个爆点找到对应的具体时间段 +2. 时间段要准确反映该爆点的完整发展过程 +3. 提供爆点前后的详细剧情描述 +4. 确保时间戳格式正确且存在于字幕中 +5. 选择最具戏剧张力的时间段 + +请返回一个JSON对象,包含一个名为"plot_points"的数组,数组中包含多个对象,每个对象都要包含以下字段: + +{ + "plot_points": [ + { + "timestamp": "时间段,格式为xx:xx:xx,xxx-xx:xx:xx,xxx", + "title": "关键剧情的主题", + "picture": "关键剧情前后的详细剧情描述,包括人物对话、动作、情感变化等" + } + ] +} + +重要要求: +1. 请确保返回的是合法的JSON格式 +2. 时间戳必须严格按照字幕中的格式 +3. 剧情描述要详细具体,包含关键对话和动作 +4. 每个爆点的时间段要合理,不能过短或过长 +5. 严禁虚构不存在的时间戳或剧情内容 +6. 只输出JSON内容,不要添加任何说明文字""" diff --git a/app/services/prompts/short_drama_editing/subtitle_analysis.py b/app/services/prompts/short_drama_editing/subtitle_analysis.py new file mode 100644 index 0000000..9237e74 --- /dev/null +++ b/app/services/prompts/short_drama_editing/subtitle_analysis.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : subtitle_analysis.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 短剧字幕分析提示词 +""" + +from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat + + +class SubtitleAnalysisPrompt(TextPrompt): + """短剧字幕分析提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="subtitle_analysis", + category="short_drama_editing", + version="v1.0", + description="分析短剧字幕内容,提取剧情梗概和关键情节点", + model_type=ModelType.TEXT, + output_format=OutputFormat.JSON, + tags=["短剧", "字幕分析", "剧情梗概", "情节提取"], + parameters=["subtitle_content", "custom_clips"] + ) + super().__init__(metadata) + + self._system_prompt = "你是一名短剧编剧和内容分析师,擅长从字幕中提取剧情要点和关键情节。" + + def get_template(self) -> str: + return """请仔细分析以下短剧字幕内容,提取剧情梗概和关键情节点。 + +字幕内容: +${subtitle_content} + +分析要求: +1. 提取整体剧情梗概,概括主要故事线和核心冲突 +2. 识别 ${custom_clips} 个最具吸引力的关键情节点(爆点) +3. 每个情节点要包含具体的时间段和详细描述 +4. 关注剧情的转折点、冲突高潮、情感爆发等关键时刻 +5. 确保选择的情节点具有强烈的戏剧张力和观看价值 + +请按照以下JSON格式输出分析结果: + +{ + "summary": "整体剧情梗概,简要概括主要故事线、角色关系和核心冲突", + "plot_titles": [ + "情节点1标题", + "情节点2标题", + "情节点3标题" + ], + "analysis_details": { + "main_characters": ["主要角色1", "主要角色2"], + "story_theme": "故事主题", + "conflict_type": "冲突类型(如:爱情、复仇、家庭等)", + "emotional_peaks": ["情感高潮点1", "情感高潮点2"] + } +} + +重要要求: +1. 必须输出有效的JSON格式,不能包含注释或其他文字 +2. 剧情梗概要简洁明了,突出核心看点 +3. 情节点标题要吸引人,体现戏剧冲突 +4. 严禁虚构不存在的剧情内容 +5. 分析要客观准确,基于字幕实际内容""" diff --git a/app/services/prompts/short_drama_narration/__init__.py b/app/services/prompts/short_drama_narration/__init__.py new file mode 100644 index 0000000..dfa0171 --- /dev/null +++ b/app/services/prompts/short_drama_narration/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : __init__.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 短剧解说提示词模块 +""" + +from .plot_analysis import PlotAnalysisPrompt +from .script_generation import ScriptGenerationPrompt +from ..manager import PromptManager + + +def register_prompts(): + """注册短剧解说相关的提示词""" + + # 注册剧情分析提示词 + plot_analysis_prompt = PlotAnalysisPrompt() + PromptManager.register_prompt(plot_analysis_prompt, is_default=True) + + # 注册解说脚本生成提示词 + script_generation_prompt = ScriptGenerationPrompt() + PromptManager.register_prompt(script_generation_prompt, is_default=True) + + +__all__ = [ + "PlotAnalysisPrompt", + "ScriptGenerationPrompt", + "register_prompts" +] diff --git a/app/services/SDE/prompt.py b/app/services/prompts/short_drama_narration/plot_analysis.py similarity index 64% rename from app/services/SDE/prompt.py rename to app/services/prompts/short_drama_narration/plot_analysis.py index 78385cc..6431754 100644 --- a/app/services/SDE/prompt.py +++ b/app/services/prompts/short_drama_narration/plot_analysis.py @@ -1,15 +1,37 @@ #!/usr/bin/env python # -*- coding: UTF-8 -*- -''' +""" @Project: NarratoAI -@File : prompt -@Author : 小林同学 -@Date : 2025/5/9 上午12:57 -''' -# 字幕剧情分析提示词 -subtitle_plot_analysis_v1 = """ -# 角色 +@File : plot_analysis.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 短剧剧情分析提示词 +""" + +from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat + + +class PlotAnalysisPrompt(TextPrompt): + """短剧剧情分析提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="plot_analysis", + category="short_drama_narration", + version="v1.0", + description="分析短剧字幕内容,提供详细的剧情分析和分段解析", + model_type=ModelType.TEXT, + output_format=OutputFormat.TEXT, + tags=["短剧", "剧情分析", "字幕解析", "分段分析"], + parameters=["subtitle_content"] + ) + super().__init__(metadata) + + self._system_prompt = "你是一位专业的剧本分析师和剧情概括助手。" + + def get_template(self) -> str: + return """# 角色 你是一位专业的剧本分析师和剧情概括助手。 # 任务 @@ -62,36 +84,7 @@ subtitle_plot_analysis_v1 = """ # 限制 1. 严禁输出与分析结果无关的内容 -2. +2. 时间戳必须严格按照字幕中的实际时间 # 请处理以下字幕: -""" - -plot_writing = """ -我是一个影视解说up主,需要为我的粉丝讲解短剧《%s》的剧情,目前正在解说剧情,希望能让粉丝通过我的解说了解剧情,并且产生 继续观看的兴趣,请生成一篇解说脚本,包含解说文案,以及穿插原声的片段,下面中的内容是短剧的剧情概述: - - -%s - - -请使用 json 格式进行输出;使用 中的输出格式: - -{ - "items": [ - { - "_id": 1, # 唯一递增id - "timestamp": "00:00:05,390-00:00:10,430", - "picture": "剧情描述或者备注", - "narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 ‘播放原片+_id‘ 进行占位", - "OST": "值为 0 表示当前片段为解说片段,值为 1 表示当前片段为穿插的原片" - } -} - - - -1. 只输出 json 内容,不要输出其他任何说明性的文字 -2. 解说文案的语言使用 简体中文 -3. 严禁虚构剧情,所有画面只能从 中摘取 -4. 严禁虚构时间戳,所有时间戳范围只能从 中摘取 - -""" \ No newline at end of file +${subtitle_content}""" diff --git a/app/services/prompts/short_drama_narration/script_generation.py b/app/services/prompts/short_drama_narration/script_generation.py new file mode 100644 index 0000000..9fd105f --- /dev/null +++ b/app/services/prompts/short_drama_narration/script_generation.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : script_generation.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 短剧解说脚本生成提示词 +""" + +from ..base import ParameterizedPrompt, PromptMetadata, ModelType, OutputFormat + + +class ScriptGenerationPrompt(ParameterizedPrompt): + """短剧解说脚本生成提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="script_generation", + category="short_drama_narration", + version="v1.0", + description="根据剧情分析生成短剧解说脚本,包含解说文案和原声片段", + model_type=ModelType.TEXT, + output_format=OutputFormat.JSON, + tags=["短剧", "解说脚本", "文案生成", "原声片段"], + parameters=["drama_name", "plot_analysis"] + ) + super().__init__(metadata, required_parameters=["drama_name", "plot_analysis"]) + + self._system_prompt = "你是一位专业的短视频解说脚本撰写专家。你必须严格按照JSON格式输出,不能包含任何其他文字、说明或代码块标记。" + + def get_template(self) -> str: + return """我是一个影视解说up主,需要为我的粉丝讲解短剧《${drama_name}》的剧情,目前正在解说剧情,希望能让粉丝通过我的解说了解剧情,并且产生继续观看的兴趣,请生成一篇解说脚本,包含解说文案,以及穿插原声的片段,下面中的内容是短剧的剧情概述: + + +${plot_analysis} + + +请严格按照以下JSON格式输出,不要添加任何其他文字、说明或代码块标记: + +{ + "items": [ + { + "_id": 1, + "timestamp": "00:00:05,390-00:00:10,430", + "picture": "剧情描述或者备注", + "narration": "解说文案,如果片段为穿插的原片片段,可以直接使用 '播放原片+_id' 进行占位", + "OST": 0 + } + ] +} + +重要要求: +1. 只输出 json 内容,不要输出其他任何说明性的文字 +2. 解说文案必须遵循“起-承-转-合”的线性时间链 +3. 解说文案需包含角色微表情、动作细节、场景氛围的描写,每段80-150字 +4. 通过细节关联普遍情感(如遗憾、和解、成长),避免直白抒情 +5. 所有细节严格源自,可对角色行为进行合理心理推导但不虚构剧情 +6. 时间戳从摘取,可根据解说内容拆分原时间片段(如将10秒拆分为两个5秒) +7. 解说与原片穿插比例控制在7:3,关键情绪点保留原片原声 +8. 严禁跳脱剧情发展顺序,所有描述必须符合“先发生A,再发生B,A导致B”的逻辑 +9. 强化流程感,让观众清晰感知剧情推进的先后顺序""" diff --git a/app/services/prompts/template.py b/app/services/prompts/template.py new file mode 100644 index 0000000..bfebcb5 --- /dev/null +++ b/app/services/prompts/template.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : template.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 模板渲染引擎 +""" + +import re +from typing import Dict, Any, List, Optional +from string import Template +from loguru import logger + +from .exceptions import TemplateRenderError + + +class TemplateRenderer: + """模板渲染器""" + + def __init__(self): + self._custom_filters = {} + + def register_filter(self, name: str, func: callable) -> None: + """注册自定义过滤器""" + self._custom_filters[name] = func + logger.debug(f"已注册模板过滤器: {name}") + + def render(self, template: str, parameters: Dict[str, Any] = None) -> str: + """ + 渲染模板 + + Args: + template: 模板字符串 + parameters: 参数字典 + + Returns: + 渲染后的字符串 + """ + parameters = parameters or {} + + try: + # 使用简单的字符串替换进行参数替换 + rendered = template + + for key, value in parameters.items(): + # 替换 ${key} 格式的参数 + rendered = rendered.replace(f"${{{key}}}", str(value)) + # 也替换 $key 格式的参数(为了兼容性) + rendered = rendered.replace(f"${key}", str(value)) + + # 处理自定义过滤器 + rendered = self._apply_filters(rendered, parameters) + + return rendered + + except Exception as e: + raise TemplateRenderError( + template_name="unknown", + error_message=f"模板渲染失败: {str(e)}" + ) + + def _apply_filters(self, text: str, parameters: Dict[str, Any]) -> str: + """应用自定义过滤器""" + # 查找过滤器模式: ${variable|filter_name} + filter_pattern = r'\$\{([^}]+)\|([^}]+)\}' + + def replace_filter(match): + var_name = match.group(1).strip() + filter_name = match.group(2).strip() + + if filter_name not in self._custom_filters: + logger.warning(f"未知的过滤器: {filter_name}") + return match.group(0) # 返回原始文本 + + if var_name not in parameters: + logger.warning(f"参数不存在: {var_name}") + return match.group(0) # 返回原始文本 + + try: + filter_func = self._custom_filters[filter_name] + filtered_value = filter_func(parameters[var_name]) + return str(filtered_value) + except Exception as e: + logger.error(f"过滤器执行失败 {filter_name}: {str(e)}") + return match.group(0) # 返回原始文本 + + return re.sub(filter_pattern, replace_filter, text) + + def extract_variables(self, template: str) -> List[str]: + """提取模板中的变量名""" + # 匹配 ${variable} 和 ${variable|filter} 模式 + pattern = r'\$\{([^}|]+)(?:\|[^}]+)?\}' + matches = re.findall(pattern, template) + return list(set(match.strip() for match in matches)) + + def validate_template(self, template: str, required_params: List[str] = None) -> bool: + """验证模板""" + try: + # 提取模板变量 + template_vars = self.extract_variables(template) + + # 检查必需参数 + if required_params: + missing_params = set(required_params) - set(template_vars) + if missing_params: + raise TemplateRenderError( + template_name="validation", + error_message="模板缺少必需参数", + missing_params=list(missing_params) + ) + + # 尝试渲染测试 + test_params = {var: f"test_{var}" for var in template_vars} + self.render(template, test_params) + + return True + + except Exception as e: + logger.error(f"模板验证失败: {str(e)}") + return False + + +# 内置过滤器 +def _upper_filter(value: Any) -> str: + """转换为大写""" + return str(value).upper() + + +def _lower_filter(value: Any) -> str: + """转换为小写""" + return str(value).lower() + + +def _title_filter(value: Any) -> str: + """转换为标题格式""" + return str(value).title() + + +def _strip_filter(value: Any) -> str: + """去除首尾空白""" + return str(value).strip() + + +def _truncate_filter(value: Any, length: int = 100) -> str: + """截断文本""" + text = str(value) + if len(text) <= length: + return text + return text[:length] + "..." + + +def _json_filter(value: Any) -> str: + """转换为JSON字符串""" + import json + return json.dumps(value, ensure_ascii=False, indent=2) + + +# 全局渲染器实例 +_global_renderer = TemplateRenderer() + +# 注册内置过滤器 +_global_renderer.register_filter("upper", _upper_filter) +_global_renderer.register_filter("lower", _lower_filter) +_global_renderer.register_filter("title", _title_filter) +_global_renderer.register_filter("strip", _strip_filter) +_global_renderer.register_filter("truncate", _truncate_filter) +_global_renderer.register_filter("json", _json_filter) + + +def get_renderer() -> TemplateRenderer: + """获取全局渲染器实例""" + return _global_renderer + + +def render_template(template: str, parameters: Dict[str, Any] = None) -> str: + """便捷的模板渲染函数""" + return _global_renderer.render(template, parameters) diff --git a/app/services/prompts/validators.py b/app/services/prompts/validators.py new file mode 100644 index 0000000..e991e2b --- /dev/null +++ b/app/services/prompts/validators.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : validators.py +@Author : viccy同学 +@Date : 2025/1/7 +@Description: 提示词输出验证器 +""" + +import json +import re +from typing import Dict, Any, List, Optional, Union +from loguru import logger + +from .base import OutputFormat +from .exceptions import PromptValidationError + + +class PromptOutputValidator: + """提示词输出验证器""" + + @staticmethod + def validate_json(output: str, schema: Dict[str, Any] = None) -> Dict[str, Any]: + """ + 验证JSON输出 + + Args: + output: 输出字符串 + schema: JSON schema(可选) + + Returns: + 解析后的JSON对象 + """ + try: + # 清理输出(移除可能的代码块标记) + cleaned_output = PromptOutputValidator._clean_json_output(output) + + # 解析JSON + parsed = json.loads(cleaned_output) + + # Schema验证(如果提供) + if schema: + PromptOutputValidator._validate_json_schema(parsed, schema) + + return parsed + + except json.JSONDecodeError as e: + raise PromptValidationError(f"JSON格式错误: {str(e)}") + except Exception as e: + raise PromptValidationError(f"JSON验证失败: {str(e)}") + + @staticmethod + def validate_narration_script(output: Union[str, Dict]) -> Dict[str, Any]: + """ + 验证解说文案输出格式 + + Args: + output: 输出内容(字符串或字典) + + Returns: + 验证后的解说文案数据 + """ + # 如果是字符串,先解析为JSON + if isinstance(output, str): + data = PromptOutputValidator.validate_json(output) + else: + data = output + + # 验证必需字段 + if "items" not in data: + raise PromptValidationError("解说文案缺少 'items' 字段") + + items = data["items"] + if not isinstance(items, list): + raise PromptValidationError("'items' 字段必须是数组") + + if not items: + raise PromptValidationError("解说文案不能为空") + + # 验证每个item + for i, item in enumerate(items): + PromptOutputValidator._validate_narration_item(item, i) + + logger.debug(f"解说文案验证通过,包含 {len(items)} 个片段") + return data + + @staticmethod + def validate_plot_analysis(output: Union[str, Dict]) -> Dict[str, Any]: + """ + 验证剧情分析输出格式 + + Args: + output: 输出内容 + + Returns: + 验证后的剧情分析数据 + """ + if isinstance(output, str): + data = PromptOutputValidator.validate_json(output) + else: + data = output + + # 验证剧情分析必需字段 + required_fields = ["summary", "plot_points"] + for field in required_fields: + if field not in data: + raise PromptValidationError(f"剧情分析缺少 '{field}' 字段") + + # 验证plot_points + plot_points = data["plot_points"] + if not isinstance(plot_points, list): + raise PromptValidationError("'plot_points' 字段必须是数组") + + for i, point in enumerate(plot_points): + PromptOutputValidator._validate_plot_point(point, i) + + logger.debug(f"剧情分析验证通过,包含 {len(plot_points)} 个情节点") + return data + + @staticmethod + def _clean_json_output(output: str) -> str: + """清理JSON输出""" + # 移除可能的代码块标记 + output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE) + output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE) + + # 移除前后空白 + output = output.strip() + + # 尝试提取JSON部分(如果有其他文本) + json_match = re.search(r'\{.*\}', output, re.DOTALL) + if json_match: + output = json_match.group(0) + + return output + + @staticmethod + def _validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]) -> None: + """验证JSON Schema""" + # 简单的schema验证实现 + for field, field_type in schema.items(): + if field not in data: + raise PromptValidationError(f"缺少必需字段: {field}") + + if not isinstance(data[field], field_type): + raise PromptValidationError( + f"字段 '{field}' 类型错误,期望: {field_type.__name__},实际: {type(data[field]).__name__}" + ) + + @staticmethod + def _validate_narration_item(item: Dict[str, Any], index: int) -> None: + """验证解说文案项目""" + required_fields = ["_id", "timestamp", "picture", "narration"] + + for field in required_fields: + if field not in item: + raise PromptValidationError(f"第 {index + 1} 个片段缺少 '{field}' 字段") + + # 验证_id + if not isinstance(item["_id"], int) or item["_id"] <= 0: + raise PromptValidationError(f"第 {index + 1} 个片段的 '_id' 必须是正整数") + + # 验证timestamp格式 + timestamp = item["timestamp"] + if not isinstance(timestamp, str): + raise PromptValidationError(f"第 {index + 1} 个片段的 'timestamp' 必须是字符串") + + # 验证时间戳格式 (HH:MM:SS,mmm-HH:MM:SS,mmm) + timestamp_pattern = r'^\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}$' + if not re.match(timestamp_pattern, timestamp): + raise PromptValidationError( + f"第 {index + 1} 个片段的时间戳格式错误,应为 'HH:MM:SS,mmm-HH:MM:SS,mmm'" + ) + + # 验证文本字段不为空 + for field in ["picture", "narration"]: + if not isinstance(item[field], str) or not item[field].strip(): + raise PromptValidationError(f"第 {index + 1} 个片段的 '{field}' 不能为空") + + # 验证OST字段(如果存在) + if "OST" in item: + if not isinstance(item["OST"], int) or item["OST"] not in [0, 1, 2]: + raise PromptValidationError( + f"第 {index + 1} 个片段的 'OST' 必须是 0、1 或 2" + ) + + @staticmethod + def _validate_plot_point(point: Dict[str, Any], index: int) -> None: + """验证剧情点""" + required_fields = ["timestamp", "title", "picture"] + + for field in required_fields: + if field not in point: + raise PromptValidationError(f"第 {index + 1} 个剧情点缺少 '{field}' 字段") + + # 验证字段类型和内容 + for field in required_fields: + if not isinstance(point[field], str) or not point[field].strip(): + raise PromptValidationError(f"第 {index + 1} 个剧情点的 '{field}' 不能为空") + + # 验证时间戳格式 + timestamp = point["timestamp"] + # 支持多种时间戳格式 + patterns = [ + r'^\d{2}:\d{2}:\d{2},\d{3}-\d{2}:\d{2}:\d{2},\d{3}$', # HH:MM:SS,mmm-HH:MM:SS,mmm + r'^\d{2}:\d{2}:\d{2}-\d{2}:\d{2}:\d{2}$', # HH:MM:SS-HH:MM:SS + ] + + if not any(re.match(pattern, timestamp) for pattern in patterns): + raise PromptValidationError( + f"第 {index + 1} 个剧情点的时间戳格式错误" + ) + + @staticmethod + def validate_by_format(output: str, format_type: OutputFormat, schema: Dict[str, Any] = None) -> Any: + """ + 根据格式类型验证输出 + + Args: + output: 输出内容 + format_type: 输出格式类型 + schema: 验证schema(可选) + + Returns: + 验证后的数据 + """ + if format_type == OutputFormat.JSON: + return PromptOutputValidator.validate_json(output, schema) + elif format_type == OutputFormat.TEXT: + return output.strip() + elif format_type == OutputFormat.MARKDOWN: + return output.strip() + elif format_type == OutputFormat.STRUCTURED: + # 结构化数据需要根据具体类型处理 + return PromptOutputValidator.validate_json(output, schema) + else: + raise PromptValidationError(f"不支持的输出格式: {format_type}") + + +# 便捷函数 +def validate_json_output(output: str, schema: Dict[str, Any] = None) -> Dict[str, Any]: + """验证JSON输出的便捷函数""" + return PromptOutputValidator.validate_json(output, schema) + + +def validate_narration_output(output: Union[str, Dict]) -> Dict[str, Any]: + """验证解说文案输出的便捷函数""" + return PromptOutputValidator.validate_narration_script(output) diff --git a/app/services/script_service.py b/app/services/script_service.py index 461978b..e9ff042 100644 --- a/app/services/script_service.py +++ b/app/services/script_service.py @@ -140,14 +140,27 @@ class ScriptGenerator: # 获取Gemini配置 vision_api_key = config.app.get("vision_gemini_api_key") vision_model = config.app.get("vision_gemini_model_name") - + vision_base_url = config.app.get("vision_gemini_base_url") + if not vision_api_key or not vision_model: raise ValueError("未配置 Gemini API Key 或者模型") - analyzer = gemini_analyzer.VisionAnalyzer( - model_name=vision_model, - api_key=vision_api_key, - ) + # 根据提供商类型选择合适的分析器 + if vision_provider == 'gemini(openai)': + # 使用OpenAI兼容的Gemini代理 + from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer + analyzer = GeminiOpenAIAnalyzer( + model_name=vision_model, + api_key=vision_api_key, + base_url=vision_base_url + ) + else: + # 使用原生Gemini分析器 + analyzer = gemini_analyzer.VisionAnalyzer( + model_name=vision_model, + api_key=vision_api_key, + base_url=vision_base_url + ) progress_callback(40, "正在分析关键帧...") @@ -213,13 +226,35 @@ class ScriptGenerator: text_provider = config.app.get('text_llm_provider', 'gemini').lower() text_api_key = config.app.get(f'text_{text_provider}_api_key') text_model = config.app.get(f'text_{text_provider}_model_name') + text_base_url = config.app.get(f'text_{text_provider}_base_url') - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - prompt=custom_prompt, - video_theme=video_theme - ) + # 根据提供商类型选择合适的处理器 + if text_provider == 'gemini(openai)': + # 使用OpenAI兼容的Gemini代理 + from app.utils.script_generator import GeminiOpenAIGenerator + generator = GeminiOpenAIGenerator( + model_name=text_model, + api_key=text_api_key, + prompt=custom_prompt, + base_url=text_base_url + ) + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + base_url=text_base_url, + prompt=custom_prompt, + video_theme=video_theme + ) + processor.generator = generator + else: + # 使用标准处理器(包括原生Gemini) + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + base_url=text_base_url, + prompt=custom_prompt, + video_theme=video_theme + ) return processor.process_frames(frame_content_list) diff --git a/app/services/update_script.py b/app/services/update_script.py index 2eb9663..9dd32d2 100644 --- a/app/services/update_script.py +++ b/app/services/update_script.py @@ -4,7 +4,7 @@ ''' @Project: NarratoAI @File : update_script -@Author : 小林同学 +@Author : Viccy同学 @Date : 2025/5/6 下午11:00 ''' diff --git a/app/utils/gemini_analyzer.py b/app/utils/gemini_analyzer.py index 7236a9e..c3685ab 100644 --- a/app/utils/gemini_analyzer.py +++ b/app/utils/gemini_analyzer.py @@ -5,53 +5,162 @@ from pathlib import Path from loguru import logger from tqdm import tqdm import asyncio -from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential -from google.api_core import exceptions -import google.generativeai as genai +from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential +import requests import PIL.Image import traceback +import base64 +import io from app.utils import utils class VisionAnalyzer: - """视觉分析器类""" + """原生Gemini视觉分析器类""" - def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None): + def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None): """初始化视觉分析器""" if not api_key: raise ValueError("必须提供API密钥") self.model_name = model_name self.api_key = api_key + self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" # 初始化配置 self._configure_client() def _configure_client(self): - """配置API客户端""" - genai.configure(api_key=self.api_key) - # 开放 Gemini 模型安全设置 - from google.generativeai.types import HarmCategory, HarmBlockThreshold - safety_settings = { - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - } - self.model = genai.GenerativeModel(self.model_name, safety_settings=safety_settings) + """配置原生Gemini API客户端""" + # 使用原生Gemini REST API + self.client = None + logger.info(f"配置原生Gemini API,端点: {self.base_url}, 模型: {self.model_name}") @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type(exceptions.ResourceExhausted) + retry=retry_if_exception_type(requests.exceptions.RequestException) ) async def _generate_content_with_retry(self, prompt, batch): - """使用重试机制的内部方法来调用 generate_content_async""" + """使用重试机制调用原生Gemini API""" try: - return await self.model.generate_content_async([prompt, *batch]) - except exceptions.ResourceExhausted as e: - print(f"API配额限制: {str(e)}") - raise RetryError("API调用失败") + return await self._generate_with_gemini_api(prompt, batch) + except requests.exceptions.RequestException as e: + logger.warning(f"Gemini API请求异常: {str(e)}") + raise + except Exception as e: + logger.error(f"Gemini API生成内容时发生错误: {str(e)}") + raise + + async def _generate_with_gemini_api(self, prompt, batch): + """使用原生Gemini REST API生成内容""" + # 将PIL图片转换为base64编码 + image_parts = [] + for img in batch: + # 将PIL图片转换为字节流 + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) # 优化图片质量 + img_bytes = img_buffer.getvalue() + + # 转换为base64 + img_base64 = base64.b64encode(img_bytes).decode('utf-8') + image_parts.append({ + "inline_data": { + "mime_type": "image/jpeg", + "data": img_base64 + } + }) + + # 构建符合官方文档的请求数据 + request_data = { + "contents": [{ + "parts": [ + {"text": prompt}, + *image_parts + ] + }], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 8192, + "candidateCount": 1, + "stopSequences": [] + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}" + + # 发送请求 + response = await asyncio.to_thread( + requests.post, + url, + json=request_data, + headers={ + "Content-Type": "application/json", + "User-Agent": "NarratoAI/1.0" + }, + timeout=120 # 增加超时时间 + ) + + # 处理HTTP错误 + if response.status_code == 429: + raise requests.exceptions.RequestException(f"API配额限制: {response.text}") + elif response.status_code == 400: + raise Exception(f"请求参数错误: {response.text}") + elif response.status_code == 403: + raise Exception(f"API密钥无效或权限不足: {response.text}") + elif response.status_code != 200: + raise Exception(f"Gemini API请求失败: {response.status_code} - {response.text}") + + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + raise Exception("Gemini API返回无效响应,可能触发了安全过滤") + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + raise Exception("内容被Gemini安全过滤器阻止") + + if "content" not in candidate or "parts" not in candidate["content"]: + raise Exception("Gemini API返回内容格式错误") + + # 提取文本内容 + text_content = "" + for part in candidate["content"]["parts"]: + if "text" in part: + text_content += part["text"] + + if not text_content.strip(): + raise Exception("Gemini API返回空内容") + + # 创建兼容的响应对象 + class CompatibleResponse: + def __init__(self, text): + self.text = text + + return CompatibleResponse(text_content) async def analyze_images(self, images: Union[List[str], List[PIL.Image.Image]], diff --git a/app/utils/gemini_openai_analyzer.py b/app/utils/gemini_openai_analyzer.py new file mode 100644 index 0000000..9d2ca0b --- /dev/null +++ b/app/utils/gemini_openai_analyzer.py @@ -0,0 +1,177 @@ +""" +OpenAI兼容的Gemini视觉分析器 +使用标准OpenAI格式调用Gemini代理服务 +""" + +import json +from typing import List, Union, Dict +import os +from pathlib import Path +from loguru import logger +from tqdm import tqdm +import asyncio +from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential +import requests +import PIL.Image +import traceback +import base64 +import io +from app.utils import utils + + +class GeminiOpenAIAnalyzer: + """OpenAI兼容的Gemini视觉分析器类""" + + def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None): + """初始化OpenAI兼容的Gemini分析器""" + if not api_key: + raise ValueError("必须提供API密钥") + + if not base_url: + raise ValueError("必须提供OpenAI兼容的代理端点URL") + + self.model_name = model_name + self.api_key = api_key + self.base_url = base_url.rstrip('/') + + # 初始化OpenAI客户端 + self._configure_client() + + def _configure_client(self): + """配置OpenAI兼容的客户端""" + from openai import OpenAI + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + logger.info(f"配置OpenAI兼容Gemini代理,端点: {self.base_url}, 模型: {self.model_name}") + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((requests.exceptions.RequestException, Exception)) + ) + async def _generate_content_with_retry(self, prompt, batch): + """使用重试机制调用OpenAI兼容的Gemini代理""" + try: + return await self._generate_with_openai_api(prompt, batch) + except Exception as e: + logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}") + raise + + async def _generate_with_openai_api(self, prompt, batch): + """使用OpenAI兼容接口生成内容""" + # 将PIL图片转换为base64编码 + image_contents = [] + for img in batch: + # 将PIL图片转换为字节流 + img_buffer = io.BytesIO() + img.save(img_buffer, format='JPEG', quality=85) + img_bytes = img_buffer.getvalue() + + # 转换为base64 + img_base64 = base64.b64encode(img_bytes).decode('utf-8') + image_contents.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{img_base64}" + } + }) + + # 构建OpenAI格式的消息 + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + *image_contents + ] + } + ] + + # 调用OpenAI兼容接口 + response = await asyncio.to_thread( + self.client.chat.completions.create, + model=self.model_name, + messages=messages, + max_tokens=4000, + temperature=1.0 + ) + + # 创建兼容的响应对象 + class CompatibleResponse: + def __init__(self, text): + self.text = text + + return CompatibleResponse(response.choices[0].message.content) + + async def analyze_images(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10) -> List[str]: + """ + 分析图片并返回结果 + + Args: + images: 图片路径列表或PIL图片对象列表 + prompt: 分析提示词 + batch_size: 批处理大小 + + Returns: + 分析结果列表 + """ + logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理") + + # 加载图片 + loaded_images = [] + for img in images: + if isinstance(img, (str, Path)): + try: + pil_img = PIL.Image.open(img) + # 调整图片大小以优化性能 + if pil_img.size[0] > 1024 or pil_img.size[1] > 1024: + pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS) + loaded_images.append(pil_img) + except Exception as e: + logger.error(f"加载图片失败 {img}: {str(e)}") + continue + elif isinstance(img, PIL.Image.Image): + loaded_images.append(img) + else: + logger.warning(f"不支持的图片类型: {type(img)}") + continue + + if not loaded_images: + raise ValueError("没有有效的图片可以分析") + + # 分批处理 + results = [] + total_batches = (len(loaded_images) + batch_size - 1) // batch_size + + for i in tqdm(range(0, len(loaded_images), batch_size), + desc="分析图片批次", total=total_batches): + batch = loaded_images[i:i + batch_size] + + try: + response = await self._generate_content_with_retry(prompt, batch) + results.append(response.text) + + # 添加延迟以避免API限流 + if i + batch_size < len(loaded_images): + await asyncio.sleep(1) + + except Exception as e: + logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}") + results.append(f"分析失败: {str(e)}") + + logger.info(f"完成图片分析,共处理 {len(results)} 个批次") + return results + + def analyze_images_sync(self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10) -> List[str]: + """ + 同步版本的图片分析方法 + """ + return asyncio.run(self.analyze_images(images, prompt, batch_size)) diff --git a/app/utils/script_generator.py b/app/utils/script_generator.py index 7020782..e6d7cea 100644 --- a/app/utils/script_generator.py +++ b/app/utils/script_generator.py @@ -6,7 +6,7 @@ from loguru import logger from typing import List, Dict from datetime import datetime from openai import OpenAI -import google.generativeai as genai +import requests import time @@ -134,59 +134,182 @@ class OpenAIGenerator(BaseGenerator): class GeminiGenerator(BaseGenerator): - """Google Gemini API 生成器实现""" - def __init__(self, model_name: str, api_key: str, prompt: str): + """原生Gemini API 生成器实现""" + def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None): super().__init__(model_name, api_key, prompt) - genai.configure(api_key=api_key) - self.model = genai.GenerativeModel(model_name) - - # Gemini特定参数 + + self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" + self.client = None + + # 原生Gemini API参数 self.default_params = { "temperature": self.default_params["temperature"], - "top_p": self.default_params["top_p"], - "candidate_count": 1, - "stop_sequences": None + "topP": self.default_params["top_p"], + "topK": 40, + "maxOutputTokens": 4000, + "candidateCount": 1, + "stopSequences": [] + } + + +class GeminiOpenAIGenerator(BaseGenerator): + """OpenAI兼容的Gemini代理生成器实现""" + def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str = None): + super().__init__(model_name, api_key, prompt) + + if not base_url: + raise ValueError("OpenAI兼容的Gemini代理必须提供base_url") + + self.base_url = base_url.rstrip('/') + + # 使用OpenAI兼容接口 + from openai import OpenAI + self.client = OpenAI( + api_key=api_key, + base_url=base_url + ) + + # OpenAI兼容接口参数 + self.default_params = { + "temperature": self.default_params["temperature"], + "max_tokens": 4000, + "stream": False } def _generate(self, messages: list, params: dict) -> any: - """实现Gemini特定的生成逻辑""" - while True: + """实现OpenAI兼容Gemini代理的生成逻辑""" + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + **params + ) + return response + except Exception as e: + logger.error(f"OpenAI兼容Gemini代理生成错误: {str(e)}") + raise + + def _process_response(self, response: any) -> str: + """处理OpenAI兼容接口的响应""" + if not response or not response.choices: + raise ValueError("OpenAI兼容Gemini代理返回无效响应") + return response.choices[0].message.content.strip() + + def _generate(self, messages: list, params: dict) -> any: + """实现原生Gemini API的生成逻辑""" + max_retries = 3 + for attempt in range(max_retries): try: # 转换消息格式为Gemini格式 prompt = "\n".join([m["content"] for m in messages]) - response = self.model.generate_content( - prompt, - generation_config=params + + # 构建请求数据 + request_data = { + "contents": [{ + "parts": [{"text": prompt}] + }], + "generationConfig": params, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + url = f"{self.base_url}/models/{self.model_name}:generateContent?key={self.api_key}" + + # 发送请求 + response = requests.post( + url, + json=request_data, + headers={ + "Content-Type": "application/json", + "User-Agent": "NarratoAI/1.0" + }, + timeout=120 ) - - # 检查响应是否包含有效内容 - if (hasattr(response, 'result') and - hasattr(response.result, 'candidates') and - response.result.candidates): - - candidate = response.result.candidates[0] - - # 检查是否有内容字段 - if not hasattr(candidate, 'content'): - logger.warning("Gemini API 返回速率限制响应,等待30秒后重试...") - time.sleep(30) # 等待3秒后重试 + + if response.status_code == 429: + # 处理限流 + wait_time = 65 if attempt == 0 else 30 + logger.warning(f"原生Gemini API 触发限流,等待{wait_time}秒后重试...") + time.sleep(wait_time) + continue + + if response.status_code == 400: + raise Exception(f"请求参数错误: {response.text}") + elif response.status_code == 403: + raise Exception(f"API密钥无效或权限不足: {response.text}") + elif response.status_code != 200: + raise Exception(f"原生Gemini API请求失败: {response.status_code} - {response.text}") + + response_data = response.json() + + # 检查响应格式 + if "candidates" not in response_data or not response_data["candidates"]: + if attempt < max_retries - 1: + logger.warning("原生Gemini API 返回无效响应,等待30秒后重试...") + time.sleep(30) continue - return response - - except Exception as e: - error_str = str(e) - if "429" in error_str: - logger.warning("Gemini API 触发限流,等待65秒后重试...") - time.sleep(65) # 等待65秒后重试 + else: + raise Exception("原生Gemini API返回无效响应,可能触发了安全过滤") + + candidate = response_data["candidates"][0] + + # 检查是否被安全过滤阻止 + if "finishReason" in candidate and candidate["finishReason"] == "SAFETY": + raise Exception("内容被Gemini安全过滤器阻止") + + # 创建兼容的响应对象 + class CompatibleResponse: + def __init__(self, data): + self.data = data + candidate = data["candidates"][0] + if "content" in candidate and "parts" in candidate["content"]: + self.text = "" + for part in candidate["content"]["parts"]: + if "text" in part: + self.text += part["text"] + else: + self.text = "" + + return CompatibleResponse(response_data) + + except requests.exceptions.RequestException as e: + if attempt < max_retries - 1: + logger.warning(f"网络请求失败,等待30秒后重试: {str(e)}") + time.sleep(30) continue else: - logger.error(f"Gemini 生成文案错误: \n{error_str}") + logger.error(f"原生Gemini API请求失败: {str(e)}") + raise + except Exception as e: + if attempt < max_retries - 1 and "429" in str(e): + logger.warning("原生Gemini API 触发限流,等待65秒后重试...") + time.sleep(65) + continue + else: + logger.error(f"原生Gemini 生成文案错误: {str(e)}") raise def _process_response(self, response: any) -> str: - """处理Gemini的响应""" + """处理原生Gemini API的响应""" if not response or not response.text: - raise ValueError("Invalid response from Gemini API") + raise ValueError("原生Gemini API返回无效响应") return response.text.strip() @@ -318,7 +441,7 @@ class ScriptProcessor: # 根据模型名称选择对应的生成器 logger.info(f"文本 LLM 提供商: {model_name}") if 'gemini' in model_name.lower(): - self.generator = GeminiGenerator(model_name, self.api_key, self.prompt) + self.generator = GeminiGenerator(model_name, self.api_key, self.prompt, self.base_url) elif 'qwen' in model_name.lower(): self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url) elif 'moonshot' in model_name.lower(): diff --git a/config.example.toml b/config.example.toml index 270ed0d..562c454 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,5 +1,5 @@ [app] - project_version="0.6.6" + project_version="0.6.7" # 支持视频理解的大模型提供商 # gemini (谷歌, 需要 VPN) # siliconflow (硅基流动) diff --git a/docs/LLM_MIGRATION_GUIDE.md b/docs/LLM_MIGRATION_GUIDE.md new file mode 100644 index 0000000..9cc606b --- /dev/null +++ b/docs/LLM_MIGRATION_GUIDE.md @@ -0,0 +1,367 @@ +# NarratoAI 大模型服务迁移指南 + +## 📋 概述 + +本指南帮助开发者将现有代码从旧的大模型调用方式迁移到新的统一LLM服务架构。新架构提供了更好的模块化、错误处理和配置管理。 + +## 🔄 迁移对比 + +### 旧的调用方式 vs 新的调用方式 + +#### 1. 视觉分析器创建 + +**旧方式:** +```python +from app.utils import gemini_analyzer, qwenvl_analyzer + +if provider == 'gemini': + analyzer = gemini_analyzer.VisionAnalyzer( + model_name=model, + api_key=api_key, + base_url=base_url + ) +elif provider == 'qwenvl': + analyzer = qwenvl_analyzer.QwenAnalyzer( + model_name=model, + api_key=api_key, + base_url=base_url + ) +``` + +**新方式:** +```python +from app.services.llm.unified_service import UnifiedLLMService + +# 方式1: 直接使用统一服务 +results = await UnifiedLLMService.analyze_images( + images=images, + prompt=prompt, + provider=provider # 可选,使用配置中的默认值 +) + +# 方式2: 使用迁移适配器(向后兼容) +from app.services.llm.migration_adapter import create_vision_analyzer +analyzer = create_vision_analyzer(provider, api_key, model, base_url) +results = await analyzer.analyze_images(images, prompt) +``` + +#### 2. 文本生成 + +**旧方式:** +```python +from openai import OpenAI + +client = OpenAI(api_key=api_key, base_url=base_url) +response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ], + temperature=temperature, + response_format={"type": "json_object"} +) +result = response.choices[0].message.content +``` + +**新方式:** +```python +from app.services.llm.unified_service import UnifiedLLMService + +result = await UnifiedLLMService.generate_text( + prompt=prompt, + system_prompt=system_prompt, + temperature=temperature, + response_format="json" +) +``` + +#### 3. 解说文案生成 + +**旧方式:** +```python +from app.services.generate_narration_script import generate_narration + +narration = generate_narration( + markdown_content, + api_key, + base_url=base_url, + model=model +) +# 手动解析JSON和验证格式 +import json +narration_dict = json.loads(narration)['items'] +``` + +**新方式:** +```python +from app.services.llm.unified_service import UnifiedLLMService + +# 自动验证输出格式 +narration_items = await UnifiedLLMService.generate_narration_script( + prompt=prompt, + validate_output=True # 自动验证JSON格式和字段 +) +``` + +## 📝 具体迁移步骤 + +### 步骤1: 更新配置文件 + +**旧配置格式:** +```toml +[app] + llm_provider = "openai" + openai_api_key = "sk-xxx" + openai_model_name = "gpt-4" + + vision_llm_provider = "gemini" + gemini_api_key = "xxx" + gemini_model_name = "gemini-1.5-pro" +``` + +**新配置格式:** +```toml +[app] + # 视觉模型配置 + vision_llm_provider = "gemini" + vision_gemini_api_key = "xxx" + vision_gemini_model_name = "gemini-2.0-flash-lite" + vision_gemini_base_url = "https://generativelanguage.googleapis.com/v1beta" + + # 文本模型配置 + text_llm_provider = "openai" + text_openai_api_key = "sk-xxx" + text_openai_model_name = "gpt-4o-mini" + text_openai_base_url = "https://api.openai.com/v1" +``` + +### 步骤2: 更新导入语句 + +**旧导入:** +```python +from app.utils import gemini_analyzer, qwenvl_analyzer +from app.services.generate_narration_script import generate_narration +from app.services.SDE.short_drama_explanation import analyze_subtitle +``` + +**新导入:** +```python +from app.services.llm.unified_service import UnifiedLLMService +from app.services.llm.migration_adapter import ( + create_vision_analyzer, + SubtitleAnalyzerAdapter +) +``` + +### 步骤3: 更新函数调用 + +#### 图片分析迁移 + +**旧代码:** +```python +def analyze_images_old(provider, api_key, model, base_url, images, prompt): + if provider == 'gemini': + analyzer = gemini_analyzer.VisionAnalyzer( + model_name=model, + api_key=api_key, + base_url=base_url + ) + else: + analyzer = qwenvl_analyzer.QwenAnalyzer( + model_name=model, + api_key=api_key, + base_url=base_url + ) + + # 同步调用 + results = [] + for batch in batches: + result = analyzer.analyze_batch(batch, prompt) + results.append(result) + return results +``` + +**新代码:** +```python +async def analyze_images_new(images, prompt, provider=None): + # 异步调用,自动批处理 + results = await UnifiedLLMService.analyze_images( + images=images, + prompt=prompt, + provider=provider, + batch_size=10 + ) + return results +``` + +#### 字幕分析迁移 + +**旧代码:** +```python +from app.services.SDE.short_drama_explanation import analyze_subtitle + +result = analyze_subtitle( + subtitle_file_path=subtitle_path, + api_key=api_key, + model=model, + base_url=base_url, + provider=provider +) +``` + +**新代码:** +```python +# 方式1: 使用统一服务 +with open(subtitle_path, 'r', encoding='utf-8') as f: + subtitle_content = f.read() + +result = await UnifiedLLMService.analyze_subtitle( + subtitle_content=subtitle_content, + provider=provider, + validate_output=True +) + +# 方式2: 使用适配器 +from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter + +analyzer = SubtitleAnalyzerAdapter(api_key, model, base_url, provider) +result = analyzer.analyze_subtitle(subtitle_content) +``` + +## 🔧 常见迁移问题 + +### 1. 同步 vs 异步调用 + +**问题:** 新架构使用异步调用,旧代码是同步的。 + +**解决方案:** +```python +# 在同步函数中调用异步函数 +import asyncio + +def sync_function(): + result = asyncio.run(UnifiedLLMService.generate_text(prompt)) + return result + +# 或者将整个函数改为异步 +async def async_function(): + result = await UnifiedLLMService.generate_text(prompt) + return result +``` + +### 2. 配置获取方式变化 + +**问题:** 配置键名发生变化。 + +**解决方案:** +```python +# 旧方式 +api_key = config.app.get('openai_api_key') +model = config.app.get('openai_model_name') + +# 新方式 +provider = config.app.get('text_llm_provider', 'openai') +api_key = config.app.get(f'text_{provider}_api_key') +model = config.app.get(f'text_{provider}_model_name') +``` + +### 3. 错误处理更新 + +**旧方式:** +```python +try: + result = some_llm_call() +except Exception as e: + print(f"Error: {e}") +``` + +**新方式:** +```python +from app.services.llm.exceptions import LLMServiceError, ValidationError + +try: + result = await UnifiedLLMService.generate_text(prompt) +except ValidationError as e: + print(f"输出验证失败: {e.message}") +except LLMServiceError as e: + print(f"LLM服务错误: {e.message}") +except Exception as e: + print(f"未知错误: {e}") +``` + +## ✅ 迁移检查清单 + +### 配置迁移 +- [ ] 更新配置文件格式 +- [ ] 验证所有API密钥配置正确 +- [ ] 运行配置验证器检查 + +### 代码迁移 +- [ ] 更新导入语句 +- [ ] 将同步调用改为异步调用 +- [ ] 更新错误处理机制 +- [ ] 使用新的统一接口 + +### 测试验证 +- [ ] 运行LLM服务测试脚本 +- [ ] 测试所有功能模块 +- [ ] 验证输出格式正确 +- [ ] 检查性能和稳定性 + +### 清理工作 +- [ ] 移除未使用的旧代码 +- [ ] 更新文档和注释 +- [ ] 清理过时的依赖 + +## 🚀 迁移最佳实践 + +### 1. 渐进式迁移 +- 先迁移一个模块,测试通过后再迁移其他模块 +- 保留旧代码作为备用方案 +- 使用迁移适配器确保向后兼容 + +### 2. 充分测试 +- 在每个迁移步骤后运行测试 +- 比较新旧实现的输出结果 +- 测试边界情况和错误处理 + +### 3. 监控和日志 +- 启用详细日志记录 +- 监控API调用成功率 +- 跟踪性能指标 + +### 4. 文档更新 +- 更新代码注释 +- 更新API文档 +- 记录迁移过程中的问题和解决方案 + +## 📞 获取帮助 + +如果在迁移过程中遇到问题: + +1. **查看测试脚本输出**: + ```bash + python app/services/llm/test_llm_service.py + ``` + +2. **验证配置**: + ```python + from app.services.llm.config_validator import LLMConfigValidator + results = LLMConfigValidator.validate_all_configs() + LLMConfigValidator.print_validation_report(results) + ``` + +3. **查看详细日志**: + ```python + from loguru import logger + logger.add("migration.log", level="DEBUG") + ``` + +4. **参考示例代码**: + - 查看 `app/services/llm/test_llm_service.py` 中的使用示例 + - 参考已迁移的文件如 `webui/tools/base.py` + +--- + +*最后更新: 2025-01-07* diff --git a/docs/LLM_SERVICE_GUIDE.md b/docs/LLM_SERVICE_GUIDE.md new file mode 100644 index 0000000..5b5fb10 --- /dev/null +++ b/docs/LLM_SERVICE_GUIDE.md @@ -0,0 +1,294 @@ +# NarratoAI 大模型服务使用指南 + +## 📖 概述 + +NarratoAI 项目已完成大模型服务的全面重构,提供了统一、模块化、可扩展的大模型集成架构。新架构支持多种大模型供应商,具有严格的输出格式验证和完善的错误处理机制。 + +## 🏗️ 架构概览 + +### 核心组件 + +``` +app/services/llm/ +├── __init__.py # 模块入口 +├── base.py # 抽象基类 +├── manager.py # 服务管理器 +├── unified_service.py # 统一服务接口 +├── validators.py # 输出格式验证器 +├── exceptions.py # 异常类定义 +├── migration_adapter.py # 迁移适配器 +├── config_validator.py # 配置验证器 +├── test_llm_service.py # 测试脚本 +└── providers/ # 提供商实现 + ├── __init__.py + ├── gemini_provider.py + ├── gemini_openai_provider.py + ├── openai_provider.py + ├── qwen_provider.py + ├── deepseek_provider.py + └── siliconflow_provider.py +``` + +### 支持的供应商 + +#### 视觉模型供应商 +- **Gemini** (原生API + OpenAI兼容) +- **QwenVL** (通义千问视觉) +- **Siliconflow** (硅基流动) + +#### 文本生成模型供应商 +- **OpenAI** (标准OpenAI API) +- **Gemini** (原生API + OpenAI兼容) +- **DeepSeek** (深度求索) +- **Qwen** (通义千问) +- **Siliconflow** (硅基流动) + +## ⚙️ 配置说明 + +### 配置文件格式 + +在 `config.toml` 中配置大模型服务: + +```toml +[app] + # 视觉模型提供商配置 + vision_llm_provider = "gemini" + + # Gemini 视觉模型 + vision_gemini_api_key = "your_gemini_api_key" + vision_gemini_model_name = "gemini-2.0-flash-lite" + vision_gemini_base_url = "https://generativelanguage.googleapis.com/v1beta" + + # QwenVL 视觉模型 + vision_qwenvl_api_key = "your_qwen_api_key" + vision_qwenvl_model_name = "qwen2.5-vl-32b-instruct" + vision_qwenvl_base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + # 文本模型提供商配置 + text_llm_provider = "openai" + + # OpenAI 文本模型 + text_openai_api_key = "your_openai_api_key" + text_openai_model_name = "gpt-4o-mini" + text_openai_base_url = "https://api.openai.com/v1" + + # DeepSeek 文本模型 + text_deepseek_api_key = "your_deepseek_api_key" + text_deepseek_model_name = "deepseek-chat" + text_deepseek_base_url = "https://api.deepseek.com" +``` + +### 配置验证 + +使用配置验证器检查配置是否正确: + +```python +from app.services.llm.config_validator import LLMConfigValidator + +# 验证所有配置 +results = LLMConfigValidator.validate_all_configs() + +# 打印验证报告 +LLMConfigValidator.print_validation_report(results) + +# 获取配置建议 +suggestions = LLMConfigValidator.get_config_suggestions() +``` + +## 🚀 使用方法 + +### 1. 统一服务接口(推荐) + +```python +from app.services.llm.unified_service import UnifiedLLMService + +# 图片分析 +results = await UnifiedLLMService.analyze_images( + images=["path/to/image1.jpg", "path/to/image2.jpg"], + prompt="请描述这些图片的内容", + provider="gemini", # 可选,不指定则使用配置中的默认值 + batch_size=10 +) + +# 文本生成 +text = await UnifiedLLMService.generate_text( + prompt="请介绍人工智能的发展历史", + system_prompt="你是一个专业的AI专家", + provider="openai", # 可选 + temperature=0.7, + response_format="json" # 可选,支持JSON格式输出 +) + +# 解说文案生成(带验证) +narration_items = await UnifiedLLMService.generate_narration_script( + prompt="根据视频内容生成解说文案...", + validate_output=True # 自动验证输出格式 +) + +# 字幕分析 +analysis = await UnifiedLLMService.analyze_subtitle( + subtitle_content="字幕内容...", + validate_output=True +) +``` + +### 2. 直接使用服务管理器 + +```python +from app.services.llm.manager import LLMServiceManager + +# 获取视觉模型提供商 +vision_provider = LLMServiceManager.get_vision_provider("gemini") +results = await vision_provider.analyze_images(images, prompt) + +# 获取文本模型提供商 +text_provider = LLMServiceManager.get_text_provider("openai") +text = await text_provider.generate_text(prompt) +``` + +### 3. 迁移适配器(向后兼容) + +```python +from app.services.llm.migration_adapter import create_vision_analyzer + +# 兼容旧的接口 +analyzer = create_vision_analyzer("gemini", api_key, model, base_url) +results = await analyzer.analyze_images(images, prompt) +``` + +## 🔍 输出格式验证 + +### 解说文案验证 + +```python +from app.services.llm.validators import OutputValidator + +# 验证解说文案格式 +try: + narration_items = OutputValidator.validate_narration_script(output) + print(f"验证成功,共 {len(narration_items)} 个片段") +except ValidationError as e: + print(f"验证失败: {e.message}") +``` + +### JSON输出验证 + +```python +# 验证JSON格式 +try: + data = OutputValidator.validate_json_output(output) + print("JSON格式验证成功") +except ValidationError as e: + print(f"JSON验证失败: {e.message}") +``` + +## 🧪 测试和调试 + +### 运行测试脚本 + +```bash +# 运行完整的LLM服务测试 +python app/services/llm/test_llm_service.py +``` + +测试脚本会验证: +- 配置有效性 +- 提供商信息获取 +- 文本生成功能 +- JSON格式生成 +- 字幕分析功能 +- 解说文案生成功能 + +### 调试技巧 + +1. **启用详细日志**: +```python +from loguru import logger +logger.add("llm_service.log", level="DEBUG") +``` + +2. **清空提供商缓存**: +```python +UnifiedLLMService.clear_cache() +``` + +3. **检查提供商信息**: +```python +info = UnifiedLLMService.get_provider_info() +print(info) +``` + +## ⚠️ 注意事项 + +### 1. API密钥安全 +- 不要在代码中硬编码API密钥 +- 使用环境变量或配置文件管理密钥 +- 定期轮换API密钥 + +### 2. 错误处理 +- 所有LLM服务调用都应该包装在try-catch中 +- 使用适当的异常类型进行错误处理 +- 实现重试机制处理临时性错误 + +### 3. 性能优化 +- 合理设置批处理大小 +- 使用缓存避免重复调用 +- 监控API调用频率和成本 + +### 4. 模型选择 +- 根据任务类型选择合适的模型 +- 考虑成本和性能的平衡 +- 定期更新到最新的模型版本 + +## 🔧 扩展新供应商 + +### 1. 创建提供商类 + +```python +# app/services/llm/providers/new_provider.py +from ..base import TextModelProvider + +class NewTextProvider(TextModelProvider): + @property + def provider_name(self) -> str: + return "new_provider" + + @property + def supported_models(self) -> List[str]: + return ["model-1", "model-2"] + + async def generate_text(self, prompt: str, **kwargs) -> str: + # 实现具体的API调用逻辑 + pass +``` + +### 2. 注册提供商 + +```python +# app/services/llm/providers/__init__.py +from .new_provider import NewTextProvider + +LLMServiceManager.register_text_provider('new_provider', NewTextProvider) +``` + +### 3. 添加配置支持 + +```toml +# config.toml +text_new_provider_api_key = "your_api_key" +text_new_provider_model_name = "model-1" +text_new_provider_base_url = "https://api.newprovider.com/v1" +``` + +## 📞 技术支持 + +如果在使用过程中遇到问题: + +1. 首先运行测试脚本检查配置 +2. 查看日志文件了解详细错误信息 +3. 检查API密钥和网络连接 +4. 参考本文档的故障排除部分 + +--- + +*最后更新: 2025-01-07* diff --git a/docs/prompt_management_system.md b/docs/prompt_management_system.md new file mode 100644 index 0000000..42662f8 --- /dev/null +++ b/docs/prompt_management_system.md @@ -0,0 +1,267 @@ +# 提示词管理系统文档 + +## 概述 + +本项目实现了统一的提示词管理系统,用于集中管理三个核心功能的提示词: +- **纪录片解说** - 视频帧分析和解说文案生成 +- **短剧混剪** - 字幕分析和爆点提取 +- **短剧解说** - 剧情分析和解说脚本生成 + +## 系统架构 + +``` +app/services/prompts/ +├── __init__.py # 模块初始化 +├── base.py # 基础提示词类 +├── manager.py # 提示词管理器 +├── registry.py # 提示词注册机制 +├── template.py # 模板渲染引擎 +├── validators.py # 输出验证器 +├── exceptions.py # 异常定义 +├── documentary/ # 纪录片解说提示词 +│ ├── __init__.py +│ ├── frame_analysis.py # 视频帧分析 +│ └── narration_generation.py # 解说文案生成 +├── short_drama_editing/ # 短剧混剪提示词 +│ ├── __init__.py +│ ├── subtitle_analysis.py # 字幕分析 +│ └── plot_extraction.py # 爆点提取 +└── short_drama_narration/ # 短剧解说提示词 + ├── __init__.py + ├── plot_analysis.py # 剧情分析 + └── script_generation.py # 解说脚本生成 +``` + +## 核心特性 + +### 1. 统一管理 +- 所有提示词集中在 `app/services/prompts/` 模块中 +- 按功能模块分类组织 +- 支持版本控制和回滚 + +### 2. 模型类型适配 +- **TextPrompt**: 文本模型专用 +- **VisionPrompt**: 视觉模型专用 +- **ParameterizedPrompt**: 支持参数化 + +### 3. 参数化支持 +- 动态参数替换 +- 参数验证 +- 模板渲染 + +### 4. 输出验证 +- 严格的JSON格式验证 +- 特定业务场景验证(解说文案、剧情分析等) +- 自定义验证规则 + +## 使用方法 + +### 基本用法 + +```python +from app.services.prompts import PromptManager + +# 获取纪录片解说的视频帧分析提示词 +prompt = PromptManager.get_prompt( + category="documentary", + name="frame_analysis", + parameters={ + "video_theme": "荒野建造", + "custom_instructions": "请特别关注建造过程的细节" + } +) + +# 获取短剧解说的剧情分析提示词 +prompt = PromptManager.get_prompt( + category="short_drama_narration", + name="plot_analysis", + parameters={"subtitle_content": "字幕内容..."} +) +``` + +### 高级功能 + +```python +# 搜索提示词 +results = PromptManager.search_prompts( + keyword="分析", + model_type=ModelType.TEXT +) + +# 获取提示词详细信息 +info = PromptManager.get_prompt_info( + category="documentary", + name="narration_generation" +) + +# 验证输出 +validated_data = PromptManager.validate_output( + output=llm_response, + category="documentary", + name="narration_generation" +) +``` + +## 已注册的提示词 + +### 纪录片解说 (documentary) +- `frame_analysis` - 视频帧分析提示词 +- `narration_generation` - 解说文案生成提示词 + +### 短剧混剪 (short_drama_editing) +- `subtitle_analysis` - 字幕分析提示词 +- `plot_extraction` - 爆点提取提示词 + +### 短剧解说 (short_drama_narration) +- `plot_analysis` - 剧情分析提示词 +- `script_generation` - 解说脚本生成提示词 + +## 迁移指南 + +### 旧代码迁移 + +**之前的用法:** +```python +from app.services.SDE.prompt import subtitle_plot_analysis_v1 +prompt = subtitle_plot_analysis_v1 +``` + +**新的用法:** +```python +from app.services.prompts import PromptManager +prompt = PromptManager.get_prompt( + category="short_drama_narration", + name="plot_analysis", + parameters={"subtitle_content": content} +) +``` + +### 已更新的文件 +- `app/services/SDE/short_drama_explanation.py` +- `app/services/SDP/utils/step1_subtitle_analyzer_openai.py` +- `app/services/generate_narration_script.py` + +## 扩展指南 + +### 添加新提示词 + +1. 在相应分类目录下创建新的提示词类: + +```python +from ..base import TextPrompt, PromptMetadata, ModelType, OutputFormat + +class NewPrompt(TextPrompt): + def __init__(self): + metadata = PromptMetadata( + name="new_prompt", + category="your_category", + version="v1.0", + description="提示词描述", + model_type=ModelType.TEXT, + output_format=OutputFormat.JSON, + parameters=["param1", "param2"] + ) + super().__init__(metadata) + + def get_template(self) -> str: + return "您的提示词模板内容..." +``` + +2. 在 `__init__.py` 中注册: + +```python +def register_prompts(): + new_prompt = NewPrompt() + PromptManager.register_prompt(new_prompt, is_default=True) +``` + +### 添加新分类 + +1. 创建新的分类目录 +2. 实现提示词类 +3. 在主模块的 `__init__.py` 中导入并注册 + +## 测试 + +运行测试脚本验证系统功能: + +```bash +python test_prompt_system.py +``` + +## 注意事项 + +1. **模板参数**: 使用 `${parameter_name}` 格式 +2. **JSON格式**: 模板中的JSON示例使用标准格式 `{` 和 `}`,不要使用双大括号 +3. **参数验证**: 必需参数会自动验证 +4. **版本管理**: 支持多版本共存,默认使用最新版本 +5. **输出验证**: 建议对LLM输出进行验证以确保格式正确 +6. **JSON解析**: 系统提供强大的JSON解析兼容性,自动处理各种格式问题 + +## JSON解析优化 + +系统提供了强大的JSON解析兼容性,能够处理LLM生成的各种格式问题: + +### 支持的格式修复 + +1. **双大括号修复**: 自动将 `{{` 和 `}}` 转换为标准的 `{` 和 `}` +2. **代码块提取**: 自动从 ````json` 代码块中提取JSON内容 +3. **额外文本处理**: 自动提取大括号包围的JSON内容,忽略前后的额外文本 +4. **尾随逗号修复**: 自动移除对象和数组末尾的多余逗号 +5. **注释移除**: 自动移除 `//` 和 `#` 注释 +6. **引号修复**: 自动修复单引号和缺失的属性名引号 + +### 解析策略 + +系统采用多重解析策略,按优先级依次尝试: + +```python +strategies = [ + ("直接解析", lambda s: json.loads(s)), + ("修复双大括号", _fix_double_braces), + ("提取代码块", _extract_code_block), + ("提取大括号内容", _extract_braces_content), + ("修复常见格式问题", _fix_common_json_issues), + ("修复引号问题", _fix_quote_issues), + ("修复尾随逗号", _fix_trailing_commas), + ("强制修复", _force_fix_json), +] +``` + +### 使用示例 + +```python +from webui.tools.generate_short_summary import parse_and_fix_json + +# 处理双大括号JSON +json_str = '{{ "items": [{{ "_id": 1, "name": "test" }}] }}' +result = parse_and_fix_json(json_str) # 自动修复并解析 + +# 处理有额外文本的JSON +json_str = '这是一些文本\n{"items": []}\n更多文本' +result = parse_and_fix_json(json_str) # 自动提取JSON部分 +``` + +## 性能优化 + +- 提示词模板会被缓存 +- 支持批量操作 +- 异步渲染支持(未来版本) +- JSON解析采用多策略优化,确保高成功率 + +## 故障排除 + +### 常见问题 + +1. **模板渲染错误**: 检查参数名称和格式 +2. **提示词未找到**: 确认分类、名称和版本正确 +3. **输出验证失败**: 检查LLM输出格式是否符合要求 + +### 日志调试 + +系统使用 loguru 记录详细日志,可通过日志排查问题: + +```python +from loguru import logger +logger.debug("调试信息") +``` diff --git a/project_version b/project_version index bf21f52..8b707c6 100644 --- a/project_version +++ b/project_version @@ -1 +1 @@ -0.6.6 \ No newline at end of file +0.6.7 \ No newline at end of file diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index a5f3c62..a887246 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -7,6 +7,45 @@ from app.utils import utils from loguru import logger +def validate_api_key(api_key: str, provider: str) -> tuple[bool, str]: + """验证API密钥格式""" + if not api_key or not api_key.strip(): + return False, f"{provider} API密钥不能为空" + + # 基本长度检查 + if len(api_key.strip()) < 10: + return False, f"{provider} API密钥长度过短,请检查是否正确" + + return True, "" + + +def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]: + """验证Base URL格式""" + if not base_url or not base_url.strip(): + return True, "" # base_url可以为空 + + base_url = base_url.strip() + if not (base_url.startswith('http://') or base_url.startswith('https://')): + return False, f"{provider} Base URL必须以http://或https://开头" + + return True, "" + + +def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]: + """验证模型名称""" + if not model_name or not model_name.strip(): + return False, f"{provider} 模型名称不能为空" + + return True, "" + + +def show_config_validation_errors(errors: list): + """显示配置验证错误""" + if errors: + for error in errors: + st.error(error) + + def render_basic_settings(tr): """渲染基础设置面板""" with st.expander(tr("Basic Settings"), expanded=False): @@ -87,29 +126,96 @@ def render_proxy_settings(tr): def test_vision_model_connection(api_key, base_url, model_name, provider, tr): """测试视觉模型连接 - + Args: api_key: API密钥 base_url: 基础URL model_name: 模型名称 provider: 提供商名称 - + Returns: bool: 连接是否成功 str: 测试结果消息 """ + import requests if provider.lower() == 'gemini': - import google.generativeai as genai - + # 原生Gemini API测试 try: - genai.configure(api_key=api_key) - model = genai.GenerativeModel(model_name) - model.generate_content("直接回复我文本'当前网络可用'") - return True, tr("gemini model is available") + # 构建请求数据 + request_data = { + "contents": [{ + "parts": [{"text": "直接回复我文本'当前网络可用'"}] + }], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 100, + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" + url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}" + + # 发送请求 + response = requests.post( + url, + json=request_data, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + if response.status_code == 200: + return True, tr("原生Gemini模型连接成功") + else: + return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}" except Exception as e: - return False, f"{tr('gemini model is not available')}: {str(e)}" + return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}" + + elif provider.lower() == 'gemini(openai)': + # OpenAI兼容的Gemini代理测试 + try: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + test_url = f"{base_url.rstrip('/')}/chat/completions" + test_data = { + "model": model_name, + "messages": [ + {"role": "user", "content": "直接回复我文本'当前网络可用'"} + ], + "stream": False + } + + response = requests.post(test_url, headers=headers, json=test_data, timeout=10) + if response.status_code == 200: + return True, tr("OpenAI兼容Gemini代理连接成功") + else: + return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}" + except Exception as e: + return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: {str(e)}" elif provider.lower() == 'narratoapi': - import requests try: # 构建测试请求 headers = { @@ -172,7 +278,7 @@ def render_vision_llm_settings(tr): st.subheader(tr("Vision Model Settings")) # 视频分析模型提供商选择 - vision_providers = ['Siliconflow', 'Gemini', 'QwenVL', 'OpenAI'] + vision_providers = ['Siliconflow', 'Gemini', 'Gemini(OpenAI)', 'QwenVL', 'OpenAI'] saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower() saved_provider_index = 0 @@ -191,9 +297,15 @@ def render_vision_llm_settings(tr): st.session_state['vision_llm_providers'] = vision_provider # 获取已保存的视觉模型配置 - vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "") - vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "") - vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "") + # 处理特殊的提供商名称映射 + if vision_provider == 'gemini(openai)': + vision_config_key = 'vision_gemini_openai' + else: + vision_config_key = f'vision_{vision_provider}' + + vision_api_key = config.app.get(f"{vision_config_key}_api_key", "") + vision_base_url = config.app.get(f"{vision_config_key}_base_url", "") + vision_model_name = config.app.get(f"{vision_config_key}_model_name", "") # 渲染视觉模型配置输入框 st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password") @@ -201,15 +313,25 @@ def render_vision_llm_settings(tr): # 根据不同提供商设置默认值和帮助信息 if vision_provider == 'gemini': st_vision_base_url = st.text_input( - tr("Vision Base URL"), - value=vision_base_url, - disabled=True, - help=tr("Gemini API does not require a base URL") + tr("Vision Base URL"), + value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta", + help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta") ) st_vision_model_name = st.text_input( - tr("Vision Model Name"), - value=vision_model_name or "gemini-2.0-flash-lite", - help=tr("Default: gemini-2.0-flash-lite") + tr("Vision Model Name"), + value=vision_model_name or "gemini-2.0-flash-exp", + help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp") + ) + elif vision_provider == 'gemini(openai)': + st_vision_base_url = st.text_input( + tr("Vision Base URL"), + value=vision_base_url or "https://generativelanguage.googleapis.com/v1beta/openai", + help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1") + ) + st_vision_model_name = st.text_input( + tr("Vision Model Name"), + value=vision_model_name or "gemini-2.0-flash-exp", + help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp") ) elif vision_provider == 'qwenvl': st_vision_base_url = st.text_input( @@ -228,30 +350,81 @@ def render_vision_llm_settings(tr): # 在配置输入框后添加测试按钮 if st.button(tr("Test Connection"), key="test_vision_connection"): - with st.spinner(tr("Testing connection...")): - success, message = test_vision_model_connection( - api_key=st_vision_api_key, - base_url=st_vision_base_url, - model_name=st_vision_model_name, - provider=vision_provider, - tr=tr - ) - - if success: - st.success(tr(message)) - else: - st.error(tr(message)) + # 先验证配置 + test_errors = [] + if not st_vision_api_key: + test_errors.append("请先输入API密钥") + if not st_vision_model_name: + test_errors.append("请先输入模型名称") - # 保存视觉模型配置 + if test_errors: + for error in test_errors: + st.error(error) + else: + with st.spinner(tr("Testing connection...")): + try: + success, message = test_vision_model_connection( + api_key=st_vision_api_key, + base_url=st_vision_base_url, + model_name=st_vision_model_name, + provider=vision_provider, + tr=tr + ) + + if success: + st.success(message) + else: + st.error(message) + except Exception as e: + st.error(f"测试连接时发生错误: {str(e)}") + logger.error(f"视频分析模型连接测试失败: {str(e)}") + + # 验证和保存视觉模型配置 + validation_errors = [] + config_changed = False + + # 验证API密钥 if st_vision_api_key: - config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key - st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key + is_valid, error_msg = validate_api_key(st_vision_api_key, f"视频分析({vision_provider})") + if is_valid: + config.app[f"{vision_config_key}_api_key"] = st_vision_api_key + st.session_state[f"{vision_config_key}_api_key"] = st_vision_api_key + config_changed = True + else: + validation_errors.append(error_msg) + + # 验证Base URL if st_vision_base_url: - config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url - st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url + is_valid, error_msg = validate_base_url(st_vision_base_url, f"视频分析({vision_provider})") + if is_valid: + config.app[f"{vision_config_key}_base_url"] = st_vision_base_url + st.session_state[f"{vision_config_key}_base_url"] = st_vision_base_url + config_changed = True + else: + validation_errors.append(error_msg) + + # 验证模型名称 if st_vision_model_name: - config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name - st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name + is_valid, error_msg = validate_model_name(st_vision_model_name, f"视频分析({vision_provider})") + if is_valid: + config.app[f"{vision_config_key}_model_name"] = st_vision_model_name + st.session_state[f"{vision_config_key}_model_name"] = st_vision_model_name + config_changed = True + else: + validation_errors.append(error_msg) + + # 显示验证错误 + show_config_validation_errors(validation_errors) + + # 如果配置有变化且没有验证错误,保存到文件 + if config_changed and not validation_errors: + try: + config.save_config() + if st_vision_api_key or st_vision_base_url or st_vision_model_name: + st.success(f"视频分析模型({vision_provider})配置已保存") + except Exception as e: + st.error(f"保存配置失败: {str(e)}") + logger.error(f"保存视频分析配置失败: {str(e)}") def test_text_model_connection(api_key, base_url, model_name, provider, tr): @@ -278,14 +451,74 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr): # 特殊处理Gemini if provider.lower() == 'gemini': - import google.generativeai as genai + # 原生Gemini API测试 try: - genai.configure(api_key=api_key) - model = genai.GenerativeModel(model_name) - model.generate_content("直接回复我文本'当前网络可用'") - return True, tr("Gemini model is available") + # 构建请求数据 + request_data = { + "contents": [{ + "parts": [{"text": "直接回复我文本'当前网络可用'"}] + }], + "generationConfig": { + "temperature": 1.0, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 100, + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE" + } + ] + } + + # 构建请求URL + api_base_url = base_url or "https://generativelanguage.googleapis.com/v1beta" + url = f"{api_base_url}/models/{model_name}:generateContent?key={api_key}" + + # 发送请求 + response = requests.post( + url, + json=request_data, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + if response.status_code == 200: + return True, tr("原生Gemini模型连接成功") + else: + return False, f"{tr('原生Gemini模型连接失败')}: HTTP {response.status_code}" except Exception as e: - return False, f"{tr('Gemini model is not available')}: {str(e)}" + return False, f"{tr('原生Gemini模型连接失败')}: {str(e)}" + + elif provider.lower() == 'gemini(openai)': + # OpenAI兼容的Gemini代理测试 + test_url = f"{base_url.rstrip('/')}/chat/completions" + test_data = { + "model": model_name, + "messages": [ + {"role": "user", "content": "直接回复我文本'当前网络可用'"} + ], + "stream": False + } + + response = requests.post(test_url, headers=headers, json=test_data, timeout=10) + if response.status_code == 200: + return True, tr("OpenAI兼容Gemini代理连接成功") + else: + return False, f"{tr('OpenAI兼容Gemini代理连接失败')}: HTTP {response.status_code}" else: test_url = f"{base_url.rstrip('/')}/chat/completions" @@ -322,7 +555,7 @@ def render_text_llm_settings(tr): st.subheader(tr("Text Generation Model Settings")) # 文案生成模型提供商选择 - text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Qwen', 'Moonshot'] + text_providers = ['OpenAI', 'Siliconflow', 'DeepSeek', 'Gemini', 'Gemini(OpenAI)', 'Qwen', 'Moonshot'] saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower() saved_provider_index = 0 @@ -346,32 +579,108 @@ def render_text_llm_settings(tr): # 渲染文本模型配置输入框 st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password") - st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url) - st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name) + + # 根据不同提供商设置默认值和帮助信息 + if text_provider == 'gemini': + st_text_base_url = st.text_input( + tr("Text Base URL"), + value=text_base_url or "https://generativelanguage.googleapis.com/v1beta", + help=tr("原生Gemini API端点,默认: https://generativelanguage.googleapis.com/v1beta") + ) + st_text_model_name = st.text_input( + tr("Text Model Name"), + value=text_model_name or "gemini-2.0-flash-exp", + help=tr("原生Gemini模型,默认: gemini-2.0-flash-exp") + ) + elif text_provider == 'gemini(openai)': + st_text_base_url = st.text_input( + tr("Text Base URL"), + value=text_base_url or "https://generativelanguage.googleapis.com/v1beta/openai", + help=tr("OpenAI兼容的Gemini代理端点,如: https://your-proxy.com/v1") + ) + st_text_model_name = st.text_input( + tr("Text Model Name"), + value=text_model_name or "gemini-2.0-flash-exp", + help=tr("OpenAI格式的Gemini模型名称,默认: gemini-2.0-flash-exp") + ) + else: + st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url) + st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name) # 添加测试按钮 if st.button(tr("Test Connection"), key="test_text_connection"): - with st.spinner(tr("Testing connection...")): - success, message = test_text_model_connection( - api_key=st_text_api_key, - base_url=st_text_base_url, - model_name=st_text_model_name, - provider=text_provider, - tr=tr - ) - - if success: - st.success(message) - else: - st.error(message) + # 先验证配置 + test_errors = [] + if not st_text_api_key: + test_errors.append("请先输入API密钥") + if not st_text_model_name: + test_errors.append("请先输入模型名称") - # 保存文本模型配置 + if test_errors: + for error in test_errors: + st.error(error) + else: + with st.spinner(tr("Testing connection...")): + try: + success, message = test_text_model_connection( + api_key=st_text_api_key, + base_url=st_text_base_url, + model_name=st_text_model_name, + provider=text_provider, + tr=tr + ) + + if success: + st.success(message) + else: + st.error(message) + except Exception as e: + st.error(f"测试连接时发生错误: {str(e)}") + logger.error(f"文案生成模型连接测试失败: {str(e)}") + + # 验证和保存文本模型配置 + text_validation_errors = [] + text_config_changed = False + + # 验证API密钥 if st_text_api_key: - config.app[f"text_{text_provider}_api_key"] = st_text_api_key + is_valid, error_msg = validate_api_key(st_text_api_key, f"文案生成({text_provider})") + if is_valid: + config.app[f"text_{text_provider}_api_key"] = st_text_api_key + text_config_changed = True + else: + text_validation_errors.append(error_msg) + + # 验证Base URL if st_text_base_url: - config.app[f"text_{text_provider}_base_url"] = st_text_base_url + is_valid, error_msg = validate_base_url(st_text_base_url, f"文案生成({text_provider})") + if is_valid: + config.app[f"text_{text_provider}_base_url"] = st_text_base_url + text_config_changed = True + else: + text_validation_errors.append(error_msg) + + # 验证模型名称 if st_text_model_name: - config.app[f"text_{text_provider}_model_name"] = st_text_model_name + is_valid, error_msg = validate_model_name(st_text_model_name, f"文案生成({text_provider})") + if is_valid: + config.app[f"text_{text_provider}_model_name"] = st_text_model_name + text_config_changed = True + else: + text_validation_errors.append(error_msg) + + # 显示验证错误 + show_config_validation_errors(text_validation_errors) + + # 如果配置有变化且没有验证错误,保存到文件 + if text_config_changed and not text_validation_errors: + try: + config.save_config() + if st_text_api_key or st_text_base_url or st_text_model_name: + st.success(f"文案生成模型({text_provider})配置已保存") + except Exception as e: + st.error(f"保存配置失败: {str(e)}") + logger.error(f"保存文案生成配置失败: {str(e)}") # # Cloudflare 特殊配置 # if text_provider == 'cloudflare': diff --git a/webui/tools/base.py b/webui/tools/base.py index b8aff6a..754d971 100644 --- a/webui/tools/base.py +++ b/webui/tools/base.py @@ -6,31 +6,45 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry from app.config import config +# 导入新的LLM服务模块 - 确保提供商被注册 +import app.services.llm # 这会触发提供商注册 +from app.services.llm.migration_adapter import create_vision_analyzer as create_vision_analyzer_new +# 保留旧的导入以确保向后兼容 from app.utils import gemini_analyzer, qwenvl_analyzer def create_vision_analyzer(provider, api_key, model, base_url): """ - 创建视觉分析器实例 - + 创建视觉分析器实例 - 已重构为使用新的LLM服务架构 + Args: - provider: 提供商名称 ('gemini' 或 'qwenvl') + provider: 提供商名称 ('gemini', 'gemini(openai)', 'qwenvl', 'siliconflow') api_key: API密钥 model: 模型名称 base_url: API基础URL - + Returns: - VisionAnalyzer 或 QwenAnalyzer 实例 + 视觉分析器实例 """ - if provider == 'gemini': - return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key) - else: - # 只传入必要的参数 - return qwenvl_analyzer.QwenAnalyzer( - model_name=model, - api_key=api_key, - base_url=base_url - ) + try: + # 优先使用新的LLM服务架构 + return create_vision_analyzer_new(provider, api_key, model, base_url) + except Exception as e: + logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}") + + # 回退到旧的实现以确保兼容性 + if provider == 'gemini': + return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key, base_url=base_url) + elif provider == 'gemini(openai)': + from app.utils.gemini_openai_analyzer import GeminiOpenAIAnalyzer + return GeminiOpenAIAnalyzer(model_name=model, api_key=api_key, base_url=base_url) + else: + # 只传入必要的参数 + return qwenvl_analyzer.QwenAnalyzer( + model_name=model, + api_key=api_key, + base_url=base_url + ) def get_batch_timestamps(batch_files, prev_batch_files=None): diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py index 189d897..53ce0f3 100644 --- a/webui/tools/generate_script_docu.py +++ b/webui/tools/generate_script_docu.py @@ -368,7 +368,16 @@ def generate_script_docu(params): base_url=text_base_url, model=text_model ) - narration_dict = json.loads(narration)['items'] + + # 使用增强的JSON解析器 + from webui.tools.generate_short_summary import parse_and_fix_json + narration_data = parse_and_fix_json(narration) + + if not narration_data or 'items' not in narration_data: + logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...") + raise Exception("解说文案格式错误,无法解析JSON或缺少items字段") + + narration_dict = narration_data['items'] # 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音 narration_dict = [{**item, "OST": 2} for item in narration_dict] logger.debug(f"解说文案创作完成:\n{"\n".join([item['narration'] for item in narration_dict])}") diff --git a/webui/tools/generate_script_short.py b/webui/tools/generate_script_short.py index c4508d9..5c4ce9d 100644 --- a/webui/tools/generate_script_short.py +++ b/webui/tools/generate_script_short.py @@ -69,6 +69,7 @@ def generate_script_short(tr, params, custom_clips=5): model_name=text_model, base_url=text_base_url, custom_clips=custom_clips, + provider=text_provider ) if script is None: diff --git a/webui/tools/generate_short_summary.py b/webui/tools/generate_short_summary.py index eb2a6f4..dc972af 100644 --- a/webui/tools/generate_short_summary.py +++ b/webui/tools/generate_short_summary.py @@ -16,6 +16,122 @@ from loguru import logger from app.config import config from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script +# 导入新的LLM服务模块 - 确保提供商被注册 +import app.services.llm # 这会触发提供商注册 +from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter +import re + + +def parse_and_fix_json(json_string): + """ + 解析并修复JSON字符串 + + Args: + json_string: 待解析的JSON字符串 + + Returns: + dict: 解析后的字典,如果解析失败返回None + """ + if not json_string or not json_string.strip(): + logger.error("JSON字符串为空") + return None + + # 清理字符串 + json_string = json_string.strip() + + # 尝试直接解析 + try: + return json.loads(json_string) + except json.JSONDecodeError as e: + logger.warning(f"直接JSON解析失败: {e}") + + # 尝试修复双大括号问题(LLM生成的常见问题) + try: + # 将双大括号替换为单大括号 + fixed_braces = json_string.replace('{{', '{').replace('}}', '}') + logger.info("修复双大括号格式") + return json.loads(fixed_braces) + except json.JSONDecodeError: + pass + + # 尝试提取JSON部分 + try: + # 查找JSON代码块 + json_match = re.search(r'```json\s*(.*?)\s*```', json_string, re.DOTALL) + if json_match: + json_content = json_match.group(1).strip() + logger.info("从代码块中提取JSON内容") + return json.loads(json_content) + except json.JSONDecodeError: + pass + + # 尝试查找大括号包围的内容 + try: + # 查找第一个 { 到最后一个 } 的内容 + start_idx = json_string.find('{') + end_idx = json_string.rfind('}') + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + json_content = json_string[start_idx:end_idx+1] + logger.info("提取大括号包围的JSON内容") + return json.loads(json_content) + except json.JSONDecodeError: + pass + + # 尝试综合修复JSON格式问题 + try: + fixed_json = json_string + + # 1. 修复双大括号问题 + fixed_json = fixed_json.replace('{{', '{').replace('}}', '}') + + # 2. 提取JSON内容(如果有其他文本包围) + start_idx = fixed_json.find('{') + end_idx = fixed_json.rfind('}') + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + fixed_json = fixed_json[start_idx:end_idx+1] + + # 3. 移除注释 + fixed_json = re.sub(r'#.*', '', fixed_json) + fixed_json = re.sub(r'//.*', '', fixed_json) + + # 4. 移除多余的逗号 + fixed_json = re.sub(r',\s*}', '}', fixed_json) + fixed_json = re.sub(r',\s*]', ']', fixed_json) + + # 5. 修复单引号 + fixed_json = re.sub(r"'([^']*)':", r'"\1":', fixed_json) + + # 6. 修复没有引号的属性名 + fixed_json = re.sub(r'(\w+)(\s*):', r'"\1"\2:', fixed_json) + + # 7. 修复重复的引号 + fixed_json = re.sub(r'""([^"]*?)""', r'"\1"', fixed_json) + + logger.info("尝试综合修复JSON格式问题后解析") + return json.loads(fixed_json) + except json.JSONDecodeError as e: + logger.debug(f"综合修复失败: {e}") + pass + + # 如果所有方法都失败,尝试创建一个基本的结构 + logger.error(f"所有JSON解析方法都失败,原始内容: {json_string[:200]}...") + + # 尝试从文本中提取关键信息创建基本结构 + try: + # 这是一个简单的回退方案 + return { + "items": [ + { + "_id": 1, + "timestamp": "00:00:00,000-00:00:10,000", + "picture": "解析失败,使用默认内容", + "narration": json_string[:100] + "..." if len(json_string) > 100 else json_string, + "OST": 0 + } + ] + } + except Exception: + return None def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperature): @@ -49,20 +165,36 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu return """ - 2. 分析字幕总结剧情 + 2. 分析字幕总结剧情 - 使用新的LLM服务架构 """ text_provider = config.app.get('text_llm_provider', 'gemini').lower() text_api_key = config.app.get(f'text_{text_provider}_api_key') text_model = config.app.get(f'text_{text_provider}_model_name') text_base_url = config.app.get(f'text_{text_provider}_base_url') - analysis_result = analyze_subtitle( - subtitle_file_path=subtitle_path, - api_key=text_api_key, - model=text_model, - base_url=text_base_url, - save_result=True, - temperature=temperature - ) + + try: + # 优先使用新的LLM服务架构 + logger.info("使用新的LLM服务架构进行字幕分析") + analyzer = SubtitleAnalyzerAdapter(text_api_key, text_model, text_base_url, text_provider) + + # 读取字幕文件 + with open(subtitle_path, 'r', encoding='utf-8') as f: + subtitle_content = f.read() + + analysis_result = analyzer.analyze_subtitle(subtitle_content) + + except Exception as e: + logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}") + # 回退到旧的实现 + analysis_result = analyze_subtitle( + subtitle_file_path=subtitle_path, + api_key=text_api_key, + model=text_model, + base_url=text_base_url, + save_result=True, + temperature=temperature, + provider=text_provider + ) """ 3. 根据剧情生成解说文案 """ @@ -70,16 +202,28 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu logger.info("字幕分析成功!") update_progress(60, "正在生成文案...") - # 根据剧情生成解说文案 - narration_result = generate_narration_script( - short_name=video_theme, - plot_analysis=analysis_result["analysis"], - api_key=text_api_key, - model=text_model, - base_url=text_base_url, - save_result=True, - temperature=temperature - ) + # 根据剧情生成解说文案 - 使用新的LLM服务架构 + try: + # 优先使用新的LLM服务架构 + logger.info("使用新的LLM服务架构生成解说文案") + narration_result = analyzer.generate_narration_script( + short_name=video_theme, + plot_analysis=analysis_result["analysis"], + temperature=temperature + ) + except Exception as e: + logger.warning(f"使用新LLM服务失败,回退到旧实现: {str(e)}") + # 回退到旧的实现 + narration_result = generate_narration_script( + short_name=video_theme, + plot_analysis=analysis_result["analysis"], + api_key=text_api_key, + model=text_model, + base_url=text_base_url, + save_result=True, + temperature=temperature, + provider=text_provider + ) if narration_result["status"] == "success": logger.info("\n解说文案生成成功!") @@ -100,7 +244,20 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu # 结果转换为JSON字符串 narration_script = narration_result["narration_script"] - narration_dict = json.loads(narration_script) + + # 增强JSON解析,包含错误处理和修复 + narration_dict = parse_and_fix_json(narration_script) + if narration_dict is None: + st.error("生成的解说文案格式错误,无法解析为JSON") + logger.error(f"JSON解析失败,原始内容: {narration_script}") + st.stop() + + # 验证JSON结构 + if 'items' not in narration_dict: + st.error("生成的解说文案缺少必要的'items'字段") + logger.error(f"JSON结构错误,缺少items字段: {narration_dict}") + st.stop() + script = json.dumps(narration_dict['items'], ensure_ascii=False, indent=2) if script is None: