mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-21 11:00:17 +00:00
feat(video_processor): 优化镜头边界检测和关键帧提取功能
- 将镜头边界检测的阈值从 30 调整到 70,提高检测精度 - 添加 tqdm 进度条,增强处理过程的可视化 - 优化内存管理,提高程序运行效率 - 调整关键帧提取日志输出,增加处理进度信息
This commit is contained in:
parent
2f41c13e19
commit
d10a84caca
@ -6,6 +6,7 @@ import re
|
||||
from typing import List, Tuple, Generator
|
||||
from loguru import logger
|
||||
import gc
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class VideoProcessor:
|
||||
@ -61,7 +62,7 @@ class VideoProcessor:
|
||||
if frame_idx % 1000 == 0:
|
||||
gc.collect()
|
||||
|
||||
def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
|
||||
def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
|
||||
"""
|
||||
使用批处理方式检测镜头边界
|
||||
|
||||
@ -75,20 +76,24 @@ class VideoProcessor:
|
||||
prev_frame = None
|
||||
prev_idx = -1
|
||||
|
||||
for frame_idx, curr_frame in self.preprocess_video():
|
||||
pbar = tqdm(self.preprocess_video(),
|
||||
total=self.total_frames,
|
||||
desc="检测镜头边界",
|
||||
unit="帧")
|
||||
|
||||
for frame_idx, curr_frame in pbar:
|
||||
if prev_frame is not None:
|
||||
# 转换为灰度图并降低分辨率以提高性能
|
||||
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
|
||||
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
|
||||
if diff > threshold:
|
||||
shot_boundaries.append(frame_idx)
|
||||
pbar.set_postfix({"检测到边界": len(shot_boundaries)})
|
||||
|
||||
prev_frame = curr_frame.copy()
|
||||
prev_idx = frame_idx
|
||||
|
||||
# 释放不需要的内存
|
||||
del curr_frame
|
||||
if frame_idx % 100 == 0:
|
||||
gc.collect()
|
||||
@ -108,20 +113,20 @@ class VideoProcessor:
|
||||
if not shot_frames:
|
||||
return None, -1
|
||||
|
||||
# 提取特征
|
||||
frame_features = []
|
||||
frame_indices = []
|
||||
|
||||
for idx, frame in shot_frames:
|
||||
for idx, frame in tqdm(shot_frames,
|
||||
desc="处理镜头帧",
|
||||
unit="帧",
|
||||
leave=False):
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
# 降低特征维度以节省内存
|
||||
resized_gray = cv2.resize(gray, (32, 32))
|
||||
frame_features.append(resized_gray.flatten())
|
||||
frame_indices.append(idx)
|
||||
|
||||
frame_features = np.array(frame_features)
|
||||
|
||||
# 使用MiniBatchKMeans替代KMeans以减少内存使用
|
||||
kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
|
||||
random_state=0).fit(frame_features)
|
||||
|
||||
@ -195,7 +200,11 @@ class VideoProcessor:
|
||||
logger.info("开始提取关键帧...")
|
||||
frame_count = 0
|
||||
|
||||
for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
|
||||
pbar = tqdm(self.extract_keyframes(shot_boundaries),
|
||||
desc="提取关键帧",
|
||||
unit="帧")
|
||||
|
||||
for keyframe, frame_idx in pbar:
|
||||
if frame_idx < skip_frames:
|
||||
continue
|
||||
|
||||
@ -212,7 +221,8 @@ class VideoProcessor:
|
||||
cv2.imwrite(output_path, keyframe)
|
||||
frame_count += 1
|
||||
|
||||
# 定期清理内存
|
||||
pbar.set_postfix({"已保存": frame_count})
|
||||
|
||||
if frame_count % 10 == 0:
|
||||
gc.collect()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user