def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
    """Validate ``base_url`` for the OpenAI-compatible providers.

    Delegates the actual check to the shared helper imported from
    ``app.utils.openai_base_url_security`` and adapts its result to this
    package's conventions.

    Args:
        base_url: The configured endpoint, or ``None``/empty when unset.

    Returns:
        Whatever the shared validator returns for ``base_url``.

    Raises:
        ConfigurationError: When the shared validator rejects the URL with a
            ``ValueError``; the original error is chained as the cause.
    """
    try:
        checked_url = _validate_openai_compatible_base_url_value(base_url)
    except ValueError as exc:
        # Re-raise in the provider package's own exception hierarchy.
        raise ConfigurationError(str(exc), "base_url") from exc

    # A non-empty message means the helper produced a warning for this URL.
    if notice := openai_compatible_base_url_warning(checked_url):
        logger.warning(notice)
    return checked_url
class OpenAICompatBaseURLValidationTests(unittest.TestCase):
    """Tests for base_url trust classification, warnings and validation."""

    # URLs the trust check is expected to accept.
    _TRUSTED_URLS = (
        "https://api.openai.com/v1",
        "https://api.siliconflow.cn/v1",
        "https://openrouter.ai/api/v1",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "https://example.openai.azure.com/openai/deployments/demo",
        "http://localhost:11434/v1",
        "http://127.0.0.1:11434/v1",
    )

    # URLs the trust check is expected to reject.
    _UNTRUSTED_URLS = (
        "https://attacker.example/v1",
        "http://api.openai.com/v1",
        "https://user@api.openai.com/v1",
        "https://127.0.0.1:9999/v1",
        "not-a-url",
    )

    @staticmethod
    def _make_text_provider(base_url):
        """Build a text provider with fixed dummy credentials and ``base_url``."""
        return OpenAICompatibleTextProvider(
            api_key="test-key",
            model_name="test-model",
            base_url=base_url,
        )

    def setUp(self):
        # Snapshot config.app so tearDown can put it back.
        self._original_app = dict(config.app)

    def tearDown(self):
        # Restore the snapshot in place (same dict object, original contents).
        config.app.clear()
        config.app.update(self._original_app)

    def test_known_providers_and_local_ollama_are_trusted(self):
        for candidate in self._TRUSTED_URLS:
            with self.subTest(url=candidate):
                self.assertTrue(is_trusted_openai_compatible_base_url(candidate))

    def test_unrecognized_or_unsafe_base_urls_are_not_trusted(self):
        for candidate in self._UNTRUSTED_URLS:
            with self.subTest(url=candidate):
                self.assertFalse(is_trusted_openai_compatible_base_url(candidate))

    def test_build_client_allows_well_formed_custom_base_url_by_default(self):
        text_provider = self._make_text_provider("https://custom.example/v1")

        patch_target = "app.services.llm.openai_compatible_provider.AsyncOpenAI"
        with patch(patch_target) as client_factory:
            text_provider._build_client()

        passed_kwargs = client_factory.call_args.kwargs
        self.assertEqual("https://custom.example/v1", passed_kwargs["base_url"])

    def test_custom_base_url_validation_returns_normalized_url(self):
        result = validate_openai_compatible_base_url(" https://custom.example/v1 ")
        self.assertEqual("https://custom.example/v1", result)

    def test_custom_base_url_warning_only_for_untrusted_well_formed_urls(self):
        custom_warning = openai_compatible_base_url_warning("https://custom.example/v1")
        self.assertIn("custom.example", custom_warning)

        # Trusted host and empty input both yield no warning text.
        self.assertEqual("", openai_compatible_base_url_warning("https://api.openai.com/v1"))
        self.assertEqual("", openai_compatible_base_url_warning(""))

    def test_custom_base_url_validation_rejects_malformed_urls(self):
        # Embedded userinfo makes the URL invalid for the provider.
        text_provider = self._make_text_provider("https://user@custom.example/v1")

        with self.assertRaises(ConfigurationError):
            text_provider._build_client()
def _parse_base_url(base_url):
    """Parse ``base_url`` with ``urlparse`` after stripping whitespace.

    Returns the parse result, or ``None`` when ``urlparse`` itself rejects the
    input with ``ValueError`` (it does so for unbalanced IPv6 brackets such as
    ``"http://[::1"``). Funnelling every parse through here keeps the public
    helpers from leaking that exception to callers that expect a bool or str.
    """
    try:
        return urlparse(str(base_url).strip())
    except ValueError:
        return None


