refactor(script_generator): 重构脚本生成器

- 优化了基本设置中的文本生成模型提供商选择
- 重新设计了脚本生成器的架构,提高了可扩展性和维护性
- 为 OpenAI、Gemini、Qwen 和 Moonshot 生成器实现了统一的接口和流程
- 移除了字数控制要求,简化了生成逻辑
This commit is contained in:
linyq 2024-11-11 17:22:01 +08:00
parent ee52600ae2
commit eaa8ceb7e3
2 changed files with 166 additions and 179 deletions

View File

@ -17,23 +17,108 @@ class BaseGenerator:
self.conversation_history = []
self.chunk_overlap = 50
self.last_chunk_ending = ""
self.default_params = {
"temperature": 0.7,
"max_tokens": 500,
"top_p": 0.9,
"frequency_penalty": 0.3,
"presence_penalty": 0.5
}
def _try_generate(self, messages: list, params: dict = None) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
try:
response = self._generate(messages, params or self.default_params)
return self._process_response(response)
except Exception as e:
if attempt == max_attempts - 1:
raise
logger.warning(f"Generation attempt {attempt + 1} failed: {str(e)}")
continue
return ""
def _generate(self, messages: list, params: dict) -> any:
raise NotImplementedError
def _process_response(self, response: any) -> str:
return response
def generate_script(self, scene_description: str, word_count: int) -> str:
raise NotImplementedError("Subclasses must implement generate_script method")
"""生成脚本的通用方法"""
prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
messages = [
{"role": "system", "content": self.base_prompt},
{"role": "user", "content": prompt}
]
try:
generated_script = self._try_generate(messages, self.default_params)
# 更新上下文
if generated_script:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
except Exception as e:
logger.error(f"Script generation failed: {str(e)}")
raise
class OpenAIGenerator(BaseGenerator):
"""OpenAI API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
self.client = OpenAI(api_key=api_key)
self.max_tokens = 7000
# OpenAI特定参数
self.default_params = {
**self.default_params,
"stream": False,
"user": "script_generator"
}
# 初始化token计数器
try:
self.encoding = tiktoken.encoding_for_model(self.model_name)
except KeyError:
logger.info(f"警告:未找到模型 {self.model_name} 的专用编码器,使用认编码器")
logger.warning(f"未找到模型 {self.model_name} 的专用编码器,使用认编码器")
self.encoding = tiktoken.get_encoding("cl100k_base")
def _generate(self, messages: list, params: dict) -> any:
"""实现OpenAI特定的生成逻辑"""
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
**params
)
return response
except Exception as e:
logger.error(f"OpenAI generation error: {str(e)}")
raise
def _process_response(self, response: any) -> str:
"""处理OpenAI的响应"""
if not response or not response.choices:
raise ValueError("Invalid response from OpenAI API")
return response.choices[0].message.content.strip()
def _count_tokens(self, messages: list) -> int:
"""计算token数量"""
num_tokens = 0
for message in messages:
num_tokens += 3
@ -44,206 +129,116 @@ class OpenAIGenerator(BaseGenerator):
num_tokens += 3
return num_tokens
def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None:
base_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": new_user_prompt}
]
base_tokens = self._count_tokens(base_messages)
temp_history = []
current_tokens = base_tokens
for message in reversed(self.conversation_history):
message_tokens = self._count_tokens([message])
if current_tokens + message_tokens > self.max_tokens:
break
temp_history.insert(0, message)
current_tokens += message_tokens
self.conversation_history = temp_history
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
system_prompt, user_prompt = self._create_prompt(scene_description, word_count)
self._trim_conversation_history(system_prompt, user_prompt)
messages = [
{"role": "system", "content": system_prompt},
*self.conversation_history,
{"role": "user", "content": user_prompt}
]
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=0.7,
max_tokens=500,
top_p=0.9,
frequency_penalty=0.3,
presence_penalty=0.5
)
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
'').replace(
'\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.conversation_history.append({"role": "user", "content": user_prompt})
self.conversation_history.append({"role": "assistant", "content": generated_script})
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
def _create_prompt(self, scene_description: str, word_count: int) -> tuple:
system_prompt = self.base_prompt.format(word_count=word_count)
user_prompt = f"""上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
return system_prompt, user_prompt
class GeminiGenerator(BaseGenerator):
"""Google Gemini API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel(model_name)
# Gemini特定参数
self.default_params = {
"temperature": self.default_params["temperature"],
"top_p": self.default_params["top_p"],
"candidate_count": 1,
"stop_sequences": None
}
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
def _generate(self, messages: list, params: dict) -> any:
"""实现Gemini特定的生成逻辑"""
try:
# 转换消息格式为Gemini格式
prompt = "\n".join([m["content"] for m in messages])
response = self.model.generate_content(
prompt,
generation_config=params
)
return response
except Exception as e:
logger.error(f"Gemini generation error: {str(e)}")
raise
for attempt in range(max_attempts):
prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
response = self.model.generate_content(prompt)
generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
def _process_response(self, response: any) -> str:
"""处理Gemini的响应"""
if not response or not response.text:
raise ValueError("Invalid response from Gemini API")
return response.text.strip()
class QwenGenerator(BaseGenerator):
"""阿里云千问 API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
self.client = OpenAI(
api_key=api_key,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
# Qwen特定参数
self.default_params = {
**self.default_params,
"stream": False,
"user": "script_generator",
"enable_search": True
}
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
messages = [
{"role": "system", "content": self.base_prompt},
{"role": "user", "content": prompt}
]
def _generate(self, messages: list, params: dict) -> any:
"""实现千问特定的生成逻辑"""
try:
response = self.client.chat.completions.create(
model=self.model_name, # 如 "qwen-plus"
model=self.model_name,
messages=messages,
temperature=0.7,
max_tokens=500,
top_p=0.9,
frequency_penalty=0.3,
presence_penalty=0.5
**params
)
return response
except Exception as e:
logger.error(f"Qwen generation error: {str(e)}")
raise
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
'').replace(
'\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
def _process_response(self, response: any) -> str:
"""处理千问的响应"""
if not response or not response.choices:
raise ValueError("Invalid response from Qwen API")
return response.choices[0].message.content.strip()
class MoonshotGenerator(BaseGenerator):
"""Moonshot API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
self.client = OpenAI(
api_key=api_key,
base_url="https://api.moonshot.cn/v1"
)
# Moonshot特定参数
self.default_params = {
**self.default_params,
"stream": False,
"stop": None,
"user": "script_generator",
"tools": None
}
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
messages = [
{"role": "system", "content": self.base_prompt},
{"role": "user", "content": prompt}
]
def _generate(self, messages: list, params: dict) -> any:
"""实现Moonshot特定的生成逻辑"""
try:
response = self.client.chat.completions.create(
model=self.model_name, # 如 "moonshot-v1-8k"
model=self.model_name,
messages=messages,
temperature=0.7,
max_tokens=500,
top_p=0.9,
frequency_penalty=0.3,
presence_penalty=0.5
**params
)
return response
except Exception as e:
logger.error(f"Moonshot generation error: {str(e)}")
raise
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
'').replace(
'\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
def _process_response(self, response: any) -> str:
"""处理Moonshot的响应"""
if not response or not response.choices:
raise ValueError("Invalid response from Moonshot API")
return response.choices[0].message.content.strip()
class ScriptProcessor:
@ -263,7 +258,7 @@ class ScriptProcessor:
else:
self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt)
def _get_default_prompt(self, word_count=None) -> str:
def _get_default_prompt(self) -> str:
return f"""你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让{self.video_theme}视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
目标受众热爱生活追求独特体验的18-35岁年轻人
@ -293,10 +288,9 @@ class ScriptProcessor:
3. 用接地气的表达方式拉近与观众距离
节奏控制 Rhythm
1. 严格控制文案字数在{word_count}字左右允许误差不超过5字
2. 像讲段子一样注意铺垫和包袱的节奏
3. 确保每段都有笑点但不强求
4. 段落结尾干净利落不拖泥带水
1. 像讲段子一样注意铺垫和包袱的节奏
2. 确保每段都有笑点但不强求
3. 段落结尾干净利落不拖泥带水
连贯性要求
1. 新生成的内容必须自然衔接上一段文案的结尾
@ -305,13 +299,6 @@ class ScriptProcessor:
4. 避免重复上一段已经提到的信息
5. 确保情节和建造过程的逻辑连续性
字数控制要求
1. 严格控制文案字数在{word_count}字左右允许误差不超过5字
2. 如果内容过长优先精简修饰性词语
3. 如果内容过短可以适当增加细节描写
4. 保持文案结构完整不因字数限制而牺牲内容质量
5. 确保每个笑点和包袱都得到完整表达
我会按顺序提供多段视频画面描述请创作既搞笑又能火爆全网的口播文案
记住要敢于用"温和的违反"制造笑点但要把握好尺度让观众在轻松愉快中感受野外建造的乐趣"""

View File

@ -191,7 +191,7 @@ def render_text_llm_settings(tr):
st.subheader(tr("Text Generation Model Settings"))
# 文案生成模型提供商选择
text_providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare']
text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'Gemini']
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
saved_provider_index = 0