mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-07-02 20:35:28 +00:00
Warn for custom OpenAI base URLs
This commit is contained in:
parent
5f7eed9f85
commit
0774ac5385
@ -35,7 +35,6 @@ DEFAULT_LLM_APP_CONFIG = {
|
||||
"text_openai_model_name": DEFAULT_TEXT_OPENAI_MODEL_NAME,
|
||||
"text_openai_api_key": "",
|
||||
"text_openai_base_url": DEFAULT_OPENAI_COMPATIBLE_BASE_URL,
|
||||
"allow_custom_openai_base_url": False,
|
||||
"tavily_api_key": "",
|
||||
"tavily_search_depth": "basic",
|
||||
"tavily_max_results": 5,
|
||||
|
||||
@ -17,7 +17,10 @@ import subprocess
|
||||
from typing import Union, TextIO
|
||||
|
||||
from app.config import config
|
||||
from app.utils.openai_base_url_security import validate_openai_compatible_base_url
|
||||
from app.utils.openai_base_url_security import (
|
||||
openai_compatible_base_url_warning,
|
||||
validate_openai_compatible_base_url,
|
||||
)
|
||||
from app.utils.utils import clean_model_output
|
||||
|
||||
_max_retries = 5
|
||||
@ -332,6 +335,9 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
|
||||
)
|
||||
else:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
warning = openai_compatible_base_url_warning(base_url)
|
||||
if warning:
|
||||
logger.warning(warning)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
|
||||
@ -25,6 +25,7 @@ from app.config import config
|
||||
from app.config.defaults import DEFAULT_LLM_GENERATION_CONFIG, normalize_openai_compatible_model_name
|
||||
from app.utils.openai_base_url_security import (
|
||||
is_trusted_openai_compatible_base_url,
|
||||
openai_compatible_base_url_warning,
|
||||
validate_openai_compatible_base_url as _validate_openai_compatible_base_url_value,
|
||||
)
|
||||
from .base import TextModelProvider, VisionModelProvider
|
||||
@ -47,9 +48,13 @@ def _is_content_filter_error(message: str) -> bool:
|
||||
|
||||
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
return _validate_openai_compatible_base_url_value(base_url)
|
||||
normalized = _validate_openai_compatible_base_url_value(base_url)
|
||||
except ValueError as exc:
|
||||
raise ConfigurationError(str(exc), "base_url") from exc
|
||||
warning = openai_compatible_base_url_warning(normalized)
|
||||
if warning:
|
||||
logger.warning(warning)
|
||||
return normalized
|
||||
|
||||
|
||||
def _clean_json_output(output: str) -> str:
|
||||
|
||||
@ -13,8 +13,10 @@ from app.services.llm.openai_compatible_provider import (
|
||||
OpenAICompatibleTextProvider,
|
||||
OpenAICompatibleVisionProvider,
|
||||
is_trusted_openai_compatible_base_url,
|
||||
validate_openai_compatible_base_url,
|
||||
)
|
||||
from app.services.llm.providers import register_all_providers
|
||||
from app.utils.openai_base_url_security import openai_compatible_base_url_warning
|
||||
|
||||
|
||||
class DummyOpenAITextProvider(TextModelProvider):
|
||||
@ -210,18 +212,7 @@ class OpenAICompatBaseURLValidationTests(unittest.TestCase):
|
||||
with self.subTest(url=url):
|
||||
self.assertFalse(is_trusted_openai_compatible_base_url(url))
|
||||
|
||||
def test_build_client_rejects_untrusted_base_url_by_default(self):
|
||||
provider = OpenAICompatibleTextProvider(
|
||||
api_key="test-key",
|
||||
model_name="test-model",
|
||||
base_url="https://attacker.example/v1",
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
provider._build_client()
|
||||
|
||||
def test_build_client_allows_explicit_custom_base_url_opt_in(self):
|
||||
config.app["allow_custom_openai_base_url"] = True
|
||||
def test_build_client_allows_well_formed_custom_base_url_by_default(self):
|
||||
provider = OpenAICompatibleTextProvider(
|
||||
api_key="test-key",
|
||||
model_name="test-model",
|
||||
@ -233,8 +224,19 @@ class OpenAICompatBaseURLValidationTests(unittest.TestCase):
|
||||
|
||||
self.assertEqual("https://custom.example/v1", async_openai.call_args.kwargs["base_url"])
|
||||
|
||||
def test_custom_base_url_opt_in_still_rejects_malformed_urls(self):
|
||||
config.app["allow_custom_openai_base_url"] = True
|
||||
def test_custom_base_url_validation_returns_normalized_url(self):
|
||||
self.assertEqual(
|
||||
"https://custom.example/v1",
|
||||
validate_openai_compatible_base_url(" https://custom.example/v1 "),
|
||||
)
|
||||
|
||||
def test_custom_base_url_warning_only_for_untrusted_well_formed_urls(self):
|
||||
warning = openai_compatible_base_url_warning("https://custom.example/v1")
|
||||
self.assertIn("custom.example", warning)
|
||||
self.assertEqual("", openai_compatible_base_url_warning("https://api.openai.com/v1"))
|
||||
self.assertEqual("", openai_compatible_base_url_warning(""))
|
||||
|
||||
def test_custom_base_url_validation_rejects_malformed_urls(self):
|
||||
provider = OpenAICompatibleTextProvider(
|
||||
api_key="test-key",
|
||||
model_name="test-model",
|
||||
|
||||
@ -2,8 +2,6 @@ import ipaddress
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from app.config import config
|
||||
|
||||
|
||||
TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS = {
|
||||
"api.openai.com",
|
||||
@ -28,17 +26,13 @@ TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES = (
|
||||
)
|
||||
|
||||
OPENAI_COMPATIBLE_BASE_URL_ERROR = (
|
||||
"OpenAI-compatible base_url is not in the trusted provider list. "
|
||||
"Use an official provider endpoint or set allow_custom_openai_base_url=true "
|
||||
"only if you understand that the configured endpoint receives the API key."
|
||||
"OpenAI-compatible base_url must be a valid http(s) URL without embedded credentials."
|
||||
)
|
||||
|
||||
|
||||
def custom_openai_base_url_allowed() -> bool:
|
||||
value = config.app.get("allow_custom_openai_base_url", False)
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
OPENAI_COMPATIBLE_BASE_URL_WARNING = (
|
||||
"OpenAI-compatible base_url host '{host}' is not in the trusted provider list. "
|
||||
"Only continue if you trust this endpoint, because it will receive the configured API key."
|
||||
)
|
||||
|
||||
|
||||
def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool:
|
||||
@ -91,6 +85,19 @@ def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool:
|
||||
return any(host.endswith(suffix) for suffix in TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES)
|
||||
|
||||
|
||||
def openai_compatible_base_url_warning(base_url: Optional[str]) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
|
||||
normalized = str(base_url).strip()
|
||||
if is_trusted_openai_compatible_base_url(normalized) or not _is_well_formed_http_base_url(normalized):
|
||||
return ""
|
||||
|
||||
parsed = urlparse(normalized)
|
||||
host = (parsed.hostname or "").rstrip(".").lower()
|
||||
return OPENAI_COMPATIBLE_BASE_URL_WARNING.format(host=host)
|
||||
|
||||
|
||||
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
|
||||
if not base_url:
|
||||
return None
|
||||
@ -98,7 +105,7 @@ def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str
|
||||
normalized = str(base_url).strip()
|
||||
if is_trusted_openai_compatible_base_url(normalized):
|
||||
return normalized
|
||||
if custom_openai_base_url_allowed() and _is_well_formed_http_base_url(normalized):
|
||||
if _is_well_formed_http_base_url(normalized):
|
||||
return normalized
|
||||
|
||||
raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR)
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
# - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct
|
||||
vision_openai_model_name = "Qwen/Qwen3.5-122B-A10B"
|
||||
vision_openai_api_key = "" # 填入对应 provider 的 API key
|
||||
vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空)
|
||||
vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点
|
||||
vision_openai_temperature = 1.0
|
||||
vision_openai_top_p = 0.95
|
||||
vision_openai_max_tokens = 65536
|
||||
@ -45,15 +45,12 @@
|
||||
# - Moonshot: moonshot/moonshot-v1-8k
|
||||
text_openai_model_name = "Pro/zai-org/GLM-5"
|
||||
text_openai_api_key = "" # 填入对应 provider 的 API key
|
||||
text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空)
|
||||
text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点
|
||||
text_openai_temperature = 1.0
|
||||
text_openai_top_p = 0.95
|
||||
text_openai_max_tokens = 65536
|
||||
text_openai_thinking_level = "auto" # auto/off/low/medium/high
|
||||
|
||||
# 默认只允许受信任的 OpenAI 兼容端点。确需自建网关时,确认 API key 会发送到该端点后再启用。
|
||||
allow_custom_openai_base_url = false
|
||||
|
||||
# ===== Tavily 联网搜索配置 =====
|
||||
# 用于短剧剧情理解前,按短剧名称检索公开剧情/人物/分集信息
|
||||
tavily_api_key = "" # 获取地址:https://app.tavily.com
|
||||
|
||||
@ -15,7 +15,10 @@ from app.config.defaults import (
|
||||
get_openai_compatible_ui_values,
|
||||
normalize_openai_compatible_model_name as normalize_openai_compatible_model_id,
|
||||
)
|
||||
from app.utils.openai_base_url_security import validate_openai_compatible_base_url
|
||||
from app.utils.openai_base_url_security import (
|
||||
openai_compatible_base_url_warning,
|
||||
validate_openai_compatible_base_url,
|
||||
)
|
||||
from app.utils import utils
|
||||
from loguru import logger
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
@ -77,11 +80,17 @@ def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
|
||||
try:
|
||||
validate_openai_compatible_base_url(base_url)
|
||||
except ValueError as exc:
|
||||
return False, f"{provider} Base URL未通过安全校验: {exc}"
|
||||
return False, f"{provider} Base URL格式无效: {exc}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def show_base_url_security_warning(base_url: str) -> None:
|
||||
warning = openai_compatible_base_url_warning(base_url)
|
||||
if warning:
|
||||
st.warning(warning)
|
||||
|
||||
|
||||
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
|
||||
"""验证模型名称"""
|
||||
if not model_name or not model_name.strip():
|
||||
@ -663,6 +672,7 @@ def render_vision_llm_settings(tr):
|
||||
if vision_base_required and not st_vision_base_url:
|
||||
info_example = vision_placeholder or "https://your-openai-compatible-endpoint/v1"
|
||||
st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example))
|
||||
show_base_url_security_warning(st_vision_base_url)
|
||||
|
||||
vision_generation_params = render_llm_generation_settings(tr, "vision")
|
||||
|
||||
@ -933,6 +943,7 @@ def render_text_llm_settings(tr):
|
||||
if text_base_required and not st_text_base_url:
|
||||
info_example = text_placeholder or "https://your-openai-compatible-endpoint/v1"
|
||||
st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example))
|
||||
show_base_url_security_warning(st_text_base_url)
|
||||
|
||||
text_generation_params = render_llm_generation_settings(tr, "text")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user