mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-07-02 12:25:35 +00:00
Merge pull request #255 from hamizan-azman/codex/validate-openai-base-url
Validate OpenAI-compatible base URLs
This commit is contained in:
commit
c288a76ff8
@ -17,6 +17,10 @@ import subprocess
|
||||
from typing import Union, TextIO
|
||||
|
||||
from app.config import config
|
||||
from app.utils.openai_base_url_security import (
|
||||
openai_compatible_base_url_warning,
|
||||
validate_openai_compatible_base_url,
|
||||
)
|
||||
from app.utils.utils import clean_model_output
|
||||
|
||||
_max_retries = 5
|
||||
@ -330,6 +334,10 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
|
||||
azure_endpoint=base_url,
|
||||
)
|
||||
else:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
warning = openai_compatible_base_url_warning(base_url)
|
||||
if warning:
|
||||
logger.warning(warning)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
|
||||
@ -23,8 +23,13 @@ from openai import (
|
||||
|
||||
from app.config import config
|
||||
from app.config.defaults import DEFAULT_LLM_GENERATION_CONFIG, normalize_openai_compatible_model_name
|
||||
from app.utils.openai_base_url_security import (
|
||||
is_trusted_openai_compatible_base_url,
|
||||
openai_compatible_base_url_warning,
|
||||
validate_openai_compatible_base_url as _validate_openai_compatible_base_url_value,
|
||||
)
|
||||
from .base import TextModelProvider, VisionModelProvider
|
||||
from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError
|
||||
from .exceptions import APICallError, AuthenticationError, ConfigurationError, ContentFilterError, RateLimitError
|
||||
|
||||
|
||||
def _normalize_model_name(model_name: str) -> str:
|
||||
@ -41,6 +46,17 @@ def _is_content_filter_error(message: str) -> bool:
|
||||
return "content_filter" in lowered or "safety" in lowered
|
||||
|
||||
|
||||
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
    """Validate ``base_url`` for provider use and log a trust warning if any.

    Delegates to the shared validator imported as
    ``_validate_openai_compatible_base_url_value`` and converts its
    ``ValueError`` into the provider-layer ``ConfigurationError``.

    Args:
        base_url: Candidate base URL, or ``None``.

    Returns:
        Whatever the shared validator returns for ``base_url``.

    Raises:
        ConfigurationError: If the shared validator raises ``ValueError``.
            The original message is kept and "base_url" is passed as the
            second argument.
    """
    try:
        checked_url = _validate_openai_compatible_base_url_value(base_url)
    except ValueError as err:
        raise ConfigurationError(str(err), "base_url") from err

    # A non-empty message means the URL was accepted but should be flagged.
    trust_notice = openai_compatible_base_url_warning(checked_url)
    if trust_notice:
        logger.warning(trust_notice)

    return checked_url
|
||||
|
||||
|
||||
def _clean_json_output(output: str) -> str:
|
||||
"""清理 JSON 输出中的 markdown 包裹。"""
|
||||
output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
|
||||
@ -118,6 +134,7 @@ class _OpenAICompatibleBase:
|
||||
"""按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。"""
|
||||
api_key = api_key_override or self.api_key
|
||||
base_url = base_url_override or self.base_url or None
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
|
||||
timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180)
|
||||
max_retries: int = max_retries_override or config.app.get("llm_max_retries", 3)
|
||||
|
||||
@ -6,10 +6,17 @@ from unittest.mock import patch
|
||||
|
||||
from app.config import config
|
||||
from app.services.llm.base import TextModelProvider
|
||||
from app.services.llm.exceptions import ConfigurationError
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
|
||||
from app.services.llm.openai_compatible_provider import OpenAICompatibleTextProvider, OpenAICompatibleVisionProvider
|
||||
from app.services.llm.openai_compatible_provider import (
|
||||
OpenAICompatibleTextProvider,
|
||||
OpenAICompatibleVisionProvider,
|
||||
is_trusted_openai_compatible_base_url,
|
||||
validate_openai_compatible_base_url,
|
||||
)
|
||||
from app.services.llm.providers import register_all_providers
|
||||
from app.utils.openai_base_url_security import openai_compatible_base_url_warning
|
||||
|
||||
|
||||
class DummyOpenAITextProvider(TextModelProvider):
|
||||
@ -170,6 +177,77 @@ class OpenAICompatGenerationOptionTests(unittest.TestCase):
|
||||
self.assertEqual(512, options["max_tokens"])
|
||||
|
||||
|
||||
class OpenAICompatBaseURLValidationTests(unittest.TestCase):
    """Tests for base_url trust classification, validation and warnings."""

    def setUp(self):
        # Snapshot config.app so tearDown can put it back unchanged.
        self._original_app = dict(config.app)

    def tearDown(self):
        config.app.clear()
        config.app.update(self._original_app)

    def test_known_providers_and_local_ollama_are_trusted(self):
        expected_trusted = (
            "https://api.openai.com/v1",
            "https://api.siliconflow.cn/v1",
            "https://openrouter.ai/api/v1",
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "https://example.openai.azure.com/openai/deployments/demo",
            "http://localhost:11434/v1",
            "http://127.0.0.1:11434/v1",
        )

        for url in expected_trusted:
            with self.subTest(url=url):
                self.assertTrue(is_trusted_openai_compatible_base_url(url))

    def test_unrecognized_or_unsafe_base_urls_are_not_trusted(self):
        expected_untrusted = (
            "https://attacker.example/v1",
            "http://api.openai.com/v1",
            "https://user@api.openai.com/v1",
            "https://127.0.0.1:9999/v1",
            "not-a-url",
        )

        for url in expected_untrusted:
            with self.subTest(url=url):
                self.assertFalse(is_trusted_openai_compatible_base_url(url))

    def test_build_client_allows_well_formed_custom_base_url_by_default(self):
        custom_url = "https://custom.example/v1"
        provider = OpenAICompatibleTextProvider(
            api_key="test-key",
            model_name="test-model",
            base_url=custom_url,
        )

        # Patch the client class so no real client is constructed.
        with patch("app.services.llm.openai_compatible_provider.AsyncOpenAI") as async_openai:
            provider._build_client()

        passed_base_url = async_openai.call_args.kwargs["base_url"]
        self.assertEqual("https://custom.example/v1", passed_base_url)

    def test_custom_base_url_validation_returns_normalized_url(self):
        result = validate_openai_compatible_base_url(" https://custom.example/v1 ")
        self.assertEqual("https://custom.example/v1", result)

    def test_custom_base_url_warning_only_for_untrusted_well_formed_urls(self):
        untrusted_message = openai_compatible_base_url_warning("https://custom.example/v1")
        self.assertIn("custom.example", untrusted_message)

        # Trusted and empty URLs must not produce a warning.
        self.assertEqual("", openai_compatible_base_url_warning("https://api.openai.com/v1"))
        self.assertEqual("", openai_compatible_base_url_warning(""))

    def test_custom_base_url_validation_rejects_malformed_urls(self):
        provider = OpenAICompatibleTextProvider(
            api_key="test-key",
            model_name="test-model",
            base_url="https://user@custom.example/v1",
        )

        with self.assertRaises(ConfigurationError):
            provider._build_client()
|
||||
|
||||
|
||||
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
|
||||
class _CapturingVisionProvider:
|
||||
last_init: tuple[str, str, str | None] | None = None
|
||||
|
||||
111
app/utils/openai_base_url_security.py
Normal file
111
app/utils/openai_base_url_security.py
Normal file
@ -0,0 +1,111 @@
|
||||
import ipaddress
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
# Exact hostnames regarded as trusted OpenAI-compatible providers. Compared
# against the lower-cased hostname of a base_url with any trailing dot removed.
TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS = {
    "api.cohere.ai",
    "api.deepseek.com",
    "api.moonshot.cn",
    "api.openai.com",
    "api.siliconflow.cn",
    "api.together.xyz",
    "api.z.ai",
    "ark.cn-beijing.volces.com",
    "ark.cn-shanghai.volces.com",
    "dashscope.aliyuncs.com",
    "generativelanguage.googleapis.com",
    "open.bigmodel.cn",
    "openrouter.ai",
}

# Hostname suffixes that are trusted for any sub-domain. Each entry keeps its
# leading dot so that only true sub-domains match.
TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES = (
    ".openai.azure.com",
    ".services.ai.azure.com",
    ".cognitiveservices.azure.com",
)

# Message of the ValueError raised when a base_url fails validation.
OPENAI_COMPATIBLE_BASE_URL_ERROR = "OpenAI-compatible base_url must be a valid http(s) URL without embedded credentials."

# Template for the warning about an untrusted host; formatted with ``host``.
OPENAI_COMPATIBLE_BASE_URL_WARNING = (
    "OpenAI-compatible base_url host '{host}' is not in the trusted provider list. "
    "Only continue if you trust this endpoint, because it will receive the configured API key."
)
|
||||
|
||||
|
||||
def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool:
    """Return True for an http(s) endpoint on port 11434 of the local machine.

    Port 11434 is the one this helper associates with a local Ollama server
    (per its name). The host must be the literal "localhost" or an IP
    address that the ``ipaddress`` module reports as loopback.

    Args:
        scheme: URL scheme.
        host: Hostname or IP literal, without brackets.
        port: Explicit port from the URL, or ``None`` when absent.
    """
    if not (scheme in {"http", "https"} and port == 11434):
        return False

    if host == "localhost":
        return True

    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal, and not "localhost" either.
        return False
    return address.is_loopback
