Fix documentary narration parsing and explicit vision overrides

This commit is contained in:
linyq 2026-04-03 12:04:09 +08:00
parent 4e2560651f
commit abc9db22e5
4 changed files with 85 additions and 37 deletions

View File

@ -119,9 +119,11 @@ JSON 必须包含以下键:
concurrency = self._resolve_max_concurrency(max_concurrency)
provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()
api_key = vision_api_key or config.app.get(f"vision_{provider}_api_key")
model_name = vision_model_name or config.app.get(f"vision_{provider}_model_name")
base_url = vision_base_url or config.app.get(f"vision_{provider}_base_url", "")
api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key")
model_name = (
vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name")
)
base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "")
if not api_key or not model_name:
raise ValueError(
f"未配置 {provider} 的 API Key 或模型名称。"

View File

@ -5,7 +5,6 @@
"""
import asyncio
import json
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
@ -111,41 +110,11 @@ class LegacyLLMAdapter:
temperature=1.5,
response_format="json"
)
# 使用增强的JSON解析器
from webui.tools.generate_short_summary import parse_and_fix_json
parsed_result = parse_and_fix_json(result)
if not parsed_result:
logger.error("无法解析LLM返回的JSON数据")
# 返回一个基本的JSON结构而不是错误字符串
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "解析失败请检查LLM输出",
"narration": "解说文案生成失败,请重试"
}
]
}, ensure_ascii=False)
# 确保返回的是JSON字符串
return json.dumps(parsed_result, ensure_ascii=False)
return result if isinstance(result, str) else str(result)
except Exception as e:
logger.error(f"生成解说文案失败: {str(e)}")
# 返回一个基本的JSON结构而不是错误字符串
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "生成失败",
"narration": f"解说文案生成失败: {str(e)}"
}
]
}, ensure_ascii=False)
raise
class VisionAnalyzerAdapter:
@ -198,6 +167,8 @@ class VisionAnalyzerAdapter:
prompt=prompt,
batch_size=batch_size,
max_concurrency=max_concurrency,
api_key=self.api_key,
api_base=self.base_url,
)
# 转换为旧格式以保持向后兼容性

View File

@ -2,11 +2,12 @@
import asyncio
import unittest
from unittest.mock import patch
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.manager import LLMServiceManager
from app.services.llm.migration_adapter import VisionAnalyzerAdapter
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
from app.services.llm.providers import register_all_providers
@ -118,6 +119,7 @@ class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
class _CapturingVisionProvider:
last_init: tuple[str, str, str | None] | None = None
last_call_kwargs: dict | None = None
def __init__(self, api_key: str, model_name: str, base_url: str | None = None):
self.api_key = api_key
@ -126,6 +128,7 @@ class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url)
async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs):
ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs)
return [f"{self.model_name}|{self.api_key}|{self.base_url}"]
def setUp(self):
@ -160,8 +163,32 @@ class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
("explicit-key", "explicit-model", "https://explicit.example/v1"),
self._CapturingVisionProvider.last_init,
)
self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"])
self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"])
self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"])
class LegacyNarrationAdapterBehaviorTests(unittest.TestCase):
    """Behavioral contract for the legacy narration adapter.

    When the LLM returns a payload that cannot be parsed as JSON, the adapter
    must hand the raw payload back to the caller unchanged rather than
    fabricating a placeholder ``"items"`` structure.
    """

    def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self):
        unparseable_payload = "not-json-at-all ::: ???"
        # Stub out prompt construction and the async LLM call so the adapter
        # receives a deliberately unrecoverable response.
        prompt_patch = patch(
            "app.services.llm.migration_adapter.PromptManager.get_prompt",
            return_value="prompt",
        )
        llm_call_patch = patch(
            "app.services.llm.migration_adapter._run_async_safely",
            return_value=unparseable_payload,
        )
        with prompt_patch, llm_call_patch:
            outcome = LegacyLLMAdapter.generate_narration(
                markdown_content="markdown",
                api_key="test-key",
                base_url="https://example.com/v1",
                model="test-model",
            )
        # The raw payload must pass through untouched — no invented fallback JSON.
        self.assertEqual(unparseable_payload, outcome)
        self.assertNotIn('"items"', outcome)
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()

View File

@ -260,9 +260,57 @@ class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyn
)
narration_input = mocked_generate.call_args.args[0]
self.assertIn("## 创作上下文", narration_input)
self.assertIn("视频主题:野生动物纪录片", narration_input)
self.assertIn("补充创作要求:重点描述危险信号", narration_input)
async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self):
service = DocumentaryFrameAnalysisService()
with patch.dict(
"app.services.documentary.frame_analysis_service.config.app",
{
"vision_llm_provider": "openai",
"vision_openai_api_key": "config-key",
"vision_openai_model_name": "config-model",
"vision_openai_base_url": "https://config.example/v1",
},
), patch(
"app.services.documentary.frame_analysis_service.os.path.exists",
return_value=True,
), patch.object(
service,
"_load_or_extract_keyframes",
return_value=["/tmp/keyframe_000001_000000100.jpg"],
), patch.object(
service,
"_analyze_batches",
AsyncMock(return_value=[]),
), patch.object(
service,
"_save_analysis_artifact",
return_value="/tmp/frame_analysis_test.json",
), patch.object(
service,
"_build_video_clip_json",
return_value=[],
), patch(
"app.services.documentary.frame_analysis_service.create_vision_analyzer",
return_value=object(),
) as mocked_create_analyzer:
await service.analyze_video(
video_path="/tmp/demo.mp4",
vision_api_key="explicit-key",
vision_model_name="explicit-model",
vision_base_url="",
)
called_kwargs = mocked_create_analyzer.call_args.kwargs
self.assertEqual("openai", called_kwargs["provider"])
self.assertEqual("explicit-key", called_kwargs["api_key"])
self.assertEqual("explicit-model", called_kwargs["model"])
self.assertEqual("", called_kwargs["base_url"])
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()