mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
fix(documentary): centralize final script generation in shared service
This commit is contained in:
parent
ac63fea953
commit
d678bf62b1
@ -9,6 +9,7 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.documentary.frame_analysis_models import FrameBatchResult
|
||||
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
|
||||
from app.services.llm.migration_adapter import create_vision_analyzer
|
||||
from app.utils import utils, video_processor
|
||||
|
||||
@ -39,15 +40,16 @@ JSON 必须包含以下键:
|
||||
video_path: str,
|
||||
video_theme: str = "",
|
||||
custom_prompt: str = "",
|
||||
frame_interval_input: int | float = 3,
|
||||
vision_batch_size: int = 10,
|
||||
vision_llm_provider: str = "openai",
|
||||
frame_interval_input: int | float | None = None,
|
||||
vision_batch_size: int | None = None,
|
||||
vision_llm_provider: str | None = None,
|
||||
progress_callback: Callable[[float, str], None] | None = None,
|
||||
vision_api_key: str | None = None,
|
||||
vision_model_name: str | None = None,
|
||||
vision_base_url: str | None = None,
|
||||
max_concurrency: int | None = None,
|
||||
) -> list[dict]:
|
||||
progress = progress_callback or (lambda _p, _m: None)
|
||||
analysis_result = await self.analyze_video(
|
||||
video_path=video_path,
|
||||
video_theme=video_theme,
|
||||
@ -61,7 +63,31 @@ JSON 必须包含以下键:
|
||||
vision_base_url=vision_base_url,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
return analysis_result["video_clip_json"]
|
||||
analysis_json_path = analysis_result["analysis_json_path"]
|
||||
|
||||
progress(80, "正在生成解说文案...")
|
||||
text_provider = config.app.get("text_llm_provider", "openai").lower()
|
||||
text_api_key = config.app.get(f"text_{text_provider}_api_key")
|
||||
text_model = config.app.get(f"text_{text_provider}_model_name")
|
||||
text_base_url = config.app.get(f"text_{text_provider}_base_url")
|
||||
if not text_api_key or not text_model:
|
||||
raise ValueError(
|
||||
f"未配置 {text_provider} 的文本模型参数。"
|
||||
f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name"
|
||||
)
|
||||
|
||||
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
|
||||
narration_raw = generate_narration(
|
||||
markdown_output,
|
||||
text_api_key,
|
||||
base_url=text_base_url,
|
||||
model=text_model,
|
||||
)
|
||||
narration_items = self._parse_narration_items(narration_raw)
|
||||
|
||||
final_script = [{**item, "OST": 2} for item in narration_items]
|
||||
progress(100, "脚本生成完成")
|
||||
return final_script
|
||||
|
||||
async def analyze_video(
|
||||
self,
|
||||
@ -69,9 +95,9 @@ JSON 必须包含以下键:
|
||||
video_path: str,
|
||||
video_theme: str = "",
|
||||
custom_prompt: str = "",
|
||||
frame_interval_input: int | float = 3,
|
||||
vision_batch_size: int = 10,
|
||||
vision_llm_provider: str = "openai",
|
||||
frame_interval_input: int | float | None = None,
|
||||
vision_batch_size: int | None = None,
|
||||
vision_llm_provider: str | None = None,
|
||||
progress_callback: Callable[[float, str], None] | None = None,
|
||||
vision_api_key: str | None = None,
|
||||
vision_model_name: str | None = None,
|
||||
@ -145,6 +171,43 @@ JSON 必须包含以下键:
|
||||
"keyframe_files": keyframe_files,
|
||||
}
|
||||
|
||||
def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]:
    """Extract narration items from the raw text returned by the narration model.

    The text is expected to contain a JSON object holding an ``"items"`` list.
    Several candidate payloads are tried in order: the stripped text, the text
    with Markdown code fences removed, and the outermost ``{...}`` span. The
    first candidate that yields at least one dict item wins, so a candidate
    that parses to something unusable (for example a top-level array wrapping
    the object) no longer prevents the later candidates from being tried.

    Args:
        narration_raw: Raw model output. ``None`` or an empty value is treated
            as an empty string.

    Returns:
        The dict entries of the ``"items"`` list. When no candidate produces
        any item, a single placeholder item is returned whose narration is the
        raw text truncated to 200 characters (or a fixed failure message when
        the raw text is empty).
    """

    def load_json_candidate(payload: str) -> Any:
        # json.loads may return any JSON type, not only a dict; None signals
        # "could not be parsed" (a literal JSON null is treated the same way).
        try:
            return json.loads(payload)
        except (ValueError, RecursionError):
            return None

    def extract_items(parsed: Any) -> list[dict[str, Any]]:
        # Only a dict carrying an "items" list is a usable result; non-dict
        # entries inside that list are dropped.
        if not isinstance(parsed, dict):
            return []
        raw_items = parsed.get("items")
        if not isinstance(raw_items, list):
            return []
        return [item for item in raw_items if isinstance(item, dict)]

    cleaned = (narration_raw or "").strip()

    def iter_candidates():
        # Candidates are produced lazily, from strictest to most lenient.
        yield cleaned
        yield cleaned.replace("```json", "").replace("```", "").strip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start >= 0 and end > start:
            yield cleaned[start : end + 1]

    for payload in iter_candidates():
        items = extract_items(load_json_candidate(payload))
        if items:
            return items

    # Nothing usable was found: surface a short excerpt of the raw output so
    # the caller still receives a well-formed single-item script.
    fallback_text = (cleaned[:200] + "...") if len(cleaned) > 200 else cleaned
    if not fallback_text:
        fallback_text = "解说文案解析失败,请重试。"
    return [
        {
            "timestamp": "00:00:00,000-00:00:10,000",
            "picture": "解析失败,使用默认内容",
            "narration": fallback_text,
        }
    ]
