mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-12 03:02:48 +00:00
0.3.4 修改各种bug
This commit is contained in:
parent
4bafd696a1
commit
d0462ce91b
@ -251,16 +251,41 @@ def download_videos(
|
||||
return video_paths
|
||||
|
||||
|
||||
def time_to_seconds(time_str: str) -> float:
    """Convert a colon-separated time string into a number of seconds.

    Supported formats:
        1. "SS"        (plain seconds, may be fractional)
        2. "MM:SS"     (minutes:seconds)
        3. "HH:MM:SS"  (hours:minutes:seconds)

    Args:
        time_str: The time string to parse. Surrounding whitespace is ignored.

    Returns:
        The total number of seconds as a float.

    Raises:
        ValueError: If a component is not numeric, or if there are more than
            three colon-separated components.
    """
    parts = time_str.strip().split(':')
    if len(parts) > 3:
        raise ValueError(f"Unsupported time format: {time_str!r}")

    # Fold left to right in base 60: every additional leading component is
    # worth 60 times the one that follows it (hours -> minutes -> seconds).
    total = 0.0
    for part in parts:
        total = total * 60 + float(part)
    return total
|
||||
|
||||
|
||||
def format_timestamp(seconds: float) -> str:
    """Format a number of seconds as an "MM:SS" string.

    The fractional part is truncated. Minutes are not wrapped at 60, so one
    hour is rendered as "60:00".

    Args:
        seconds: Number of seconds to format.

    Returns:
        The zero-padded "MM:SS" string. Negative values are rendered as the
        magnitude prefixed with "-" (e.g. -5 -> "-00:05").
    """
    # Work on the magnitude: floor division and modulo on a negative integer
    # would otherwise yield results such as "-1:55" for -5 seconds.
    whole_seconds = int(abs(seconds))
    minutes, secs = divmod(whole_seconds, 60)
    # Values that truncate to zero (e.g. -0.4) stay "00:00" without a sign.
    sign = "-" if seconds < 0 and whole_seconds else ""
    return f"{sign}{minutes:02d}:{secs:02d}"
|
||||
|
||||
|
||||
def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> dict:
|
||||
"""
|
||||
保存剪辑后的视频
|
||||
Args:
|
||||
timestamp: 需要裁剪的单个时间戳,如:'00:36-00:40'
|
||||
timestamp: 需要裁剪的单个时间戳,支持两种格式:
|
||||
1. '00:36-00:40' (分:秒-分:秒)
|
||||
2. 'SS-SS' (秒-秒)
|
||||
origin_video: 原视频路径
|
||||
save_dir: 存储目录
|
||||
|
||||
Returns:
|
||||
裁剪后的视频路径
|
||||
裁剪后的视频路径,格式为 {timestamp: video_path}
|
||||
"""
|
||||
if not save_dir:
|
||||
save_dir = utils.storage_dir("cache_videos")
|
||||
@ -276,35 +301,64 @@ def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> di
|
||||
return {timestamp: video_path}
|
||||
|
||||
try:
|
||||
# 先加载视频获取总时长
|
||||
video = VideoFileClip(origin_video)
|
||||
total_duration = video.duration
|
||||
|
||||
# 获取目标时间段
|
||||
start_str, end_str = timestamp.split('-')
|
||||
start = time_to_seconds(start_str)
|
||||
end = time_to_seconds(end_str)
|
||||
|
||||
# 验证时间段是否有效
|
||||
if start >= total_duration:
|
||||
logger.warning(f"起始时间 {format_timestamp(start)} ({start:.2f}秒) 超出视频总时长 {format_timestamp(total_duration)} ({total_duration:.2f}秒)")
|
||||
video.close()
|
||||
return {}
|
||||
|
||||
if end > total_duration:
|
||||
logger.warning(f"结束时间 {format_timestamp(end)} ({end:.2f}秒) 超出视频总时长 {format_timestamp(total_duration)} ({total_duration:.2f}秒),将自动调整为视频结尾")
|
||||
end = total_duration
|
||||
|
||||
if end <= start:
|
||||
logger.warning(f"结束时间 {format_timestamp(end)} 必须大于起始时间 {format_timestamp(start)}")
|
||||
video.close()
|
||||
return {}
|
||||
|
||||
# 剪辑视频
|
||||
start, end = utils.split_timestamp(timestamp)
|
||||
video = VideoFileClip(origin_video).subclip(start, end)
|
||||
duration = end - start
|
||||
logger.info(f"开始剪辑视频: {format_timestamp(start)} - {format_timestamp(end)},时长 {format_timestamp(duration)}")
|
||||
subclip = video.subclip(start, end)
|
||||
|
||||
# 检查视频是否有音频轨道
|
||||
if video.audio is not None:
|
||||
video.write_videofile(video_path, logger=None) # 有音频时正常处理
|
||||
else:
|
||||
# 没有音频时使用不同的写入方式
|
||||
video.write_videofile(video_path, audio=False, logger=None)
|
||||
|
||||
video.close() # 确保关闭视频文件
|
||||
|
||||
if os.path.getsize(video_path) > 0 and os.path.exists(video_path):
|
||||
try:
|
||||
clip = VideoFileClip(video_path)
|
||||
duration = clip.duration
|
||||
fps = clip.fps
|
||||
clip.close()
|
||||
if duration > 0 and fps > 0:
|
||||
# 检查视频是否有音频轨道并写入文件
|
||||
subclip.write_videofile(video_path, audio=(subclip.audio is not None), logger=None)
|
||||
|
||||
# 验证生成的视频文件
|
||||
if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
|
||||
with VideoFileClip(video_path) as clip:
|
||||
if clip.duration > 0 and clip.fps > 0:
|
||||
return {timestamp: video_path}
|
||||
|
||||
raise ValueError("视频文件验证失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"视频文件验证失败: {video_path} => {str(e)}")
|
||||
logger.warning(f"视频文件处理失败: {video_path} => {str(e)}")
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"视频剪辑失败: {str(e)}")
|
||||
logger.warning(f"视频剪辑失败: \n{str(traceback.format_exc())}")
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
finally:
|
||||
# 确保视频对象被正确关闭
|
||||
try:
|
||||
video.close()
|
||||
if 'subclip' in locals():
|
||||
subclip.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
return {}
|
||||
|
||||
@ -426,6 +480,4 @@ def merge_videos(video_paths, ost_list):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_videos(
|
||||
"test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
|
||||
)
|
||||
save_clip_video('00:50-01:41', 'E:\\projects\\NarratoAI\\resource\\videos\\WeChat_20241110144511.mp4')
|
||||
|
||||
@ -326,7 +326,7 @@ class ScriptProcessor:
|
||||
start_seconds = time_to_seconds(start_str)
|
||||
end_seconds = time_to_seconds(end_str)
|
||||
duration = end_seconds - start_seconds
|
||||
word_count = int(duration / 0.2)
|
||||
word_count = int(duration / 0.35)
|
||||
|
||||
return word_count
|
||||
except Exception as e:
|
||||
|
||||
@ -396,6 +396,8 @@ def cut_video(params, progress_callback=None):
|
||||
|
||||
total_clips = len(time_list)
|
||||
|
||||
print("time_list", time_list)
|
||||
|
||||
def clip_progress(current, total):
|
||||
progress = int((current / total) * 100)
|
||||
if progress_callback:
|
||||
@ -413,12 +415,16 @@ def cut_video(params, progress_callback=None):
|
||||
|
||||
st.session_state['subclip_videos'] = subclip_videos
|
||||
|
||||
print("list:", subclip_videos)
|
||||
|
||||
for i, video_script in enumerate(video_script_list):
|
||||
print(i)
|
||||
print(video_script)
|
||||
try:
|
||||
video_script['path'] = subclip_videos[video_script['timestamp']]
|
||||
except KeyError as err:
|
||||
logger.error(f"裁剪视频失败: {err}")
|
||||
raise ValueError(f"裁剪视频失败: {err}")
|
||||
# raise ValueError(f"裁剪视频失败: {err}")
|
||||
|
||||
return task_id, subclip_videos
|
||||
|
||||
|
||||
@ -65,16 +65,44 @@ class VideoProcessor:
|
||||
shot_boundaries.append(i)
|
||||
return shot_boundaries
|
||||
|
||||
def filter_keyframes_by_time(self, keyframes: List[np.ndarray],
                             keyframe_indices: List[int]) -> Tuple[List[np.ndarray], List[int]]:
    """Thin out keyframes so that consecutive entries fall in different seconds.

    The second of a frame is computed as ``index // self.fps``. A frame is
    kept only when its second differs from the second of the most recently
    kept frame, so a run of keyframes within one second collapses to the
    first of that run.

    Args:
        keyframes: Keyframe images.
        keyframe_indices: Frame index of each keyframe, paired by position
            with ``keyframes``.

    Returns:
        Tuple[List[np.ndarray], List[int]]: The retained keyframes and their
        frame indices. If either input is empty, both inputs are returned
        unchanged.
    """
    if not (keyframes and keyframe_indices):
        return keyframes, keyframe_indices

    kept_frames = []
    kept_indices = []
    previous_second = -1

    for frame, frame_idx in zip(keyframes, keyframe_indices):
        second = frame_idx // self.fps
        if second == previous_second:
            # Same second as the last kept keyframe: drop this one.
            continue
        previous_second = second
        kept_frames.append(frame)
        kept_indices.append(frame_idx)

    return kept_frames, kept_indices
|
||||
|
||||
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
|
||||
"""
|
||||
从每个镜头中提取关键帧
|
||||
从每个镜头中提取关键帧,并确保每秒最多一个关键帧
|
||||
|
||||
Args:
|
||||
frames: 视频帧列表
|
||||
shot_boundaries: 镜头边界列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
|
||||
Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
|
||||
"""
|
||||
keyframes = []
|
||||
keyframe_indices = []
|
||||
@ -92,7 +120,10 @@ class VideoProcessor:
|
||||
keyframes.append(shot_frames[center_idx])
|
||||
keyframe_indices.append(start + center_idx)
|
||||
|
||||
return keyframes, keyframe_indices
|
||||
# 过滤每秒多余的关键帧
|
||||
filtered_keyframes, filtered_indices = self.filter_keyframes_by_time(keyframes, keyframe_indices)
|
||||
|
||||
return filtered_keyframes, filtered_indices
|
||||
|
||||
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
|
||||
output_dir: str) -> None:
|
||||
|
||||
@ -100,7 +100,7 @@ class VisionAnalyzer:
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
if retry_count >= 3:
|
||||
@ -111,7 +111,7 @@ class VisionAnalyzer:
|
||||
'model_used': self.model_name
|
||||
})
|
||||
else:
|
||||
logger.info(f"批次 {i // batch_size} 处理失败,等待60秒后重试...")
|
||||
logger.info(f"批次 {i // batch_size} 处理失败,等待60秒后重试当前批次...")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
[app]
|
||||
project_version="0.3.2"
|
||||
project_version="0.3.4"
|
||||
# 支持视频理解的大模型提供商
|
||||
# gemini
|
||||
# NarratoAPI
|
||||
# qwen2-vl (待增加)
|
||||
vision_llm_provider="gemini"
|
||||
vision_batch_size = 5
|
||||
vision_batch_size = 7
|
||||
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
|
||||
|
||||
########## Vision Gemini API Key
|
||||
|
||||
@ -18,6 +18,69 @@ from app.utils import utils, check_script, vision_analyzer, video_processor
|
||||
from webui.utils import file_utils
|
||||
|
||||
|
||||
def get_batch_timestamps(batch_files, prev_batch_files=None):
    """Derive the time range covered by one batch of keyframe image files.

    Keyframe file names look like ``keyframe_001253_000050.jpg``; the third
    underscore-separated field encodes the frame time as HHMMSS
    (``000050`` -> 00:00:50, ``000101`` -> 00:01:01).

    Args:
        batch_files: Paths of the keyframe files in the current batch.
        prev_batch_files: Paths of the files in the previous batch, used to
            handle the case where the current batch holds a single image.

    Returns:
        A tuple ``(first_timestamp, last_timestamp, timestamp_range)``. Both
        timestamps are "MM:SS" strings and ``timestamp_range`` is
        "MM:SS-MM:SS". Hours are folded into the minutes field
        (``010203`` -> "62:03"). An empty batch, or a file name whose time
        field cannot be parsed, yields "00:00".
    """
    if not batch_files:
        logger.warning("Empty batch files")
        return "00:00", "00:00", "00:00-00:00"

    def _to_mmss(frame_name):
        """Convert a keyframe file name into an "MM:SS" timestamp."""
        fields = frame_name.split('_')
        if len(fields) < 3:
            logger.warning(f"Invalid timestamp format: {frame_name}")
            return "00:00"

        time_str = fields[2].replace('.jpg', '')  # e.g. 000050
        if len(time_str) < 4:
            logger.warning(f"Invalid timestamp format: {time_str}")
            return "00:00"

        try:
            seconds = int(time_str[-2:])
            minutes = int(time_str[-4:-2])
            # Everything before the MMSS tail is the hour count; ignoring it
            # would make timestamps wrap around after one hour.
            hours = int(time_str[:-4]) if len(time_str) > 4 else 0
        except ValueError:
            logger.warning(f"Invalid timestamp format: {time_str}")
            return "00:00"

        # Carry overflowing seconds into minutes and fold hours into minutes.
        carry, seconds = divmod(seconds, 60)
        minutes += hours * 60 + carry
        return f"{minutes:02d}:{seconds:02d}"

    if len(batch_files) == 1 and prev_batch_files:
        # Single-image batch: start the range at the last frame of the
        # previous batch.
        first_frame = os.path.basename(prev_batch_files[-1])
        last_frame = os.path.basename(batch_files[0])
        logger.debug(f"单张图片批次,使用上一批次最后一帧作为首帧: {first_frame}")
    else:
        first_frame = os.path.basename(batch_files[0])
        last_frame = os.path.basename(batch_files[-1])

    first_timestamp = _to_mmss(first_frame)
    last_timestamp = _to_mmss(last_frame)
    timestamp_range = f"{first_timestamp}-{last_timestamp}"

    logger.debug(f"解析时间戳: {first_frame} -> {first_timestamp}, {last_frame} -> {last_timestamp}")
    return first_timestamp, last_timestamp, timestamp_range
