NarratoAI/app/utils/video_processor.py
linyq ee52600ae2 feat(app): 优化关键帧提取功能
- 重构 VideoProcessor 类,优化内存使用和性能
- 添加分批处理逻辑,支持大视频文件的处理
- 使用 MiniBatchKMeans 替代 KMeans,减少内存消耗
- 优化镜头边界检测和关键帧提取算法
- 增加日志记录和错误处理,提高程序的健壮性
2024-11-11 15:53:33 +08:00

228 lines
7.7 KiB
Python

import cv2
import numpy as np
from sklearn.cluster import MiniBatchKMeans
import os
import re
from typing import List, Tuple, Generator
from loguru import logger
import gc
class VideoProcessor:
def __init__(self, video_path: str, batch_size: int = 100):
"""
初始化视频处理器
Args:
video_path: 视频文件路径
batch_size: 批处理大小,控制内存使用
"""
if not os.path.exists(video_path):
raise FileNotFoundError(f"视频文件不存在: {video_path}")
self.video_path = video_path
self.batch_size = batch_size
self.cap = cv2.VideoCapture(video_path)
if not self.cap.isOpened():
raise RuntimeError(f"无法打开视频文件: {video_path}")
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
def __del__(self):
"""析构函数,确保视频资源被释放"""
if hasattr(self, 'cap'):
self.cap.release()
gc.collect()
def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
"""
使用生成器方式分批读取视频帧
Yields:
Tuple[int, np.ndarray]: (帧索引, 视频帧)
"""
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
frame_idx = 0
while self.cap.isOpened():
ret, frame = self.cap.read()
if not ret:
break
# 降低分辨率以减少内存使用
frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
yield frame_idx, frame
frame_idx += 1
# 定期进行垃圾回收
if frame_idx % 1000 == 0:
gc.collect()
def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
"""
使用批处理方式检测镜头边界
Args:
threshold: 差异阈值
Returns:
List[int]: 镜头边界帧的索引列表
"""
shot_boundaries = []
prev_frame = None
prev_idx = -1
for frame_idx, curr_frame in self.preprocess_video():
if prev_frame is not None:
# 转换为灰度图并降低分辨率以提高性能
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
if diff > threshold:
shot_boundaries.append(frame_idx)
prev_frame = curr_frame.copy()
prev_idx = frame_idx
# 释放不需要的内存
del curr_frame
if frame_idx % 100 == 0:
gc.collect()
return shot_boundaries
def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
"""
处理单个镜头的帧
Args:
shot_frames: 镜头中的帧列表
Returns:
Tuple[np.ndarray, int]: (关键帧, 帧索引)
"""
if not shot_frames:
return None, -1
# 提取特征
frame_features = []
frame_indices = []
for idx, frame in shot_frames:
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# 降低特征维度以节省内存
resized_gray = cv2.resize(gray, (32, 32))
frame_features.append(resized_gray.flatten())
frame_indices.append(idx)
frame_features = np.array(frame_features)
# 使用MiniBatchKMeans替代KMeans以减少内存使用
kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
random_state=0).fit(frame_features)
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
return shot_frames[center_idx][1], frame_indices[center_idx]
def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
"""
使用生成器方式提取关键帧
Args:
shot_boundaries: 镜头边界列表
Yields:
Tuple[np.ndarray, int]: (关键帧, 帧索引)
"""
shot_frames = []
current_shot_start = 0
for frame_idx, frame in self.preprocess_video():
if frame_idx in shot_boundaries:
if shot_frames:
keyframe, keyframe_idx = self.process_shot(shot_frames)
if keyframe is not None:
yield keyframe, keyframe_idx
# 清理内存
shot_frames.clear()
gc.collect()
current_shot_start = frame_idx
shot_frames.append((frame_idx, frame))
# 控制单个镜头的最大帧数
if len(shot_frames) > self.batch_size:
keyframe, keyframe_idx = self.process_shot(shot_frames)
if keyframe is not None:
yield keyframe, keyframe_idx
shot_frames.clear()
gc.collect()
# 处理最后一个镜头
if shot_frames:
keyframe, keyframe_idx = self.process_shot(shot_frames)
if keyframe is not None:
yield keyframe, keyframe_idx
def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
"""
处理视频并提取关键帧,使用分批处理方式
Args:
output_dir: 输出目录
skip_seconds: 跳过视频开头的秒数
"""
try:
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 计算要跳过的帧数
skip_frames = int(skip_seconds * self.fps)
self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
# 检测镜头边界
logger.info("开始检测镜头边界...")
shot_boundaries = self.detect_shot_boundaries()
# 提取关键帧
logger.info("开始提取关键帧...")
frame_count = 0
for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
if frame_idx < skip_frames:
continue
# 计算时间戳
timestamp = frame_idx / self.fps
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
# 保存关键帧
output_path = os.path.join(output_dir,
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
cv2.imwrite(output_path, keyframe)
frame_count += 1
# 定期清理内存
if frame_count % 10 == 0:
gc.collect()
logger.info(f"关键帧提取完成,共保存 {frame_count} 帧到 {output_dir}")
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
raise
finally:
# 确保资源被释放
self.cap.release()
gc.collect()