|
||||
|
||||
def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
|
||||
interval = frame_interval_input
|
||||
if interval in (None, ""):
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||
|
||||
|
||||
@ -12,14 +14,21 @@ class ScriptGenerator:
|
||||
video_path: str,
|
||||
video_theme: str = "",
|
||||
custom_prompt: str = "",
|
||||
frame_interval_input: int = 5,
|
||||
frame_interval_input: int | None = None,
|
||||
skip_seconds: int = 0,
|
||||
threshold: int = 30,
|
||||
vision_batch_size: int = 5,
|
||||
vision_llm_provider: str = "gemini",
|
||||
vision_batch_size: int | None = None,
|
||||
vision_llm_provider: str | None = None,
|
||||
progress_callback: Callable[[float, str], None] | None = None,
|
||||
) -> list[dict[Any, Any]]:
|
||||
callback = progress_callback or (lambda _p, _m: None)
|
||||
if skip_seconds != 0 or threshold != 30:
|
||||
logger.warning(
|
||||
"ScriptGenerator documentary path received "
|
||||
f"skip_seconds={skip_seconds} threshold={threshold}; "
|
||||
"the shared documentary frame pipeline does not currently apply these parameters."
|
||||
)
|
||||
|
||||
return await self.documentary_service.generate_documentary_script(
|
||||
video_path=video_path,
|
||||
video_theme=video_theme,
|
||||
@ -28,7 +37,4 @@ class ScriptGenerator:
|
||||
vision_batch_size=vision_batch_size,
|
||||
vision_llm_provider=vision_llm_provider,
|
||||
progress_callback=callback,
|
||||
# 历史参数保留在签名中以兼容调用方;共享逐帧分析当前不使用这两个参数。
|
||||
# skip_seconds=skip_seconds,
|
||||
# threshold=threshold,
|
||||
)
|
||||
|
||||
@ -1,23 +1,24 @@
|
||||
import json
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||
from app.services.script_service import ScriptGenerator
|
||||
|
||||
|
||||
class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_generate_script_passes_frame_interval_to_shared_service(self):
|
||||
async def test_generate_script_forwards_explicit_values_to_shared_service(self):
|
||||
expected_script = [
|
||||
{
|
||||
"timestamp": "00:00:00,000-00:00:03,000",
|
||||
"picture": "批次描述",
|
||||
"narration": "",
|
||||
"narration": "这里是解说词",
|
||||
"OST": 2,
|
||||
}
|
||||
]
|
||||
progress = []
|
||||
|
||||
def progress_callback(percent, message):
|
||||
progress.append((percent, message))
|
||||
callback = lambda _percent, _message: None
|
||||
|
||||
with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
|
||||
service = service_cls.return_value
|
||||
@ -31,10 +32,11 @@ class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
|
||||
frame_interval_input=3,
|
||||
vision_batch_size=6,
|
||||
vision_llm_provider="openai",
|
||||
progress_callback=progress_callback,
|
||||
progress_callback=callback,
|
||||
)
|
||||
|
||||
self.assertEqual(expected_script, result)
|
||||
self.assertTrue(result[0]["narration"])
|
||||
service.generate_documentary_script.assert_awaited_once()
|
||||
called_kwargs = service.generate_documentary_script.await_args.kwargs
|
||||
self.assertEqual("demo.mp4", called_kwargs["video_path"])
|
||||
@ -43,8 +45,101 @@ class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual("openai", called_kwargs["vision_llm_provider"])
|
||||
self.assertEqual("荒野生存", called_kwargs["video_theme"])
|
||||
self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
|
||||
self.assertIs(called_kwargs["progress_callback"], progress_callback)
|
||||
self.assertEqual([], progress)
|
||||
self.assertIs(called_kwargs["progress_callback"], callback)
|
||||
|
||||
async def test_generate_script_forwards_unset_values_as_none(self):
    """Omitted tuning arguments must be forwarded to the shared service as ``None``."""
    stub_script = [
        {
            "timestamp": "00:00:00,000-00:00:03,000",
            "picture": "批次描述",
            "narration": "这里是解说词",
            "OST": 2,
        }
    ]
    with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
        shared_service = service_cls.return_value
        shared_service.generate_documentary_script = AsyncMock(return_value=stub_script)

        # Only the mandatory argument is supplied; everything else stays unset.
        await ScriptGenerator().generate_script(video_path="demo.mp4")

        forwarded = shared_service.generate_documentary_script.await_args.kwargs
        for name in ("frame_interval_input", "vision_batch_size", "vision_llm_provider"):
            self.assertIsNone(forwarded[name])
|
||||
|
||||
async def test_generate_script_warns_when_skip_seconds_or_threshold_are_non_default(self):
    """Non-default ``skip_seconds``/``threshold`` must trigger exactly one warning."""
    stub_script = [
        {
            "timestamp": "00:00:00,000-00:00:03,000",
            "picture": "批次描述",
            "narration": "这里是解说词",
            "OST": 2,
        }
    ]
    service_patch = patch("app.services.script_service.DocumentaryFrameAnalysisService")
    warning_patch = patch("app.services.script_service.logger.warning")
    with service_patch as service_cls, warning_patch as warning:
        shared_service = service_cls.return_value
        shared_service.generate_documentary_script = AsyncMock(return_value=stub_script)

        await ScriptGenerator().generate_script(
            video_path="demo.mp4",
            skip_seconds=2,
            threshold=20,
        )

        warning.assert_called_once()
        # The first positional argument of the warning call is the message text.
        emitted = warning.call_args.args[0]
        for fragment in ("skip_seconds", "threshold", "does not currently apply"):
            self.assertIn(fragment, emitted)
|
||||
|
||||
|
||||
class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyncioTestCase):
    """Checks the script produced by ``generate_documentary_script`` with stubbed analysis and narration."""

    async def test_generate_documentary_script_returns_final_narrated_items(self):
        """The returned items carry the narrated fields plus ``OST`` set to 2."""
        service = DocumentaryFrameAnalysisService()
        single_batch = {
            "batch_index": 0,
            "time_range": "00:00:00,000-00:00:03,000",
            "overall_activity_summary": "",
            "fallback_summary": "回退摘要",
            "frame_observations": [
                {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
            ],
        }
        analysis_payload = {"batches": [single_batch]}
        text_llm_settings = {
            "text_llm_provider": "openai",
            "text_openai_api_key": "test-key",
            "text_openai_model_name": "test-model",
            "text_openai_base_url": "https://example.com/v1",
        }

        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")

            # Frame analysis is stubbed to point at the JSON file written above.
            stub_analysis = patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            )
            stub_config = patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                text_llm_settings,
            )
            # The narration call is replaced with a canned JSON reply.
            stub_narration = patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
            )
            with stub_analysis, stub_config, stub_narration:
                result = await service.generate_documentary_script(video_path="demo.mp4")

        self.assertEqual(1, len(result))
        only_item = result[0]
        self.assertEqual("00:00:00,000-00:00:03,000", only_item["timestamp"])
        self.assertEqual("镜头里有一只猫", only_item["picture"])
        self.assertEqual("一只猫警觉地望向镜头。", only_item["narration"])
        self.assertEqual(2, only_item["OST"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -9,8 +9,6 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
|
||||
from webui.tools.generate_short_summary import parse_and_fix_json
|
||||
|
||||
|
||||
def generate_script_docu(params):
|
||||
@ -66,8 +64,8 @@ def generate_script_docu(params):
|
||||
|
||||
update_progress(10, "正在提取关键帧...")
|
||||
service = DocumentaryFrameAnalysisService()
|
||||
analysis_result = asyncio.run(
|
||||
service.analyze_video(
|
||||
script_items = asyncio.run(
|
||||
service.generate_documentary_script(
|
||||
video_path=params.video_origin_path,
|
||||
video_theme=st.session_state.get("video_theme", ""),
|
||||
custom_prompt=st.session_state.get("custom_prompt", ""),
|
||||
@ -82,31 +80,8 @@ def generate_script_docu(params):
|
||||
)
|
||||
)
|
||||
|
||||
analysis_json_path = analysis_result["analysis_json_path"]
|
||||
update_progress(80, "正在生成解说文案...")
|
||||
|
||||
text_provider = config.app.get("text_llm_provider", "gemini").lower()
|
||||
text_api_key = config.app.get(f"text_{text_provider}_api_key")
|
||||
text_model = config.app.get(f"text_{text_provider}_model_name")
|
||||
text_base_url = config.app.get(f"text_{text_provider}_base_url")
|
||||
|
||||
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
|
||||
narration = generate_narration(
|
||||
markdown_output,
|
||||
text_api_key,
|
||||
base_url=text_base_url,
|
||||
model=text_model,
|
||||
)
|
||||
narration_data = parse_and_fix_json(narration)
|
||||
|
||||
if not narration_data or "items" not in narration_data:
|
||||
logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...")
|
||||
raise Exception("解说文案格式错误,无法解析JSON或缺少items字段")
|
||||
|
||||
narration_dict = [{**item, "OST": 2} for item in narration_data["items"]]
|
||||
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段")
|
||||
logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段")
|
||||
script = json.dumps(script_items, ensure_ascii=False, indent=2)
|
||||
if isinstance(script, list):
|
||||
st.session_state["video_clip_json"] = script
|
||||
elif isinstance(script, str):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user