mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-29 19:58:15 +00:00
119 lines
3.9 KiB
Python
119 lines
3.9 KiB
Python
"""Minimal regression tests for the OpenAI-compatible provider."""
|
|
|
|
import asyncio
|
|
import unittest
|
|
|
|
from app.config import config
|
|
from app.services.llm.base import TextModelProvider
|
|
from app.services.llm.manager import LLMServiceManager
|
|
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
|
|
from app.services.llm.providers import register_all_providers
|
|
|
|
|
|
class DummyOpenAITextProvider(TextModelProvider):
    """Stand-in text provider that reports itself as ``"openai"``.

    It never talks to a remote service: ``generate_text`` hands the prompt
    straight back and ``_make_api_call`` hands the payload straight back.
    The manager tests register it to observe how ``LLMServiceManager``
    builds a text provider from ``config.app``.
    """

    @property
    def provider_name(self) -> str:
        # Registry name this stub answers to.
        name = "openai"
        return name

    @property
    def supported_models(self) -> list[str]:
        # The stub advertises no specific models.
        return list()

    async def generate_text(self, prompt: str, **kwargs) -> str:
        # Echo the prompt unchanged; extra keyword arguments are ignored.
        echoed = prompt
        return echoed

    async def _make_api_call(self, payload: dict) -> dict:
        # Echo the request payload instead of performing an API call.
        response = payload
        return response
|
|
|
|
|
|
def _reset_manager_state():
    """Empty every provider registry and instance cache on ``LLMServiceManager``.

    Clears, in order: the vision provider registry, the text provider
    registry, the vision instance cache and the text instance cache. The
    test cases call this around each test so state from one test cannot
    leak into the next.
    """
    registries = (
        LLMServiceManager._vision_providers,
        LLMServiceManager._text_providers,
        LLMServiceManager._vision_instance_cache,
        LLMServiceManager._text_instance_cache,
    )
    for registry in registries:
        registry.clear()
|
|
|
|
|
|
class OpenAICompatManagerTests(unittest.TestCase):
    """Checks how ``LLMServiceManager`` registers and builds OpenAI providers."""

    def setUp(self):
        # Start from empty registries, then snapshot the app config so
        # tearDown can restore it after the test mutates it.
        _reset_manager_state()
        self._original_app = {**config.app}

    def tearDown(self):
        _reset_manager_state()
        # Restore config.app in place rather than rebinding it, presumably
        # because other modules hold a reference to the same object.
        config.app.clear()
        config.app.update(self._original_app)

    def test_register_all_providers_only_registers_openai_provider(self):
        register_all_providers()

        text_names = set(LLMServiceManager.list_text_providers())
        self.assertEqual({"openai"}, text_names)

        vision_names = set(LLMServiceManager.list_vision_providers())
        self.assertEqual({"openai"}, vision_names)

    def test_get_text_provider_uses_openai_keys(self):
        LLMServiceManager.register_text_provider("openai", DummyOpenAITextProvider)

        # Point the text LLM settings at the "openai" provider keys.
        overrides = {
            "text_llm_provider": "openai",
            "text_openai_api_key": "new-key",
            "text_openai_model_name": "new-model",
            "text_openai_base_url": "https://new.example/v1",
        }
        for key, value in overrides.items():
            config.app[key] = value

        provider = LLMServiceManager.get_text_provider()

        self.assertIsInstance(provider, DummyOpenAITextProvider)
        # Each provider attribute must come from the matching text_openai_* key.
        expected_attrs = {
            "api_key": "new-key",
            "model_name": "new-model",
            "base_url": "https://new.example/v1",
        }
        for attr_name, expected in expected_attrs.items():
            self.assertEqual(expected, getattr(provider, attr_name))
|
|
|
|
|
|
class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
    """Concurrency checks for ``OpenAICompatibleVisionProvider.analyze_images``.

    ``_prepare_images`` and ``_analyze_batch`` are replaced on the instance
    with fakes, so the tests exercise only the batching/scheduling logic.
    The fakes use short ``asyncio.sleep`` calls to make batches overlap.
    """

    async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)

        async def fake_analyze_batch(batch, prompt, **kwargs):
            # The first batch is the slowest, so batches finish out of
            # submission order; the returned list must still follow
            # submission order.
            delays = {"a": 0.03, "c": 0.01, "e": 0.0}
            await asyncio.sleep(delays[batch[0]])
            return f"batch-{batch[0]}"

        provider._analyze_batch = fake_analyze_batch

        result = await provider.analyze_images(
            images=["a", "b", "c", "d", "e", "f"],
            prompt="prompt",
            batch_size=2,
            max_concurrency=2,
        )

        self.assertEqual(["batch-a", "batch-c", "batch-e"], result)

    async def test_analyze_images_respects_max_concurrency_limit(self):
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)

        in_flight = 0
        max_in_flight = 0

        async def fake_analyze_batch(batch, prompt, **kwargs):
            nonlocal in_flight, max_in_flight
            in_flight += 1
            max_in_flight = max(max_in_flight, in_flight)
            try:
                await asyncio.sleep(0.02)
            finally:
                # Decrement even if this batch is cancelled, so the counter
                # cannot stay inflated and skew later measurements.
                in_flight -= 1
            return f"batch-{batch[0]}"

        provider._analyze_batch = fake_analyze_batch

        result = await provider.analyze_images(
            images=["a", "b", "c", "d", "e", "f"],
            prompt="prompt",
            batch_size=1,
            max_concurrency=2,
        )

        # Pin the full ordered result rather than only its length: a length
        # check alone would accept reordered, duplicated or substituted
        # batch results.
        self.assertEqual(
            ["batch-a", "batch-b", "batch-c", "batch-d", "batch-e", "batch-f"],
            result,
        )
        # Exactly 2 proves both that batches overlapped and that the
        # max_concurrency limit was never exceeded.
        self.assertEqual(2, max_in_flight)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main(module="__main__")
|