mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
feat(app): 优化关键帧提取功能
- 重构 VideoProcessor 类,优化内存使用和性能
- 添加分批处理逻辑,支持大视频文件的处理
- 使用 MiniBatchKMeans 替代 KMeans,减少内存消耗
- 优化镜头边界检测和关键帧提取算法
- 增加日志记录和错误处理,提高程序的健壮性
This commit is contained in:
parent
f1603097fa
commit
ee52600ae2
@ -1,23 +1,27 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.cluster import MiniBatchKMeans
|
||||
import os
|
||||
import re
|
||||
from typing import List, Tuple, Generator
|
||||
from loguru import logger
|
||||
import gc
|
||||
|
||||
|
||||
class VideoProcessor:
|
||||
def __init__(self, video_path: str):
|
||||
def __init__(self, video_path: str, batch_size: int = 100):
|
||||
"""
|
||||
初始化视频处理器
|
||||
|
||||
Args:
|
||||
video_path: 视频文件路径
|
||||
batch_size: 批处理大小,控制内存使用
|
||||
"""
|
||||
if not os.path.exists(video_path):
|
||||
raise FileNotFoundError(f"视频文件不存在: {video_path}")
|
||||
|
||||
self.video_path = video_path
|
||||
self.batch_size = batch_size
|
||||
self.cap = cv2.VideoCapture(video_path)
|
||||
|
||||
if not self.cap.isOpened():
|
||||
@ -30,211 +34,194 @@ class VideoProcessor:
|
||||
"""析构函数,确保视频资源被释放"""
|
||||
if hasattr(self, 'cap'):
|
||||
self.cap.release()
|
||||
gc.collect()
|
||||
|
||||
def preprocess_video(self) -> Generator[np.ndarray, None, None]:
|
||||
def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
|
||||
"""
|
||||
使用生成器方式读取视频帧
|
||||
使用生成器方式分批读取视频帧
|
||||
|
||||
Yields:
|
||||
np.ndarray: 视频帧
|
||||
Tuple[int, np.ndarray]: (帧索引, 视频帧)
|
||||
"""
|
||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始
|
||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
frame_idx = 0
|
||||
|
||||
while self.cap.isOpened():
|
||||
ret, frame = self.cap.read()
|
||||
if not ret:
|
||||
break
|
||||
yield frame
|
||||
|
||||
# 降低分辨率以减少内存使用
|
||||
frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
|
||||
yield frame_idx, frame
|
||||
|
||||
frame_idx += 1
|
||||
|
||||
# 定期进行垃圾回收
|
||||
if frame_idx % 1000 == 0:
|
||||
gc.collect()
|
||||
|
||||
def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
|
||||
def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
|
||||
"""
|
||||
使用帧差法检测镜头边界
|
||||
使用批处理方式检测镜头边界
|
||||
|
||||
Args:
|
||||
frames: 视频帧列表
|
||||
threshold: 差异阈值
|
||||
|
||||
Returns:
|
||||
List[int]: 镜头边界帧的索引列表
|
||||
"""
|
||||
shot_boundaries = []
|
||||
for i in range(1, len(frames)):
|
||||
prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
|
||||
curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
|
||||
diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
|
||||
if diff > threshold:
|
||||
shot_boundaries.append(i)
|
||||
prev_frame = None
|
||||
prev_idx = -1
|
||||
|
||||
for frame_idx, curr_frame in self.preprocess_video():
|
||||
if prev_frame is not None:
|
||||
# 转换为灰度图并降低分辨率以提高性能
|
||||
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
|
||||
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
|
||||
if diff > threshold:
|
||||
shot_boundaries.append(frame_idx)
|
||||
|
||||
prev_frame = curr_frame.copy()
|
||||
prev_idx = frame_idx
|
||||
|
||||
# 释放不需要的内存
|
||||
del curr_frame
|
||||
if frame_idx % 100 == 0:
|
||||
gc.collect()
|
||||
|
||||
return shot_boundaries
|
||||
|
||||
def filter_keyframes_by_time(self, keyframes: List[np.ndarray],
|
||||
keyframe_indices: List[int]) -> Tuple[List[np.ndarray], List[int]]:
|
||||
def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
过滤关键帧,确保每秒最多只有一个关键帧
|
||||
处理单个镜头的帧
|
||||
|
||||
Args:
|
||||
keyframes: 关键帧列表
|
||||
keyframe_indices: 关键帧索引列表
|
||||
shot_frames: 镜头中的帧列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], List[int]]: 过滤后的关键帧列表和对应的帧索引
|
||||
Tuple[np.ndarray, int]: (关键帧, 帧索引)
|
||||
"""
|
||||
if not keyframes or not keyframe_indices:
|
||||
return keyframes, keyframe_indices
|
||||
if not shot_frames:
|
||||
return None, -1
|
||||
|
||||
filtered_frames = []
|
||||
filtered_indices = []
|
||||
last_second = -1
|
||||
# 提取特征
|
||||
frame_features = []
|
||||
frame_indices = []
|
||||
|
||||
for frame, idx in zip(keyframes, keyframe_indices):
|
||||
current_second = idx // self.fps
|
||||
if current_second != last_second:
|
||||
filtered_frames.append(frame)
|
||||
filtered_indices.append(idx)
|
||||
last_second = current_second
|
||||
|
||||
return filtered_frames, filtered_indices
|
||||
for idx, frame in shot_frames:
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
# 降低特征维度以节省内存
|
||||
resized_gray = cv2.resize(gray, (32, 32))
|
||||
frame_features.append(resized_gray.flatten())
|
||||
frame_indices.append(idx)
|
||||
|
||||
frame_features = np.array(frame_features)
|
||||
|
||||
# 使用MiniBatchKMeans替代KMeans以减少内存使用
|
||||
kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
|
||||
random_state=0).fit(frame_features)
|
||||
|
||||
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
|
||||
|
||||
return shot_frames[center_idx][1], frame_indices[center_idx]
|
||||
|
||||
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
|
||||
def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
|
||||
"""
|
||||
从每个镜头中提取关键帧,并确保每秒最多一个关键帧
|
||||
使用生成器方式提取关键帧
|
||||
|
||||
Args:
|
||||
frames: 视频帧列表
|
||||
shot_boundaries: 镜头边界列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
|
||||
Yields:
|
||||
Tuple[np.ndarray, int]: (关键帧, 帧索引)
|
||||
"""
|
||||
keyframes = []
|
||||
keyframe_indices = []
|
||||
shot_frames = []
|
||||
current_shot_start = 0
|
||||
|
||||
for i in range(len(shot_boundaries)):
|
||||
start = shot_boundaries[i - 1] if i > 0 else 0
|
||||
end = shot_boundaries[i]
|
||||
shot_frames = frames[start:end]
|
||||
|
||||
frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
|
||||
for frame in shot_frames])
|
||||
kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
|
||||
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
|
||||
|
||||
keyframes.append(shot_frames[center_idx])
|
||||
keyframe_indices.append(start + center_idx)
|
||||
|
||||
# 过滤每秒多余的关键帧
|
||||
filtered_keyframes, filtered_indices = self.filter_keyframes_by_time(keyframes, keyframe_indices)
|
||||
|
||||
return filtered_keyframes, filtered_indices
|
||||
|
||||
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
|
||||
output_dir: str) -> None:
|
||||
"""
|
||||
保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg
|
||||
|
||||
Args:
|
||||
keyframes: 关键帧列表
|
||||
keyframe_indices: 关键帧索引列表
|
||||
output_dir: 输出目录
|
||||
"""
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
for keyframe, frame_idx in zip(keyframes, keyframe_indices):
|
||||
# 计算时间戳(秒)
|
||||
timestamp = frame_idx / self.fps
|
||||
# 将时间戳转换为 HH:MM:SS 格式
|
||||
hours = int(timestamp // 3600)
|
||||
minutes = int((timestamp % 3600) // 60)
|
||||
seconds = int(timestamp % 60)
|
||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
|
||||
|
||||
# 构建新的文件名格式:keyframe_帧序号_时间戳.jpg
|
||||
output_path = os.path.join(output_dir,
|
||||
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
|
||||
cv2.imwrite(output_path, keyframe)
|
||||
|
||||
print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}")
|
||||
|
||||
def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
|
||||
"""
|
||||
根据指定的帧号提取帧
|
||||
|
||||
Args:
|
||||
frame_numbers: 要提取的帧号列表
|
||||
output_folder: 输出文件夹路径
|
||||
"""
|
||||
if not frame_numbers:
|
||||
raise ValueError("未提供帧号列表")
|
||||
|
||||
if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
|
||||
raise ValueError("存在无效的帧号")
|
||||
|
||||
if not os.path.exists(output_folder):
|
||||
os.makedirs(output_folder)
|
||||
|
||||
for frame_number in frame_numbers:
|
||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
ret, frame = self.cap.read()
|
||||
|
||||
if ret:
|
||||
# 计算时间戳
|
||||
timestamp = frame_number / self.fps
|
||||
hours = int(timestamp // 3600)
|
||||
minutes = int((timestamp % 3600) // 60)
|
||||
seconds = int(timestamp % 60)
|
||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
|
||||
for frame_idx, frame in self.preprocess_video():
|
||||
if frame_idx in shot_boundaries:
|
||||
if shot_frames:
|
||||
keyframe, keyframe_idx = self.process_shot(shot_frames)
|
||||
if keyframe is not None:
|
||||
yield keyframe, keyframe_idx
|
||||
|
||||
# 清理内存
|
||||
shot_frames.clear()
|
||||
gc.collect()
|
||||
|
||||
# 使用与关键帧相同的命名格式
|
||||
output_path = os.path.join(output_folder,
|
||||
f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
|
||||
cv2.imwrite(output_path, frame)
|
||||
print(f"已提取并保存帧 {frame_number}")
|
||||
else:
|
||||
print(f"无法读取帧 {frame_number}")
|
||||
|
||||
@staticmethod
|
||||
def extract_numbers_from_folder(folder_path: str) -> List[int]:
|
||||
"""
|
||||
从文件夹中提取帧号
|
||||
|
||||
Args:
|
||||
folder_path: 关键帧文件夹路径
|
||||
current_shot_start = frame_idx
|
||||
|
||||
Returns:
|
||||
List[int]: 排序后的帧号列表
|
||||
"""
|
||||
files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
|
||||
# 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534.jpg
|
||||
pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
|
||||
numbers = []
|
||||
for f in files:
|
||||
match = pattern.search(f)
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
return sorted(numbers)
|
||||
shot_frames.append((frame_idx, frame))
|
||||
|
||||
# 控制单个镜头的最大帧数
|
||||
if len(shot_frames) > self.batch_size:
|
||||
keyframe, keyframe_idx = self.process_shot(shot_frames)
|
||||
if keyframe is not None:
|
||||
yield keyframe, keyframe_idx
|
||||
shot_frames.clear()
|
||||
gc.collect()
|
||||
|
||||
# 处理最后一个镜头
|
||||
if shot_frames:
|
||||
keyframe, keyframe_idx = self.process_shot(shot_frames)
|
||||
if keyframe is not None:
|
||||
yield keyframe, keyframe_idx
|
||||
|
||||
def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
|
||||
"""
|
||||
处理视频并提取关键帧
|
||||
处理视频并提取关键帧,使用分批处理方式
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
skip_seconds: 跳过视频开头的秒数
|
||||
"""
|
||||
# 计算要跳过的帧数
|
||||
skip_frames = int(skip_seconds * self.fps)
|
||||
|
||||
# 获取所有帧
|
||||
frames = list(self.preprocess_video())
|
||||
|
||||
# 跳过指定秒数的帧
|
||||
frames = frames[skip_frames:]
|
||||
|
||||
if not frames:
|
||||
raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
|
||||
|
||||
shot_boundaries = self.detect_shot_boundaries(frames)
|
||||
keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
|
||||
|
||||
# 调整关键帧索引,加上跳过的帧数
|
||||
adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
|
||||
self.save_keyframes(keyframes, adjusted_indices, output_dir)
|
||||
try:
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 计算要跳过的帧数
|
||||
skip_frames = int(skip_seconds * self.fps)
|
||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
|
||||
|
||||
# 检测镜头边界
|
||||
logger.info("开始检测镜头边界...")
|
||||
shot_boundaries = self.detect_shot_boundaries()
|
||||
|
||||
# 提取关键帧
|
||||
logger.info("开始提取关键帧...")
|
||||
frame_count = 0
|
||||
|
||||
for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
|
||||
if frame_idx < skip_frames:
|
||||
continue
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = frame_idx / self.fps
|
||||
hours = int(timestamp // 3600)
|
||||
minutes = int((timestamp % 3600) // 60)
|
||||
seconds = int(timestamp % 60)
|
||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
|
||||
|
||||
# 保存关键帧
|
||||
output_path = os.path.join(output_dir,
|
||||
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
|
||||
cv2.imwrite(output_path, keyframe)
|
||||
frame_count += 1
|
||||
|
||||
# 定期清理内存
|
||||
if frame_count % 10 == 0:
|
||||
gc.collect()
|
||||
|
||||
logger.info(f"关键帧提取完成,共保存 {frame_count} 帧到 {output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
# 确保资源被释放
|
||||
self.cap.release()
|
||||
gc.collect()
|
||||
|
||||
@ -7,7 +7,16 @@ x-common: &common
|
||||
- ./:/NarratoAI
|
||||
environment:
|
||||
- VPN_PROXY_URL=http://host.docker.internal:7890
|
||||
- PYTHONUNBUFFERED=1
|
||||
- PYTHONMALLOC=malloc
|
||||
- OPENCV_OPENCL_RUNTIME=disabled
|
||||
- OPENCV_CPU_DISABLE=0
|
||||
restart: always
|
||||
mem_limit: 4g
|
||||
mem_reservation: 2g
|
||||
memswap_limit: 6g
|
||||
cpus: 2.0
|
||||
cpu_shares: 1024
|
||||
|
||||
services:
|
||||
webui:
|
||||
@ -16,3 +25,14 @@ services:
|
||||
ports:
|
||||
- "8501:8501"
|
||||
command: ["webui"]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "200m"
|
||||
max-file: "3"
|
||||
tmpfs:
|
||||
- /tmp:size=1G
|
||||
ulimits:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user