From 4c57fe0fa9f6204a4581fc7280bdb4200a4236aa Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 15 Nov 2024 16:03:59 +0800 Subject: [PATCH] =?UTF-8?q?feat(webui):=20=E9=9B=86=E6=88=90=20DeepSeek=20?= =?UTF-8?q?=E6=96=87=E6=9C=AC=E7=94=9F=E6=88=90=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在文本生成模型提供商列表中添加 DeepSeek - 实现 DeepSeek API 的生成器类 - 在脚本生成器中支持 DeepSeek 模型 - 优化脚本处理过程中的错误提示 --- app/utils/script_generator.py | 41 ++++++++++++++++++++++++++++- webui/components/basic_settings.py | 2 +- webui/components/script_settings.py | 4 +-- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/app/utils/script_generator.py b/app/utils/script_generator.py index f643aac..f336ab2 100644 --- a/app/utils/script_generator.py +++ b/app/utils/script_generator.py @@ -223,7 +223,7 @@ class MoonshotGenerator(BaseGenerator): } def _generate(self, messages: list, params: dict) -> any: - """实现Moonshot特定的生成逻辑,包含429错误重试机制""" + """实现Moonshot特定的生成逻辑,包含429���误重试机制""" while True: try: response = self.client.chat.completions.create( @@ -249,6 +249,42 @@ class MoonshotGenerator(BaseGenerator): return response.choices[0].message.content.strip() +class DeepSeekGenerator(BaseGenerator): + """DeepSeek API 生成器实现""" + def __init__(self, model_name: str, api_key: str, prompt: str, base_url: str): + super().__init__(model_name, api_key, prompt) + self.client = OpenAI( + api_key=api_key, + base_url=base_url or "https://api.deepseek.com" + ) + + # DeepSeek特定参数 + self.default_params = { + **self.default_params, + "stream": False, + "user": "script_generator" + } + + def _generate(self, messages: list, params: dict) -> any: + """实现DeepSeek特定的生成逻辑""" + try: + response = self.client.chat.completions.create( + model=self.model_name, # deepseek-chat 或 deepseek-coder + messages=messages, + **params + ) + return response + except Exception as e: + logger.error(f"DeepSeek generation error: {str(e)}") + raise + + def _process_response(self, response: any) -> str: + """处理DeepSeek的响应""" + if not response or not response.choices: + raise ValueError("Invalid response from DeepSeek API") + return response.choices[0].message.content.strip() + + class ScriptProcessor: def __init__(self, model_name: str, api_key: str = None, base_url: str = None, prompt: str = None, video_theme: str = ""): self.model_name = model_name @@ -258,12 +294,15 @@ class ScriptProcessor: self.prompt = prompt or self._get_default_prompt() # 根据模型名称选择对应的生成器 + logger.info(f"文本 LLM 提供商: {model_name}") if 'gemini' in model_name.lower(): self.generator = GeminiGenerator(model_name, self.api_key, self.prompt) elif 'qwen' in model_name.lower(): self.generator = QwenGenerator(model_name, self.api_key, self.prompt, self.base_url) elif 'moonshot' in model_name.lower(): self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt, self.base_url) + elif 'deepseek' in model_name.lower(): + self.generator = DeepSeekGenerator(model_name, self.api_key, self.prompt, self.base_url) else: self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt, self.base_url) diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index 855db83..e5fa8f7 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -191,7 +191,7 @@ def render_text_llm_settings(tr): st.subheader(tr("Text Generation Model Settings")) # 文案生成模型提供商选择 - text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'Gemini'] + text_providers = ['OpenAI', 'Qwen', 'Moonshot', 'DeepSeek', 'Gemini'] saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower() saved_provider_index = 0 diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index 0b3d5d1..bfbe297 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -501,8 +501,8 @@ def generate_script(tr, params): script = json.dumps(script_result, ensure_ascii=False, indent=2) except Exception as e: - logger.exception(f"Gemini 处理过程中发生错误\n{traceback.format_exc()}") - raise Exception(f"视觉分析失败: {str(e)}") + logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}") + raise Exception(f"分析失败: {str(e)}") elif vision_llm_provider == 'narratoapi': # NarratoAPI try: