mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
128 lines
5.4 KiB
Python
128 lines
5.4 KiB
Python
# 纪录片脚本生成
|
||
import asyncio
|
||
import json
|
||
import time
|
||
import traceback
|
||
|
||
import streamlit as st
|
||
from loguru import logger
|
||
|
||
from app.config import config
|
||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
|
||
from webui.tools.generate_short_summary import parse_and_fix_json
|
||
|
||
|
||
def generate_script_docu(params):
    """
    Generate a documentary-style narration script for a video.

    Requirement: the source video has no subtitles and no voice-over.
    Suitable for: documentaries, funny animal commentary, wilderness-building
    footage and similar material.

    Vision-model settings (provider, API key, model name, base URL) and the
    frame-sampling settings are read from ``st.session_state`` first and fall
    back to ``config``. Text-model settings are read from ``config.app`` only.

    On success the list of narration clips (every item returned by the text
    model, with ``"OST": 2`` added) is stored in
    ``st.session_state["video_clip_json"]``.

    Any exception raised during generation is caught, shown with ``st.error``
    and logged; it is not propagated to the caller. The progress widgets are
    always cleared before returning.

    Args:
        params: Object exposing ``video_origin_path``, the path of the video
            to analyse. If it is empty an error is shown and nothing is
            generated.

    Returns:
        None.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: float, message: str = ""):
        """Advance the progress bar and show ``message`` (or the raw percentage)."""
        # NOTE(review): this function is also handed to the analysis service as
        # its progress_callback; confirm the service reports progress on the
        # same 0-100 scale used by the direct calls below.
        progress_bar.progress(progress)
        if message:
            status_text.text(f"🎬 {message}")
        else:
            status_text.text(f"📊 进度: {progress}%")

    try:
        with st.spinner("正在生成脚本..."):
            if not params.video_origin_path:
                st.error("请先选择视频文件")
                return

            # Vision model settings: the session state wins over the config file.
            vision_llm_provider = (
                st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
            ).lower()
            vision_api_key = (
                st.session_state.get(f"vision_{vision_llm_provider}_api_key")
                or config.app.get(f"vision_{vision_llm_provider}_api_key")
            )
            vision_model = (
                st.session_state.get(f"vision_{vision_llm_provider}_model_name")
                or config.app.get(f"vision_{vision_llm_provider}_model_name")
            )
            vision_base_url = (
                st.session_state.get(f"vision_{vision_llm_provider}_base_url")
                or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
            )
            if not vision_api_key or not vision_model:
                raise ValueError(
                    f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
                    f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
                )

            # Frame sampling / batching settings, again session state first.
            frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
                "frame_interval_input", 3
            )
            vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
            vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
                "vision_max_concurrency", 2
            )

            update_progress(10, "正在提取关键帧...")
            service = DocumentaryFrameAnalysisService()
            # analyze_video is a coroutine; run it to completion from this
            # synchronous function.
            analysis_result = asyncio.run(
                service.analyze_video(
                    video_path=params.video_origin_path,
                    video_theme=st.session_state.get("video_theme", ""),
                    custom_prompt=st.session_state.get("custom_prompt", ""),
                    frame_interval_input=frame_interval_input,
                    vision_batch_size=vision_batch_size,
                    vision_llm_provider=vision_llm_provider,
                    progress_callback=update_progress,
                    vision_api_key=vision_api_key,
                    vision_model_name=vision_model,
                    vision_base_url=vision_base_url,
                    max_concurrency=vision_max_concurrency,
                )
            )

            analysis_json_path = analysis_result["analysis_json_path"]
            update_progress(80, "正在生成解说文案...")

            # Text model settings come from the config file only.
            text_provider = config.app.get("text_llm_provider", "gemini").lower()
            text_api_key = config.app.get(f"text_{text_provider}_api_key")
            text_model = config.app.get(f"text_{text_provider}_model_name")
            text_base_url = config.app.get(f"text_{text_provider}_base_url")

            markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
            narration = generate_narration(
                markdown_output,
                text_api_key,
                base_url=text_base_url,
                model=text_model,
            )
            narration_data = parse_and_fix_json(narration)

            if not narration_data or "items" not in narration_data:
                logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...")
                raise Exception("解说文案格式错误,无法解析JSON或缺少items字段")

            # Tag every narration item with OST=2 before handing it to the UI.
            narration_dict = [{**item, "OST": 2} for item in narration_data["items"]]

            logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段")
            # narration_dict is already a list of dicts, so it is stored
            # directly instead of being serialised to JSON and parsed back.
            st.session_state["video_clip_json"] = narration_dict
            update_progress(100, "脚本生成完成")

        time.sleep(0.1)
        progress_bar.progress(100)
        status_text.text("🎉 脚本生成完成!")
        st.success("✅ 视频脚本生成成功!")

    except Exception as err:
        # Top-level UI boundary: report the failure to the user and log it
        # rather than letting the exception escape.
        st.error(f"❌ 生成过程中发生错误: {str(err)}")
        logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
    finally:
        # Leave the final status visible briefly, then remove the widgets.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()