|
||||
|
||||
def get_batch_files(keyframe_files, result, batch_size=5):
    """Return the keyframe files that belong to the batch described by ``result``.

    Args:
        keyframe_files: Sequence of all keyframe file paths.
        result: Mapping with a ``'batch_index'`` entry giving the batch number.
        batch_size: Number of files per batch.

    Returns:
        The slice of ``keyframe_files`` for that batch; the last batch may be
        shorter than ``batch_size``.
    """
    first = result['batch_index'] * batch_size
    # Slicing already clamps the upper bound to the sequence length.
    return keyframe_files[first:first + batch_size]
|
||||
|
||||
def render_script_panel(tr):
|
||||
"""渲染脚本配置面板"""
|
||||
with st.container(border=True):
|
||||
@ -238,7 +301,7 @@ def generate_script(tr, params):
|
||||
# 检查是否已经提取过关键帧
|
||||
keyframe_files = []
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
# 获取已有的关键帧文件
|
||||
# 取已有的关键帧文件
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
@ -282,7 +345,7 @@ def generate_script(tr, params):
|
||||
except Exception as cleanup_err:
|
||||
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
|
||||
|
||||
raise Exception(f"关键帧提取失败: {str(e)}")
|
||||
raise Exception(f"关键帧提取失败: {str(e)}")
|
||||
|
||||
# 根据不同的 LLM 提供商处理
|
||||
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
|
||||
@ -327,39 +390,28 @@ def generate_script(tr, params):
|
||||
|
||||
# 合并所有批次的分析结果
|
||||
frame_analysis = ""
|
||||
prev_batch_files = None
|
||||
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
|
||||
# 获取当前批次的图片文件
|
||||
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
|
||||
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
|
||||
batch_files = keyframe_files[batch_start:batch_end]
|
||||
batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
|
||||
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
||||
logger.debug(batch_files)
|
||||
|
||||
# 提取首帧和尾帧的时间戳
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
# 从文件名中提取时间信息
|
||||
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
|
||||
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
|
||||
|
||||
# 转换为分:秒格式
|
||||
def format_timestamp(time_str):
|
||||
seconds = int(time_str)
|
||||
minutes = seconds // 60
|
||||
seconds = seconds % 60
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
|
||||
first_timestamp = format_timestamp(first_time)
|
||||
last_timestamp = format_timestamp(last_time)
|
||||
first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}")
|
||||
|
||||
# 添加带时间戳的分析结果
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += result['response']
|
||||
frame_analysis += "\n"
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_analysis.strip():
|
||||
raise Exception("未能生成有效的帧分析结果")
|
||||
|
||||
@ -378,46 +430,28 @@ def generate_script(tr, params):
|
||||
|
||||
# 构建帧内容列表
|
||||
frame_content_list = []
|
||||
prev_batch_files = None
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
# 获取当前批次的图片文件
|
||||
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
|
||||
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
|
||||
batch_files = keyframe_files[batch_start:batch_end]
|
||||
|
||||
# 提取首帧和尾帧的时间戳
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
# 从文件名中提取时间信息
|
||||
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
|
||||
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
|
||||
|
||||
# 转换为分:秒格式
|
||||
def format_timestamp(time_str):
|
||||
seconds = int(time_str)
|
||||
minutes = seconds // 60
|
||||
seconds = seconds % 60
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
|
||||
first_timestamp = format_timestamp(first_time)
|
||||
last_timestamp = format_timestamp(last_time)
|
||||
|
||||
# 构建时间戳范围字符串 (MM:SS-MM:SS 格式)
|
||||
timestamp_range = f"{first_timestamp}-{last_timestamp}"
|
||||
batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
|
||||
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
frame_content = {
|
||||
"timestamp": timestamp_range, # 使用时间范围格式 "MM:SS-MM:SS"
|
||||
"picture": result['response'], # 图片分析结果
|
||||
"narration": "", # 将由 ScriptProcessor 生成
|
||||
"OST": 2 # 默认值
|
||||
"timestamp": timestamp_range,
|
||||
"picture": result['response'],
|
||||
"narration": "",
|
||||
"OST": 2
|
||||
}
|
||||
frame_content_list.append(frame_content)
|
||||
|
||||
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_content_list:
|
||||
raise Exception("没有有效的帧内容可以处理")
|
||||
|
||||
@ -442,13 +476,15 @@ def generate_script(tr, params):
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
try:
|
||||
response = session.post(
|
||||
f"{config.app.get('narrato_api_url')}/video/config",
|
||||
params=api_params,
|
||||
timeout=30,
|
||||
verify=True # 启用证书验证
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
custom_prompt = st.session_state.get('custom_prompt', '')
|
||||
processor = ScriptProcessor(
|
||||
@ -621,7 +657,7 @@ def save_script(tr, video_clip_json_details):
|
||||
# 显示成功消息
|
||||
st.success(tr("Script saved successfully"))
|
||||
|
||||
# 强制重新加载页面以更新选择框
|
||||
# 强制重新加载页面更新选择框
|
||||
time.sleep(0.5) # 给一点时间让用户看到成功消息
|
||||
st.rerun()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user