Merge pull request #255 from hamizan-azman/codex/validate-openai-base-url

Validate OpenAI-compatible base URLs
This commit is contained in:
viccy 2026-07-02 11:58:37 +08:00 committed by GitHub
commit c288a76ff8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 241 additions and 6 deletions

View File

@ -17,6 +17,10 @@ import subprocess
from typing import Union, TextIO
from app.config import config
from app.utils.openai_base_url_security import (
openai_compatible_base_url_warning,
validate_openai_compatible_base_url,
)
from app.utils.utils import clean_model_output
_max_retries = 5
@ -330,6 +334,10 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
azure_endpoint=base_url,
)
else:
base_url = validate_openai_compatible_base_url(base_url)
warning = openai_compatible_base_url_warning(base_url)
if warning:
logger.warning(warning)
client = OpenAI(
api_key=api_key,
base_url=base_url,

View File

@ -23,8 +23,13 @@ from openai import (
from app.config import config
from app.config.defaults import DEFAULT_LLM_GENERATION_CONFIG, normalize_openai_compatible_model_name
from app.utils.openai_base_url_security import (
is_trusted_openai_compatible_base_url,
openai_compatible_base_url_warning,
validate_openai_compatible_base_url as _validate_openai_compatible_base_url_value,
)
from .base import TextModelProvider, VisionModelProvider
from .exceptions import APICallError, AuthenticationError, ContentFilterError, RateLimitError
from .exceptions import APICallError, AuthenticationError, ConfigurationError, ContentFilterError, RateLimitError
def _normalize_model_name(model_name: str) -> str:
@ -41,6 +46,17 @@ def _is_content_filter_error(message: str) -> bool:
return "content_filter" in lowered or "safety" in lowered
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
    """Validate *base_url* for the provider layer and log any trust warning.

    Delegates to the shared validator imported from
    ``app.utils.openai_base_url_security`` and converts its ``ValueError``
    into this package's ``ConfigurationError`` for the ``base_url`` field.

    Returns:
        Whatever the shared validator returns for an accepted URL.

    Raises:
        ConfigurationError: if the shared validator rejects the URL.
    """
    try:
        checked_url = _validate_openai_compatible_base_url_value(base_url)
    except ValueError as exc:
        raise ConfigurationError(str(exc), "base_url") from exc
    else:
        # A non-empty message means the URL was accepted but should be flagged.
        notice = openai_compatible_base_url_warning(checked_url)
        if notice:
            logger.warning(notice)
        return checked_url
def _clean_json_output(output: str) -> str:
"""清理 JSON 输出中的 markdown 包裹。"""
output = re.sub(r"^```json\s*", "", output, flags=re.MULTILINE)
@ -118,6 +134,7 @@ class _OpenAICompatibleBase:
"""按请求构建 AsyncOpenAI 客户端,支持动态覆盖 api_key / base_url。"""
api_key = api_key_override or self.api_key
base_url = base_url_override or self.base_url or None
base_url = validate_openai_compatible_base_url(base_url)
timeout_seconds: float = timeout_override or config.app.get("llm_text_timeout", 180)
max_retries: int = max_retries_override or config.app.get("llm_max_retries", 3)

View File

@ -6,10 +6,17 @@ from unittest.mock import patch
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.exceptions import ConfigurationError
from app.services.llm.manager import LLMServiceManager
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
from app.services.llm.openai_compatible_provider import OpenAICompatibleTextProvider, OpenAICompatibleVisionProvider
from app.services.llm.openai_compatible_provider import (
OpenAICompatibleTextProvider,
OpenAICompatibleVisionProvider,
is_trusted_openai_compatible_base_url,
validate_openai_compatible_base_url,
)
from app.services.llm.providers import register_all_providers
from app.utils.openai_base_url_security import openai_compatible_base_url_warning
class DummyOpenAITextProvider(TextModelProvider):
@ -170,6 +177,77 @@ class OpenAICompatGenerationOptionTests(unittest.TestCase):
self.assertEqual(512, options["max_tokens"])
class OpenAICompatBaseURLValidationTests(unittest.TestCase):
    """Tests for base URL trust classification, validation and warnings."""

    def setUp(self):
        # Snapshot the app config so it can be restored after every test.
        self._app_snapshot = dict(config.app)

    def tearDown(self):
        config.app.clear()
        config.app.update(self._app_snapshot)

    def test_known_providers_and_local_ollama_are_trusted(self):
        # Hosted providers over https plus loopback endpoints on port 11434.
        for url in (
            "https://api.openai.com/v1",
            "https://api.siliconflow.cn/v1",
            "https://openrouter.ai/api/v1",
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "https://example.openai.azure.com/openai/deployments/demo",
            "http://localhost:11434/v1",
            "http://127.0.0.1:11434/v1",
        ):
            with self.subTest(url=url):
                self.assertTrue(is_trusted_openai_compatible_base_url(url))

    def test_unrecognized_or_unsafe_base_urls_are_not_trusted(self):
        # Unknown host, plain http to a known host, embedded user info,
        # loopback on a different port, and a string with no scheme.
        for url in (
            "https://attacker.example/v1",
            "http://api.openai.com/v1",
            "https://user@api.openai.com/v1",
            "https://127.0.0.1:9999/v1",
            "not-a-url",
        ):
            with self.subTest(url=url):
                self.assertFalse(is_trusted_openai_compatible_base_url(url))

    def test_build_client_allows_well_formed_custom_base_url_by_default(self):
        custom_url = "https://custom.example/v1"
        provider = OpenAICompatibleTextProvider(
            api_key="test-key",
            model_name="test-model",
            base_url=custom_url,
        )

        patch_target = "app.services.llm.openai_compatible_provider.AsyncOpenAI"
        with patch(patch_target) as client_factory:
            provider._build_client()

        # The custom URL must reach the client constructor unchanged.
        forwarded_kwargs = client_factory.call_args.kwargs
        self.assertEqual(custom_url, forwarded_kwargs["base_url"])

    def test_custom_base_url_validation_returns_normalized_url(self):
        padded_url = " https://custom.example/v1 "
        self.assertEqual(
            "https://custom.example/v1",
            validate_openai_compatible_base_url(padded_url),
        )

    def test_custom_base_url_warning_only_for_untrusted_well_formed_urls(self):
        untrusted_message = openai_compatible_base_url_warning("https://custom.example/v1")
        self.assertIn("custom.example", untrusted_message)

        # A trusted host and an empty value both produce no warning text.
        for quiet_url in ("https://api.openai.com/v1", ""):
            self.assertEqual("", openai_compatible_base_url_warning(quiet_url))

    def test_custom_base_url_validation_rejects_malformed_urls(self):
        # Embedded user info in the URL must be refused when building a client.
        provider = OpenAICompatibleTextProvider(
            api_key="test-key",
            model_name="test-model",
            base_url="https://user@custom.example/v1",
        )

        self.assertRaises(ConfigurationError, provider._build_client)
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
class _CapturingVisionProvider:
last_init: tuple[str, str, str | None] | None = None

View File

@ -0,0 +1,111 @@
import ipaddress
from typing import Optional
from urllib.parse import urlparse
# Exact hostnames that are treated as trusted when reached over https.
TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS = {
    # OpenAI and aggregators
    "api.openai.com",
    "openrouter.ai",
    "api.together.xyz",
    "api.cohere.ai",
    "generativelanguage.googleapis.com",
    # Other hosted providers
    "api.siliconflow.cn",
    "dashscope.aliyuncs.com",
    "api.deepseek.com",
    "api.moonshot.cn",
    "open.bigmodel.cn",
    "api.z.ai",
    "ark.cn-beijing.volces.com",
    "ark.cn-shanghai.volces.com",
}

# Hostname suffixes (each starting with a dot) that are trusted over https;
# any subdomain ending in one of these matches.
TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES = (
    ".openai.azure.com",
    ".services.ai.azure.com",
    ".cognitiveservices.azure.com",
)

# Message carried by the ValueError raised for a rejected base URL.
OPENAI_COMPATIBLE_BASE_URL_ERROR = (
    "OpenAI-compatible base_url must be a valid http(s) URL "
    "without embedded credentials."
)

# Template for the advisory shown for accepted-but-untrusted hosts;
# formatted with the keyword argument ``host``.
OPENAI_COMPATIBLE_BASE_URL_WARNING = (
    "OpenAI-compatible base_url host '{host}' is not in the trusted "
    "provider list. Only continue if you trust this endpoint, "
    "because it will receive the configured API key."
)
def _is_loopback_ollama_url(scheme: str, host: str, port: Optional[int]) -> bool:
if scheme not in {"http", "https"} or port != 11434:
return False
if host == "localhost":
return True
try:
return ipaddress.ip_address(host).is_loopback
except ValueError:
return False
def _is_well_formed_http_base_url(base_url: str) -> bool:
parsed = urlparse(str(base_url).strip())
if parsed.scheme not in {"http", "https"} or not parsed.hostname:
return False
if parsed.username or parsed.password:
return False
try:
parsed.port
except ValueError:
return False
return True
def is_trusted_openai_compatible_base_url(base_url: Optional[str]) -> bool:
    """Return True when *base_url* points at a trusted endpoint.

    An empty or missing value counts as trusted. Otherwise the URL is
    trusted when it is a loopback endpoint accepted by
    ``_is_loopback_ollama_url``, or when it uses ``https`` and its host is
    listed in ``TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS`` or ends with one of
    ``TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES``.

    Args:
        base_url: Candidate URL; surrounding whitespace is ignored.

    Returns:
        True if trusted, False otherwise. URLs that cannot be parsed, have
        an invalid port, or embed a username/password are never trusted;
        this function does not raise for such input.
    """
    if not base_url:
        return True

    try:
        # Both urlparse() (e.g. unbalanced "[" in the host) and the .port
        # property (non-numeric / out-of-range port) can raise ValueError.
        parsed = urlparse(str(base_url).strip())
        port = parsed.port
    except ValueError:
        return False

    # Drop a trailing root dot so "api.openai.com." matches "api.openai.com".
    host = (parsed.hostname or "").rstrip(".").lower()
    if not parsed.scheme or not host or parsed.username or parsed.password:
        return False

    if _is_loopback_ollama_url(parsed.scheme, host, port):
        return True

    # Everything except the loopback case above must use TLS.
    if parsed.scheme != "https":
        return False
    if host in TRUSTED_OPENAI_COMPATIBLE_BASE_HOSTS:
        return True
    return host.endswith(tuple(TRUSTED_OPENAI_COMPATIBLE_BASE_SUFFIXES))
def openai_compatible_base_url_warning(base_url: Optional[str]) -> str:
    """Return a warning message for an untrusted but well-formed base URL.

    Args:
        base_url: Candidate URL; surrounding whitespace is ignored.

    Returns:
        ``OPENAI_COMPATIBLE_BASE_URL_WARNING`` formatted with the URL's
        host, or an empty string when the value is empty, trusted, or not
        a well-formed http(s) URL.
    """
    if not base_url:
        return ""

    candidate = str(base_url).strip()
    # Only URLs that would be accepted yet are not on the trusted list
    # deserve a warning; malformed ones are handled by validation instead.
    needs_warning = (
        not is_trusted_openai_compatible_base_url(candidate)
        and _is_well_formed_http_base_url(candidate)
    )
    if not needs_warning:
        return ""

    hostname = urlparse(candidate).hostname or ""
    return OPENAI_COMPATIBLE_BASE_URL_WARNING.format(
        host=hostname.rstrip(".").lower()
    )
def validate_openai_compatible_base_url(base_url: Optional[str]) -> Optional[str]:
    """Validate *base_url* and return it with surrounding whitespace removed.

    Args:
        base_url: Candidate URL, or an empty/None value meaning "not set".

    Returns:
        None when the value is empty, None, or whitespace only; otherwise
        the stripped URL if it is trusted or a well-formed http(s) URL.

    Raises:
        ValueError: with ``OPENAI_COMPATIBLE_BASE_URL_ERROR`` when the URL
            is neither trusted nor well formed.
    """
    if not base_url:
        return None

    normalized = str(base_url).strip()
    if not normalized:
        # A whitespace-only value means "not set"; returning "" here would
        # hand callers an empty string where they expect None.
        return None

    try:
        acceptable = (
            is_trusted_openai_compatible_base_url(normalized)
            or _is_well_formed_http_base_url(normalized)
        )
    except ValueError as exc:
        # URL parsing can itself raise ValueError; surface the same message
        # as every other rejection instead of the parser's internal text.
        raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR) from exc

    if acceptable:
        return normalized
    raise ValueError(OPENAI_COMPATIBLE_BASE_URL_ERROR)

View File

@ -26,7 +26,7 @@
# - SiliconFlow: siliconflow/Qwen/Qwen2.5-VL-32B-Instruct
vision_openai_model_name = "Qwen/Qwen3.5-122B-A10B"
vision_openai_api_key = "" # 填入对应 provider 的 API key
vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空)
vision_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点
vision_openai_temperature = 1.0
vision_openai_top_p = 0.95
vision_openai_max_tokens = 65536
@ -55,7 +55,7 @@
# - Moonshot: moonshot/moonshot-v1-8k
text_openai_model_name = "Pro/zai-org/GLM-5"
text_openai_api_key = "" # 填入对应 provider 的 API key
text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL(官方 OpenAI 可留空)
text_openai_base_url = "https://api.siliconflow.cn/v1" # 可选:自定义 API base URL;界面会提示 API key 将发送到对应端点
text_openai_temperature = 1.0
text_openai_top_p = 0.95
text_openai_max_tokens = 65536

