From d678bf62b15b0d384d0666266f89df4fa231d361 Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 3 Apr 2026 02:38:54 +0800 Subject: [PATCH] fix(documentary): centralize final script generation in shared service --- .../documentary/frame_analysis_service.py | 77 ++++++++++-- app/services/script_service.py | 18 ++- ...est_script_service_documentary_unittest.py | 113 ++++++++++++++++-- webui/tools/generate_script_docu.py | 33 +---- 4 files changed, 190 insertions(+), 51 deletions(-) diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py index 3cc4cc3..a6dc92e 100644 --- a/app/services/documentary/frame_analysis_service.py +++ b/app/services/documentary/frame_analysis_service.py @@ -9,6 +9,7 @@ from loguru import logger from app.config import config from app.services.documentary.frame_analysis_models import FrameBatchResult +from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown from app.services.llm.migration_adapter import create_vision_analyzer from app.utils import utils, video_processor @@ -39,15 +40,16 @@ JSON 必须包含以下键: video_path: str, video_theme: str = "", custom_prompt: str = "", - frame_interval_input: int | float = 3, - vision_batch_size: int = 10, - vision_llm_provider: str = "openai", + frame_interval_input: int | float | None = None, + vision_batch_size: int | None = None, + vision_llm_provider: str | None = None, progress_callback: Callable[[float, str], None] | None = None, vision_api_key: str | None = None, vision_model_name: str | None = None, vision_base_url: str | None = None, max_concurrency: int | None = None, ) -> list[dict]: + progress = progress_callback or (lambda _p, _m: None) analysis_result = await self.analyze_video( video_path=video_path, video_theme=video_theme, @@ -61,7 +63,31 @@ JSON 必须包含以下键: vision_base_url=vision_base_url, max_concurrency=max_concurrency, ) - return analysis_result["video_clip_json"] + analysis_json_path = analysis_result["analysis_json_path"] + + progress(80, "正在生成解说文案...") + text_provider = config.app.get("text_llm_provider", "openai").lower() + text_api_key = config.app.get(f"text_{text_provider}_api_key") + text_model = config.app.get(f"text_{text_provider}_model_name") + text_base_url = config.app.get(f"text_{text_provider}_base_url") + if not text_api_key or not text_model: + raise ValueError( + f"未配置 {text_provider} 的文本模型参数。" + f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name" + ) + + markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) + narration_raw = generate_narration( + markdown_output, + text_api_key, + base_url=text_base_url, + model=text_model, + ) + narration_items = self._parse_narration_items(narration_raw) + + final_script = [{**item, "OST": 2} for item in narration_items] + progress(100, "脚本生成完成") + return final_script async def analyze_video( self, @@ -69,9 +95,9 @@ JSON 必须包含以下键: video_path: str, video_theme: str = "", custom_prompt: str = "", - frame_interval_input: int | float = 3, - vision_batch_size: int = 10, - vision_llm_provider: str = "openai", + frame_interval_input: int | float | None = None, + vision_batch_size: int | None = None, + vision_llm_provider: str | None = None, progress_callback: Callable[[float, str], None] | None = None, vision_api_key: str | None = None, vision_model_name: str | None = None, @@ -145,6 +171,43 @@ JSON 必须包含以下键: "keyframe_files": keyframe_files, } + def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]: + def load_json_candidate(payload: str) -> dict[str, Any] | None: + try: + return json.loads(payload) + except Exception: + return None + + cleaned = (narration_raw or "").strip() + parsed = load_json_candidate(cleaned) + if parsed is None: + parsed = load_json_candidate(cleaned.replace("```json", "").replace("```", "").strip()) + if parsed is None: + start = cleaned.find("{") + end = cleaned.rfind("}") + if start >= 0 and end > start: + parsed = load_json_candidate(cleaned[start : end + 1]) + + items = [] + if isinstance(parsed, dict): + raw_items = parsed.get("items") + if isinstance(raw_items, list): + items = [item for item in raw_items if isinstance(item, dict)] + + if items: + return items + + fallback_text = (cleaned[:200] + "...") if len(cleaned) > 200 else cleaned + if not fallback_text: + fallback_text = "解说文案解析失败,请重试。" + return [ + { + "timestamp": "00:00:00,000-00:00:10,000", + "picture": "解析失败,使用默认内容", + "narration": fallback_text, + } + ] + def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float: interval = frame_interval_input if interval in (None, ""): diff --git a/app/services/script_service.py b/app/services/script_service.py index 61c36a7..47c329c 100644 --- a/app/services/script_service.py +++ b/app/services/script_service.py @@ -1,5 +1,7 @@ from typing import Any, Callable +from loguru import logger + from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService @@ -12,14 +14,21 @@ class ScriptGenerator: video_path: str, video_theme: str = "", custom_prompt: str = "", - frame_interval_input: int = 5, + frame_interval_input: int | None = None, skip_seconds: int = 0, threshold: int = 30, - vision_batch_size: int = 5, - vision_llm_provider: str = "gemini", + vision_batch_size: int | None = None, + vision_llm_provider: str | None = None, progress_callback: Callable[[float, str], None] | None = None, ) -> list[dict[Any, Any]]: callback = progress_callback or (lambda _p, _m: None) + if skip_seconds != 0 or threshold != 30: + logger.warning( + "ScriptGenerator documentary path received " + f"skip_seconds={skip_seconds} threshold={threshold}; " + "the shared documentary frame pipeline does not currently apply these parameters." + ) + return await self.documentary_service.generate_documentary_script( video_path=video_path, video_theme=video_theme, @@ -28,7 +37,4 @@ class ScriptGenerator: vision_batch_size=vision_batch_size, vision_llm_provider=vision_llm_provider, progress_callback=callback, - # 历史参数保留在签名中以兼容调用方;共享逐帧分析当前不使用这两个参数。 - # skip_seconds=skip_seconds, - # threshold=threshold, ) diff --git a/tests/test_script_service_documentary_unittest.py b/tests/test_script_service_documentary_unittest.py index d1fcd70..d39b572 100644 --- a/tests/test_script_service_documentary_unittest.py +++ b/tests/test_script_service_documentary_unittest.py @@ -1,23 +1,24 @@ +import json import unittest +from pathlib import Path +from tempfile import TemporaryDirectory from unittest.mock import AsyncMock, patch +from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService from app.services.script_service import ScriptGenerator class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase): - async def test_generate_script_passes_frame_interval_to_shared_service(self): + async def test_generate_script_forwards_explicit_values_to_shared_service(self): expected_script = [ { "timestamp": "00:00:00,000-00:00:03,000", "picture": "批次描述", - "narration": "", + "narration": "这里是解说词", "OST": 2, } ] - progress = [] - - def progress_callback(percent, message): - progress.append((percent, message)) + callback = lambda _percent, _message: None with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls: service = service_cls.return_value @@ -31,10 +32,11 @@ class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase): frame_interval_input=3, vision_batch_size=6, vision_llm_provider="openai", - progress_callback=progress_callback, + progress_callback=callback, ) self.assertEqual(expected_script, result) + self.assertTrue(result[0]["narration"]) service.generate_documentary_script.assert_awaited_once() called_kwargs = service.generate_documentary_script.await_args.kwargs self.assertEqual("demo.mp4", called_kwargs["video_path"]) @@ -43,8 +45,101 @@ class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase): self.assertEqual("openai", called_kwargs["vision_llm_provider"]) self.assertEqual("荒野生存", called_kwargs["video_theme"]) self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"]) - self.assertIs(called_kwargs["progress_callback"], progress_callback) - self.assertEqual([], progress) + self.assertIs(called_kwargs["progress_callback"], callback) + + async def test_generate_script_forwards_unset_values_as_none(self): + expected_script = [ + { + "timestamp": "00:00:00,000-00:00:03,000", + "picture": "批次描述", + "narration": "这里是解说词", + "OST": 2, + } + ] + with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls: + service = service_cls.return_value + service.generate_documentary_script = AsyncMock(return_value=expected_script) + generator = ScriptGenerator() + + await generator.generate_script(video_path="demo.mp4") + + called_kwargs = service.generate_documentary_script.await_args.kwargs + self.assertIsNone(called_kwargs["frame_interval_input"]) + self.assertIsNone(called_kwargs["vision_batch_size"]) + self.assertIsNone(called_kwargs["vision_llm_provider"]) + + async def test_generate_script_warns_when_skip_seconds_or_threshold_are_non_default(self): + expected_script = [ + { + "timestamp": "00:00:00,000-00:00:03,000", + "picture": "批次描述", + "narration": "这里是解说词", + "OST": 2, + } + ] + with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls, patch( + "app.services.script_service.logger.warning" + ) as warning: + service = service_cls.return_value + service.generate_documentary_script = AsyncMock(return_value=expected_script) + generator = ScriptGenerator() + await generator.generate_script( + video_path="demo.mp4", + skip_seconds=2, + threshold=20, + ) + + warning.assert_called_once() + warning_message = warning.call_args.args[0] + self.assertIn("skip_seconds", warning_message) + self.assertIn("threshold", warning_message) + self.assertIn("does not currently apply", warning_message) + + +class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyncioTestCase): + async def test_generate_documentary_script_returns_final_narrated_items(self): + service = DocumentaryFrameAnalysisService() + analysis_payload = { + "batches": [ + { + "batch_index": 0, + "time_range": "00:00:00,000-00:00:03,000", + "overall_activity_summary": "", + "fallback_summary": "回退摘要", + "frame_observations": [ + {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"}, + ], + } + ] + } + + with TemporaryDirectory() as temp_dir: + analysis_path = Path(temp_dir) / "frame_analysis_test.json" + analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8") + + with patch.object( + DocumentaryFrameAnalysisService, + "analyze_video", + AsyncMock(return_value={"analysis_json_path": str(analysis_path)}), + ), patch.dict( + "app.services.documentary.frame_analysis_service.config.app", + { + "text_llm_provider": "openai", + "text_openai_api_key": "test-key", + "text_openai_model_name": "test-model", + "text_openai_base_url": "https://example.com/v1", + }, + ), patch( + "app.services.documentary.frame_analysis_service.generate_narration", + return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}', + ): + result = await service.generate_documentary_script(video_path="demo.mp4") + + self.assertEqual(1, len(result)) + self.assertEqual("00:00:00,000-00:00:03,000", result[0]["timestamp"]) + self.assertEqual("镜头里有一只猫", result[0]["picture"]) + self.assertEqual("一只猫警觉地望向镜头。", result[0]["narration"]) + self.assertEqual(2, result[0]["OST"]) if __name__ == "__main__": diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py index 18fba78..77c322d 100644 --- a/webui/tools/generate_script_docu.py +++ b/webui/tools/generate_script_docu.py @@ -9,8 +9,6 @@ from loguru import logger from app.config import config from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService -from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown -from webui.tools.generate_short_summary import parse_and_fix_json def generate_script_docu(params): @@ -66,8 +64,8 @@ def generate_script_docu(params): update_progress(10, "正在提取关键帧...") service = DocumentaryFrameAnalysisService() - analysis_result = asyncio.run( - service.analyze_video( + script_items = asyncio.run( + service.generate_documentary_script( video_path=params.video_origin_path, video_theme=st.session_state.get("video_theme", ""), custom_prompt=st.session_state.get("custom_prompt", ""), @@ -82,31 +80,8 @@ def generate_script_docu(params): ) ) - analysis_json_path = analysis_result["analysis_json_path"] - update_progress(80, "正在生成解说文案...") - - text_provider = config.app.get("text_llm_provider", "gemini").lower() - text_api_key = config.app.get(f"text_{text_provider}_api_key") - text_model = config.app.get(f"text_{text_provider}_model_name") - text_base_url = config.app.get(f"text_{text_provider}_base_url") - - markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) - narration = generate_narration( - markdown_output, - text_api_key, - base_url=text_base_url, - model=text_model, - ) - narration_data = parse_and_fix_json(narration) - - if not narration_data or "items" not in narration_data: - logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...") - raise Exception("解说文案格式错误,无法解析JSON或缺少items字段") - - narration_dict = [{**item, "OST": 2} for item in narration_data["items"]] - script = json.dumps(narration_dict, ensure_ascii=False, indent=2) - - logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段") + logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段") + script = json.dumps(script_items, ensure_ascii=False, indent=2) if isinstance(script, list): st.session_state["video_clip_json"] = script elif isinstance(script, str):