mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
103 lines
4.2 KiB
Python
103 lines
4.2 KiB
Python
# 纪录片脚本生成
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import traceback
|
|
|
|
import streamlit as st
|
|
from loguru import logger
|
|
|
|
from app.config import config
|
|
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
|
|
|
|
|
def _to_streamlit_progress(progress):
    """Coerce a progress value into a range ``st.progress`` accepts.

    ``st.progress`` accepts ints in [0, 100] or floats in [0.0, 1.0] and raises
    otherwise. Progress in this module is reported on a 0-100 scale, so a
    float above 1.0 is treated as a percentage and converted to an int.
    Out-of-range values are clamped instead of aborting script generation.
    Floats already in [0.0, 1.0] are passed through unchanged.
    """
    if isinstance(progress, float):
        if progress <= 1.0:
            return max(progress, 0.0)
        return min(int(progress), 100)
    return max(0, min(int(progress), 100))


def _resolve_vision_settings():
    """Look up the vision-LLM provider, API key, model name and base URL.

    Each value is read from ``st.session_state`` first. It falls back to
    ``config.app`` when the session value is missing or falsy.

    Returns:
        tuple: ``(provider, api_key, model_name, base_url)``. ``provider`` is
        lower-cased and ``base_url`` falls back to ``""``.

    Raises:
        ValueError: if the API key or the model name is not configured.
    """
    provider = (
        st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
    ).lower()
    api_key = (
        st.session_state.get(f"vision_{provider}_api_key")
        or config.app.get(f"vision_{provider}_api_key")
    )
    model_name = (
        st.session_state.get(f"vision_{provider}_model_name")
        or config.app.get(f"vision_{provider}_model_name")
    )
    base_url = (
        st.session_state.get(f"vision_{provider}_base_url")
        or config.app.get(f"vision_{provider}_base_url", "")
    )
    if not api_key or not model_name:
        raise ValueError(
            f"未配置 {provider} 的 API Key 或模型名称。"
            f"请在设置页面配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
        )
    return provider, api_key, model_name, base_url


def _resolve_frame_settings():
    """Return ``(frame_interval_input, vision_batch_size, vision_max_concurrency)``.

    Each value is read from ``st.session_state`` first. It falls back to
    ``config.frames``, with defaults of 3, 10 and 2 respectively.
    """
    frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
        "frame_interval_input", 3
    )
    vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
    vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
        "vision_max_concurrency", 2
    )
    return frame_interval_input, vision_batch_size, vision_max_concurrency


def generate_script_docu(params):
    """Generate a narration script for a documentary-style video.

    Requirement: the source video has no subtitles and no voice-over.
    Suitable for: documentaries, funny animal commentary, wilderness-building
    videos, etc.

    Args:
        params: object exposing ``video_origin_path`` (the video to analyze).

    Side effects:
        Renders a progress bar, status text, spinner and success/error
        messages on the Streamlit page. On success, the generated script items
        are stored in ``st.session_state["video_clip_json"]``. Errors are shown
        in the UI and logged; they are not re-raised.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: float, message: str = ""):
        # Callback handed to the analysis service; progress is on a 0-100 scale.
        progress_bar.progress(_to_streamlit_progress(progress))
        if message:
            status_text.text(f"🎬 {message}")
        else:
            status_text.text(f"📊 进度: {progress}%")

    try:
        with st.spinner("正在生成脚本..."):
            if not params.video_origin_path:
                st.error("请先选择视频文件")
                return

            vision_llm_provider, vision_api_key, vision_model, vision_base_url = _resolve_vision_settings()
            frame_interval_input, vision_batch_size, vision_max_concurrency = _resolve_frame_settings()

            update_progress(10, "正在提取关键帧...")
            service = DocumentaryFrameAnalysisService()
            # The service API is async; asyncio.run drives it to completion
            # from this synchronous Streamlit callback.
            script_items = asyncio.run(
                service.generate_documentary_script(
                    video_path=params.video_origin_path,
                    video_theme=st.session_state.get("video_theme", ""),
                    custom_prompt=st.session_state.get("custom_prompt", ""),
                    frame_interval_input=frame_interval_input,
                    vision_batch_size=vision_batch_size,
                    vision_llm_provider=vision_llm_provider,
                    progress_callback=update_progress,
                    vision_api_key=vision_api_key,
                    vision_model_name=vision_model,
                    vision_base_url=vision_base_url,
                    max_concurrency=vision_max_concurrency,
                )
            )

            logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段")
            # Round-trip through JSON so session state holds only plain JSON
            # types (e.g. tuples become lists); a non-serializable result
            # fails here rather than later.
            st.session_state["video_clip_json"] = json.loads(
                json.dumps(script_items, ensure_ascii=False)
            )
            update_progress(100, "脚本生成完成")

        time.sleep(0.1)
        progress_bar.progress(100)
        status_text.text("🎉 脚本生成完成!")
        st.success("✅ 视频脚本生成成功!")

    except Exception as err:
        st.error(f"❌ 生成过程中发生错误: {str(err)}")
        # loguru's logger.exception already attaches the traceback; formatting
        # it into the message as well would log it twice.
        logger.exception("生成脚本时发生错误")
    finally:
        # Leave the final progress/status visible briefly before clearing them.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()
|