feat(app): 优化关键帧提取功能

- 重构 VideoProcessor 类，优化内存使用和性能
- 添加分批处理逻辑，支持大视频文件的处理
- 使用 MiniBatchKMeans 替代 KMeans，减少内存消耗
- 优化镜头边界检测和关键帧提取算法
- 增加日志记录和错误处理，提高程序的健壮性
This commit is contained in:
linyq 2024-11-11 15:53:33 +08:00
parent f1603097fa
commit ee52600ae2
2 changed files with 170 additions and 163 deletions

View File

@ -1,23 +1,27 @@
import cv2 import cv2
import numpy as np import numpy as np
from sklearn.cluster import KMeans from sklearn.cluster import MiniBatchKMeans
import os import os
import re import re
from typing import List, Tuple, Generator from typing import List, Tuple, Generator
from loguru import logger
import gc
class VideoProcessor: class VideoProcessor:
def __init__(self, video_path: str): def __init__(self, video_path: str, batch_size: int = 100):
""" """
初始化视频处理器 初始化视频处理器
Args: Args:
video_path: 视频文件路径 video_path: 视频文件路径
batch_size: 批处理大小控制内存使用
""" """
if not os.path.exists(video_path): if not os.path.exists(video_path):
raise FileNotFoundError(f"视频文件不存在: {video_path}") raise FileNotFoundError(f"视频文件不存在: {video_path}")
self.video_path = video_path self.video_path = video_path
self.batch_size = batch_size
self.cap = cv2.VideoCapture(video_path) self.cap = cv2.VideoCapture(video_path)
if not self.cap.isOpened(): if not self.cap.isOpened():
@ -30,211 +34,194 @@ class VideoProcessor:
"""析构函数,确保视频资源被释放""" """析构函数,确保视频资源被释放"""
if hasattr(self, 'cap'): if hasattr(self, 'cap'):
self.cap.release() self.cap.release()
gc.collect()
def preprocess_video(self) -> Generator[np.ndarray, None, None]: def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
""" """
使用生成器方式读取视频帧 使用生成器方式分批读取视频帧
Yields: Yields:
np.ndarray: 视频帧 Tuple[int, np.ndarray]: (帧索引, 视频帧)
""" """
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始 self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
frame_idx = 0
while self.cap.isOpened(): while self.cap.isOpened():
ret, frame = self.cap.read() ret, frame = self.cap.read()
if not ret: if not ret:
break break
yield frame
def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]: # 降低分辨率以减少内存使用
frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
yield frame_idx, frame
frame_idx += 1
# 定期进行垃圾回收
if frame_idx % 1000 == 0:
gc.collect()
def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
""" """
使用帧差法检测镜头边界 使用批处理方式检测镜头边界
Args: Args:
frames: 视频帧列表
threshold: 差异阈值 threshold: 差异阈值
Returns: Returns:
List[int]: 镜头边界帧的索引列表 List[int]: 镜头边界帧的索引列表
""" """
shot_boundaries = [] shot_boundaries = []
for i in range(1, len(frames)): prev_frame = None
prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY) prev_idx = -1
curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int))) for frame_idx, curr_frame in self.preprocess_video():
if prev_frame is not None:
# 转换为灰度图并降低分辨率以提高性能
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
if diff > threshold: if diff > threshold:
shot_boundaries.append(i) shot_boundaries.append(frame_idx)
prev_frame = curr_frame.copy()
prev_idx = frame_idx
# 释放不需要的内存
del curr_frame
if frame_idx % 100 == 0:
gc.collect()
return shot_boundaries return shot_boundaries
def filter_keyframes_by_time(self, keyframes: List[np.ndarray], def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
keyframe_indices: List[int]) -> Tuple[List[np.ndarray], List[int]]:
""" """
过滤关键帧确保每秒最多只有一个关键帧 处理单个镜头的
Args: Args:
keyframes: 关键帧列表 shot_frames: 镜头中的帧列表
keyframe_indices: 关键帧索引列表
Returns: Returns:
Tuple[List[np.ndarray], List[int]]: 过滤后的关键帧列表和对应的帧索引 Tuple[np.ndarray, int]: (关键帧, 帧索引)
""" """
if not keyframes or not keyframe_indices: if not shot_frames:
return keyframes, keyframe_indices return None, -1
filtered_frames = [] # 提取特征
filtered_indices = [] frame_features = []
last_second = -1 frame_indices = []
for frame, idx in zip(keyframes, keyframe_indices): for idx, frame in shot_frames:
current_second = idx // self.fps gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
if current_second != last_second: # 降低特征维度以节省内存
filtered_frames.append(frame) resized_gray = cv2.resize(gray, (32, 32))
filtered_indices.append(idx) frame_features.append(resized_gray.flatten())
last_second = current_second frame_indices.append(idx)
return filtered_frames, filtered_indices frame_features = np.array(frame_features)
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]: # 使用MiniBatchKMeans替代KMeans以减少内存使用
""" kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
从每个镜头中提取关键帧并确保每秒最多一个关键帧 random_state=0).fit(frame_features)
Args:
frames: 视频帧列表
shot_boundaries: 镜头边界列表
Returns:
Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
"""
keyframes = []
keyframe_indices = []
for i in range(len(shot_boundaries)):
start = shot_boundaries[i - 1] if i > 0 else 0
end = shot_boundaries[i]
shot_frames = frames[start:end]
frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
for frame in shot_frames])
kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1)) center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
keyframes.append(shot_frames[center_idx]) return shot_frames[center_idx][1], frame_indices[center_idx]
keyframe_indices.append(start + center_idx)
# 过滤每秒多余的关键帧 def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
filtered_keyframes, filtered_indices = self.filter_keyframes_by_time(keyframes, keyframe_indices)
return filtered_keyframes, filtered_indices
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
output_dir: str) -> None:
""" """
保存关键帧到指定目录文件名格式为keyframe_帧序号_时间戳.jpg 使用生成器方式提取关键帧
Args: Args:
keyframes: 关键帧列表 shot_boundaries: 镜头边界列表
keyframe_indices: 关键帧索引列表
output_dir: 输出目录 Yields:
Tuple[np.ndarray, int]: (关键帧, 帧索引)
""" """
if not os.path.exists(output_dir): shot_frames = []
os.makedirs(output_dir) current_shot_start = 0
for keyframe, frame_idx in zip(keyframes, keyframe_indices): for frame_idx, frame in self.preprocess_video():
# 计算时间戳(秒) if frame_idx in shot_boundaries:
timestamp = frame_idx / self.fps if shot_frames:
# 将时间戳转换为 HH:MM:SS 格式 keyframe, keyframe_idx = self.process_shot(shot_frames)
hours = int(timestamp // 3600) if keyframe is not None:
minutes = int((timestamp % 3600) // 60) yield keyframe, keyframe_idx
seconds = int(timestamp % 60)
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
# 构建新的文件名格式keyframe_帧序号_时间戳.jpg # 清理内存
output_path = os.path.join(output_dir, shot_frames.clear()
f'keyframe_{frame_idx:06d}_{time_str}.jpg') gc.collect()
cv2.imwrite(output_path, keyframe)
print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}") current_shot_start = frame_idx
def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None: shot_frames.append((frame_idx, frame))
"""
根据指定的帧号提取帧
Args: # 控制单个镜头的最大帧数
frame_numbers: 要提取的帧号列表 if len(shot_frames) > self.batch_size:
output_folder: 输出文件夹路径 keyframe, keyframe_idx = self.process_shot(shot_frames)
""" if keyframe is not None:
if not frame_numbers: yield keyframe, keyframe_idx
raise ValueError("未提供帧号列表") shot_frames.clear()
gc.collect()
if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers): # 处理最后一个镜头
raise ValueError("存在无效的帧号") if shot_frames:
keyframe, keyframe_idx = self.process_shot(shot_frames)
if not os.path.exists(output_folder): if keyframe is not None:
os.makedirs(output_folder) yield keyframe, keyframe_idx
for frame_number in frame_numbers:
self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = self.cap.read()
if ret:
# 计算时间戳
timestamp = frame_number / self.fps
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
# 使用与关键帧相同的命名格式
output_path = os.path.join(output_folder,
f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
cv2.imwrite(output_path, frame)
print(f"已提取并保存帧 {frame_number}")
else:
print(f"无法读取帧 {frame_number}")
@staticmethod
def extract_numbers_from_folder(folder_path: str) -> List[int]:
"""
从文件夹中提取帧号
Args:
folder_path: 关键帧文件夹路径
Returns:
List[int]: 排序后的帧号列表
"""
files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
# 更新正则表达式以匹配新的文件名格式keyframe_000123_010534.jpg
pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
numbers = []
for f in files:
match = pattern.search(f)
if match:
numbers.append(int(match.group(1)))
return sorted(numbers)
def process_video(self, output_dir: str, skip_seconds: float = 0) -> None: def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
""" """
处理视频并提取关键帧 处理视频并提取关键帧使用分批处理方式
Args: Args:
output_dir: 输出目录 output_dir: 输出目录
skip_seconds: 跳过视频开头的秒数 skip_seconds: 跳过视频开头的秒数
""" """
try:
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 计算要跳过的帧数 # 计算要跳过的帧数
skip_frames = int(skip_seconds * self.fps) skip_frames = int(skip_seconds * self.fps)
self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
# 获取所有帧 # 检测镜头边界
frames = list(self.preprocess_video()) logger.info("开始检测镜头边界...")
shot_boundaries = self.detect_shot_boundaries()
# 跳过指定秒数的帧 # 提取关键帧
frames = frames[skip_frames:] logger.info("开始提取关键帧...")
frame_count = 0
if not frames: for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理") if frame_idx < skip_frames:
continue
shot_boundaries = self.detect_shot_boundaries(frames) # 计算时间戳
keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries) timestamp = frame_idx / self.fps
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
# 调整关键帧索引,加上跳过的帧数 # 保存关键帧
adjusted_indices = [idx + skip_frames for idx in keyframe_indices] output_path = os.path.join(output_dir,
self.save_keyframes(keyframes, adjusted_indices, output_dir) f'keyframe_{frame_idx:06d}_{time_str}.jpg')
cv2.imwrite(output_path, keyframe)
frame_count += 1
# 定期清理内存
if frame_count % 10 == 0:
gc.collect()
logger.info(f"关键帧提取完成,共保存 {frame_count} 帧到 {output_dir}")
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
raise
finally:
# 确保资源被释放
self.cap.release()
gc.collect()

View File

@ -7,7 +7,16 @@ x-common: &common
- ./:/NarratoAI - ./:/NarratoAI
environment: environment:
- VPN_PROXY_URL=http://host.docker.internal:7890 - VPN_PROXY_URL=http://host.docker.internal:7890
- PYTHONUNBUFFERED=1
- PYTHONMALLOC=malloc
- OPENCV_OPENCL_RUNTIME=disabled
- OPENCV_CPU_DISABLE=0
restart: always restart: always
mem_limit: 4g
mem_reservation: 2g
memswap_limit: 6g
cpus: 2.0
cpu_shares: 1024
services: services:
webui: webui:
@ -16,3 +25,14 @@ services:
ports: ports:
- "8501:8501" - "8501:8501"
command: ["webui"] command: ["webui"]
logging:
driver: "json-file"
options:
max-size: "200m"
max-file: "3"
tmpfs:
- /tmp:size=1G
ulimits:
nofile:
soft: 65536
hard: 65536