Validate OpenAI-compatible base URLs

This commit is contained in:
Hamizan Azman 2026-06-19 15:08:17 +08:00
parent fede336592
commit 5f7eed9f85
7 changed files with 212 additions and 4 deletions

View File

@ -35,6 +35,7 @@ DEFAULT_LLM_APP_CONFIG = {
"text_openai_model_name": DEFAULT_TEXT_OPENAI_MODEL_NAME,
"text_openai_api_key": "",
"text_openai_base_url": DEFAULT_OPENAI_COMPATIBLE_BASE_URL,
"allow_custom_openai_base_url": False,
"tavily_api_key": "",
"tavily_search_depth": "basic",
"tavily_max_results": 5,

View File

@ -17,6 +17,7 @@ import subprocess
from typing import Union, TextIO
from app.config import config
from app.utils.openai_base_url_security import validate_openai_compatible_base_url
from app.utils.utils import clean_model_output
_max_retries = 5
@ -330,6 +331,7 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
azure_endpoint=base_url,
)
else:
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url,

View File

@ -23,8 +23,12 @@ from openai import (
from app.config import config
from app.config.defaults import DEFAULT_LLM_GENERATION_CONFIG, normalize_openai_compatible_model_name
from app.utils.openai_base_url_security import (
is_trusted_openai_compatible_base_url,
validate_openai_compatible_base_url as _validate_openai_compatible_base_url_value,
)
from .base import TextModelProvider, VisionModelProvider
from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError
from .exceptions import APICallError, AuthenticationError, ConfigurationError, ContentFilterError, RateLimitError
def _normalize_model_name(model_name: str) -> str:
@ -41,6 +45,13 @@ def _is_content_filter_error(message: str) -> bool:
return "content_filter" in lowered or "safety" in lowered
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
try:
return _validate_openai_compatible_base_url_value(base_url)
except ValueError as exc:
raise ConfigurationError(str(exc), "base_url") from exc
def _clean_json_output(output: str) -> str:
"""清理 JSON 输出中的 markdown 包裹。"""
output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
@ -118,6 +129,7 @@ class _OpenAICompatibleBase:
"""按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。"""
api_key = api_key_override or self.api_key
base_url = base_url_override or self.base_url or None
base_url = validate_openai_compatible_base_url(base_url)
timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180)
max_retries: int = max_retries_override or config.app.get("llm_max_retries", 3)

View File

@ -6,9 +6,14 @@ from unittest.mock import patch
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.exceptions import ConfigurationError
from app.services.llm.manager import LLMServiceManager
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
from app.services.llm.openai_compatible_provider import OpenAICompatibleTextProvider, OpenAICompatibleVisionProvider
from app.services.llm.openai_compatible_provider import (
OpenAICompatibleTextProvider,
OpenAICompatibleVisionProvider,
is_trusted_openai_compatible_base_url,
)
from app.services.llm.providers import register_all_providers
@ -169,6 +174,77 @@ class OpenAICompatGenerationOptionTests(unittest.TestCase):
self.assertEqual(512, options["max_tokens"])
class OpenAICompatBaseURLValidationTests(unittest.TestCase):
def setUp(self):
self._original_app = dict(config.app)
def tearDown(self):
config.app.clear()
config.app.update(self._original_app)
def test_known_providers_and_local_ollama_are_trusted(self):
trusted_urls = [
"https://api.openai.com/v1",
"https://api.siliconflow.cn/v1",
"https://openrouter.ai/api/v1",
"https://dashscope.aliyuncs.com/compatible-mode/v1",
"https://example.openai.azure.com/openai/deployments/demo",
"http://localhost:11434/v1",
"http://127.0.0.1:11434/v1",
]
for url in trusted_urls:
with self.subTest(url=url):
self.assertTrue(is_trusted_openai_compatible_base_url(url))
def test_unrecognized_or_unsafe_base_urls_are_not_trusted(self):
untrusted_urls = [
"https://attacker.example/v1",
"http://api.openai.com/v1",
"https://user@api.openai.com/v1",
"https://127.0.0.1:9999/v1",
"not-a-url",
]
for url in untrusted_urls:
with self.subTest(url=url):
self.assertFalse(is_trusted_openai_compatible_base_url(url))
def test_build_client_rejects_untrusted_base_url_by_default(self):
provider = OpenAICompatibleTextProvider(
api_key="test-key",
model_name="test-model",
base_url="https://attacker.example/v1",
)
with self.assertRaises(ConfigurationError):
provider._build_client()
def test_build_client_allows_explicit_custom_base_url_opt_in(self):
config.app["allow_custom_openai_base_url"] = True
provider = OpenAICompatibleTextProvider(
api_key="test-key",
model_name="test-model",
base_url="https://custom.example/v1",
)
with patch("app.services.llm.openai_compatible_provider.AsyncOpenAI") as async_openai:
provider._build_client()
self.assertEqual("https://custom.example/v1", async_openai.call_args.kwargs["base_url"])
def test_custom_base_url_opt_in_still_rejects_malformed_urls(self):
config.app["allow_custom_openai_base_url"] = True
provider = OpenAICompatibleTextProvider(
api_key="test-key",
model_name="test-model",
base_url="https://user@custom.example/v1",
)
with self.assertRaises(ConfigurationError):
provider._build_client()
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
class _CapturingVisionProvider:
last_init: tuple[str, str, str | None] | None = None