|
||||
|
||||
|
||||
def _is_well_formed_http_base_url(base_url: str) -> bool:
    """Return True if ``base_url`` is a parseable http(s) URL without credentials.

    The value is stringified and stripped of surrounding whitespace first.
    It is well formed when it parses, uses the ``http`` or ``https`` scheme,
    has a hostname, carries no username/password, and has a valid port (or
    none).

    Args:
        base_url: Candidate URL.

    Returns:
        ``True`` when every check passes, ``False`` otherwise. This function
        never raises for malformed URL text.
    """
    try:
        # urlparse itself raises ValueError on unmatched "[" / "]" in the
        # netloc, so it must sit inside the try together with the port check.
        parsed = urlparse(str(base_url).strip())
        # Reading .port raises ValueError for a non-numeric or out-of-range port.
        _ = parsed.port
    except ValueError:
        return False

    if parsed.scheme not in {"http", "https"} or not parsed.hostname:
        return False

    # Reject URLs that embed credentials in the authority part.
    return not (parsed.username or parsed.password)
|
||||
|
||||
|
||||
def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool:
    """Return True if ``base_url`` points at a trusted provider endpoint.

    An empty or ``None`` value counts as trusted. Otherwise the URL is
    trusted when it is either:

    * accepted by ``_is_loopback_ollama_url`` (http or https), or
    * an ``https`` URL whose host is listed in
      ``TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS`` or ends with one of
      ``TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES``.

    URLs that cannot be parsed, lack a scheme or host, embed credentials,
    or have an invalid port are never trusted.

    Args:
        base_url: Candidate base URL, or ``None``.

    Returns:
        ``True`` if trusted, ``False`` otherwise. This function never raises
        for malformed URL text.
    """
    if not base_url:
        return True

    try:
        # urlparse raises ValueError on unmatched brackets in the netloc and
        # .port raises it for an invalid port; both mean "not trusted".
        parsed = urlparse(str(base_url).strip())
        port = parsed.port
    except ValueError:
        return False

    # Normalise the host: drop a trailing root dot and compare lower-cased.
    host = (parsed.hostname or "").rstrip(".").lower()
    if not parsed.scheme or not host or parsed.username or parsed.password:
        return False

    if _is_loopback_ollama_url(parsed.scheme, host, port):
        return True

    # Everything other than the local exception above must use TLS.
    if parsed.scheme != "https":
        return False

    if host in TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS:
        return True

    return host.endswith(tuple(TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES))
|
||||
|
||||
|
||||
def openai_compatible_base_url_warning(base_url: Optional[str]) -> str:
    """Return a warning for a well-formed but untrusted ``base_url``.

    Args:
        base_url: Candidate base URL, or ``None``.

    Returns:
        ``OPENAI_COMPATIBLE_BASE_URL_WARNING`` formatted with the URL's
        host when the URL is well formed but not trusted. Returns an empty
        string when the URL is empty, trusted, malformed, or unparseable;
        malformed input is reported by validation, not by this warning.
    """
    if not base_url:
        return ""

    normalized = str(base_url).strip()
    try:
        if is_trusted_openai_compatible_base_url(normalized) or not _is_well_formed_http_base_url(normalized):
            return ""
        parsed = urlparse(normalized)
    except ValueError:
        # URL parsing can raise (e.g. unmatched "[" in the netloc). This
        # function is called on raw user input, so it stays silent here
        # instead of propagating the error to the caller.
        return ""

    host = (parsed.hostname or "").rstrip(".").lower()
    return OPENAI_COMPATIBLE_BASE_URL_WARNING.format(host=host)
|
||||
|
||||
|
||||
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
    """Validate ``base_url`` and return it stripped of surrounding whitespace.

    A URL is accepted when it is trusted or is a well-formed http(s) URL
    without embedded credentials. Untrusted but well-formed URLs are
    accepted here; the trust warning is produced separately.

    Args:
        base_url: Candidate base URL, or ``None``.

    Returns:
        ``None`` for an empty or ``None`` input, otherwise the stripped URL.

    Raises:
        ValueError: With ``OPENAI_COMPATIBLE_BASE_URL_ERROR`` as its message
            when the URL is malformed or cannot be parsed.
    """
    if not base_url:
        return None

    normalized = str(base_url).strip()
    try:
        acceptable = (
            is_trusted_openai_compatible_base_url(normalized)
            or _is_well_formed_http_base_url(normalized)
        )
    except ValueError:
        # URL parsing can raise its own ValueError (e.g. "Invalid IPv6 URL");
        # fold it into the single documented error below.
        acceptable = False

    if not acceptable:
        raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR)
    return normalized
