mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-07-02 20:35:28 +00:00
Validate OpenAI-compatible base URLs
This commit is contained in:
parent
fede336592
commit
5f7eed9f85
@ -35,6 +35,7 @@ DEFAULT_LLM_APP_CONFIG = {
|
||||
"text_openai_model_name": DEFAULT_TEXT_OPENAI_MODEL_NAME,
|
||||
"text_openai_api_key": "",
|
||||
"text_openai_base_url": DEFAULT_OPENAI_COMPATIBLE_BASE_URL,
|
||||
"allow_custom_openai_base_url": False,
|
||||
"tavily_api_key": "",
|
||||
"tavily_search_depth": "basic",
|
||||
"tavily_max_results": 5,
|
||||
|
||||
@ -17,6 +17,7 @@ import subprocess
|
||||
from typing import Union, TextIO
|
||||
|
||||
from app.config import config
|
||||
from app.utils.openai_base_url_security import validate_openai_compatible_base_url
|
||||
from app.utils.utils import clean_model_output
|
||||
|
||||
_max_retries = 5
|
||||
@ -330,6 +331,7 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
|
||||
azure_endpoint=base_url,
|
||||
)
|
||||
else:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
|
||||
@ -23,8 +23,12 @@ from openai import (
|
||||
|
||||
from app.config import config
|
||||
from app.config.defaults import DEFAULT_LLM_GENERATION_CONFIG, normalize_openai_compatible_model_name
|
||||
from app.utils.openai_base_url_security import (
|
||||
is_trusted_openai_compatible_base_url,
|
||||
validate_openai_compatible_base_url as _validate_openai_compatible_base_url_value,
|
||||
)
|
||||
from .base import TextModelProvider, VisionModelProvider
|
||||
from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError
|
||||
from .exceptions import APICallError, AuthenticationError, ConfigurationError, ContentFilterError, RateLimitError
|
||||
|
||||
|
||||
def _normalize_model_name(model_name: str) -> str:
|
||||
@ -41,6 +45,13 @@ def _is_content_filter_error(message: str) -> bool:
|
||||
return "content_filter" in lowered or "safety" in lowered
|
||||
|
||||
|
||||
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
return _validate_openai_compatible_base_url_value(base_url)
|
||||
except ValueError as exc:
|
||||
raise ConfigurationError(str(exc), "base_url") from exc
|
||||
|
||||
|
||||
def _clean_json_output(output: str) -> str:
|
||||
"""清理 JSON 输出中的 markdown 包裹。"""
|
||||
output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
|
||||
@ -118,6 +129,7 @@ class _OpenAICompatibleBase:
|
||||
"""按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。"""
|
||||
api_key = api_key_override or self.api_key
|
||||
base_url = base_url_override or self.base_url or None
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
|
||||
timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180)
|
||||
max_retries: int = max_retries_override or config.app.get("llm_max_retries", 3)
|
||||
|
||||
@ -6,9 +6,14 @@ from unittest.mock import patch
|
||||
|
||||
from app.config import config
|
||||
from app.services.llm.base import TextModelProvider
|
||||
from app.services.llm.exceptions import ConfigurationError
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
|
||||
from app.services.llm.openai_compatible_provider import OpenAICompatibleTextProvider, OpenAICompatibleVisionProvider
|
||||
from app.services.llm.openai_compatible_provider import (
|
||||
OpenAICompatibleTextProvider,
|
||||
OpenAICompatibleVisionProvider,
|
||||
is_trusted_openai_compatible_base_url,
|
||||
)
|
||||
from app.services.llm.providers import register_all_providers
|
||||
|
||||
|
||||
@ -169,6 +174,77 @@ class OpenAICompatGenerationOptionTests(unittest.TestCase):
|
||||
self.assertEqual(512, options["max_tokens"])
|
||||
|
||||
|
||||
class OpenAICompatBaseURLValidationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._original_app = dict(config.app)
|
||||
|
||||
def tearDown(self):
|
||||
config.app.clear()
|
||||
config.app.update(self._original_app)
|
||||
|
||||
def test_known_providers_and_local_ollama_are_trusted(self):
|
||||
trusted_urls = [
|
||||
"https://api.openai.com/v1",
|
||||
"https://api.siliconflow.cn/v1",
|
||||
"https://openrouter.ai/api/v1",
|
||||
"https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"https://example.openai.azure.com/openai/deployments/demo",
|
||||
"http://localhost:11434/v1",
|
||||
"http://127.0.0.1:11434/v1",
|
||||
]
|
||||
|
||||
for url in trusted_urls:
|
||||
with self.subTest(url=url):
|
||||
self.assertTrue(is_trusted_openai_compatible_base_url(url))
|
||||
|
||||
def test_unrecognized_or_unsafe_base_urls_are_not_trusted(self):
|
||||
untrusted_urls = [
|
||||
"https://attacker.example/v1",
|
||||
"http://api.openai.com/v1",
|
||||
"https://user@api.openai.com/v1",
|
||||
"https://127.0.0.1:9999/v1",
|
||||
"not-a-url",
|
||||
]
|
||||
|
||||
for url in untrusted_urls:
|
||||
with self.subTest(url=url):
|
||||
self.assertFalse(is_trusted_openai_compatible_base_url(url))
|
||||
|
||||
def test_build_client_rejects_untrusted_base_url_by_default(self):
|
||||
provider = OpenAICompatibleTextProvider(
|
||||
api_key="test-key",
|
||||
model_name="test-model",
|
||||
base_url="https://attacker.example/v1",
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
provider._build_client()
|
||||
|
||||
def test_build_client_allows_explicit_custom_base_url_opt_in(self):
|
||||
config.app["allow_custom_openai_base_url"] = True
|
||||
provider = OpenAICompatibleTextProvider(
|
||||
api_key="test-key",
|
||||
model_name="test-model",
|
||||
base_url="https://custom.example/v1",
|
||||
)
|
||||
|
||||
with patch("app.services.llm.openai_compatible_provider.AsyncOpenAI") as async_openai:
|
||||
provider._build_client()
|
||||
|
||||
self.assertEqual("https://custom.example/v1", async_openai.call_args.kwargs["base_url"])
|
||||
|
||||
def test_custom_base_url_opt_in_still_rejects_malformed_urls(self):
|
||||
config.app["allow_custom_openai_base_url"] = True
|
||||
provider = OpenAICompatibleTextProvider(
|
||||
api_key="test-key",
|
||||
model_name="test-model",
|
||||
base_url="https://user@custom.example/v1",
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
provider._build_client()
|
||||
|
||||
|
||||
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
|
||||
class _CapturingVisionProvider:
|
||||
last_init: tuple[str, str, str | None] | None = None
|
||||
|
||||
104
app/utils/openai_base_url_security.py
Normal file
104
app/utils/openai_base_url_security.py
Normal file
@ -0,0 +1,104 @@
|
||||
import ipaddress
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from app.config import config
|
||||
|
||||
|
||||
TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS = {
|
||||
"api.openai.com",
|
||||
"openrouter.ai",
|
||||
"api.siliconflow.cn",
|
||||
"dashscope.aliyuncs.com",
|
||||
"api.deepseek.com",
|
||||
"api.moonshot.cn",
|
||||
"api.together.xyz",
|
||||
"api.cohere.ai",
|
||||
"generativelanguage.googleapis.com",
|
||||
"open.bigmodel.cn",
|
||||
"api.z.ai",
|
||||
"ark.cn-beijing.volces.com",
|
||||
"ark.cn-shanghai.volces.com",
|
||||
}
|
||||
|
||||
TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES = (
|
||||
".openai.azure.com",
|
||||
".services.ai.azure.com",
|
||||
".cognitiveservices.azure.com",
|
||||
)
|
||||
|
||||
OPENAI_COMPATIBLE_BASE_URL_ERROR = (
|
||||
"OpenAI-compatible base_url is not in the trusted provider list. "
|
||||
"Use an official provider endpoint or set allow_custom_openai_base_url=true "
|
||||
"only if you understand that the configured endpoint receives the API key."
|
||||
)
|
||||
|
||||
|
||||
def custom_openai_base_url_allowed() -> bool:
|
||||
value = config.app.get("allow_custom_openai_base_url", False)
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool:
|
||||
if scheme not in {"http", "https"} or port != 11434:
|
||||
return False
|
||||
if host == "localhost":
|
||||
return True
|
||||
try:
|
||||
return ipaddress.ip_address(host).is_loopback
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _is_well_formed_http_base_url(base_url: str) -> bool:
|
||||
parsed = urlparse(str(base_url).strip())
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.hostname:
|
||||
return False
|
||||
if parsed.username or parsed.password:
|
||||
return False
|
||||
try:
|
||||
parsed.port
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool:
|
||||
if not base_url:
|
||||
return True
|
||||
|
||||
parsed = urlparse(str(base_url).strip())
|
||||
host = (parsed.hostname or "").rstrip(".").lower()
|
||||
if not parsed.scheme or not host or parsed.username or parsed.password:
|
||||
return False
|
||||
|
||||
try:
|
||||
port = parsed.port
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if _is_loopback_ollama_url(parsed.scheme, host, port):
|
||||
return True
|
||||
|
||||
if parsed.scheme != "https":
|
||||
return False
|
||||
|
||||
if host in TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS:
|
||||
return True
|
||||
|
||||
return any(host.endswith(suffix) for suffix in TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES)
|
||||
|
||||
|
||||
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
normalized = str(base_url).strip()
|
||||
if is_trusted_openai_compatible_base_url(normalized):
|
||||
return normalized
|
||||
if custom_openai_base_url_allowed() and _is_well_formed_http_base_url(normalized):
|
||||
return normalized
|
||||
|
||||
raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR)
|
||||
@ -51,6 +51,9 @@
|
||||
text_openai_max_tokens = 65536
|
||||
text_openai_thinking_level = "auto" # auto/off/low/medium/high
|
||||
|
||||
# 默认只允许受信任的 OpenAI 兼容端点。确需自建网关时,确认 API key 会发送到该端点后再启用。
|
||||
allow_custom_openai_base_url = false
|
||||
|
||||
# ===== Tavily 联网搜索配置 =====
|
||||
# 用于短剧剧情理解前,按短剧名称检索公开剧情/人物/分集信息
|
||||
tavily_api_key = "" # 获取地址:https://app.tavily.com
|
||||
|
||||
@ -15,6 +15,7 @@ from app.config.defaults import (
|
||||
get_openai_compatible_ui_values,
|
||||
normalize_openai_compatible_model_name as normalize_openai_compatible_model_id,
|
||||
)
|
||||
from app.utils.openai_base_url_security import validate_openai_compatible_base_url
|
||||
from app.utils import utils
|
||||
from loguru import logger
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
@ -73,6 +74,11 @@ def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
|
||||
if not (base_url.startswith('http://') or base_url.startswith('https://')):
|
||||
return False, f"{provider} Base URL必须以http://或https://开头"
|
||||
|
||||
try:
|
||||
validate_openai_compatible_base_url(base_url)
|
||||
except ValueError as exc:
|
||||
return False, f"{provider} Base URL未通过安全校验: {exc}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@ -420,6 +426,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
elif provider.lower() == 'gemini(openai)':
|
||||
# OpenAI兼容的Gemini代理测试
|
||||
try:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
@ -444,6 +451,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
else:
|
||||
from openai import OpenAI
|
||||
try:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
@ -493,9 +501,10 @@ def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name:
|
||||
from PIL import Image
|
||||
|
||||
logger.debug(
|
||||
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
|
||||
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, base_url={base_url}"
|
||||
)
|
||||
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
@ -548,9 +557,10 @@ def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: s
|
||||
from openai import OpenAI
|
||||
|
||||
logger.debug(
|
||||
f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
|
||||
f"OpenAI 兼容文本模型连通性测试: model={model_name}, base_url={base_url}"
|
||||
)
|
||||
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user