From abc9db22e5f1730aaafaa2b4e5d87a19bee17668 Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 3 Apr 2026 12:04:09 +0800 Subject: [PATCH] Fix documentary narration parsing and explicit vision overrides --- .../documentary/frame_analysis_service.py | 8 ++-- app/services/llm/migration_adapter.py | 37 ++------------ .../llm/test_openai_compat_unittest.py | 29 ++++++++++- ...est_script_service_documentary_unittest.py | 48 +++++++++++++++++++ 4 files changed, 85 insertions(+), 37 deletions(-) diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py index e4278f7..cbb794a 100644 --- a/app/services/documentary/frame_analysis_service.py +++ b/app/services/documentary/frame_analysis_service.py @@ -119,9 +119,11 @@ JSON 必须包含以下键: concurrency = self._resolve_max_concurrency(max_concurrency) provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower() - api_key = vision_api_key or config.app.get(f"vision_{provider}_api_key") - model_name = vision_model_name or config.app.get(f"vision_{provider}_model_name") - base_url = vision_base_url or config.app.get(f"vision_{provider}_base_url", "") + api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key") + model_name = ( + vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name") + ) + base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "") if not api_key or not model_name: raise ValueError( f"未配置 {provider} 的 API Key 或模型名称。" diff --git a/app/services/llm/migration_adapter.py b/app/services/llm/migration_adapter.py index 4107488..7bd5142 100644 --- a/app/services/llm/migration_adapter.py +++ b/app/services/llm/migration_adapter.py @@ -5,7 +5,6 @@ """ import asyncio -import json from typing import List, Dict, Any, Optional, Union from pathlib import Path import PIL.Image @@ -111,41 +110,11 @@ class LegacyLLMAdapter: temperature=1.5, response_format="json" ) - - # 使用增强的JSON解析器 - from webui.tools.generate_short_summary import parse_and_fix_json - parsed_result = parse_and_fix_json(result) - - if not parsed_result: - logger.error("无法解析LLM返回的JSON数据") - # 返回一个基本的JSON结构而不是错误字符串 - return json.dumps({ - "items": [ - { - "_id": 1, - "timestamp": "00:00:00-00:00:10", - "picture": "解析失败,请检查LLM输出", - "narration": "解说文案生成失败,请重试" - } - ] - }, ensure_ascii=False) - - # 确保返回的是JSON字符串 - return json.dumps(parsed_result, ensure_ascii=False) + return result if isinstance(result, str) else str(result) except Exception as e: logger.error(f"生成解说文案失败: {str(e)}") - # 返回一个基本的JSON结构而不是错误字符串 - return json.dumps({ - "items": [ - { - "_id": 1, - "timestamp": "00:00:00-00:00:10", - "picture": "生成失败", - "narration": f"解说文案生成失败: {str(e)}" - } - ] - }, ensure_ascii=False) + raise class VisionAnalyzerAdapter: @@ -198,6 +167,8 @@ class VisionAnalyzerAdapter: prompt=prompt, batch_size=batch_size, max_concurrency=max_concurrency, + api_key=self.api_key, + api_base=self.base_url, ) # 转换为旧格式以保持向后兼容性 diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py index 68c0f47..acef31a 100644 --- a/app/services/llm/test_openai_compat_unittest.py +++ b/app/services/llm/test_openai_compat_unittest.py @@ -2,11 +2,12 @@ import asyncio import unittest +from unittest.mock import patch from app.config import config from app.services.llm.base import TextModelProvider from app.services.llm.manager import LLMServiceManager -from app.services.llm.migration_adapter import VisionAnalyzerAdapter +from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider from app.services.llm.providers import register_all_providers @@ -118,6 +119,7 @@ class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase): class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase): class _CapturingVisionProvider: last_init: tuple[str, str, str | None] | None = None + last_call_kwargs: dict | None = None def __init__(self, api_key: str, model_name: str, base_url: str | None = None): self.api_key = api_key @@ -126,6 +128,7 @@ class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase): ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url) async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs): + ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs) return [f"{self.model_name}|{self.api_key}|{self.base_url}"] def setUp(self): @@ -160,8 +163,32 @@ class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase): ("explicit-key", "explicit-model", "https://explicit.example/v1"), self._CapturingVisionProvider.last_init, ) + self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"]) + self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"]) self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"]) +class LegacyNarrationAdapterBehaviorTests(unittest.TestCase): + def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self): + raw_payload = "not-json-at-all ::: ???" + + with patch( + "app.services.llm.migration_adapter.PromptManager.get_prompt", + return_value="prompt", + ), patch( + "app.services.llm.migration_adapter._run_async_safely", + return_value=raw_payload, + ): + result = LegacyLLMAdapter.generate_narration( + markdown_content="markdown", + api_key="test-key", + base_url="https://example.com/v1", + model="test-model", + ) + + self.assertEqual(raw_payload, result) + self.assertNotIn('"items"', result) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_script_service_documentary_unittest.py b/tests/test_script_service_documentary_unittest.py index 2808cbe..d6789f1 100644 --- a/tests/test_script_service_documentary_unittest.py +++ b/tests/test_script_service_documentary_unittest.py @@ -260,9 +260,57 @@ class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyn ) narration_input = mocked_generate.call_args.args[0] + self.assertIn("## 创作上下文", narration_input) self.assertIn("视频主题:野生动物纪录片", narration_input) self.assertIn("补充创作要求:重点描述危险信号", narration_input) + async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self): + service = DocumentaryFrameAnalysisService() + + with patch.dict( + "app.services.documentary.frame_analysis_service.config.app", + { + "vision_llm_provider": "openai", + "vision_openai_api_key": "config-key", + "vision_openai_model_name": "config-model", + "vision_openai_base_url": "https://config.example/v1", + }, + ), patch( + "app.services.documentary.frame_analysis_service.os.path.exists", + return_value=True, + ), patch.object( + service, + "_load_or_extract_keyframes", + return_value=["/tmp/keyframe_000001_000000100.jpg"], + ), patch.object( + service, + "_analyze_batches", + AsyncMock(return_value=[]), + ), patch.object( + service, + "_save_analysis_artifact", + return_value="/tmp/frame_analysis_test.json", + ), patch.object( + service, + "_build_video_clip_json", + return_value=[], + ), patch( + "app.services.documentary.frame_analysis_service.create_vision_analyzer", + return_value=object(), + ) as mocked_create_analyzer: + await service.analyze_video( + video_path="/tmp/demo.mp4", + vision_api_key="explicit-key", + vision_model_name="explicit-model", + vision_base_url="", + ) + + called_kwargs = mocked_create_analyzer.call_args.kwargs + self.assertEqual("openai", called_kwargs["provider"]) + self.assertEqual("explicit-key", called_kwargs["api_key"]) + self.assertEqual("explicit-model", called_kwargs["model"]) + self.assertEqual("", called_kwargs["base_url"]) + if __name__ == "__main__": unittest.main()