From 4e2560651f958a306f4ba1f184be7d500960d6de Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 3 Apr 2026 11:29:27 +0800 Subject: [PATCH] fix(documentary): restore narration repair and explicit vision overrides --- .../documentary/frame_analysis_service.py | 81 +++++++++++++++---- app/services/llm/migration_adapter.py | 23 +++++- .../llm/test_openai_compat_unittest.py | 49 +++++++++++ ...est_script_service_documentary_unittest.py | 80 ++++++++++++++++++ 4 files changed, 214 insertions(+), 19 deletions(-) diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py index d4d36ee..e4278f7 100644 --- a/app/services/documentary/frame_analysis_service.py +++ b/app/services/documentary/frame_analysis_service.py @@ -77,8 +77,13 @@ JSON 必须包含以下键: ) markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) + narration_input = self._build_narration_input( + markdown_output=markdown_output, + video_theme=video_theme, + custom_prompt=custom_prompt, + ) narration_raw = generate_narration( - markdown_output, + narration_input, text_api_key, base_url=text_base_url, model=text_model, @@ -172,21 +177,7 @@ JSON 必须包含以下键: } def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]: - def load_json_candidate(payload: str) -> dict[str, Any] | None: - try: - return json.loads(payload) - except Exception: - return None - - cleaned = (narration_raw or "").strip() - parsed = load_json_candidate(cleaned) - if parsed is None: - parsed = load_json_candidate(cleaned.replace("```json", "").replace("```", "").strip()) - if parsed is None: - start = cleaned.find("{") - end = cleaned.rfind("}") - if start >= 0 and end > start: - parsed = load_json_candidate(cleaned[start : end + 1]) + parsed = self._repair_narration_payload(narration_raw) items: list[dict[str, Any]] = [] if isinstance(parsed, dict): @@ -199,6 +190,64 @@ JSON 必须包含以下键: return items + def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str: + context_lines: list[str] = [] + if (video_theme or "").strip(): + context_lines.append(f"视频主题:{video_theme.strip()}") + if (custom_prompt or "").strip(): + context_lines.append(f"补充创作要求:{custom_prompt.strip()}") + + if not context_lines: + return markdown_output + + context_block = "\n".join(f"- {line}" for line in context_lines) + return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n" + + def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None: + def load_json_candidate(payload: str) -> dict[str, Any] | None: + try: + parsed = json.loads(payload) + return parsed if isinstance(parsed, dict) else None + except Exception: + return None + + cleaned = (narration_raw or "").strip() + if not cleaned: + return None + + candidates: list[str] = [cleaned] + candidates.append(cleaned.replace("{{", "{").replace("}}", "}")) + + json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL) + if json_block: + candidates.append(json_block.group(1).strip()) + + start = cleaned.find("{") + end = cleaned.rfind("}") + if start >= 0 and end > start: + candidates.append(cleaned[start : end + 1]) + + for candidate in candidates: + parsed = load_json_candidate(candidate) + if parsed is not None: + return parsed + + fixed = cleaned.replace("{{", "{").replace("}}", "}") + fixed_start = fixed.find("{") + fixed_end = fixed.rfind("}") + if fixed_start >= 0 and fixed_end > fixed_start: + fixed = fixed[fixed_start : fixed_end + 1] + + fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE) + fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE) + fixed = re.sub(r",\s*}", "}", fixed) + fixed = re.sub(r",\s*]", "]", fixed) + fixed = re.sub(r"'([^']*)'\s*:", r'"\1":', fixed) + fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed) + fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed) + + return load_json_candidate(fixed) + def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float: interval = frame_interval_input if interval in (None, ""): diff --git a/app/services/llm/migration_adapter.py b/app/services/llm/migration_adapter.py index 68615a8..4107488 100644 --- a/app/services/llm/migration_adapter.py +++ b/app/services/llm/migration_adapter.py @@ -13,6 +13,7 @@ from loguru import logger from .unified_service import UnifiedLLMService from .exceptions import LLMServiceError +from .manager import LLMServiceManager # 导入新的提示词管理系统 from app.services.prompts import PromptManager @@ -155,6 +156,23 @@ class VisionAnalyzerAdapter: self.api_key = api_key self.model = model self.base_url = base_url + + def _build_provider_with_explicit_settings(self): + provider_name = (self.provider or "").lower() + if not LLMServiceManager.is_registered(): + from .providers import register_all_providers + + register_all_providers() + + provider_class = LLMServiceManager._vision_providers.get(provider_name) + if provider_class is None: + raise LLMServiceError(f"视觉模型提供商未注册: {provider_name}") + + return provider_class( + api_key=self.api_key, + model_name=self.model, + base_url=self.base_url, + ) async def analyze_images(self, images: List[Union[str, Path, PIL.Image.Image]], @@ -174,11 +192,10 @@ class VisionAnalyzerAdapter: 分析结果列表,格式与旧实现兼容 """ try: - # 使用统一服务分析图片 - results = await UnifiedLLMService.analyze_images( + provider = self._build_provider_with_explicit_settings() + results = await provider.analyze_images( images=images, prompt=prompt, - provider=self.provider, batch_size=batch_size, max_concurrency=max_concurrency, ) diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py index 8393ee3..68c0f47 100644 --- a/app/services/llm/test_openai_compat_unittest.py +++ b/app/services/llm/test_openai_compat_unittest.py @@ -6,6 +6,7 @@ import unittest from app.config import config from app.services.llm.base import TextModelProvider from app.services.llm.manager import LLMServiceManager +from app.services.llm.migration_adapter import VisionAnalyzerAdapter from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider from app.services.llm.providers import register_all_providers @@ -114,5 +115,53 @@ class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(2, max_in_flight) +class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase): + class _CapturingVisionProvider: + last_init: tuple[str, str, str | None] | None = None + + def __init__(self, api_key: str, model_name: str, base_url: str | None = None): + self.api_key = api_key + self.model_name = model_name + self.base_url = base_url + ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url) + + async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs): + return [f"{self.model_name}|{self.api_key}|{self.base_url}"] + + def setUp(self): + _reset_manager_state() + self._original_app = dict(config.app) + + def tearDown(self): + _reset_manager_state() + config.app.clear() + config.app.update(self._original_app) + + async def test_adapter_uses_explicit_settings_instead_of_global_config(self): + LLMServiceManager.register_vision_provider("openai", self._CapturingVisionProvider) + config.app["vision_openai_api_key"] = "config-key" + config.app["vision_openai_model_name"] = "config-model" + config.app["vision_openai_base_url"] = "https://config.example/v1" + + adapter = VisionAnalyzerAdapter( + provider="openai", + api_key="explicit-key", + model="explicit-model", + base_url="https://explicit.example/v1", + ) + result = await adapter.analyze_images( + images=["/tmp/keyframe_000001_000000100.jpg"], + prompt="描述画面", + batch_size=1, + max_concurrency=1, + ) + + self.assertEqual( + ("explicit-key", "explicit-model", "https://explicit.example/v1"), + self._CapturingVisionProvider.last_init, + ) + self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"]) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_script_service_documentary_unittest.py b/tests/test_script_service_documentary_unittest.py index 7b63b77..2808cbe 100644 --- a/tests/test_script_service_documentary_unittest.py +++ b/tests/test_script_service_documentary_unittest.py @@ -183,6 +183,86 @@ class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyn self.assertIn("解说文案格式错误", str(ctx.exception)) self.assertIn("items", str(ctx.exception)) + def test_parse_narration_items_recovers_from_common_json_damage(self): + service = DocumentaryFrameAnalysisService() + damaged_payload = """ +解释文字 +```json +{{ + "items": [ + {{ + "timestamp": "00:00:00,000-00:00:03,000", + "picture": "镜头里有一只猫", + "narration": "一只猫警觉地望向镜头。", + }}, + ], +}} +``` +补充文字 +""".strip() + + parsed_items = service._parse_narration_items(damaged_payload) + + self.assertEqual(1, len(parsed_items)) + self.assertEqual("00:00:00,000-00:00:03,000", parsed_items[0]["timestamp"]) + self.assertEqual("镜头里有一只猫", parsed_items[0]["picture"]) + self.assertEqual("一只猫警觉地望向镜头。", parsed_items[0]["narration"]) + + def test_parse_narration_items_raises_for_unrecoverable_payload(self): + service = DocumentaryFrameAnalysisService() + + with self.assertRaises(ValueError) as ctx: + service._parse_narration_items("not-json-at-all ::: ???") + + self.assertIn("解说文案格式错误", str(ctx.exception)) + self.assertIn("items", str(ctx.exception)) + + async def test_generate_documentary_script_includes_theme_and_custom_prompt_for_narration(self): + service = DocumentaryFrameAnalysisService() + analysis_payload = { + "batches": [ + { + "batch_index": 0, + "time_range": "00:00:00,000-00:00:03,000", + "overall_activity_summary": "测试摘要", + "fallback_summary": "", + "frame_observations": [ + {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"}, + ], + } + ] + } + + with TemporaryDirectory() as temp_dir: + analysis_path = Path(temp_dir) / "frame_analysis_test.json" + analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8") + + with patch.object( + DocumentaryFrameAnalysisService, + "analyze_video", + AsyncMock(return_value={"analysis_json_path": str(analysis_path)}), + ), patch.dict( + "app.services.documentary.frame_analysis_service.config.app", + { + "text_llm_provider": "openai", + "text_openai_api_key": "test-key", + "text_openai_model_name": "test-model", + "text_openai_base_url": "https://example.com/v1", + }, + ), patch( + "app.services.documentary.frame_analysis_service.generate_narration", + return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}', + ) as mocked_generate: + await service.generate_documentary_script( + video_path="demo.mp4", + video_theme="野生动物纪录片", + custom_prompt="重点描述危险信号", + ) + + narration_input = mocked_generate.call_args.args[0] + self.assertIn("视频主题:野生动物纪录片", narration_input) + self.assertIn("补充创作要求:重点描述危险信号", narration_input) + if __name__ == "__main__": unittest.main()