diff --git a/app/services/llm/litellm_provider.py b/app/services/llm/litellm_provider.py
index d3302ee..a774f9f 100644
--- a/app/services/llm/litellm_provider.py
+++ b/app/services/llm/litellm_provider.py
@@ -187,8 +187,27 @@ class LiteLLMVisionProvider(VisionModelProvider):
         # Call LiteLLM
         try:
             # Prepare the arguments
+            effective_model_name = self.model_name
+
+            # Special handling for SiliconFlow
+            if self.model_name.lower().startswith("siliconflow/"):
+                # Swap the provider prefix to openai
+                if "/" in self.model_name:
+                    effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
+                else:
+                    effective_model_name = f"openai/{self.model_name}"
+
+                # Make sure OPENAI_API_KEY is set (if it isn't already)
+                import os
+                if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
+                    os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
+
+                # Make sure a base_url is set (if it isn't already)
+                if not hasattr(self, '_api_base'):
+                    self._api_base = "https://api.siliconflow.cn/v1"
+
             completion_kwargs = {
-                "model": self.model_name,
+                "model": effective_model_name,
                 "messages": messages,
                 "temperature": kwargs.get("temperature", 1.0),
                 "max_tokens": kwargs.get("max_tokens", 4000)
@@ -346,8 +365,27 @@ class LiteLLMTextProvider(TextModelProvider):
         messages = self._build_messages(prompt, system_prompt)
 
         # Prepare the arguments
+        effective_model_name = self.model_name
+
+        # Special handling for SiliconFlow
+        if self.model_name.lower().startswith("siliconflow/"):
+            # Swap the provider prefix to openai
+            if "/" in self.model_name:
+                effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
+            else:
+                effective_model_name = f"openai/{self.model_name}"
+
+            # Make sure OPENAI_API_KEY is set (if it isn't already)
+            import os
+            if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
+                os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
+
+            # Make sure a base_url is set (if it isn't already)
+            if not hasattr(self, '_api_base'):
+                self._api_base = "https://api.siliconflow.cn/v1"
+
         completion_kwargs = {
-            "model": self.model_name,
+            "model": effective_model_name,
             "messages": messages,
             "temperature": temperature
         }
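Both providers perform the same rewrite: a "siliconflow/..." model is rerouted through LiteLLM's OpenAI-compatible provider, pointed at SiliconFlow's endpoint. For reference, the mapping can be sketched as a standalone helper; resolve_siliconflow is a hypothetical name, not code from the patch, though the environment-variable names and default base URL are the ones used above.

import os
from typing import Optional, Tuple

SILICONFLOW_API_BASE = "https://api.siliconflow.cn/v1"

def resolve_siliconflow(model_name: str, api_base: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """Route a siliconflow/* model through LiteLLM's OpenAI-compatible provider.

    e.g. "siliconflow/Qwen/Qwen2.5-VL-32B-Instruct" -> "openai/Qwen/Qwen2.5-VL-32B-Instruct"
    """
    if not model_name.lower().startswith("siliconflow/"):
        # Any other provider passes through untouched.
        return model_name, api_base

    # Keep everything after the first slash: SiliconFlow model IDs can
    # contain slashes of their own (e.g. "Qwen/Qwen2.5-VL-32B-Instruct").
    effective = f"openai/{model_name.split('/', 1)[1]}"

    # Reuse SILICONFLOW_API_KEY for the OpenAI route if no key is set yet.
    if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
        os.environ["OPENAI_API_KEY"] = os.environ["SILICONFLOW_API_KEY"]

    return effective, api_base or SILICONFLOW_API_BASE

Collapsing the patch's inner if/else into a single split('/', 1) is safe because the startswith("siliconflow/") guard already guarantees a slash is present, which is also why the else branch in the patch can never fire.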
diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py
index f19ffd1..b6cf115 100644
--- a/webui/components/basic_settings.py
+++ b/webui/components/basic_settings.py
@@ -316,9 +316,26 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
     old_key = os.environ.get(env_var)
     os.environ[env_var] = api_key
 
+    # Special handling for SiliconFlow: use OpenAI-compatible mode
+    test_model_name = model_name
+    if provider.lower() == "siliconflow":
+        # Swap the provider prefix to openai
+        if "/" in model_name:
+            test_model_name = f"openai/{model_name.split('/', 1)[1]}"
+        else:
+            test_model_name = f"openai/{model_name}"
+
+        # Make sure base_url is set
+        if not base_url:
+            base_url = "https://api.siliconflow.cn/v1"
+
+        # Set OPENAI_API_KEY (SiliconFlow uses the OpenAI protocol)
+        os.environ["OPENAI_API_KEY"] = api_key
+        os.environ["OPENAI_API_BASE"] = base_url
+
     try:
-        # Create a test image (1x1 white pixel)
-        test_image = Image.new('RGB', (1, 1), color='white')
+        # Create a test image (64x64 white pixels; avoids the limits some models place on very small images)
+        test_image = Image.new('RGB', (64, 64), color='white')
         img_buffer = io.BytesIO()
         test_image.save(img_buffer, format='JPEG')
         img_bytes = img_buffer.getvalue()
@@ -340,7 +357,7 @@
 
         # Prepare the arguments
         completion_kwargs = {
-            "model": model_name,
+            "model": test_model_name,
             "messages": messages,
             "temperature": 0.1,
             "max_tokens": 50
@@ -363,6 +380,11 @@
             os.environ[env_var] = old_key
         else:
             os.environ.pop(env_var, None)
+
+        # Clean up the temporarily set OpenAI environment variables
+        if provider.lower() == "siliconflow":
+            os.environ.pop("OPENAI_API_KEY", None)
+            os.environ.pop("OPENAI_API_BASE", None)
 
     except Exception as e:
         error_msg = str(e)
@@ -415,6 +437,23 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
     old_key = os.environ.get(env_var)
     os.environ[env_var] = api_key
 
+    # Special handling for SiliconFlow: use OpenAI-compatible mode
+    test_model_name = model_name
+    if provider.lower() == "siliconflow":
+        # Swap the provider prefix to openai
+        if "/" in model_name:
+            test_model_name = f"openai/{model_name.split('/', 1)[1]}"
+        else:
+            test_model_name = f"openai/{model_name}"
+
+        # Make sure base_url is set
+        if not base_url:
+            base_url = "https://api.siliconflow.cn/v1"
+
+        # Set OPENAI_API_KEY (SiliconFlow uses the OpenAI protocol)
+        os.environ["OPENAI_API_KEY"] = api_key
+        os.environ["OPENAI_API_BASE"] = base_url
+
     try:
         # Build the test request
         messages = [
@@ -423,7 +462,7 @@
 
         # Prepare the arguments
        completion_kwargs = {
-            "model": model_name,
+            "model": test_model_name,
             "messages": messages,
             "temperature": 0.1,
             "max_tokens": 20
@@ -446,6 +485,11 @@
             os.environ[env_var] = old_key
         else:
             os.environ.pop(env_var, None)
+
+        # Clean up the temporarily set OpenAI environment variables
+        if provider.lower() == "siliconflow":
+            os.environ.pop("OPENAI_API_KEY", None)
+            os.environ.pop("OPENAI_API_BASE", None)
 
     except Exception as e:
         error_msg = str(e)
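One caveat in the test helpers above: they set OPENAI_API_KEY and OPENAI_API_BASE for the duration of the call and then pop them unconditionally, which would also discard a value the user had exported before running the test. A snapshot-and-restore context manager is one way to make that cleanup lossless. The sketch below uses only the standard library and is not code from the patch; temp_env is an illustrative name.

import os
from contextlib import contextmanager

@contextmanager
def temp_env(**overrides):
    """Temporarily set environment variables, restoring the previous values on exit."""
    saved = {name: os.environ.get(name) for name in overrides}
    try:
        os.environ.update(overrides)
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

# e.g. inside a test helper, with completion_kwargs built as above:
# with temp_env(OPENAI_API_KEY=api_key, OPENAI_API_BASE=base_url):
#     litellm.completion(**completion_kwargs)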
@@ -469,23 +513,61 @@
     config.app["vision_llm_provider"] = "litellm"
 
     # Fetch the saved LiteLLM configuration
-    vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
+    full_vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
     vision_api_key = config.app.get("vision_litellm_api_key", "")
     vision_base_url = config.app.get("vision_litellm_base_url", "")
 
+    # Parse out the provider and model
+    default_provider = "gemini"
+    default_model = "gemini-2.0-flash-lite"
+
+    if "/" in full_vision_model_name:
+        parts = full_vision_model_name.split("/", 1)
+        current_provider = parts[0]
+        current_model = parts[1]
+    else:
+        current_provider = default_provider
+        current_model = full_vision_model_name
+
+    # Define the list of supported providers
+    LITELLM_PROVIDERS = [
+        "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
+        "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
+        "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
+        "openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm",
+        "bedrock", "cloudflare"
+    ]
+
+    # If the current provider is not in the list, prepend it
+    if current_provider not in LITELLM_PROVIDERS:
+        LITELLM_PROVIDERS.insert(0, current_provider)
+
     # Render the configuration inputs
-    st_vision_model_name = st.text_input(
-        tr("Vision Model Name"),
-        value=vision_model_name,
-        help="LiteLLM model format: provider/model\n\n"
-             "Common examples:\n"
-             "• gemini/gemini-2.0-flash-lite (recommended, fast)\n"
-             "• gemini/gemini-1.5-pro (higher accuracy)\n"
-             "• openai/gpt-4o, openai/gpt-4o-mini\n"
-             "• qwen/qwen2.5-vl-32b-instruct\n"
-             "• siliconflow/Qwen/Qwen2.5-VL-32B-Instruct\n\n"
-             "Supports 100+ providers; see: https://docs.litellm.ai/docs/providers"
-    )
+    col1, col2 = st.columns([1, 2])
+    with col1:
+        selected_provider = st.selectbox(
+            tr("Vision Model Provider"),
+            options=LITELLM_PROVIDERS,
+            index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
+            key="vision_provider_select"
+        )
+
+    with col2:
+        model_name_input = st.text_input(
+            tr("Vision Model Name"),
+            value=current_model,
+            help="Enter the model name (without the provider prefix)\n\n"
+                 "Common examples:\n"
+                 "• gemini-2.0-flash-lite\n"
+                 "• gpt-4o\n"
+                 "• qwen-vl-max\n"
+                 "• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n"
+                 "Supports 100+ providers; see: https://docs.litellm.ai/docs/providers",
+            key="vision_model_input"
+        )
+
+    # Assemble the full model name
+    st_vision_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else ""
 
     st_vision_api_key = st.text_input(
         tr("Vision API Key"),
@@ -515,7 +597,7 @@
     test_errors = []
     if not st_vision_api_key:
         test_errors.append("Please enter the API key first")
-    if not st_vision_model_name:
+    if not model_name_input:
         test_errors.append("Please enter the model name first")
 
     if test_errors:
@@ -545,6 +627,7 @@
 
     # Validate the model name
     if st_vision_model_name:
+        # This validation logic may need tweaking, since the full name is now assembled automatically
         is_valid, error_msg = validate_litellm_model_name(st_vision_model_name, "video analysis")
         if is_valid:
             config.app["vision_litellm_model_name"] = st_vision_model_name
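The selectbox/text-input pair relies on splitting the stored name on the first slash only, so nested SiliconFlow model IDs survive the save/reload round trip. A sketch of that invariant follows; the helper names are illustrative, not part of the patch.

def split_model_name(full_name: str, default_provider: str) -> tuple:
    """Split "provider/model" on the first slash only."""
    if "/" in full_name:
        provider, model = full_name.split("/", 1)
        return provider, model
    return default_provider, full_name

def join_model_name(provider: str, model: str) -> str:
    """Reassemble the full LiteLLM model name; empty if either part is missing."""
    return f"{provider}/{model}" if provider and model else ""

# The round trip preserves nested SiliconFlow IDs:
assert split_model_name("siliconflow/Qwen/Qwen2.5-VL-32B-Instruct", "gemini") == (
    "siliconflow", "Qwen/Qwen2.5-VL-32B-Instruct")
assert join_model_name("siliconflow", "Qwen/Qwen2.5-VL-32B-Instruct") == (
    "siliconflow/Qwen/Qwen2.5-VL-32B-Instruct")

The text-model settings below mirror the vision-model settings, differing only in their defaults and widget keys.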
@@ -698,24 +781,61 @@
     config.app["text_llm_provider"] = "litellm"
 
     # Fetch the saved LiteLLM configuration
-    text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
+    full_text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
     text_api_key = config.app.get("text_litellm_api_key", "")
     text_base_url = config.app.get("text_litellm_base_url", "")
 
+    # Parse out the provider and model
+    default_provider = "deepseek"
+    default_model = "deepseek-chat"
+
+    if "/" in full_text_model_name:
+        parts = full_text_model_name.split("/", 1)
+        current_provider = parts[0]
+        current_model = parts[1]
+    else:
+        current_provider = default_provider
+        current_model = full_text_model_name
+
+    # Define the list of supported providers
+    LITELLM_PROVIDERS = [
+        "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
+        "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
+        "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
+        "openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm",
+        "bedrock", "cloudflare"
+    ]
+
+    # If the current provider is not in the list, prepend it
+    if current_provider not in LITELLM_PROVIDERS:
+        LITELLM_PROVIDERS.insert(0, current_provider)
+
     # Render the configuration inputs
-    st_text_model_name = st.text_input(
-        tr("Text Model Name"),
-        value=text_model_name,
-        help="LiteLLM model format: provider/model\n\n"
-             "Common examples:\n"
-             "• deepseek/deepseek-chat (recommended, cost-effective)\n"
-             "• gemini/gemini-2.0-flash (fast)\n"
-             "• openai/gpt-4o, openai/gpt-4o-mini\n"
-             "• qwen/qwen-plus, qwen/qwen-turbo\n"
-             "• siliconflow/deepseek-ai/DeepSeek-R1\n"
-             "• moonshot/moonshot-v1-8k\n\n"
-             "Supports 100+ providers; see: https://docs.litellm.ai/docs/providers"
-    )
+    col1, col2 = st.columns([1, 2])
+    with col1:
+        selected_provider = st.selectbox(
+            tr("Text Model Provider"),
+            options=LITELLM_PROVIDERS,
+            index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
+            key="text_provider_select"
+        )
+
+    with col2:
+        model_name_input = st.text_input(
+            tr("Text Model Name"),
+            value=current_model,
+            help="Enter the model name (without the provider prefix)\n\n"
+                 "Common examples:\n"
+                 "• deepseek-chat\n"
+                 "• gpt-4o\n"
+                 "• gemini-2.0-flash\n"
+                 "• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n"
+                 "Supports 100+ providers; see: https://docs.litellm.ai/docs/providers",
+            key="text_model_input"
+        )
+
+    # Assemble the full model name
+    st_text_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else ""
 
     st_text_api_key = st.text_input(
         tr("Text API Key"),
@@ -747,7 +867,7 @@
     test_errors = []
     if not st_text_api_key:
         test_errors.append("Please enter the API key first")
-    if not st_text_model_name:
+    if not model_name_input:
         test_errors.append("Please enter the model name first")
 
     if test_errors: