mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
454 lines
24 KiB
Python
454 lines
24 KiB
Python
# Documentary narration script generation
import os
import json
import time
import asyncio
import traceback

import streamlit as st
from loguru import logger
from datetime import datetime

from app.config import config
from app.utils import utils, video_processor
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps, chekc_video_config


|
||
def generate_script_docu(params):
    """
    Generate a documentary-style video narration script.

    Requirement: the source video must have no subtitles and no voice-over.
    Suitable scenarios: documentaries, funny animal commentary,
    wilderness-building videos, etc.

    Pipeline:
        1. Extract keyframes from the video (cached by video path + mtime).
        2. Run batched vision-LLM analysis over the keyframes.
        3. Merge and normalize the per-batch analysis results into JSON.
        4. Generate narration copy with a text LLM and store the parsed
           script in ``st.session_state['video_clip_json']``.

    Progress and errors are reported through Streamlit widgets; the function
    returns nothing.

    Args:
        params: object carrying at least ``video_origin_path`` (path to the
            source video). NOTE(review): exact type is declared elsewhere —
            confirm against the caller.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: float, message: str = ""):
        # Update the progress bar and the status line shown in the UI.
        progress_bar.progress(progress)
        if message:
            status_text.text(f"🎬 {message}")
        else:
            status_text.text(f"📊 进度: {progress}%")

    try:
        with st.spinner("正在生成脚本..."):
            if not params.video_origin_path:
                st.error("请先选择视频文件")
                return
            """
            1. 提取键帧
            """
            update_progress(10, "正在提取关键帧...")

            # Keyframes are cached in a temp dir keyed by a hash of the video
            # path + mtime, so re-running on the same file reuses the frames.
            keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
            video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
            video_keyframes_dir = os.path.join(keyframes_dir, video_hash)

            # Check whether keyframes were already extracted for this video.
            keyframe_files = []
            if os.path.exists(video_keyframes_dir):
                # Collect the cached keyframe files (sorted => chronological,
                # since file names embed the timestamp).
                for filename in sorted(os.listdir(video_keyframes_dir)):
                    if filename.endswith('.jpg'):
                        keyframe_files.append(os.path.join(video_keyframes_dir, filename))

                if keyframe_files:
                    logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
                    st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)} 帧")
                    update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧")

            # No cached keyframes: extract them now.
            if not keyframe_files:
                try:
                    # Make sure the output directory exists.
                    os.makedirs(video_keyframes_dir, exist_ok=True)

                    # Initialize the video processor.
                    processor = video_processor.VideoProcessor(params.video_origin_path)

                    # Show basic video information in the UI.
                    st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}秒")

                    # Extract keyframes using the "ultra compatible" strategy.
                    update_progress(15, "正在提取关键帧(使用超级兼容性方案)...")

                    try:
                        # Interval-based keyframe extraction; interval comes
                        # from the UI (may be None if the widget was not set —
                        # presumably the processor applies a default; confirm).
                        processor.extract_frames_by_interval_ultra_compatible(
                            output_dir=video_keyframes_dir,
                            interval_seconds=st.session_state.get('frame_interval_input'),
                        )
                    except Exception as extract_error:
                        logger.error(f"关键帧提取失败: {extract_error}")

                        # Map common failure causes to actionable suggestions.
                        error_msg = str(extract_error)
                        if "权限" in error_msg or "permission" in error_msg.lower():
                            suggestion = "建议:检查输出目录权限,或更换输出位置"
                        elif "空间" in error_msg or "space" in error_msg.lower():
                            suggestion = "建议:检查磁盘空间是否足够"
                        else:
                            suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式"

                        raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}")

                    # Collect the freshly extracted keyframe file paths.
                    for filename in sorted(os.listdir(video_keyframes_dir)):
                        if filename.endswith('.jpg'):
                            keyframe_files.append(os.path.join(video_keyframes_dir, filename))

                    if not keyframe_files:
                        # Log whatever ended up in the directory to aid debugging.
                        all_files = os.listdir(video_keyframes_dir)
                        logger.error(f"关键帧目录内容: {all_files}")
                        raise Exception("未提取到任何关键帧文件,请检查视频文件格式")

                    update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧")
                    st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧")

                except Exception as e:
                    # On failure, remove the partially populated keyframe dir
                    # so a later run does not treat it as a valid cache.
                    try:
                        if os.path.exists(video_keyframes_dir):
                            import shutil
                            shutil.rmtree(video_keyframes_dir)
                    except Exception as cleanup_err:
                        logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")

                    raise Exception(f"关键帧提取失败: {str(e)}")

            """
            2. 视觉分析(批量分析每一帧)
            """
            # Provider resolution: session state first, then config default.
            vision_llm_provider = (
                st.session_state.get('vision_llm_provider') or
                config.app.get('vision_llm_provider', 'litellm')
            ).lower()

            logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")

            try:
                # =================== Initialize the vision analyzer ===================
                update_progress(30, "正在初始化视觉分析器...")

                # Unified config-key format vision_<provider>_<field>,
                # resolved from session state first, then config.
                vision_api_key = (
                    st.session_state.get(f'vision_{vision_llm_provider}_api_key') or
                    config.app.get(f'vision_{vision_llm_provider}_api_key')
                )
                vision_model = (
                    st.session_state.get(f'vision_{vision_llm_provider}_model_name') or
                    config.app.get(f'vision_{vision_llm_provider}_model_name')
                )
                vision_base_url = (
                    st.session_state.get(f'vision_{vision_llm_provider}_base_url') or
                    config.app.get(f'vision_{vision_llm_provider}_base_url', '')
                )

                # Validate required configuration.
                if not vision_api_key or not vision_model:
                    raise ValueError(
                        f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
                        f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
                    )

                # Parameter bundle; extended with text-LLM settings later and
                # passed to chekc_video_config() for validation.
                llm_params = {
                    "vision_provider": vision_llm_provider,
                    "vision_api_key": vision_api_key,
                    "vision_model_name": vision_model,
                    "vision_base_url": vision_base_url,
                }

                logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}")

                # Create the vision analyzer instance (unified interface).
                analyzer = create_vision_analyzer(
                    provider=vision_llm_provider,
                    api_key=vision_api_key,
                    model=vision_model,
                    base_url=vision_base_url
                )

                update_progress(40, "正在分析关键帧...")

                # =================== Create a dedicated asyncio event loop ===================
                # NOTE(review): loop.close() below is not in a finally, so the
                # loop leaks if analyze_images raises — confirm acceptable.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                # Run the async analysis. Batch size: session state first,
                # then config.frames.
                vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
                # NOTE(review): the prompt contains %s placeholders that are
                # not formatted here — presumably analyze_images substitutes
                # the frame count per batch; confirm against the analyzer.
                vision_analysis_prompt = """
我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。

首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。

请务必使用 JSON 格式输出你的结果。JSON 结构应如下:
{
"frame_observations": [
{
"frame_number": 1, // 或其他标识帧的方式
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
},
// ... 更多帧的观察 ...
],
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
}

请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素

请只返回 JSON 字符串,不要包含任何其他解释性文字。
"""
                results = loop.run_until_complete(
                    analyzer.analyze_images(
                        images=keyframe_files,
                        prompt=vision_analysis_prompt,
                        batch_size=vision_batch_size
                    )
                )
                loop.close()

                """
                3. 处理分析结果(格式化为 json 数据)
                """
                # =================== Process the analysis results ===================
                update_progress(60, "正在整理分析结果...")

                # Accumulators for merging all batch results.
                frame_analysis = ""               # raw-text fallback for unparseable batches
                merged_frame_observations = []    # per-frame observations across all batches
                overall_activity_summaries = []   # per-batch activity summaries
                prev_batch_files = None
                frame_counter = 1                 # assigns consecutive global frame numbers

                # Ensure the analysis output directory exists and dump the
                # raw batch results for debugging/inspection.
                analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
                os.makedirs(analysis_dir, exist_ok=True)
                origin_res = os.path.join(analysis_dir, "frame_analysis.json")
                with open(origin_res, 'w', encoding='utf-8') as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)

                # Process each batch result.
                for result in results:
                    if 'error' in result:
                        # Skip failed batches but keep going with the rest.
                        logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
                        continue

                    # Files belonging to the current batch.
                    batch_files = get_batch_files(keyframe_files, result, vision_batch_size)

                    # Timestamp range covered by this batch.
                    first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)

                    # Parse the JSON payload contained in the model response.
                    response_text = result['response']
                    try:
                        # The response may be wrapped in ```json ... ``` fences.
                        if "```json" in response_text:
                            json_content = response_text.split("```json")[1].split("```")[0].strip()
                        elif "```" in response_text:
                            json_content = response_text.split("```")[1].split("```")[0].strip()
                        else:
                            json_content = response_text.strip()

                        response_data = json.loads(json_content)

                        # Extract frame_observations and overall_activity_summary.
                        if "frame_observations" in response_data:
                            frame_obs = response_data["frame_observations"]
                            overall_summary = response_data.get("overall_activity_summary", "")

                            # Attach timestamp info to every frame observation.
                            for i, obs in enumerate(frame_obs):
                                if i < len(batch_files):
                                    # Derive the timestamp from the file name.
                                    file_path = batch_files[i]
                                    file_name = os.path.basename(file_path)
                                    # Expected name format: keyframe_<frameindex>_<timestamp>.jpg
                                    # e.g. keyframe_000675_000027000.jpg
                                    timestamp_parts = file_name.split('_')
                                    if len(timestamp_parts) >= 3:
                                        timestamp_str = timestamp_parts[-1].split('.')[0]
                                        try:
                                            # Timestamp layout HHMMSSmmm:
                                            # first two digits hours, next two
                                            # minutes, then seconds + millis.
                                            # e.g. 000100000 -> 00:01:00,000.
                                            if len(timestamp_str) >= 9:
                                                hours = int(timestamp_str[0:2])
                                                minutes = int(timestamp_str[2:4])
                                                seconds = int(timestamp_str[4:6])
                                                milliseconds = int(timestamp_str[6:9])

                                                # Total seconds, then formatted.
                                                timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
                                                formatted_time = utils.format_time(timestamp_seconds)
                                            else:
                                                # Legacy format: plain milliseconds.
                                                timestamp_seconds = int(timestamp_str) / 1000
                                                formatted_time = utils.format_time(timestamp_seconds)
                                        except ValueError:
                                            logger.warning(f"无法解析时间戳: {timestamp_str}")
                                            timestamp_seconds = 0
                                            formatted_time = "00:00:00,000"
                                    else:
                                        logger.warning(f"文件名格式不符合预期: {file_name}")
                                        timestamp_seconds = 0
                                        formatted_time = "00:00:00,000"

                                    # Enrich the observation with extra metadata.
                                    obs["frame_path"] = file_path
                                    obs["timestamp"] = formatted_time
                                    obs["timestamp_seconds"] = timestamp_seconds
                                    obs["batch_index"] = result['batch_index']

                                    # Replace the model-reported frame_number
                                    # with a globally increasing counter,
                                    # keeping the original for reference.
                                    if "frame_number" in obs:
                                        obs["original_frame_number"] = obs["frame_number"]
                                    obs["frame_number"] = frame_counter
                                    frame_counter += 1

                                    # Add to the merged list.
                                    merged_frame_observations.append(obs)

                            # Record the batch-level activity summary.
                            if overall_summary:
                                # Extract the numeric timestamp part of the
                                # first/last file names of the batch.
                                first_time_str = first_timestamp.split('_')[-1].split('.')[0]
                                last_time_str = last_timestamp.split('_')[-1].split('.')[0]

                                # Compute the batch duration in seconds.
                                try:
                                    # Same HHMMSSmmm parsing as above.
                                    if len(first_time_str) >= 9 and len(last_time_str) >= 9:
                                        # Parse the first timestamp.
                                        first_hours = int(first_time_str[0:2])
                                        first_minutes = int(first_time_str[2:4])
                                        first_seconds = int(first_time_str[4:6])
                                        first_ms = int(first_time_str[6:9])
                                        first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000

                                        # Parse the last timestamp.
                                        last_hours = int(last_time_str[0:2])
                                        last_minutes = int(last_time_str[2:4])
                                        last_seconds = int(last_time_str[4:6])
                                        last_ms = int(last_time_str[6:9])
                                        last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000

                                        batch_duration = last_time_seconds - first_time_seconds
                                    else:
                                        # Legacy format: plain milliseconds.
                                        first_time_ms = int(first_time_str)
                                        last_time_ms = int(last_time_str)
                                        batch_duration = (last_time_ms - first_time_ms) / 1000
                                except ValueError:
                                    # Fall back to utils.time_to_seconds for
                                    # already-formatted timestamps.
                                    first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ','))
                                    last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ','))
                                    batch_duration = last_time_seconds - first_time_seconds

                                overall_activity_summaries.append({
                                    "batch_index": result['batch_index'],
                                    "time_range": f"{first_timestamp}-{last_timestamp}",
                                    "duration_seconds": batch_duration,
                                    "summary": overall_summary
                                })
                    except Exception as e:
                        logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}")
                        # Fall back to keeping the raw response text.
                        frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
                        frame_analysis += response_text
                        frame_analysis += "\n"

                    # Remember this batch's files for the next iteration
                    # (used by get_batch_timestamps for continuity).
                    prev_batch_files = batch_files

                # Merge the accumulated results.
                merged_results = {
                    "frame_observations": merged_frame_observations,
                    "overall_activity_summaries": overall_activity_summaries
                }

                # Build a timestamped output file name.
                now = datetime.now()
                timestamp_str = now.strftime("%Y%m%d_%H%M")

                # Persist the complete merged analysis as JSON.
                analysis_filename = f"frame_analysis_{timestamp_str}.json"
                analysis_json_path = os.path.join(analysis_dir, analysis_filename)
                with open(analysis_json_path, 'w', encoding='utf-8') as f:
                    json.dump(merged_results, f, ensure_ascii=False, indent=2)
                logger.info(f"分析结果已保存到: {analysis_json_path}")

                """
                4. 生成文案
                """
                logger.info("开始生成解说文案")
                update_progress(80, "正在生成解说文案...")
                from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration
                # Text-LLM configuration (config only, no session-state override).
                text_provider = config.app.get('text_llm_provider', 'gemini').lower()
                text_api_key = config.app.get(f'text_{text_provider}_api_key')
                text_model = config.app.get(f'text_{text_provider}_model_name')
                text_base_url = config.app.get(f'text_{text_provider}_base_url')
                llm_params.update({
                    "text_provider": text_provider,
                    "text_api_key": text_api_key,
                    "text_model_name": text_model,
                    "text_base_url": text_base_url
                })
                chekc_video_config(llm_params)
                # Convert the frame analysis JSON into markdown for the prompt.
                markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)

                # Generate the narration copy.
                narration = generate_narration(
                    markdown_output,
                    text_api_key,
                    base_url=text_base_url,
                    model=text_model
                )

                # Parse (and auto-repair) the JSON returned by the LLM.
                from webui.tools.generate_short_summary import parse_and_fix_json
                narration_data = parse_and_fix_json(narration)

                if not narration_data or 'items' not in narration_data:
                    logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...")
                    raise Exception("解说文案格式错误,无法解析JSON或缺少items字段")

                narration_dict = narration_data['items']
                # Add "OST": 2 to every item — keep original audio plus voice-over.
                narration_dict = [{**item, "OST": 2} for item in narration_dict]
                logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段")
                # Serialize the final script to a JSON string.
                script = json.dumps(narration_dict, ensure_ascii=False, indent=2)

            except Exception as e:
                logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
                raise Exception(f"分析失败: {str(e)}")

            # NOTE(review): script is always a JSON string here (json.dumps
            # never returns None); the None/list checks below look defensive.
            if script is None:
                st.error("生成脚本失败,请检查日志")
                st.stop()
            logger.info(f"纪录片解说脚本生成完成")
            if isinstance(script, list):
                st.session_state['video_clip_json'] = script
            elif isinstance(script, str):
                st.session_state['video_clip_json'] = json.loads(script)
            update_progress(100, "脚本生成完成")

            time.sleep(0.1)
            progress_bar.progress(100)
            status_text.text("🎉 脚本生成完成!")
            st.success("✅ 视频脚本生成成功!")

    except Exception as err:
        st.error(f"❌ 生成过程中发生错误: {str(err)}")
        logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
    finally:
        # Leave the final status visible briefly, then clear the widgets.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()
|