NarratoAI/webui/tools/generate_script_docu.py

128 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 纪录片脚本生成
import asyncio
import json
import time
import traceback
import streamlit as st
from loguru import logger
from app.config import config
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
from webui.tools.generate_short_summary import parse_and_fix_json
def generate_script_docu(params):
    """
    Generate a narration script for a documentary-style video.

    Requirement: the source video has no subtitles and no voice-over.
    Suitable for: documentaries, funny animal commentary, wilderness
    building videos, and similar material.

    Workflow:
        1. Resolve the vision-model settings (Streamlit session state first,
           then ``config.app``) and the frame settings (session state first,
           then ``config.frames``).
        2. Run ``DocumentaryFrameAnalysisService.analyze_video`` and read
           ``analysis_json_path`` from its result.
        3. Convert the analysis to markdown, ask the text model for a
           narration, and parse the reply with ``parse_and_fix_json``.
        4. Add ``"OST": 2`` to every narration item and store the resulting
           list in ``st.session_state["video_clip_json"]``.

    Args:
        params: Object exposing ``video_origin_path``, the path of the
            video to analyse. If it is empty an error is shown and the
            function returns early.

    Returns:
        None. The result is published through ``st.session_state``. Any
        ``Exception`` raised along the way is shown with ``st.error`` and
        logged instead of being propagated.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: float, message: str = ""):
        """Update the progress bar (value in percent) and the status line."""
        # st.progress accepts an int in [0, 100] or a float in [0.0, 1.0].
        # This callback works in percent, so a float such as 42.5 has to be
        # converted to an int; out-of-range values are clamped rather than
        # letting a progress update abort the whole generation.
        if isinstance(progress, float):
            if progress > 1.0:
                bar_value = min(100, round(progress))
            else:
                bar_value = max(0.0, progress)
        else:
            bar_value = min(100, max(0, progress))
        progress_bar.progress(bar_value)

        if message:
            status_text.text(f"🎬 {message}")
        else:
            status_text.text(f"📊 进度: {progress}%")

    try:
        with st.spinner("正在生成脚本..."):
            if not params.video_origin_path:
                st.error("请先选择视频文件")
                return

            # ---- vision model settings -------------------------------------
            vision_llm_provider = (
                st.session_state.get("vision_llm_provider")
                or config.app.get("vision_llm_provider", "openai")
            ).lower()

            def vision_setting(suffix, default=None):
                # Session state takes precedence over the config file.
                key = f"vision_{vision_llm_provider}_{suffix}"
                return st.session_state.get(key) or config.app.get(key, default)

            vision_api_key = vision_setting("api_key")
            vision_model = vision_setting("model_name")
            vision_base_url = vision_setting("base_url", "")

            if not vision_api_key or not vision_model:
                raise ValueError(
                    f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
                    f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
                )

            # ---- frame extraction settings ---------------------------------
            frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
                "frame_interval_input", 3
            )
            vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get(
                "vision_batch_size", 10
            )
            vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
                "vision_max_concurrency", 2
            )

            # ---- frame analysis --------------------------------------------
            update_progress(10, "正在提取关键帧...")
            service = DocumentaryFrameAnalysisService()
            analysis_result = asyncio.run(
                service.analyze_video(
                    video_path=params.video_origin_path,
                    video_theme=st.session_state.get("video_theme", ""),
                    custom_prompt=st.session_state.get("custom_prompt", ""),
                    frame_interval_input=frame_interval_input,
                    vision_batch_size=vision_batch_size,
                    vision_llm_provider=vision_llm_provider,
                    progress_callback=update_progress,
                    vision_api_key=vision_api_key,
                    vision_model_name=vision_model,
                    vision_base_url=vision_base_url,
                    max_concurrency=vision_max_concurrency,
                )
            )
            analysis_json_path = analysis_result["analysis_json_path"]

            # ---- narration generation --------------------------------------
            update_progress(80, "正在生成解说文案...")
            # Text model settings are read from the config file only.
            text_provider = config.app.get("text_llm_provider", "gemini").lower()
            text_api_key = config.app.get(f"text_{text_provider}_api_key")
            text_model = config.app.get(f"text_{text_provider}_model_name")
            text_base_url = config.app.get(f"text_{text_provider}_base_url")

            markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
            narration = generate_narration(
                markdown_output,
                text_api_key,
                base_url=text_base_url,
                model=text_model,
            )

            narration_data = parse_and_fix_json(narration)
            if not isinstance(narration_data, dict) or "items" not in narration_data:
                # str() keeps the diagnostic from raising when the model
                # returned None or another non-sliceable value, which would
                # otherwise hide the real "bad format" error below.
                logger.error(f"解说文案JSON解析失败原始内容: {str(narration)[:200]}...")
                raise Exception("解说文案格式错误无法解析JSON或缺少items字段")

            # Tag every narration item with OST = 2.
            narration_dict = [{**item, "OST": 2} for item in narration_data["items"]]
            logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段")

            # Store the list directly: dumping it to a JSON string and
            # parsing it straight back only reproduced the same data.
            # NOTE(review): assumes parse_and_fix_json returns plain
            # JSON-compatible data — confirm against that helper.
            st.session_state["video_clip_json"] = narration_dict

            update_progress(100, "脚本生成完成")

        time.sleep(0.1)
        progress_bar.progress(100)
        status_text.text("🎉 脚本生成完成!")
        st.success("✅ 视频脚本生成成功!")
    except Exception as err:
        st.error(f"❌ 生成过程中发生错误: {str(err)}")
        # logger.exception already records the active traceback, so it is
        # not embedded in the message a second time.
        logger.exception("生成脚本时发生错误")
    finally:
        # Leave the final status visible briefly, then clear the widgets.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()