|
||||
@ -26,7 +26,7 @@
|
||||
# - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct
|
||||
vision_openai_model_name = "Qwen/Qwen3.5-122B-A10B"
|
||||
vision_openai_api_key = "" # 填入对应 provider 的 API key
|
||||
vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空)
|
||||
vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点
|
||||
vision_openai_temperature = 1.0
|
||||
vision_openai_top_p = 0.95
|
||||
vision_openai_max_tokens = 65536
|
||||
@ -55,7 +55,7 @@
|
||||
# - Moonshot: moonshot/moonshot-v1-8k
|
||||
text_openai_model_name = "Pro/zai-org/GLM-5"
|
||||
text_openai_api_key = "" # 填入对应 provider 的 API key
|
||||
text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空)
|
||||
text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点
|
||||
text_openai_temperature = 1.0
|
||||
text_openai_top_p = 0.95
|
||||
text_openai_max_tokens = 65536
|
||||
|
||||
@ -15,6 +15,10 @@ from app.config.defaults import (
|
||||
get_openai_compatible_ui_values,
|
||||
normalize_openai_compatible_model_name as normalize_openai_compatible_model_id,
|
||||
)
|
||||
from app.utils.openai_base_url_security import (
|
||||
openai_compatible_base_url_warning,
|
||||
validate_openai_compatible_base_url,
|
||||
)
|
||||
from app.utils import utils
|
||||
from loguru import logger
|
||||
from app.services.llm.unified_service import UnifiedLLMService
|
||||
@ -73,9 +77,20 @@ def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
|
||||
if not (base_url.startswith('http://') or base_url.startswith('https://')):
|
||||
return False, f"{provider} Base URL必须以http://或https://开头"
|
||||
|
||||
try:
|
||||
validate_openai_compatible_base_url(base_url)
|
||||
except ValueError as exc:
|
||||
return False, f"{provider} Base URL格式无效: {exc}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def show_base_url_security_warning(base_url: str) -> None:
    """Display the base_url trust warning in the UI, if there is one.

    Asks ``openai_compatible_base_url_warning`` for a message and passes any
    non-empty result to ``st.warning``; otherwise nothing is displayed.

    Args:
        base_url: Base URL currently entered in the settings form.
    """
    message = openai_compatible_base_url_warning(base_url)
    if not message:
        return
    st.warning(message)
|
||||
|
||||
|
||||
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
|
||||
"""验证模型名称"""
|
||||
if not model_name or not model_name.strip():
|
||||
@ -420,6 +435,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
elif provider.lower() == 'gemini(openai)':
|
||||
# OpenAI兼容的Gemini代理测试
|
||||
try:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
@ -444,6 +460,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
else:
|
||||
from openai import OpenAI
|
||||
try:
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
@ -493,9 +510,10 @@ def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name:
|
||||
from PIL import Image
|
||||
|
||||
logger.debug(
|
||||
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
|
||||
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, base_url={base_url}"
|
||||
)
|
||||
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
@ -548,9 +566,10 @@ def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: s
|
||||
from openai import OpenAI
|
||||
|
||||
logger.debug(
|
||||
f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
|
||||
f"OpenAI 兼容文本模型连通性测试: model={model_name}, base_url={base_url}"
|
||||
)
|
||||
|
||||
base_url = validate_openai_compatible_base_url(base_url)
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
@ -653,6 +672,7 @@ def render_vision_llm_settings(tr):
|
||||
if vision_base_required and not st_vision_base_url:
|
||||
info_example = vision_placeholder or "https://your-openai-compatible-endpoint/v1"
|
||||
st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example))
|
||||
show_base_url_security_warning(st_vision_base_url)
|
||||
|
||||
vision_generation_params = render_llm_generation_settings(tr, "vision")
|
||||
|
||||
@ -923,6 +943,7 @@ def render_text_llm_settings(tr):
|
||||
if text_base_required and not st_text_base_url:
|
||||
info_example = text_placeholder or "https://your-openai-compatible-endpoint/v1"
|
||||
st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example))
|
||||
show_base_url_security_warning(st_text_base_url)
|
||||
|
||||
text_generation_params = render_llm_generation_settings(tr, "text")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user