feat(video_processor): 优化镜头边界检测和关键帧提取功能

- 将镜头边界检测的阈值从 30 调整到 70,提高检测精度
- 添加 tqdm 进度条,增强处理过程的可视化
- 优化内存管理,提高程序运行效率
- 调整关键帧提取日志输出,增加处理进度信息
This commit is contained in:
linyqh 2024-11-13 20:19:29 +08:00
parent 2f41c13e19
commit d10a84caca

View File

@ -6,6 +6,7 @@ import re
from typing import List, Tuple, Generator
from loguru import logger
import gc
from tqdm import tqdm
class VideoProcessor:
@ -61,7 +62,7 @@ class VideoProcessor:
if frame_idx % 1000 == 0:
gc.collect()
def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
"""
使用批处理方式检测镜头边界
@ -75,20 +76,24 @@ class VideoProcessor:
prev_frame = None
prev_idx = -1
for frame_idx, curr_frame in self.preprocess_video():
pbar = tqdm(self.preprocess_video(),
total=self.total_frames,
desc="检测镜头边界",
unit="")
for frame_idx, curr_frame in pbar:
if prev_frame is not None:
# 转换为灰度图并降低分辨率以提高性能
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
if diff > threshold:
shot_boundaries.append(frame_idx)
pbar.set_postfix({"检测到边界": len(shot_boundaries)})
prev_frame = curr_frame.copy()
prev_idx = frame_idx
# 释放不需要的内存
del curr_frame
if frame_idx % 100 == 0:
gc.collect()
@ -108,20 +113,20 @@ class VideoProcessor:
if not shot_frames:
return None, -1
# 提取特征
frame_features = []
frame_indices = []
for idx, frame in shot_frames:
for idx, frame in tqdm(shot_frames,
desc="处理镜头帧",
unit="",
leave=False):
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# 降低特征维度以节省内存
resized_gray = cv2.resize(gray, (32, 32))
frame_features.append(resized_gray.flatten())
frame_indices.append(idx)
frame_features = np.array(frame_features)
# 使用MiniBatchKMeans替代KMeans以减少内存使用
kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
random_state=0).fit(frame_features)
@ -195,7 +200,11 @@ class VideoProcessor:
logger.info("开始提取关键帧...")
frame_count = 0
for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
pbar = tqdm(self.extract_keyframes(shot_boundaries),
desc="提取关键帧",
unit="")
for keyframe, frame_idx in pbar:
if frame_idx < skip_frames:
continue
@ -212,7 +221,8 @@ class VideoProcessor:
cv2.imwrite(output_path, keyframe)
frame_count += 1
# 定期清理内存
pbar.set_postfix({"已保存": frame_count})
if frame_count % 10 == 0:
gc.collect()