mirror of https://github.com/linyqh/NarratoAI.git
synced 2026-02-21 16:00:28 +00:00
feat(webui): integrate the DeepSeek text generation model
- Add DeepSeek to the list of text generation model providers
- Implement a generator class for the DeepSeek API
- Support DeepSeek models in the script generator
- Improve the error messages raised during script processing
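With this change, any model name containing "deepseek" is routed to the new DeepSeekGenerator, while other names keep their existing providers (for example, a name like "gpt-4o" still falls through to OpenAIGenerator). A minimal caller-side sketch; the constructor arguments and dispatch rule come from the diff below, while the import path and API key value are placeholders:

# Sketch only: constructing ScriptProcessor with a DeepSeek model.
from app.utils.script_generator import ScriptProcessor  # assumed module path

processor = ScriptProcessor(
    model_name="deepseek-chat",           # any name containing "deepseek" selects DeepSeekGenerator
    api_key="sk-...",                     # DeepSeek API key (placeholder)
    base_url="https://api.deepseek.com",  # also the default used when base_url is empty
    video_theme="travel vlog",
)
# processor.generator is now a DeepSeekGenerator instance.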
This commit is contained in:
parent 1a16d2b655
commit 4c57fe0fa9
@@ -223,7 +223,7 @@ class MoonshotGenerator(BaseGenerator):
         }
 
     def _generate(self, messages: list, params: dict) -> any:
         """Moonshot-specific generation logic, with retry handling for 429 rate-limit errors"""
         while True:
             try:
                 response = self.client.chat.completions.create(
@@ -249,6 +249,42 @@ class MoonshotGenerator(BaseGenerator):
         return response.choices[0].message.content.strip()
 
 
+class DeepSeekGenerator(BaseGenerator):
+    """DeepSeek API generator implementation"""
+    def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str):
+        super().__init__(model_name, api_key, prompt)
+        self.client = OpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.deepseek.com"
+        )
+
+        # DeepSeek-specific parameters
+        self.default_params = {
+            **self.default_params,
+            "stream": False,
+            "user": "script_generator"
+        }
+
+    def _generate(self, messages: list, params: dict) -> any:
+        """DeepSeek-specific generation logic"""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,  # deepseek-chat or deepseek-coder
+                messages=messages,
+                **params
+            )
+            return response
+        except Exception as e:
+            logger.error(f"DeepSeek generation error: {str(e)}")
+            raise
+
+    def _process_response(self, response: any) -> str:
+        """Process the response returned by DeepSeek"""
+        if not response or not response.choices:
+            raise ValueError("Invalid response from DeepSeek API")
+        return response.choices[0].message.content.strip()
+
+
 class ScriptProcessor:
     def __init__(self, model_name: str, api_key: str = None, base_url: str = None, prompt: str = None, video_theme: str = ""):
         self.model_name = model_name
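The added class is a thin wrapper over DeepSeek's OpenAI-compatible chat endpoint, so the underlying request can be reproduced with the stock openai client alone. A minimal sketch, assuming the openai Python package (v1-style client); the base URL, model name, and stream default are the values hard-coded in the diff, and the API key and prompt content are placeholders:

from openai import OpenAI

client = OpenAI(
    api_key="sk-...",                    # DeepSeek API key (placeholder)
    base_url="https://api.deepseek.com"  # OpenAI-compatible endpoint used by DeepSeekGenerator
)
response = client.chat.completions.create(
    model="deepseek-chat",               # or "deepseek-coder"
    messages=[{"role": "user", "content": "Summarize this scene in one sentence."}],
    stream=False,                        # mirrors the "stream": False default added above
)
print(response.choices[0].message.content.strip())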
@@ -258,12 +294,15 @@ class ScriptProcessor:
         self.prompt = prompt or self._get_default_prompt()
 
         # Pick the generator that matches the model name
+        logger.info(f"Text LLM provider: {model_name}")
         if 'gemini' in model_name.lower():
             self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
         elif 'qwen' in model_name.lower():
             self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url)
         elif 'moonshot' in model_name.lower():
             self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt, self.base_url)
+        elif 'deepseek' in model_name.lower():
+            self.generator = DeepSeekGenerator(model_name, self.api_key, self.prompt, self.base_url)
         else:
             self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt, self.base_url)
 
@@ -191,7 +191,7 @@ def render_text_llm_settings(tr):
     st.subheader(tr("Text Generation Model Settings"))
 
     # Text generation model provider selection
-    text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'Gemini']
+    text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'DeepSeek', 'Gemini']
     saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
     saved_provider_index = 0
 
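The hunk above only extends the provider list; the saved provider name is stored lower-cased, so the settings page still has to map it back to a dropdown index. The actual lookup loop is outside this hunk; a hedged sketch of how such a mapping could look (text_providers, saved_text_provider, and saved_provider_index are the names from the diff, the lookup itself is an assumption):

text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'DeepSeek', 'Gemini']
saved_text_provider = "deepseek"  # e.g. config.app.get("text_llm_provider", "OpenAI").lower()

# Match case-insensitively against the display names; fall back to index 0 (OpenAI).
saved_provider_index = next(
    (i for i, name in enumerate(text_providers) if name.lower() == saved_text_provider),
    0,
)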
@@ -501,8 +501,8 @@ def generate_script(tr, params):
             script = json.dumps(script_result, ensure_ascii=False, indent=2)
 
         except Exception as e:
-            logger.exception(f"Error during Gemini processing\n{traceback.format_exc()}")
-            raise Exception(f"Visual analysis failed: {str(e)}")
+            logger.exception(f"Error during LLM processing\n{traceback.format_exc()}")
+            raise Exception(f"Analysis failed: {str(e)}")
 
     elif vision_llm_provider == 'narratoapi':  # NarratoAPI
         try: