mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-12 19:52:48 +00:00
feat: 增强 LiteLLM 提供商配置并更新基本设置界面
This commit is contained in:
parent
efa02d83ca
commit
77c0aa47f2
@ -187,8 +187,27 @@ class LiteLLMVisionProvider(VisionModelProvider):
|
||||
# 调用 LiteLLM
|
||||
try:
|
||||
# 准备参数
|
||||
effective_model_name = self.model_name
|
||||
|
||||
# SiliconFlow 特殊处理
|
||||
if self.model_name.lower().startswith("siliconflow/"):
|
||||
# 替换 provider 为 openai
|
||||
if "/" in self.model_name:
|
||||
effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
|
||||
else:
|
||||
effective_model_name = f"openai/{self.model_name}"
|
||||
|
||||
# 确保设置了 OPENAI_API_KEY (如果尚未设置)
|
||||
import os
|
||||
if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
|
||||
os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
|
||||
|
||||
# 确保设置了 base_url (如果尚未设置)
|
||||
if not hasattr(self, '_api_base'):
|
||||
self._api_base = "https://api.siliconflow.cn/v1"
|
||||
|
||||
completion_kwargs = {
|
||||
"model": self.model_name,
|
||||
"model": effective_model_name,
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", 1.0),
|
||||
"max_tokens": kwargs.get("max_tokens", 4000)
|
||||
@ -346,8 +365,27 @@ class LiteLLMTextProvider(TextModelProvider):
|
||||
messages = self._build_messages(prompt, system_prompt)
|
||||
|
||||
# 准备参数
|
||||
effective_model_name = self.model_name
|
||||
|
||||
# SiliconFlow 特殊处理
|
||||
if self.model_name.lower().startswith("siliconflow/"):
|
||||
# 替换 provider 为 openai
|
||||
if "/" in self.model_name:
|
||||
effective_model_name = f"openai/{self.model_name.split('/', 1)[1]}"
|
||||
else:
|
||||
effective_model_name = f"openai/{self.model_name}"
|
||||
|
||||
# 确保设置了 OPENAI_API_KEY (如果尚未设置)
|
||||
import os
|
||||
if not os.environ.get("OPENAI_API_KEY") and os.environ.get("SILICONFLOW_API_KEY"):
|
||||
os.environ["OPENAI_API_KEY"] = os.environ.get("SILICONFLOW_API_KEY")
|
||||
|
||||
# 确保设置了 base_url (如果尚未设置)
|
||||
if not hasattr(self, '_api_base'):
|
||||
self._api_base = "https://api.siliconflow.cn/v1"
|
||||
|
||||
completion_kwargs = {
|
||||
"model": self.model_name,
|
||||
"model": effective_model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
@ -316,9 +316,26 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
|
||||
old_key = os.environ.get(env_var)
|
||||
os.environ[env_var] = api_key
|
||||
|
||||
# SiliconFlow 特殊处理:使用 OpenAI 兼容模式
|
||||
test_model_name = model_name
|
||||
if provider.lower() == "siliconflow":
|
||||
# 替换 provider 为 openai
|
||||
if "/" in model_name:
|
||||
test_model_name = f"openai/{model_name.split('/', 1)[1]}"
|
||||
else:
|
||||
test_model_name = f"openai/{model_name}"
|
||||
|
||||
# 确保设置了 base_url
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
|
||||
# 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议)
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
os.environ["OPENAI_API_BASE"] = base_url
|
||||
|
||||
try:
|
||||
# 创建测试图片(1x1 白色像素)
|
||||
test_image = Image.new('RGB', (1, 1), color='white')
|
||||
# 创建测试图片(64x64 白色像素,避免某些模型对极小图片的限制)
|
||||
test_image = Image.new('RGB', (64, 64), color='white')
|
||||
img_buffer = io.BytesIO()
|
||||
test_image.save(img_buffer, format='JPEG')
|
||||
img_bytes = img_buffer.getvalue()
|
||||
@ -340,7 +357,7 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
|
||||
|
||||
# 准备参数
|
||||
completion_kwargs = {
|
||||
"model": model_name,
|
||||
"model": test_model_name,
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 50
|
||||
@ -363,6 +380,11 @@ def test_litellm_vision_model(api_key: str, base_url: str, model_name: str, tr)
|
||||
os.environ[env_var] = old_key
|
||||
else:
|
||||
os.environ.pop(env_var, None)
|
||||
|
||||
# 清理临时设置的 OpenAI 环境变量
|
||||
if provider.lower() == "siliconflow":
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_BASE", None)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@ -415,6 +437,23 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
|
||||
old_key = os.environ.get(env_var)
|
||||
os.environ[env_var] = api_key
|
||||
|
||||
# SiliconFlow 特殊处理:使用 OpenAI 兼容模式
|
||||
test_model_name = model_name
|
||||
if provider.lower() == "siliconflow":
|
||||
# 替换 provider 为 openai
|
||||
if "/" in model_name:
|
||||
test_model_name = f"openai/{model_name.split('/', 1)[1]}"
|
||||
else:
|
||||
test_model_name = f"openai/{model_name}"
|
||||
|
||||
# 确保设置了 base_url
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
|
||||
# 设置 OPENAI_API_KEY (SiliconFlow 使用 OpenAI 协议)
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
os.environ["OPENAI_API_BASE"] = base_url
|
||||
|
||||
try:
|
||||
# 构建测试请求
|
||||
messages = [
|
||||
@ -423,7 +462,7 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
|
||||
|
||||
# 准备参数
|
||||
completion_kwargs = {
|
||||
"model": model_name,
|
||||
"model": test_model_name,
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 20
|
||||
@ -446,6 +485,11 @@ def test_litellm_text_model(api_key: str, base_url: str, model_name: str, tr) ->
|
||||
os.environ[env_var] = old_key
|
||||
else:
|
||||
os.environ.pop(env_var, None)
|
||||
|
||||
# 清理临时设置的 OpenAI 环境变量
|
||||
if provider.lower() == "siliconflow":
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("OPENAI_API_BASE", None)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@ -469,23 +513,61 @@ def render_vision_llm_settings(tr):
|
||||
config.app["vision_llm_provider"] = "litellm"
|
||||
|
||||
# 获取已保存的 LiteLLM 配置
|
||||
vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
|
||||
full_vision_model_name = config.app.get("vision_litellm_model_name", "gemini/gemini-2.0-flash-lite")
|
||||
vision_api_key = config.app.get("vision_litellm_api_key", "")
|
||||
vision_base_url = config.app.get("vision_litellm_base_url", "")
|
||||
|
||||
# 解析 provider 和 model
|
||||
default_provider = "gemini"
|
||||
default_model = "gemini-2.0-flash-lite"
|
||||
|
||||
if "/" in full_vision_model_name:
|
||||
parts = full_vision_model_name.split("/", 1)
|
||||
current_provider = parts[0]
|
||||
current_model = parts[1]
|
||||
else:
|
||||
current_provider = default_provider
|
||||
current_model = full_vision_model_name
|
||||
|
||||
# 定义支持的 provider 列表
|
||||
LITELLM_PROVIDERS = [
|
||||
"openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
|
||||
"anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
|
||||
"volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
|
||||
"openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm",
|
||||
"bedrock", "cloudflare"
|
||||
]
|
||||
|
||||
# 如果当前 provider 不在列表中,添加到列表头部
|
||||
if current_provider not in LITELLM_PROVIDERS:
|
||||
LITELLM_PROVIDERS.insert(0, current_provider)
|
||||
|
||||
# 渲染配置输入框
|
||||
st_vision_model_name = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=vision_model_name,
|
||||
help="LiteLLM 模型格式: provider/model\n\n"
|
||||
"常用示例:\n"
|
||||
"• gemini/gemini-2.0-flash-lite (推荐,速度快)\n"
|
||||
"• gemini/gemini-1.5-pro (高精度)\n"
|
||||
"• openai/gpt-4o, openai/gpt-4o-mini\n"
|
||||
"• qwen/qwen2.5-vl-32b-instruct\n"
|
||||
"• siliconflow/Qwen/Qwen2.5-VL-32B-Instruct\n\n"
|
||||
"支持 100+ providers,详见: https://docs.litellm.ai/docs/providers"
|
||||
)
|
||||
col1, col2 = st.columns([1, 2])
|
||||
with col1:
|
||||
selected_provider = st.selectbox(
|
||||
tr("Vision Model Provider"),
|
||||
options=LITELLM_PROVIDERS,
|
||||
index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
|
||||
key="vision_provider_select"
|
||||
)
|
||||
|
||||
with col2:
|
||||
model_name_input = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=current_model,
|
||||
help="输入模型名称(不包含 provider 前缀)\n\n"
|
||||
"常用示例:\n"
|
||||
"• gemini-2.0-flash-lite\n"
|
||||
"• gpt-4o\n"
|
||||
"• qwen-vl-max\n"
|
||||
"• Qwen/Qwen2.5-VL-32B-Instruct (SiliconFlow)\n\n"
|
||||
"支持 100+ providers,详见: https://docs.litellm.ai/docs/providers",
|
||||
key="vision_model_input"
|
||||
)
|
||||
|
||||
# 组合完整的模型名称
|
||||
st_vision_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else ""
|
||||
|
||||
st_vision_api_key = st.text_input(
|
||||
tr("Vision API Key"),
|
||||
@ -515,7 +597,7 @@ def render_vision_llm_settings(tr):
|
||||
test_errors = []
|
||||
if not st_vision_api_key:
|
||||
test_errors.append("请先输入 API 密钥")
|
||||
if not st_vision_model_name:
|
||||
if not model_name_input:
|
||||
test_errors.append("请先输入模型名称")
|
||||
|
||||
if test_errors:
|
||||
@ -545,6 +627,7 @@ def render_vision_llm_settings(tr):
|
||||
|
||||
# 验证模型名称
|
||||
if st_vision_model_name:
|
||||
# 这里的验证逻辑可能需要微调,因为我们现在是自动组合的
|
||||
is_valid, error_msg = validate_litellm_model_name(st_vision_model_name, "视频分析")
|
||||
if is_valid:
|
||||
config.app["vision_litellm_model_name"] = st_vision_model_name
|
||||
@ -698,24 +781,61 @@ def render_text_llm_settings(tr):
|
||||
config.app["text_llm_provider"] = "litellm"
|
||||
|
||||
# 获取已保存的 LiteLLM 配置
|
||||
text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
|
||||
full_text_model_name = config.app.get("text_litellm_model_name", "deepseek/deepseek-chat")
|
||||
text_api_key = config.app.get("text_litellm_api_key", "")
|
||||
text_base_url = config.app.get("text_litellm_base_url", "")
|
||||
|
||||
# 解析 provider 和 model
|
||||
default_provider = "deepseek"
|
||||
default_model = "deepseek-chat"
|
||||
|
||||
if "/" in full_text_model_name:
|
||||
parts = full_text_model_name.split("/", 1)
|
||||
current_provider = parts[0]
|
||||
current_model = parts[1]
|
||||
else:
|
||||
current_provider = default_provider
|
||||
current_model = full_text_model_name
|
||||
|
||||
# 定义支持的 provider 列表
|
||||
LITELLM_PROVIDERS = [
|
||||
"openai", "gemini", "deepseek", "qwen", "siliconflow", "moonshot",
|
||||
"anthropic", "azure", "ollama", "vertex_ai", "mistral", "codestral",
|
||||
"volcengine", "groq", "cohere", "together_ai", "fireworks_ai",
|
||||
"openrouter", "replicate", "huggingface", "xai", "deepgram", "vllm",
|
||||
"bedrock", "cloudflare"
|
||||
]
|
||||
|
||||
# 如果当前 provider 不在列表中,添加到列表头部
|
||||
if current_provider not in LITELLM_PROVIDERS:
|
||||
LITELLM_PROVIDERS.insert(0, current_provider)
|
||||
|
||||
# 渲染配置输入框
|
||||
st_text_model_name = st.text_input(
|
||||
tr("Text Model Name"),
|
||||
value=text_model_name,
|
||||
help="LiteLLM 模型格式: provider/model\n\n"
|
||||
"常用示例:\n"
|
||||
"• deepseek/deepseek-chat (推荐,性价比高)\n"
|
||||
"• gemini/gemini-2.0-flash (速度快)\n"
|
||||
"• openai/gpt-4o, openai/gpt-4o-mini\n"
|
||||
"• qwen/qwen-plus, qwen/qwen-turbo\n"
|
||||
"• siliconflow/deepseek-ai/DeepSeek-R1\n"
|
||||
"• moonshot/moonshot-v1-8k\n\n"
|
||||
"支持 100+ providers,详见: https://docs.litellm.ai/docs/providers"
|
||||
)
|
||||
col1, col2 = st.columns([1, 2])
|
||||
with col1:
|
||||
selected_provider = st.selectbox(
|
||||
tr("Text Model Provider"),
|
||||
options=LITELLM_PROVIDERS,
|
||||
index=LITELLM_PROVIDERS.index(current_provider) if current_provider in LITELLM_PROVIDERS else 0,
|
||||
key="text_provider_select"
|
||||
)
|
||||
|
||||
with col2:
|
||||
model_name_input = st.text_input(
|
||||
tr("Text Model Name"),
|
||||
value=current_model,
|
||||
help="输入模型名称(不包含 provider 前缀)\n\n"
|
||||
"常用示例:\n"
|
||||
"• deepseek-chat\n"
|
||||
"• gpt-4o\n"
|
||||
"• gemini-2.0-flash\n"
|
||||
"• deepseek-ai/DeepSeek-R1 (SiliconFlow)\n\n"
|
||||
"支持 100+ providers,详见: https://docs.litellm.ai/docs/providers",
|
||||
key="text_model_input"
|
||||
)
|
||||
|
||||
# 组合完整的模型名称
|
||||
st_text_model_name = f"{selected_provider}/{model_name_input}" if selected_provider and model_name_input else ""
|
||||
|
||||
st_text_api_key = st.text_input(
|
||||
tr("Text API Key"),
|
||||
@ -747,7 +867,7 @@ def render_text_llm_settings(tr):
|
||||
test_errors = []
|
||||
if not st_text_api_key:
|
||||
test_errors.append("请先输入 API 密钥")
|
||||
if not st_text_model_name:
|
||||
if not model_name_input:
|
||||
test_errors.append("请先输入模型名称")
|
||||
|
||||
if test_errors:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user