feat(video_processor): optimize shot boundary detection and keyframe extraction

- Raise the shot boundary detection threshold from 30 to 70 to improve detection precision (a sketch of the underlying frame-difference check follows below)
- Add tqdm progress bars to make processing progress visible
- Improve memory management for more efficient execution
- Extend keyframe extraction logging with processing progress information
Author: linyqh
Date: 2024-11-13 20:19:29 +08:00
Parent: 2f41c13e19
Commit: d10a84caca
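For reference, here is a minimal, self-contained sketch (an illustrative assumption, not the project's actual code) of the check the raised threshold feeds into: consecutive frames are converted to grayscale, the mean absolute pixel difference is computed, and a shot boundary is recorded when the difference exceeds the threshold, with tqdm reporting progress. The function name and the standalone OpenCV/VideoCapture wiring are hypothetical; the real logic lives inside the VideoProcessor class shown in the diff below.

import cv2
import numpy as np
from tqdm import tqdm
from typing import List

def detect_boundaries_sketch(video_path: str, threshold: float = 70) -> List[int]:
    # Hypothetical standalone helper; in the project this runs inside VideoProcessor.
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    boundaries: List[int] = []
    prev_gray = None
    frame_idx = 0
    with tqdm(total=total, desc="detect shot boundaries") as pbar:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if prev_gray is not None:
                # Mean absolute difference between consecutive grayscale frames
                diff = np.mean(np.abs(gray.astype(float) - prev_gray.astype(float)))
                if diff > threshold:
                    boundaries.append(frame_idx)
                    pbar.set_postfix({"boundaries": len(boundaries)})
            prev_gray = gray
            frame_idx += 1
            pbar.update(1)
    cap.release()
    return boundaries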


@@ -6,6 +6,7 @@ import re
 from typing import List, Tuple, Generator
 from loguru import logger
 import gc
+from tqdm import tqdm


 class VideoProcessor:
@@ -61,7 +62,7 @@ class VideoProcessor:
             if frame_idx % 1000 == 0:
                 gc.collect()

-    def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
+    def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
         """
         使用批处理方式检测镜头边界
@@ -75,20 +76,24 @@ class VideoProcessor:
         prev_frame = None
         prev_idx = -1

-        for frame_idx, curr_frame in self.preprocess_video():
+        pbar = tqdm(self.preprocess_video(),
+                    total=self.total_frames,
+                    desc="检测镜头边界",
+                    unit="")
+        for frame_idx, curr_frame in pbar:
             if prev_frame is not None:
-                # 转换为灰度图并降低分辨率以提高性能
                 prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
                 curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)

                 diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))

                 if diff > threshold:
                     shot_boundaries.append(frame_idx)
+                    pbar.set_postfix({"检测到边界": len(shot_boundaries)})

             prev_frame = curr_frame.copy()
             prev_idx = frame_idx

-            # 释放不需要的内存
             del curr_frame
             if frame_idx % 100 == 0:
                 gc.collect()
@@ -108,20 +113,20 @@ class VideoProcessor:
         if not shot_frames:
             return None, -1

-        # 提取特征
         frame_features = []
         frame_indices = []

-        for idx, frame in shot_frames:
+        for idx, frame in tqdm(shot_frames,
+                               desc="处理镜头帧",
+                               unit="",
+                               leave=False):
             gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-            # 降低特征维度以节省内存
             resized_gray = cv2.resize(gray, (32, 32))
             frame_features.append(resized_gray.flatten())
             frame_indices.append(idx)

         frame_features = np.array(frame_features)

-        # 使用MiniBatchKMeans替代KMeans以减少内存使用
         kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
                                  random_state=0).fit(frame_features)
@@ -195,7 +200,11 @@ class VideoProcessor:
         logger.info("开始提取关键帧...")
         frame_count = 0

-        for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
+        pbar = tqdm(self.extract_keyframes(shot_boundaries),
+                    desc="提取关键帧",
+                    unit="")
+        for keyframe, frame_idx in pbar:
             if frame_idx < skip_frames:
                 continue
@@ -212,7 +221,8 @@ class VideoProcessor:
             cv2.imwrite(output_path, keyframe)
             frame_count += 1

-            # 定期清理内存
+            pbar.set_postfix({"已保存": frame_count})
             if frame_count % 10 == 0:
                 gc.collect()
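On the keyframe-selection side, the diff fits MiniBatchKMeans with n_clusters=1 over 32x32 grayscale features of a shot's frames, which amounts to picking the frame nearest the shot's average appearance. The sketch below illustrates that idea end to end; it is an assumption about intent rather than the repository's implementation, and the final distance/argmin selection step in particular is not visible in the diff above.

import cv2
import numpy as np
from sklearn.cluster import MiniBatchKMeans

def pick_keyframe_sketch(shot_frames):
    """shot_frames: list of (frame_index, BGR frame) tuples for one shot (hypothetical helper)."""
    if not shot_frames:
        return None, -1
    features, indices = [], []
    for idx, frame in shot_frames:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # 32x32 grayscale thumbnails keep the feature matrix small
        features.append(cv2.resize(gray, (32, 32)).flatten())
        indices.append(idx)
    features = np.array(features, dtype=np.float32)
    kmeans = MiniBatchKMeans(n_clusters=1,
                             batch_size=min(len(features), 100),
                             random_state=0).fit(features)
    # The frame closest to the single cluster centre serves as the shot's keyframe
    dists = np.linalg.norm(features - kmeans.cluster_centers_[0], axis=1)
    best = int(np.argmin(dists))
    return shot_frames[best][1], indices[best]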