View File

@ -15,6 +15,10 @@ from app.config.defaults import (
get_openai_compatible_ui_values,
normalize_openai_compatible_model_name as normalize_openai_compatible_model_id,
)
from app.utils.openai_base_url_security import (
openai_compatible_base_url_warning,
validate_openai_compatible_base_url,
)
from app.utils import utils
from loguru import logger
from app.services.llm.unified_service import UnifiedLLMService
@ -73,9 +77,20 @@ def validate_base_url(base_url: str, provider: str) -> tuple[bool, str]:
if not (base_url.startswith('http://') or base_url.startswith('https://')):
return False, f"{provider} Base URL必须以http://或https://开头"
try:
validate_openai_compatible_base_url(base_url)
except ValueError as exc:
return False, f"{provider} Base URL格式无效: {exc}"
return True, ""
def show_base_url_security_warning(base_url: str) -> None:
    """Display a UI warning via ``st.warning`` when *base_url* produces one.

    Nothing is shown when ``openai_compatible_base_url_warning`` returns an
    empty message.
    """
    notice = openai_compatible_base_url_warning(base_url)
    if not notice:
        return
    st.warning(notice)
def validate_model_name(model_name: str, provider: str) -> tuple[bool, str]:
"""验证模型名称"""
if not model_name or not model_name.strip():
@ -420,6 +435,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
elif provider.lower() == 'gemini(openai)':
# OpenAI兼容的Gemini代理测试
try:
base_url = validate_openai_compatible_base_url(base_url)
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
@ -444,6 +460,7 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
else:
from openai import OpenAI
try:
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url,
@ -493,9 +510,10 @@ def test_openai_compatible_vision_model(api_key: str, base_url: str, model_name:
from PIL import Image
logger.debug(
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
f"OpenAI 兼容视觉模型连通性测试: model={model_name}, base_url={base_url}"
)
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url or None,
@ -548,9 +566,10 @@ def test_openai_compatible_text_model(api_key: str, base_url: str, model_name: s
from openai import OpenAI
logger.debug(
f"OpenAI 兼容文本模型连通性测试: model={model_name}, api_key={api_key[:10]}..., base_url={base_url}"
f"OpenAI 兼容文本模型连通性测试: model={model_name}, base_url={base_url}"
)
base_url = validate_openai_compatible_base_url(base_url)
client = OpenAI(
api_key=api_key,
base_url=base_url or None,
@ -653,6 +672,7 @@ def render_vision_llm_settings(tr):
if vision_base_required and not st_vision_base_url:
info_example = vision_placeholder or "https://your-openai-compatible-endpoint/v1"
st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example))
show_base_url_security_warning(st_vision_base_url)
vision_generation_params = render_llm_generation_settings(tr, "vision")
@ -923,6 +943,7 @@ def render_text_llm_settings(tr):
if text_base_required and not st_text_base_url:
info_example = text_placeholder or "https://your-openai-compatible-endpoint/v1"
st.info(tr("Please fill OpenAI compatible gateway").format(example=info_example))
show_base_url_security_warning(st_text_base_url)
text_generation_params = render_llm_generation_settings(tr, "text")