Mirror of https://github.com/linyqh/NarratoAI.git
refactor(script_generator): refactor the script generator

- Improved the text generation model provider selection in the basic settings
- Redesigned the script generator architecture for better extensibility and maintainability
- Implemented a unified interface and generation flow for the OpenAI, Gemini, Qwen, and Moonshot generators
- Removed the word-count control requirements and simplified the generation logic
parent ee52600ae2
commit eaa8ceb7e3
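After this change every provider generator shares one flow: BaseGenerator.generate_script assembles the prompt and carries context between segments, and each provider subclass only supplies _generate and _process_response. A minimal usage sketch of that unified interface (the import path, API key, model name, and scene text below are placeholders, not values from this commit):

    # Sketch only; assumes the classes from this diff are importable.
    from app.services.script_generator import QwenGenerator  # hypothetical path

    generator = QwenGenerator(
        model_name="qwen-plus",      # placeholder model
        api_key="sk-...",            # placeholder key
        prompt="You are a short-video script-writing master...",  # base prompt
    )

    # Consecutive calls stay coherent: generate_script keeps the last
    # chunk_overlap (50) characters and feeds them into the next prompt.
    part1 = generator.generate_script("The host builds a shelter in the woods", word_count=100)
    part2 = generator.generate_script("The shelter is done; the host cooks dinner", word_count=100)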
@@ -17,23 +17,108 @@ class BaseGenerator:
        self.conversation_history = []
        self.chunk_overlap = 50
        self.last_chunk_ending = ""
        self.default_params = {
            "temperature": 0.7,
            "max_tokens": 500,
            "top_p": 0.9,
            "frequency_penalty": 0.3,
            "presence_penalty": 0.5
        }

    def _try_generate(self, messages: list, params: dict = None) -> str:
        max_attempts = 3
        tolerance = 5

        for attempt in range(max_attempts):
            try:
                response = self._generate(messages, params or self.default_params)
                return self._process_response(response)
            except Exception as e:
                if attempt == max_attempts - 1:
                    raise
                logger.warning(f"Generation attempt {attempt + 1} failed: {str(e)}")
                continue
        return ""

    def _generate(self, messages: list, params: dict) -> any:
        raise NotImplementedError

    def _process_response(self, response: any) -> str:
        return response
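These two hooks are the full provider contract: _try_generate calls _generate for the raw API response and _process_response to turn it into text. Purely as an illustration (this class is not part of the commit), a dummy provider that echoes the last user message would satisfy it:

    class EchoGenerator(BaseGenerator):
        """Hypothetical provider used only to illustrate the hook contract."""

        def _generate(self, messages: list, params: dict) -> any:
            # A real provider calls its API here; this sketch just echoes.
            return messages[-1]["content"]

        def _process_response(self, response: any) -> str:
            # Optional override: the base implementation already passes
            # plain-string responses through unchanged.
            return str(response).strip()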

    def generate_script(self, scene_description: str, word_count: int) -> str:
-        raise NotImplementedError("Subclasses must implement generate_script method")
        """Shared script-generation flow."""
        prompt = f"""{self.base_prompt}

Ending of the previous segment: {self.last_chunk_ending if self.last_chunk_ending else "This is the first segment; there is no preceding text"}

Current scene description: {scene_description}

Make sure the new copy connects naturally with the preceding text and keeps the narration coherent and entertaining.
Strict length requirement: {word_count} characters, with a tolerance of ±5."""

        messages = [
            {"role": "system", "content": self.base_prompt},
            {"role": "user", "content": prompt}
        ]

        try:
            generated_script = self._try_generate(messages, self.default_params)

            # Update the rolling context
            if generated_script:
                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
                    generated_script) > self.chunk_overlap else generated_script

            return generated_script

        except Exception as e:
            logger.error(f"Script generation failed: {str(e)}")
            raise
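The context carryover above is plain string slicing: after each successful call, the last chunk_overlap (50) characters of the script become last_chunk_ending for the next prompt. A small worked example with invented strings:

    chunk_overlap = 50
    script = "A" * 80       # stand-in for an 80-character segment
    tail = script[-chunk_overlap:] if len(script) > chunk_overlap else script
    assert len(tail) == 50  # only the last 50 characters carry over

    short = "B" * 30        # segments shorter than the overlap
    tail = short[-chunk_overlap:] if len(short) > chunk_overlap else short
    assert tail == short    # are carried over whole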


class OpenAIGenerator(BaseGenerator):
    """OpenAI API generator implementation."""
    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(api_key=api_key)
        self.max_tokens = 7000

        # OpenAI-specific parameters
        self.default_params = {
            **self.default_params,
            "stream": False,
            "user": "script_generator"
        }

        # Initialize the token counter
        try:
            self.encoding = tiktoken.encoding_for_model(self.model_name)
        except KeyError:
-            logger.info(f"Warning: no dedicated encoder found for model {self.model_name}; falling back to the default encoder")
+            logger.warning(f"No dedicated encoder found for model {self.model_name}; falling back to the default encoder")
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def _generate(self, messages: list, params: dict) -> any:
        """OpenAI-specific generation logic."""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **params
            )
            return response
        except Exception as e:
            logger.error(f"OpenAI generation error: {str(e)}")
            raise

    def _process_response(self, response: any) -> str:
        """Process the OpenAI response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from OpenAI API")
        return response.choices[0].message.content.strip()

    def _count_tokens(self, messages: list) -> int:
        """Count the tokens in a message list."""
        num_tokens = 0
        for message in messages:
            num_tokens += 3
@@ -44,206 +129,116 @@ class OpenAIGenerator(BaseGenerator):
        num_tokens += 3
        return num_tokens
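The two += 3 constants mirror the token-accounting recipe OpenAI publishes for its chat models: roughly three framing tokens per message, plus three to prime the assistant's reply. The middle of the method is elided by the diff; under the assumption that it encodes each message field with self.encoding, the whole calculation would look like:

    # Sketch of the presumed full method, following OpenAI's counting recipe.
    def _count_tokens(self, messages: list) -> int:
        num_tokens = 0
        for message in messages:
            num_tokens += 3                        # per-message framing tokens
            for value in message.values():         # "role" and "content" strings
                num_tokens += len(self.encoding.encode(value))
        num_tokens += 3                            # primes the assistant's reply
        return num_tokens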

    def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None:
        base_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": new_user_prompt}
        ]
        base_tokens = self._count_tokens(base_messages)

        temp_history = []
        current_tokens = base_tokens

        for message in reversed(self.conversation_history):
            message_tokens = self._count_tokens([message])
            if current_tokens + message_tokens > self.max_tokens:
                break
            temp_history.insert(0, message)
            current_tokens += message_tokens

        self.conversation_history = temp_history
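Because the history is walked newest-first and inserted at the front, the oldest exchanges are the first to be dropped once the budget is hit. For example, with max_tokens = 100, 40 base tokens, and history messages costing 30, 30, and 20 tokens (oldest first): the newest 20 fits (total 60), the middle 30 fits (90), but the oldest 30 would reach 120, so only the last two messages survive.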

-    def generate_script(self, scene_description: str, word_count: int) -> str:
-        max_attempts = 3
-        tolerance = 5
-
-        for attempt in range(max_attempts):
-            system_prompt, user_prompt = self._create_prompt(scene_description, word_count)
-            self._trim_conversation_history(system_prompt, user_prompt)
-
-            messages = [
-                {"role": "system", "content": system_prompt},
-                *self.conversation_history,
-                {"role": "user", "content": user_prompt}
-            ]
-
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=0.7,
-                max_tokens=500,
-                top_p=0.9,
-                frequency_penalty=0.3,
-                presence_penalty=0.5
-            )
-
-            generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
-
-            current_length = len(generated_script)
-            if abs(current_length - word_count) <= tolerance:
-                self.conversation_history.append({"role": "user", "content": user_prompt})
-                self.conversation_history.append({"role": "assistant", "content": generated_script})
-                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
-                    generated_script) > self.chunk_overlap else generated_script
-                return generated_script
-
-        return generated_script
-
-    def _create_prompt(self, scene_description: str, word_count: int) -> tuple:
-        system_prompt = self.base_prompt.format(word_count=word_count)
-
-        user_prompt = f"""Ending of the previous segment: {self.last_chunk_ending if self.last_chunk_ending else "This is the first segment; there is no preceding text"}
-
-Current scene description: {scene_description}
-
-Make sure the new copy connects naturally with the preceding text and keeps the narration coherent and entertaining.
-Strict length requirement: {word_count} characters, with a tolerance of ±5."""
-
-        return system_prompt, user_prompt


class GeminiGenerator(BaseGenerator):
    """Google Gemini API generator implementation."""
    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)

        # Gemini-specific parameters
        self.default_params = {
            "temperature": self.default_params["temperature"],
            "top_p": self.default_params["top_p"],
            "candidate_count": 1,
            "stop_sequences": None
        }

-    def generate_script(self, scene_description: str, word_count: int) -> str:
-        max_attempts = 3
-        tolerance = 5

    def _generate(self, messages: list, params: dict) -> any:
        """Gemini-specific generation logic."""
        try:
            # Convert the chat messages into Gemini's single-prompt format
            prompt = "\n".join([m["content"] for m in messages])
            response = self.model.generate_content(
                prompt,
                generation_config=params
            )
            return response
        except Exception as e:
            logger.error(f"Gemini generation error: {str(e)}")
            raise

-        for attempt in range(max_attempts):
-            prompt = f"""{self.base_prompt}
-
-Ending of the previous segment: {self.last_chunk_ending if self.last_chunk_ending else "This is the first segment; there is no preceding text"}
-
-Current scene description: {scene_description}
-
-Make sure the new copy connects naturally with the preceding text and keeps the narration coherent and entertaining.
-Strict length requirement: {word_count} characters, with a tolerance of ±5."""
-
-            response = self.model.generate_content(prompt)
-            generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
-
-            current_length = len(generated_script)
-            if abs(current_length - word_count) <= tolerance:
-                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
-                    generated_script) > self.chunk_overlap else generated_script
-                return generated_script
-
-        return generated_script

    def _process_response(self, response: any) -> str:
        """Process the Gemini response."""
        if not response or not response.text:
            raise ValueError("Invalid response from Gemini API")
        return response.text.strip()
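Note that this _generate flattens the chat history by joining message contents into one string, so the system/user role distinction is lost before the text reaches Gemini. If preserving roles mattered, a variant along these lines could be used inside _generate (a sketch only: it assumes a google-generativeai version that supports system_instruction, which this commit does not use):

    # Sketch: keep roles instead of joining everything into one prompt string.
    system_text = "\n".join(m["content"] for m in messages if m["role"] == "system")
    user_turns = [
        {"role": "user", "parts": [m["content"]]}
        for m in messages if m["role"] != "system"
    ]
    model = genai.GenerativeModel(self.model_name, system_instruction=system_text)
    response = model.generate_content(user_turns, generation_config=params)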


class QwenGenerator(BaseGenerator):
    """Alibaba Cloud Qwen API generator implementation."""
    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(
            api_key=api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )

        # Qwen-specific parameters
        self.default_params = {
            **self.default_params,
            "stream": False,
            "user": "script_generator",
            "enable_search": True
        }

-    def generate_script(self, scene_description: str, word_count: int) -> str:
-        max_attempts = 3
-        tolerance = 5
-
-        for attempt in range(max_attempts):
-            prompt = f"""{self.base_prompt}
-
-Ending of the previous segment: {self.last_chunk_ending if self.last_chunk_ending else "This is the first segment; there is no preceding text"}
-
-Current scene description: {scene_description}
-
-Make sure the new copy connects naturally with the preceding text and keeps the narration coherent and entertaining.
-Strict length requirement: {word_count} characters, with a tolerance of ±5."""
-
-            messages = [
-                {"role": "system", "content": self.base_prompt},
-                {"role": "user", "content": prompt}
-            ]

    def _generate(self, messages: list, params: dict) -> any:
        """Qwen-specific generation logic."""
        try:
            response = self.client.chat.completions.create(
-                model=self.model_name,  # e.g. "qwen-plus"
+                model=self.model_name,
                messages=messages,
-                temperature=0.7,
-                max_tokens=500,
-                top_p=0.9,
-                frequency_penalty=0.3,
-                presence_penalty=0.5
+                **params
            )
            return response
        except Exception as e:
            logger.error(f"Qwen generation error: {str(e)}")
            raise

-            generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
-
-            current_length = len(generated_script)
-            if abs(current_length - word_count) <= tolerance:
-                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
-                    generated_script) > self.chunk_overlap else generated_script
-                return generated_script
-
-        return generated_script

    def _process_response(self, response: any) -> str:
        """Process the Qwen response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from Qwen API")
        return response.choices[0].message.content.strip()
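QwenGenerator (above) and MoonshotGenerator (below) differ from OpenAIGenerator only in the base_url handed to the OpenAI SDK; the call shape is identical across all three. Side by side (endpoints copied from this diff, keys are placeholders):

    from openai import OpenAI

    openai_client   = OpenAI(api_key="sk-...")  # default OpenAI endpoint
    qwen_client     = OpenAI(api_key="sk-...",
                             base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
    moonshot_client = OpenAI(api_key="sk-...",
                             base_url="https://api.moonshot.cn/v1")
    # All three then share: client.chat.completions.create(model=..., messages=..., **params)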


class MoonshotGenerator(BaseGenerator):
    """Moonshot API generator implementation."""
    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(
            api_key=api_key,
            base_url="https://api.moonshot.cn/v1"
        )

        # Moonshot-specific parameters
        self.default_params = {
            **self.default_params,
            "stream": False,
            "stop": None,
            "user": "script_generator",
            "tools": None
        }

-    def generate_script(self, scene_description: str, word_count: int) -> str:
-        max_attempts = 3
-        tolerance = 5
-
-        for attempt in range(max_attempts):
-            prompt = f"""{self.base_prompt}
-
-Ending of the previous segment: {self.last_chunk_ending if self.last_chunk_ending else "This is the first segment; there is no preceding text"}
-
-Current scene description: {scene_description}
-
-Make sure the new copy connects naturally with the preceding text and keeps the narration coherent and entertaining.
-Strict length requirement: {word_count} characters, with a tolerance of ±5."""
-
-            messages = [
-                {"role": "system", "content": self.base_prompt},
-                {"role": "user", "content": prompt}
-            ]

    def _generate(self, messages: list, params: dict) -> any:
        """Moonshot-specific generation logic."""
        try:
            response = self.client.chat.completions.create(
-                model=self.model_name,  # e.g. "moonshot-v1-8k"
+                model=self.model_name,
                messages=messages,
-                temperature=0.7,
-                max_tokens=500,
-                top_p=0.9,
-                frequency_penalty=0.3,
-                presence_penalty=0.5
+                **params
            )
            return response
        except Exception as e:
            logger.error(f"Moonshot generation error: {str(e)}")
            raise

-            generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
-
-            current_length = len(generated_script)
-            if abs(current_length - word_count) <= tolerance:
-                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
-                    generated_script) > self.chunk_overlap else generated_script
-                return generated_script
-
-        return generated_script

    def _process_response(self, response: any) -> str:
        """Process the Moonshot response."""
        if not response or not response.choices:
            raise ValueError("Invalid response from Moonshot API")
        return response.choices[0].message.content.strip()


class ScriptProcessor:
@@ -263,7 +258,7 @@ class ScriptProcessor:
        else:
            self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt)

-    def _get_default_prompt(self, word_count=None) -> str:
+    def _get_default_prompt(self) -> str:
        return f"""You are a short-video script-writing master with a sharp sense of humor, skilled at using "benign violations" to create laughs and make {self.video_theme} videos both funny and highly shareable. Your task is to turn video scene descriptions into viral voice-over copy for social platforms.

Target audience: 18-35-year-olds who love life and chase unique experiences
@@ -293,10 +288,9 @@ class ScriptProcessor:
3. Use down-to-earth language to close the distance to the audience

【Rhythm Control】
-1. Keep the copy strictly around {word_count} characters, with a deviation of no more than 5
-2. Pace setups and punchlines the way a stand-up comedian would
-3. Make sure every segment has a laugh, but never force one
-4. End each segment cleanly, without dragging
+1. Pace setups and punchlines the way a stand-up comedian would
+2. Make sure every segment has a laugh, but never force one
+3. End each segment cleanly, without dragging

【Coherence Requirements】
1. Newly generated content must connect naturally with the ending of the previous segment
@@ -305,13 +299,6 @@ class ScriptProcessor:
4. Avoid repeating information already covered in the previous segment
5. Keep the plot and the build process logically continuous

-【Word Count Requirements】
-1. Keep the copy strictly around {word_count} characters, with a deviation of no more than 5
-2. If the copy runs long, trim decorative wording first
-3. If the copy runs short, add a little more descriptive detail
-4. Keep the copy structurally complete; never sacrifice quality to hit the count
-5. Make sure every joke and punchline lands in full

I will provide several scene descriptions in order. Please write voice-over copy that is both funny and primed to go viral.
Remember: dare to use "benign violations" for laughs, but keep a good sense of proportion, so the audience can enjoy the fun of wilderness building in a relaxed mood."""


@@ -191,7 +191,7 @@ def render_text_llm_settings(tr):
    st.subheader(tr("Text Generation Model Settings"))

    # Text generation model provider selection
-    text_providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare']
+    text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'Gemini']
    saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
    saved_provider_index = 0
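The trimmed provider list matches exactly the four generators implemented above. This hunk does not show how saved_text_provider is turned into a list index; a plausible sketch of that lookup (the loop is an assumption, not code from this commit):

    # Sketch: map the saved provider name back to its position in the list.
    for i, provider in enumerate(text_providers):
        if provider.lower() == saved_text_provider:
            saved_provider_index = i
            break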