def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool:
    """Return True for an http(s) loopback endpoint on port 11434.

    ``host`` must be ``"localhost"`` or an IP literal that ``ipaddress``
    classifies as loopback; any other host, scheme or port yields False.
    """
    if scheme not in {"http", "https"} or port != 11434:
        return False
    if host == "localhost":
        return True
    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        # Not an IP literal (e.g. a DNS name), so it is not a loopback address.
        return False


def _is_well_formed_http_base_url(base_url: str) -> bool:
    """Return True when ``base_url`` is a parseable http(s) URL without credentials.

    Requires an ``http``/``https`` scheme, a hostname, no username/password,
    and a port that ``urlparse`` accepts (absent or valid).
    """
    parsed = _parse_base_url(base_url)
    if parsed is None:
        return False
    if parsed.scheme not in {"http", "https"} or not parsed.hostname:
        return False
    if parsed.username or parsed.password:
        return False
    try:
        parsed.port  # accessing .port raises ValueError for an invalid port
    except ValueError:
        return False
    return True


def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool:
    """Return True when ``base_url`` is empty or points at a trusted endpoint.

    Trusted means either a loopback endpoint on port 11434 (http or https), or
    an ``https`` URL whose host is in ``TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS``
    or ends with one of ``TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES``. URLs that
    cannot be parsed, lack a scheme/host, carry credentials, or have an
    invalid port are never trusted. This function does not raise.
    """
    if not base_url:
        return True

    parsed = _parse_base_url(base_url)
    if parsed is None:
        return False

    # Drop a trailing root dot so "api.openai.com." matches "api.openai.com".
    host = (parsed.hostname or "").rstrip(".").lower()
    if not parsed.scheme or not host or parsed.username or parsed.password:
        return False

    try:
        port = parsed.port
    except ValueError:
        return False

    if _is_loopback_ollama_url(parsed.scheme, host, port):
        return True

    # Everything except the loopback case must use TLS.
    if parsed.scheme != "https":
        return False

    if host in TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS:
        return True

    return any(host.endswith(suffix) for suffix in TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES)


def openai_compatible_base_url_warning(base_url: Optional[str]) -> str:
    """Return a warning message for a well-formed but untrusted ``base_url``.

    Returns ``""`` when ``base_url`` is empty, trusted, or not a well-formed
    http(s) URL (malformed input is the validator's concern, not a warning).
    This function does not raise on unparseable input.
    """
    if not base_url:
        return ""

    normalized = str(base_url).strip()
    if is_trusted_openai_compatible_base_url(normalized) or not _is_well_formed_http_base_url(normalized):
        return ""

    # Well-formed implies the parse succeeded, so urlparse cannot raise here.
    parsed = urlparse(normalized)
    host = (parsed.hostname or "").rstrip(".").lower()
    return OPENAI_COMPATIBLE_BASE_URL_WARNING.format(host=host)


def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
    """Return the whitespace-stripped ``base_url`` or raise if it is malformed.

    Args:
        base_url: Configured endpoint; ``None``, ``""`` and whitespace-only
            values all mean "not configured".

    Returns:
        ``None`` when no endpoint is configured, otherwise the stripped URL.

    Raises:
        ValueError: With ``OPENAI_COMPATIBLE_BASE_URL_ERROR`` when the value is
            neither trusted nor a well-formed http(s) URL without credentials.
    """
    if not base_url:
        return None

    normalized = str(base_url).strip()
    if not normalized:
        # Whitespace-only is an unset value, not a malformed URL.
        return None

    if is_trusted_openai_compatible_base_url(normalized) or _is_well_formed_http_base_url(normalized):
        return normalized

    raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR)
def show_base_url_security_warning(base_url: str) -> None:
    """Render a Streamlit warning when ``base_url`` is a custom, untrusted endpoint.

    Shows nothing when the helper returns an empty message, or when the input
    cannot be parsed as a URL at all.

    Args:
        base_url: Current content of the base URL input; may be empty or
            half-typed.
    """
    try:
        warning = openai_compatible_base_url_warning(base_url)
    except ValueError:
        # This runs on every render with raw user input. URL parsing can raise
        # ValueError for text such as "http://[", and the warning is advisory
        # only; format errors are reported by validate_base_url. An unparseable
        # value must not break the settings page.
        return
    if warning:
        st.warning(warning)
compatible gateway").format(example=info_example)) + show_base_url_security_warning(st_vision_base_url) vision_generation_params = render_llm_generation_settings(tr, "vision") @@ -923,6 +943,7 @@ def render_text_llm_settings(tr): if text_base_required and not st_text_base_url: info_example = text_placeholder or "https://your-openai-compatible-endpoint/v1" st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example)) + show_base_url_security_warning(st_text_base_url) text_generation_params = render_llm_generation_settings(tr, "text")