feat(script_generator): 为 OpenAI、Qwen 和 Moonshot 生成器添加 base_url 参数

-为 OpenAIGenerator、QwenGenerator 和 MoonshotGenerator 类添加 base_url 参数
- 更新 ScriptProcessor 类以支持 base_url 参数
-调整 OpenAI 生成器的最大 token 数量从 7000 减少到 5000
- 移动 seconds_to_time 函数以减少代码重复
This commit is contained in:
linyq 2024-11-15 14:35:33 +08:00
parent 177304aec0
commit af9e7fa279

View File

@ -79,10 +79,11 @@ class BaseGenerator:
class OpenAIGenerator(BaseGenerator): class OpenAIGenerator(BaseGenerator):
"""OpenAI API 生成器实现""" """OpenAI API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str): def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
super().__init__(model_name, api_key, prompt) super().__init__(model_name, api_key, prompt)
self.client = OpenAI(api_key=api_key) base_url = base_url or f"https://api.openai.com/v1"
self.max_tokens = 7000 self.client = OpenAI(api_key=api_key, base_url=base_url)
self.max_tokens = 5000
# OpenAI特定参数 # OpenAI特定参数
self.default_params = { self.default_params = {
@ -168,11 +169,11 @@ class GeminiGenerator(BaseGenerator):
class QwenGenerator(BaseGenerator): class QwenGenerator(BaseGenerator):
"""阿里云千问 API 生成器实现""" """阿里云千问 API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str): def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
super().__init__(model_name, api_key, prompt) super().__init__(model_name, api_key, prompt)
self.client = OpenAI( self.client = OpenAI(
api_key=api_key, api_key=api_key,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" base_url=base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
) )
# Qwen特定参数 # Qwen特定参数
@ -204,11 +205,11 @@ class QwenGenerator(BaseGenerator):
class MoonshotGenerator(BaseGenerator): class MoonshotGenerator(BaseGenerator):
"""Moonshot API 生成器实现""" """Moonshot API 生成器实现"""
def __init__(self, model_name: str, api_key: str, prompt: str): def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
super().__init__(model_name, api_key, prompt) super().__init__(model_name, api_key, prompt)
self.client = OpenAI( self.client = OpenAI(
api_key=api_key, api_key=api_key,
base_url="https://api.moonshot.cn/v1" base_url=base_url or "https://api.moonshot.cn/v1"
) )
# Moonshot特定参数 # Moonshot特定参数
@ -241,9 +242,10 @@ class MoonshotGenerator(BaseGenerator):
class ScriptProcessor: class ScriptProcessor:
def __init__(self, model_name: str, api_key: str = None, prompt: str = None, video_theme: str = ""): def __init__(self, model_name: str, api_key: str = None, base_url: str = None, prompt: str = None, video_theme: str = ""):
self.model_name = model_name self.model_name = model_name
self.api_key = api_key self.api_key = api_key
self.base_url = base_url
self.video_theme = video_theme self.video_theme = video_theme
self.prompt = prompt or self._get_default_prompt() self.prompt = prompt or self._get_default_prompt()
@ -251,11 +253,11 @@ class ScriptProcessor:
if 'gemini' in model_name.lower(): if 'gemini' in model_name.lower():
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt) self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
elif 'qwen' in model_name.lower(): elif 'qwen' in model_name.lower():
self.generator = QwenGenerator(model_name, self.api_key, self.prompt) self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
elif 'moonshot' in model_name.lower(): elif 'moonshot' in model_name.lower():
self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt) self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt, self.base_url)
else: else:
self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt) self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt, self.base_url)
def _get_default_prompt(self) -> str: def _get_default_prompt(self) -> str:
return f"""你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让{self.video_theme}视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。 return f"""你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让{self.video_theme}视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
@ -334,6 +336,12 @@ class ScriptProcessor:
def _save_results(self, frame_content_list: List[Dict]): def _save_results(self, frame_content_list: List[Dict]):
"""保存处理结果,并添加新的时间戳""" """保存处理结果,并添加新的时间戳"""
try: try:
# 转换秒数为 MM:SS 格式
def seconds_to_time(seconds):
minutes = seconds // 60
remaining_seconds = seconds % 60
return f"{minutes:02d}:{remaining_seconds:02d}"
# 计算新的时间戳 # 计算新的时间戳
current_time = 0 # 当前时间点(秒) current_time = 0 # 当前时间点(秒)
@ -350,12 +358,6 @@ class ScriptProcessor:
end_seconds = time_to_seconds(end_str) end_seconds = time_to_seconds(end_str)
duration = end_seconds - start_seconds duration = end_seconds - start_seconds
# 转换秒数为 MM:SS 格式
def seconds_to_time(seconds):
minutes = seconds // 60
remaining_seconds = seconds % 60
return f"{minutes:02d}:{remaining_seconds:02d}"
# 设置新的时间戳 # 设置新的时间戳
new_start = seconds_to_time(current_time) new_start = seconds_to_time(current_time)
new_end = seconds_to_time(current_time + duration) new_end = seconds_to_time(current_time + duration)