View File

@ -0,0 +1,104 @@
import ipaddress
from typing import Optional
from urllib.parse import urlparse
from app.config import config
TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS = {
"api.openai.com",
"openrouter.ai",
"api.siliconflow.cn",
"dashscope.aliyuncs.com",
"api.deepseek.com",
"api.moonshot.cn",
"api.together.xyz",
"api.cohere.ai",
"generativelanguage.googleapis.com",
"open.bigmodel.cn",
"api.z.ai",
"ark.cn-beijing.volces.com",
"ark.cn-shanghai.volces.com",
}
TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES = (
".openai.azure.com",
".services.ai.azure.com",
".cognitiveservices.azure.com",
)
OPENAI_COMPATIBLE_BASE_URL_ERROR = (
"OpenAI-compatible base_url is not in the trusted provider list. "
"Use an official provider endpoint or set allow_custom_openai_base_url=true "
"only if you understand that the configured endpoint receives the API key."
)
def custom_openai_base_url_allowed() -> bool:
value = config.app.get("allow_custom_openai_base_url", False)
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "on"}
return bool(value)
def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool:
if scheme not in {"http", "https"} or port != 11434:
return False
if host == "localhost":
return True
try:
return ipaddress.ip_address(host).is_loopback
except ValueError:
return False
def _is_well_formed_http_base_url(base_url: str) -> bool:
parsed = urlparse(str(base_url).strip())
if parsed.scheme not in {"http", "https"} or not parsed.hostname:
return False
if parsed.username or parsed.password:
return False
try:
parsed.port
except ValueError:
return False
return True
def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool:
if not base_url:
return True
parsed = urlparse(str(base_url).strip())
host = (parsed.hostname or "").rstrip(".").lower()
if not parsed.scheme or not host or parsed.username or parsed.password:
return False
try:
port = parsed.port
except ValueError:
return False
if _is_loopback_ollama_url(parsed.scheme, host, port):
return True
if parsed.scheme != "https":
return False
if host in TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS:
return True
return any(host.endswith(suffix) for suffix in TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES)
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
if not base_url:
return None
normalized = str(base_url).strip()
if is_trusted_openai_compatible_base_url(normalized):
return normalized
if custom_openai_base_url_allowed() and _is_well_formed_http_base_url(normalized):
return normalized
raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR)

View File

@ -51,6 +51,9 @@
text_openai_max_tokens = 65536
text_openai_thinking_level = "auto" # auto/off/low/medium/high
# 默认只允许受信任的 OpenAI 兼容端点。确需自建网关时,确认 API key 会发送到该端点后再启用。
allow_custom_openai_base_url = false
# ===== Tavily 联网搜索配置 =====
# 用于短剧剧情理解前,按短剧名称检索公开剧情/人物/分集信息
tavily_api_key = "" # 获取地址https://app.tavily.com

View File

@ -15,6 +15,7 @@ from app.config.defaults import (
get_openai_compatible_ui_values,
normalize_openai_compatible_model_name as normalize_openai_compatible_model_id,
)
from app.utils.openai_base_url_security import validate_openai_compatible_base_url
from app.utils import utils
from loguru import logger
from app.services.llm.unified_service import UnifiedLLMService
@ -73,6 +74,11 @@ def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
if not (base_url.startswith('http://') or base_url.startswith('https://')):
return False, f"{provider} Base URL必须以http://或https://开头"
try:
validate_openai_compatible_base_url(base_url)
except ValueError as exc:
return False, f"{provider} Base URL未通过安全校验: {exc}"
return True, ""
@ -420,6 +426,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
elif provider.lower() == 'gemini(openai)':
# OpenAI兼容的Gemini代理测试
try:
base_url = validate_openai_compatible_base_url(base_url)
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
@ -444,6 +451,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
else:
from openai import OpenAI
try:
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url,
@ -493,9 +501,10 @@ def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name:
from PIL import Image
logger.debug(
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, base_url={base_url}"
)
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url or None,
@ -548,9 +557,10 @@ def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: s
from openai import OpenAI
logger.debug(
f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
f"OpenAI 兼容文本模型连通性测试: model={model_name}, base_url={base_url}"
)
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url or None,