feat: enhance LiteLLM provider configuration and update the basic settings UI

linyq 2025-11-19 19:10:07 +08:00
parent efa02d83ca
commit 77c0aa47f2
2 changed files with 193 additions and 35 deletions

View File

@@ -187,8 +187,27 @@ class LiteLLMVisionProvider(VisionModelProvider):
         # Call LiteLLM
         try:
             # Prepare parameters
+            effective_model_name = self.model_name
+            # Special handling for SiliconFlow
+            if self.model_name.lower().startswith("siliconflow/"):
+                # Swap the provider prefix to openai
+                if "/" in self.model_name:
+                    effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
+                else:
+                    effective_model_name = f"openai/{self.model_name}"
+                # Make sure OPENAI_API_KEY is set (if it is not set yet)
+                import os
+                if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
+                    os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
+                # Make sure base_url is set (if it is not set yet)
+                if not hasattr(self, '_api_base'):
+                    self._api_base = "https://api.siliconflow.cn/v1"
             completion_kwargs = {
-                "model": self.model_name,
+                "model": effective_model_name,
                 "messages": messages,
                 "temperature": kwargs.get("temperature", 1.0),
                 "max_tokens": kwargs.get("max_tokens", 4000)
@@ -346,8 +365,27 @@ class LiteLLMTextProvider(TextModelProvider):
         messages = self._build_messages(prompt, system_prompt)
         # Prepare parameters
+        effective_model_name = self.model_name
+        # Special handling for SiliconFlow
+        if self.model_name.lower().startswith("siliconflow/"):
+            # Swap the provider prefix to openai
+            if "/" in self.model_name:
+                effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
+            else:
+                effective_model_name = f"openai/{self.model_name}"
+            # Make sure OPENAI_API_KEY is set (if it is not set yet)
+            import os
+            if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
+                os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
+            # Make sure base_url is set (if it is not set yet)
+            if not hasattr(self, '_api_base'):
+                self._api_base = "https://api.siliconflow.cn/v1"
         completion_kwargs = {
-            "model": self.model_name,
+            "model": effective_model_name,
             "messages": messages,
             "temperature": temperature
         }
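
Both providers funnel the SiliconFlow credentials through process-wide environment variables, which is shared mutable state. LiteLLM's completion() also accepts per-call api_key and api_base arguments, so a sketch of an alternative that avoids touching os.environ (the key and model values below are placeholders, not values from the commit):

    import litellm

    # Per-call credentials avoid mutating os.environ; the values below stand in
    # for whatever was read from the app config.
    response = litellm.completion(
        model="openai/deepseek-ai/DeepSeek-R1",
        messages=[{"role": "user", "content": "ping"}],
        api_key="sk-...",                          # SiliconFlow key (placeholder)
        api_base="https://api.siliconflow.cn/v1",  # OpenAI-compatible endpoint
        temperature=0.1,
    )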

View File

@@ -316,9 +316,26 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
     old_key = os.environ.get(env_var)
     os.environ[env_var] = api_key
+    # Special handling for SiliconFlow: use the OpenAI-compatible mode
+    test_model_name = model_name
+    if provider.lower() == "siliconflow":
+        # Swap the provider prefix to openai
+        if "/" in model_name:
+            test_model_name = f"openai/{model_name.split('/', 1)[1]}"
+        else:
+            test_model_name = f"openai/{model_name}"
+        # Make sure base_url is set
+        if not base_url:
+            base_url = "https://api.siliconflow.cn/v1"
+        # Set OPENAI_API_KEY (SiliconFlow speaks the OpenAI protocol)
+        os.environ["OPENAI_API_KEY"] = api_key
+        os.environ["OPENAI_API_BASE"] = base_url
     try:
-        # Create a test image (1x1 white pixel)
-        test_image = Image.new('RGB', (1, 1), color='white')
+        # Create a test image (64x64 white pixels, avoiding the limits some models place on tiny images)
+        test_image = Image.new('RGB', (64, 64), color='white')
         img_buffer = io.BytesIO()
         test_image.save(img_buffer, format='JPEG')
         img_bytes = img_buffer.getvalue()
@@ -340,7 +357,7 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
         # Prepare parameters
         completion_kwargs = {
-            "model": model_name,
+            "model": test_model_name,
             "messages": messages,
             "temperature": 0.1,
             "max_tokens": 50
@@ -363,6 +380,11 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
             os.environ[env_var] = old_key
         else:
             os.environ.pop(env_var, None)
+        # Clean up the temporarily set OpenAI environment variables
+        if provider.lower() == "siliconflow":
+            os.environ.pop("OPENAI_API_KEY", None)
+            os.environ.pop("OPENAI_API_BASE", None)
     except Exception as e:
         error_msg = str(e)
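
The restore path puts env_var back by hand and then unconditionally pops the OpenAI variables; note that the pop also discards any OPENAI_API_KEY that existed before the test ran. One exception-safe, lossless way to scope such overrides is a small context manager (a sketch; this helper is not part of the commit):

    import os
    from contextlib import contextmanager

    @contextmanager
    def scoped_env(**overrides):
        """Temporarily set environment variables, restoring the old values on exit."""
        saved = {key: os.environ.get(key) for key in overrides}
        os.environ.update(overrides)
        try:
            yield
        finally:
            # Runs even if the wrapped call raises, so nothing leaks between tests.
            for key, old in saved.items():
                if old is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old

    # Hypothetical usage inside the test:
    # with scoped_env(OPENAI_API_KEY=api_key, OPENAI_API_BASE=base_url):
    #     litellm.completion(**completion_kwargs)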
@@ -415,6 +437,23 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
     old_key = os.environ.get(env_var)
     os.environ[env_var] = api_key
+    # Special handling for SiliconFlow: use the OpenAI-compatible mode
+    test_model_name = model_name
+    if provider.lower() == "siliconflow":
+        # Swap the provider prefix to openai
+        if "/" in model_name:
+            test_model_name = f"openai/{model_name.split('/', 1)[1]}"
+        else:
+            test_model_name = f"openai/{model_name}"
+        # Make sure base_url is set
+        if not base_url:
+            base_url = "https://api.siliconflow.cn/v1"
+        # Set OPENAI_API_KEY (SiliconFlow speaks the OpenAI protocol)
+        os.environ["OPENAI_API_KEY"] = api_key
+        os.environ["OPENAI_API_BASE"] = base_url
     try:
         # Build the test request
         messages = [
@@ -423,7 +462,7 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
         # Prepare parameters
         completion_kwargs = {
-            "model": model_name,
+            "model": test_model_name,
             "messages": messages,
             "temperature": 0.1,
             "max_tokens": 20
@@ -446,6 +485,11 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
             os.environ[env_var] = old_key
         else:
             os.environ.pop(env_var, None)
+        # Clean up the temporarily set OpenAI environment variables
+        if provider.lower() == "siliconflow":
+            os.environ.pop("OPENAI_API_KEY", None)
+            os.environ.pop("OPENAI_API_BASE", None)
     except Exception as e:
         error_msg = str(e)
@@ -469,23 +513,61 @@ def render_vision_llm_settings(tr):
         config.app["vision_llm_provider"] = "litellm"
         # Read the saved LiteLLM configuration
-        vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
+        full_vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
         vision_api_key = config.app.get("vision_litellm_api_key", "")
         vision_base_url = config.app.get("vision_litellm_base_url", "")
+        # Parse provider and model
+        default_provider = "gemini"
+        default_model = "gemini-2.0-flash-lite"
+        if "/" in full_vision_model_name:
+            parts = full_vision_model_name.split("/", 1)
+            current_provider = parts[0]
+            current_model = parts[1]
+        else:
+            current_provider = default_provider
+            current_model = full_vision_model_name
+        # Define the list of supported providers
+        LITELLM_PROVIDERS = [
+            "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
+            "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
+            "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
+            "openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm",
+            "bedrock", "cloudflare"
+        ]
+        # If the current provider is not in the list, prepend it
+        if current_provider not in LITELLM_PROVIDERS:
+            LITELLM_PROVIDERS.insert(0, current_provider)
         # Render the configuration inputs
-        st_vision_model_name = st.text_input(
-            tr("Vision Model Name"),
-            value=vision_model_name,
-            help="LiteLLM model format: provider/model\n\n"
-                 "Common examples:\n"
-                 "• gemini/gemini-2.0-flash-lite (recommended, fast)\n"
-                 "• gemini/gemini-1.5-pro (higher accuracy)\n"
-                 "• openai/gpt-4o, openai/gpt-4o-mini\n"
-                 "• qwen/qwen2.5-vl-32b-instruct\n"
-                 "• siliconflow/Qwen/Qwen2.5-VL-32B-Instruct\n\n"
-                 "100+ providers are supported, see: https://docs.litellm.ai/docs/providers"
-        )
+        col1, col2 = st.columns([1, 2])
+        with col1:
+            selected_provider = st.selectbox(
+                tr("Vision Model Provider"),
+                options=LITELLM_PROVIDERS,
+                index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
+                key="vision_provider_select"
+            )
+        with col2:
+            model_name_input = st.text_input(
+                tr("Vision Model Name"),
+                value=current_model,
+                help="Enter the model name (without the provider prefix)\n\n"
+                     "Common examples:\n"
+                     "• gemini-2.0-flash-lite\n"
+                     "• gpt-4o\n"
+                     "• qwen-vl-max\n"
+                     "• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n"
+                     "100+ providers are supported, see: https://docs.litellm.ai/docs/providers",
+                key="vision_model_input"
+            )
+        # Assemble the full model name
+        st_vision_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else ""
         st_vision_api_key = st.text_input(
             tr("Vision API Key"),
@@ -515,7 +597,7 @@ def render_vision_llm_settings(tr):
             test_errors = []
             if not st_vision_api_key:
                 test_errors.append("Please enter the API key first")
-            if not st_vision_model_name:
+            if not model_name_input:
                 test_errors.append("Please enter the model name first")
             if test_errors:
@@ -545,6 +627,7 @@ def render_vision_llm_settings(tr):
         # Validate the model name
         if st_vision_model_name:
+            # The validation logic here may need tweaking, since the name is now assembled automatically
             is_valid, error_msg = validate_litellm_model_name(st_vision_model_name, "video analysis")
             if is_valid:
                 config.app["vision_litellm_model_name"] = st_vision_model_name
@@ -698,24 +781,61 @@ def render_text_llm_settings(tr):
         config.app["text_llm_provider"] = "litellm"
         # Read the saved LiteLLM configuration
-        text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
+        full_text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
         text_api_key = config.app.get("text_litellm_api_key", "")
         text_base_url = config.app.get("text_litellm_base_url", "")
+        # Parse provider and model
+        default_provider = "deepseek"
+        default_model = "deepseek-chat"
+        if "/" in full_text_model_name:
+            parts = full_text_model_name.split("/", 1)
+            current_provider = parts[0]
+            current_model = parts[1]
+        else:
+            current_provider = default_provider
+            current_model = full_text_model_name
+        # Define the list of supported providers
+        LITELLM_PROVIDERS = [
+            "openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
+            "anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
+            "volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
+            "openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm",
+            "bedrock", "cloudflare"
+        ]
+        # If the current provider is not in the list, prepend it
+        if current_provider not in LITELLM_PROVIDERS:
+            LITELLM_PROVIDERS.insert(0, current_provider)
         # Render the configuration inputs
-        st_text_model_name = st.text_input(
-            tr("Text Model Name"),
-            value=text_model_name,
-            help="LiteLLM model format: provider/model\n\n"
-                 "Common examples:\n"
-                 "• deepseek/deepseek-chat (recommended, good value)\n"
-                 "• gemini/gemini-2.0-flash (fast)\n"
-                 "• openai/gpt-4o, openai/gpt-4o-mini\n"
-                 "• qwen/qwen-plus, qwen/qwen-turbo\n"
-                 "• siliconflow/deepseek-ai/DeepSeek-R1\n"
-                 "• moonshot/moonshot-v1-8k\n\n"
-                 "100+ providers are supported, see: https://docs.litellm.ai/docs/providers"
-        )
+        col1, col2 = st.columns([1, 2])
+        with col1:
+            selected_provider = st.selectbox(
+                tr("Text Model Provider"),
+                options=LITELLM_PROVIDERS,
+                index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
+                key="text_provider_select"
+            )
+        with col2:
+            model_name_input = st.text_input(
+                tr("Text Model Name"),
+                value=current_model,
+                help="Enter the model name (without the provider prefix)\n\n"
+                     "Common examples:\n"
+                     "• deepseek-chat\n"
+                     "• gpt-4o\n"
+                     "• gemini-2.0-flash\n"
+                     "• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n"
+                     "100+ providers are supported, see: https://docs.litellm.ai/docs/providers",
+                key="text_model_input"
+            )
+        # Assemble the full model name
+        st_text_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else ""
         st_text_api_key = st.text_input(
             tr("Text API Key"),
@@ -747,7 +867,7 @@ def render_text_llm_settings(tr):
             test_errors = []
             if not st_text_api_key:
                 test_errors.append("Please enter the API key first")
-            if not st_text_model_name:
+            if not model_name_input:
                 test_errors.append("Please enter the model name first")
             if test_errors: