diff --git a/app/utils/script_generator.py b/app/utils/script_generator.py index 2bec834..b3e4d58 100644 --- a/app/utils/script_generator.py +++ b/app/utils/script_generator.py @@ -17,23 +17,108 @@ class BaseGenerator: self.conversation_history = [] self.chunk_overlap = 50 self.last_chunk_ending = "" + self.default_params = { + "temperature": 0.7, + "max_tokens": 500, + "top_p": 0.9, + "frequency_penalty": 0.3, + "presence_penalty": 0.5 + } + + def _try_generate(self, messages: list, params: dict = None) -> str: + max_attempts = 3 + tolerance = 5 + + for attempt in range(max_attempts): + try: + response = self._generate(messages, params or self.default_params) + return self._process_response(response) + except Exception as e: + if attempt == max_attempts - 1: + raise + logger.warning(f"Generation attempt {attempt + 1} failed: {str(e)}") + continue + return "" + + def _generate(self, messages: list, params: dict) -> any: + raise NotImplementedError + + def _process_response(self, response: any) -> str: + return response def generate_script(self, scene_description: str, word_count: int) -> str: - raise NotImplementedError("Subclasses must implement generate_script method") + """生成脚本的通用方法""" + prompt = f"""{self.base_prompt} + +上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"} + +当前画面描述:{scene_description} + +请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。 +严格字数要求:{word_count}字,允许误差±5字。""" + + messages = [ + {"role": "system", "content": self.base_prompt}, + {"role": "user", "content": prompt} + ] + + try: + generated_script = self._try_generate(messages, self.default_params) + + # 更新上下文 + if generated_script: + self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len( + generated_script) > self.chunk_overlap else generated_script + + return generated_script + + except Exception as e: + logger.error(f"Script generation failed: {str(e)}") + raise class OpenAIGenerator(BaseGenerator): + """OpenAI API 生成器实现""" def __init__(self, model_name: str, api_key: str, prompt: str): super().__init__(model_name, api_key, prompt) self.client = OpenAI(api_key=api_key) self.max_tokens = 7000 + + # OpenAI特定参数 + self.default_params = { + **self.default_params, + "stream": False, + "user": "script_generator" + } + + # 初始化token计数器 try: self.encoding = tiktoken.encoding_for_model(self.model_name) except KeyError: - logger.info(f"警告:未找到模型 {self.model_name} 的专用编码器,使用认编码器") + logger.warning(f"未找到模型 {self.model_name} 的专用编码器,使用默认编码器") self.encoding = tiktoken.get_encoding("cl100k_base") + def _generate(self, messages: list, params: dict) -> any: + """实现OpenAI特定的生成逻辑""" + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + **params + ) + return response + except Exception as e: + logger.error(f"OpenAI generation error: {str(e)}") + raise + + def _process_response(self, response: any) -> str: + """处理OpenAI的响应""" + if not response or not response.choices: + raise ValueError("Invalid response from OpenAI API") + return response.choices[0].message.content.strip() + def _count_tokens(self, messages: list) -> int: + """计算token数量""" num_tokens = 0 for message in messages: num_tokens += 3 @@ -44,206 +129,116 @@ class OpenAIGenerator(BaseGenerator): num_tokens += 3 return num_tokens - def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None: - base_messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": new_user_prompt} - ] - base_tokens = self._count_tokens(base_messages) - - temp_history = [] - current_tokens = base_tokens - - for message in reversed(self.conversation_history): - message_tokens = self._count_tokens([message]) - if current_tokens + message_tokens > self.max_tokens: - break - temp_history.insert(0, message) - current_tokens += message_tokens - - self.conversation_history = temp_history - - def generate_script(self, scene_description: str, word_count: int) -> str: - max_attempts = 3 - tolerance = 5 - - for attempt in range(max_attempts): - system_prompt, user_prompt = self._create_prompt(scene_description, word_count) - self._trim_conversation_history(system_prompt, user_prompt) - - messages = [ - {"role": "system", "content": system_prompt}, - *self.conversation_history, - {"role": "user", "content": user_prompt} - ] - - response = self.client.chat.completions.create( - model=self.model_name, - messages=messages, - temperature=0.7, - max_tokens=500, - top_p=0.9, - frequency_penalty=0.3, - presence_penalty=0.5 - ) - - generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"', - '').replace( - '\n', '') - - current_length = len(generated_script) - if abs(current_length - word_count) <= tolerance: - self.conversation_history.append({"role": "user", "content": user_prompt}) - self.conversation_history.append({"role": "assistant", "content": generated_script}) - self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len( - generated_script) > self.chunk_overlap else generated_script - return generated_script - - return generated_script - - def _create_prompt(self, scene_description: str, word_count: int) -> tuple: - system_prompt = self.base_prompt.format(word_count=word_count) - - user_prompt = f"""上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"} - -当前画面描述:{scene_description} - -请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。 -严格字数要求:{word_count}字,允许误差±5字。""" - - return system_prompt, user_prompt - class GeminiGenerator(BaseGenerator): + """Google Gemini API 生成器实现""" def __init__(self, model_name: str, api_key: str, prompt: str): super().__init__(model_name, api_key, prompt) genai.configure(api_key=api_key) self.model = genai.GenerativeModel(model_name) + + # Gemini特定参数 + self.default_params = { + "temperature": self.default_params["temperature"], + "top_p": self.default_params["top_p"], + "candidate_count": 1, + "stop_sequences": None + } - def generate_script(self, scene_description: str, word_count: int) -> str: - max_attempts = 3 - tolerance = 5 + def _generate(self, messages: list, params: dict) -> any: + """实现Gemini特定的生成逻辑""" + try: + # 转换消息格式为Gemini格式 + prompt = "\n".join([m["content"] for m in messages]) + response = self.model.generate_content( + prompt, + generation_config=params + ) + return response + except Exception as e: + logger.error(f"Gemini generation error: {str(e)}") + raise - for attempt in range(max_attempts): - prompt = f"""{self.base_prompt} - -上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"} - -当前画面描述:{scene_description} - -请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。 -严格字数要求:{word_count}字,允许误差±5字。""" - - response = self.model.generate_content(prompt) - generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '') - - current_length = len(generated_script) - if abs(current_length - word_count) <= tolerance: - self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len( - generated_script) > self.chunk_overlap else generated_script - return generated_script - - return generated_script + def _process_response(self, response: any) -> str: + """处理Gemini的响应""" + if not response or not response.text: + raise ValueError("Invalid response from Gemini API") + return response.text.strip() class QwenGenerator(BaseGenerator): + """阿里云千问 API 生成器实现""" def __init__(self, model_name: str, api_key: str, prompt: str): super().__init__(model_name, api_key, prompt) self.client = OpenAI( api_key=api_key, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" ) + + # Qwen特定参数 + self.default_params = { + **self.default_params, + "stream": False, + "user": "script_generator", + "enable_search": True + } - def generate_script(self, scene_description: str, word_count: int) -> str: - max_attempts = 3 - tolerance = 5 - - for attempt in range(max_attempts): - prompt = f"""{self.base_prompt} - -上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"} - -当前画面描述:{scene_description} - -请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。 -严格字数要求:{word_count}字,允许误差±5字。""" - - messages = [ - {"role": "system", "content": self.base_prompt}, - {"role": "user", "content": prompt} - ] - + def _generate(self, messages: list, params: dict) -> any: + """实现千问特定的生成逻辑""" + try: response = self.client.chat.completions.create( - model=self.model_name, # 如 "qwen-plus" + model=self.model_name, messages=messages, - temperature=0.7, - max_tokens=500, - top_p=0.9, - frequency_penalty=0.3, - presence_penalty=0.5 + **params ) + return response + except Exception as e: + logger.error(f"Qwen generation error: {str(e)}") + raise - generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"', - '').replace( - '\n', '') - - current_length = len(generated_script) - if abs(current_length - word_count) <= tolerance: - self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len( - generated_script) > self.chunk_overlap else generated_script - return generated_script - - return generated_script + def _process_response(self, response: any) -> str: + """处理千问的响应""" + if not response or not response.choices: + raise ValueError("Invalid response from Qwen API") + return response.choices[0].message.content.strip() class MoonshotGenerator(BaseGenerator): + """Moonshot API 生成器实现""" def __init__(self, model_name: str, api_key: str, prompt: str): super().__init__(model_name, api_key, prompt) self.client = OpenAI( api_key=api_key, base_url="https://api.moonshot.cn/v1" ) + + # Moonshot特定参数 + self.default_params = { + **self.default_params, + "stream": False, + "stop": None, + "user": "script_generator", + "tools": None + } - def generate_script(self, scene_description: str, word_count: int) -> str: - max_attempts = 3 - tolerance = 5 - - for attempt in range(max_attempts): - prompt = f"""{self.base_prompt} - -上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"} - -当前画面描述:{scene_description} - -请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。 -严格字数要求:{word_count}字,允许误差±5字。""" - - messages = [ - {"role": "system", "content": self.base_prompt}, - {"role": "user", "content": prompt} - ] - + def _generate(self, messages: list, params: dict) -> any: + """实现Moonshot特定的生成逻辑""" + try: response = self.client.chat.completions.create( - model=self.model_name, # 如 "moonshot-v1-8k" + model=self.model_name, messages=messages, - temperature=0.7, - max_tokens=500, - top_p=0.9, - frequency_penalty=0.3, - presence_penalty=0.5 + **params ) + return response + except Exception as e: + logger.error(f"Moonshot generation error: {str(e)}") + raise - generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"', - '').replace( - '\n', '') - - current_length = len(generated_script) - if abs(current_length - word_count) <= tolerance: - self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len( - generated_script) > self.chunk_overlap else generated_script - return generated_script - - return generated_script + def _process_response(self, response: any) -> str: + """处理Moonshot的响应""" + if not response or not response.choices: + raise ValueError("Invalid response from Moonshot API") + return response.choices[0].message.content.strip() class ScriptProcessor: @@ -263,7 +258,7 @@ class ScriptProcessor: else: self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt) - def _get_default_prompt(self, word_count=None) -> str: + def _get_default_prompt(self) -> str: return f"""你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让{self.video_theme}视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。 目标受众:热爱生活、追求独特体验的18-35岁年轻人 @@ -293,10 +288,9 @@ class ScriptProcessor: 3. 用接地气的表达方式拉近与观众距离 【节奏控制 Rhythm】 -1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字 -2. 像讲段子一样,注意铺垫和包袱的节奏 -3. 确保每段都有笑点,但不强求 -4. 段落结尾干净利落,不拖泥带水 +1. 像讲段子一样,注意铺垫和包袱的节奏 +2. 确保每段都有笑点,但不强求 +3. 段落结尾干净利落,不拖泥带水 【连贯性要求】 1. 新生成的内容必须自然衔接上一段文案的结尾 @@ -305,13 +299,6 @@ class ScriptProcessor: 4. 避免重复上一段已经提到的信息 5. 确保情节和建造过程的逻辑连续性 -【字数控制要求】 -1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字 -2. 如果内容过长,优先精简修饰性词语 -3. 如果内容过短,可以适当增加细节描写 -4. 保持文案结构完整,不因字数限制而牺牲内容质量 -5. 确保每个笑点和包袱都得到完整表达 - 我会按顺序提供多段视频画面描述。请创作既搞笑又能火爆全网的口播文案。 记住:要敢于用"温和的违反"制造笑点,但要把握好尺度,让观众在轻松愉快中感受野外建造的乐趣。""" diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index b8e69ad..855db83 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -191,7 +191,7 @@ def render_text_llm_settings(tr): st.subheader(tr("Text Generation Model Settings")) # 文案生成模型提供商选择 - text_providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare'] + text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'Gemini'] saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower() saved_provider_index = 0