diff --git a/app/config/defaults.py b/app/config/defaults.py index 04e480c..9f648fa 100644 --- a/app/config/defaults.py +++ b/app/config/defaults.py @@ -35,7 +35,6 @@ DEFAULT_LLM_APP_CONFIG = { "text_openai_model_name": DEFAULT_TEXT_OPENAI_MODEL_NAME, "text_openai_api_key": "", "text_openai_base_url": DEFAULT_OPENAI_COMPATIBLE_BASE_URL, - "allow_custom_openai_base_url": False, "tavily_api_key": "", "tavily_search_depth": "basic", "tavily_max_results": 5, diff --git a/app/services/llm.py b/app/services/llm.py index e056c62..488ce32 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -17,7 +17,10 @@ import subprocess from typing import Union, TextIO from app.config import config -from app.utils.openai_base_url_security import validate_openai_compatible_base_url +from app.utils.openai_base_url_security import ( + openai_compatible_base_url_warning, + validate_openai_compatible_base_url, +) from app.utils.utils import clean_model_output _max_retries = 5 @@ -332,6 +335,9 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str: ) else: base_url = validate_openai_compatible_base_url(base_url) + warning = openai_compatible_base_url_warning(base_url) + if warning: + logger.warning(warning) client = OpenAI( api_key=api_key, base_url=base_url, diff --git a/app/services/llm/openai_compatible_provider.py b/app/services/llm/openai_compatible_provider.py index b295eee..74b6440 100644 --- a/app/services/llm/openai_compatible_provider.py +++ b/app/services/llm/openai_compatible_provider.py @@ -25,6 +25,7 @@ from app.config import config from app.config.defaults import DEFAULT_LLM_GENERATION_CONFIG, normalize_openai_compatible_model_name from app.utils.openai_base_url_security import ( is_trusted_openai_compatible_base_url, + openai_compatible_base_url_warning, validate_openai_compatible_base_url as _validate_openai_compatible_base_url_value, ) from .base import TextModelProvider, VisionModelProvider @@ -47,9 +48,13 @@ def _is_content_filter_error(message: str) -> bool: def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]: try: - return _validate_openai_compatible_base_url_value(base_url) + normalized = _validate_openai_compatible_base_url_value(base_url) except ValueError as exc: raise ConfigurationError(str(exc), "base_url") from exc + warning = openai_compatible_base_url_warning(normalized) + if warning: + logger.warning(warning) + return normalized def _clean_json_output(output: str) -> str: diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py index 3791782..4277343 100644 --- a/app/services/llm/test_openai_compat_unittest.py +++ b/app/services/llm/test_openai_compat_unittest.py @@ -13,8 +13,10 @@ from app.services.llm.openai_compatible_provider import ( OpenAICompatibleTextProvider, OpenAICompatibleVisionProvider, is_trusted_openai_compatible_base_url, + validate_openai_compatible_base_url, ) from app.services.llm.providers import register_all_providers +from app.utils.openai_base_url_security import openai_compatible_base_url_warning class DummyOpenAITextProvider(TextModelProvider): @@ -210,18 +212,7 @@ class OpenAICompatBaseURLValidationTests(unittest.TestCase): with self.subTest(url=url): self.assertFalse(is_trusted_openai_compatible_base_url(url)) - def test_build_client_rejects_untrusted_base_url_by_default(self): - provider = OpenAICompatibleTextProvider( - api_key="test-key", - model_name="test-model", - base_url="https://attacker.example/v1", - ) - - with self.assertRaises(ConfigurationError): - provider._build_client() - - def test_build_client_allows_explicit_custom_base_url_opt_in(self): - config.app["allow_custom_openai_base_url"] = True + def test_build_client_allows_well_formed_custom_base_url_by_default(self): provider = OpenAICompatibleTextProvider( api_key="test-key", model_name="test-model", @@ -233,8 +224,19 @@ class OpenAICompatBaseURLValidationTests(unittest.TestCase): self.assertEqual("https://custom.example/v1", async_openai.call_args.kwargs["base_url"]) - def test_custom_base_url_opt_in_still_rejects_malformed_urls(self): - config.app["allow_custom_openai_base_url"] = True + def test_custom_base_url_validation_returns_normalized_url(self): + self.assertEqual( + "https://custom.example/v1", + validate_openai_compatible_base_url(" https://custom.example/v1 "), + ) + + def test_custom_base_url_warning_only_for_untrusted_well_formed_urls(self): + warning = openai_compatible_base_url_warning("https://custom.example/v1") + self.assertIn("custom.example", warning) + self.assertEqual("", openai_compatible_base_url_warning("https://api.openai.com/v1")) + self.assertEqual("", openai_compatible_base_url_warning("")) + + def test_custom_base_url_validation_rejects_malformed_urls(self): provider = OpenAICompatibleTextProvider( api_key="test-key", model_name="test-model", diff --git a/app/utils/openai_base_url_security.py b/app/utils/openai_base_url_security.py index 860216c..12cb0fa 100644 --- a/app/utils/openai_base_url_security.py +++ b/app/utils/openai_base_url_security.py @@ -2,8 +2,6 @@ import ipaddress from typing import Optional from urllib.parse import urlparse -from app.config import config - TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS = { "api.openai.com", @@ -28,17 +26,13 @@ TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES = ( ) OPENAI_COMPATIBLE_BASE_URL_ERROR = ( - "OpenAI-compatible base_url is not in the trusted provider list. " - "Use an official provider endpoint or set allow_custom_openai_base_url=true " - "only if you understand that the configured endpoint receives the API key." + "OpenAI-compatible base_url must be a valid http(s) URL without embedded credentials." ) - -def custom_openai_base_url_allowed() -> bool: - value = config.app.get("allow_custom_openai_base_url", False) - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "on"} - return bool(value) +OPENAI_COMPATIBLE_BASE_URL_WARNING = ( + "OpenAI-compatible base_url host '{host}' is not in the trusted provider list. " + "Only continue if you trust this endpoint, because it will receive the configured API key." +) def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool: @@ -91,6 +85,19 @@ def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool: return any(host.endswith(suffix) for suffix in TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES) +def openai_compatible_base_url_warning(base_url: Optional[str]) -> str: + if not base_url: + return "" + + normalized = str(base_url).strip() + if is_trusted_openai_compatible_base_url(normalized) or not _is_well_formed_http_base_url(normalized): + return "" + + parsed = urlparse(normalized) + host = (parsed.hostname or "").rstrip(".").lower() + return OPENAI_COMPATIBLE_BASE_URL_WARNING.format(host=host) + + def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]: if not base_url: return None @@ -98,7 +105,7 @@ def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str normalized = str(base_url).strip() if is_trusted_openai_compatible_base_url(normalized): return normalized - if custom_openai_base_url_allowed() and _is_well_formed_http_base_url(normalized): + if _is_well_formed_http_base_url(normalized): return normalized raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR) diff --git a/config.example.toml b/config.example.toml index f058902..47ffe28 100644 --- a/config.example.toml +++ b/config.example.toml @@ -26,7 +26,7 @@ # - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct vision_openai_model_name = "Qwen/Qwen3.5-122B-A10B" vision_openai_api_key = "" # 填入对应 provider 的 API key - vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空) + vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点 vision_openai_temperature = 1.0 vision_openai_top_p = 0.95 vision_openai_max_tokens = 65536 @@ -45,15 +45,12 @@ # - Moonshot: moonshot/moonshot-v1-8k text_openai_model_name = "Pro/zai-org/GLM-5" text_openai_api_key = "" # 填入对应 provider 的 API key - text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空) + text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点 text_openai_temperature = 1.0 text_openai_top_p = 0.95 text_openai_max_tokens = 65536 text_openai_thinking_level = "auto" # auto/off/low/medium/high - # 默认只允许受信任的 OpenAI 兼容端点。确需自建网关时,确认 API key 会发送到该端点后再启用。 - allow_custom_openai_base_url = false - # ===== Tavily 联网搜索配置 ===== # 用于短剧剧情理解前,按短剧名称检索公开剧情/人物/分集信息 tavily_api_key = "" # 获取地址:https://app.tavily.com diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py index 66bd4d0..e569809 100644 --- a/webui/components/basic_settings.py +++ b/webui/components/basic_settings.py @@ -15,7 +15,10 @@ from app.config.defaults import ( get_openai_compatible_ui_values, normalize_openai_compatible_model_name as normalize_openai_compatible_model_id, ) -from app.utils.openai_base_url_security import validate_openai_compatible_base_url +from app.utils.openai_base_url_security import ( + openai_compatible_base_url_warning, + validate_openai_compatible_base_url, +) from app.utils import utils from loguru import logger from app.services.llm.unified_service import UnifiedLLMService @@ -77,11 +80,17 @@ def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]: try: validate_openai_compatible_base_url(base_url) except ValueError as exc: - return False, f"{provider} Base URL未通过安全校验: {exc}" + return False, f"{provider} Base URL格式无效: {exc}" return True, "" +def show_base_url_security_warning(base_url: str) -> None: + warning = openai_compatible_base_url_warning(base_url) + if warning: + st.warning(warning) + + def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]: """验证模型名称""" if not model_name or not model_name.strip(): @@ -663,6 +672,7 @@ def render_vision_llm_settings(tr): if vision_base_required and not st_vision_base_url: info_example = vision_placeholder or "https://your-openai-compatible-endpoint/v1" st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example)) + show_base_url_security_warning(st_vision_base_url) vision_generation_params = render_llm_generation_settings(tr, "vision") @@ -933,6 +943,7 @@ def render_text_llm_settings(tr): if text_base_required and not st_text_base_url: info_example = text_placeholder or "https://your-openai-compatible-endpoint/v1" st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example)) + show_base_url_security_warning(st_text_base_url) text_generation_params = render_llm_generation_settings(tr, "text")