# Source: NarratoAI/webui/tools/generate_script_docu.py
# (117 lines, 4.6 KiB, Python)

# 纪录片脚本生成
import asyncio
import json
import time
import traceback
import streamlit as st
from loguru import logger
from app.config import config
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
def _normalize_progress_value(progress: float | int) -> int:
"""Normalize mixed progress inputs to Streamlit's 0-100 integer range."""
try:
value = float(progress)
except (TypeError, ValueError):
return 0
if 0.0 <= value <= 1.0:
value *= 100
return max(0, min(100, int(round(value))))
def generate_script_docu(params):
    """
    Generate a documentary narration script and store it in session state.

    Requirement: the source video has no subtitles and no voice-over.
    Suitable for: documentaries, funny animal commentary, wilderness
    building videos, and similar.

    Vision-model and frame settings are read from ``st.session_state``
    first and fall back to ``config``. On success the generated items are
    stored under ``st.session_state["video_clip_json"]``. Any exception is
    caught, shown with ``st.error`` and logged; nothing is re-raised.

    Args:
        params: Object exposing ``video_origin_path``, the path of the video
            to analyse.

    Returns:
        None.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: float, message: str = ""):
        """Progress callback: refresh the progress bar and the status line."""
        normalized_progress = _normalize_progress_value(progress)
        progress_bar.progress(normalized_progress)
        if message:
            status_text.text(f"🎬 {message}")
        else:
            status_text.text(f"📊 进度: {normalized_progress}%")

    try:
        with st.spinner("正在生成脚本..."):
            if not params.video_origin_path:
                st.error("请先选择视频文件")
                return

            # The trailing `or "openai"` covers a provider that is present
            # but empty/None in both session state and config, which would
            # otherwise break `.lower()` or produce "vision__api_key" keys.
            vision_llm_provider = (
                st.session_state.get("vision_llm_provider")
                or config.app.get("vision_llm_provider", "openai")
                or "openai"
            ).lower()
            vision_api_key = (
                st.session_state.get(f"vision_{vision_llm_provider}_api_key")
                or config.app.get(f"vision_{vision_llm_provider}_api_key")
            )
            vision_model = (
                st.session_state.get(f"vision_{vision_llm_provider}_model_name")
                or config.app.get(f"vision_{vision_llm_provider}_model_name")
            )
            vision_base_url = (
                st.session_state.get(f"vision_{vision_llm_provider}_base_url")
                or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
            )

            if not vision_api_key or not vision_model:
                raise ValueError(
                    f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
                    f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
                )

            frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
                "frame_interval_input", 3
            )
            vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get(
                "vision_batch_size", 10
            )
            vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
                "vision_max_concurrency", 2
            )

            update_progress(10, "正在提取关键帧...")

            service = DocumentaryFrameAnalysisService()
            # The service method is a coroutine; drive it to completion here.
            script_items = asyncio.run(
                service.generate_documentary_script(
                    video_path=params.video_origin_path,
                    video_theme=st.session_state.get("video_theme", ""),
                    custom_prompt=st.session_state.get("custom_prompt", ""),
                    frame_interval_input=frame_interval_input,
                    vision_batch_size=vision_batch_size,
                    vision_llm_provider=vision_llm_provider,
                    progress_callback=update_progress,
                    vision_api_key=vision_api_key,
                    vision_model_name=vision_model,
                    vision_base_url=vision_base_url,
                    max_concurrency=vision_max_concurrency,
                )
            )
            logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段")

            # Round-trip through JSON so the stored value is made of plain
            # JSON types and a non-serializable item fails here, inside the
            # error handler, rather than later in the UI.
            st.session_state["video_clip_json"] = json.loads(
                json.dumps(script_items, ensure_ascii=False)
            )

            update_progress(100, "脚本生成完成")

        time.sleep(0.1)
        progress_bar.progress(100)
        status_text.text("🎉 脚本生成完成!")
        st.success("✅ 视频脚本生成成功!")
    except Exception as err:
        # UI boundary: report the failure to the user and the log instead
        # of letting the exception propagate.
        st.error(f"❌ 生成过程中发生错误: {str(err)}")
        logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
    finally:
        # Presumably the pause keeps the final status visible for a moment
        # before the progress widgets are cleared.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()