Mirror of https://github.com/linyqh/NarratoAI.git, synced 2025-12-11 18:42:49 +00:00
feat(app): optimize keyframe extraction

- Refactor the VideoProcessor class to improve memory usage and performance
- Add batch-processing logic to support large video files
- Replace KMeans with MiniBatchKMeans to reduce memory consumption
- Improve the shot-boundary detection and keyframe-extraction algorithms
- Add logging and error handling to make the program more robust
This commit is contained in:
parent f1603097fa
commit ee52600ae2
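Before the diff, a minimal usage sketch of the refactored class, assuming an importable module that exposes VideoProcessor; the module path and file paths below are placeholders, not taken from this commit:

# Usage sketch (illustrative; module path and file paths are placeholders)
from video_processor import VideoProcessor

processor = VideoProcessor("input/demo.mp4", batch_size=100)
# Skip the first 3 seconds, then write keyframes as keyframe_<frame>_<HHMMSS>.jpg
processor.process_video(output_dir="output/keyframes", skip_seconds=3)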
@@ -1,23 +1,27 @@
 import cv2
 import numpy as np
-from sklearn.cluster import KMeans
+from sklearn.cluster import MiniBatchKMeans
 import os
 import re
 from typing import List, Tuple, Generator
+from loguru import logger
+import gc


 class VideoProcessor:
-    def __init__(self, video_path: str):
+    def __init__(self, video_path: str, batch_size: int = 100):
         """
         Initialize the video processor.

         Args:
             video_path: path to the video file
+            batch_size: batch size that bounds memory usage
         """
         if not os.path.exists(video_path):
             raise FileNotFoundError(f"Video file not found: {video_path}")

         self.video_path = video_path
+        self.batch_size = batch_size
         self.cap = cv2.VideoCapture(video_path)

         if not self.cap.isOpened():
@@ -30,211 +34,194 @@ class VideoProcessor:
         """Destructor: make sure the video resource is released."""
         if hasattr(self, 'cap'):
             self.cap.release()
+        gc.collect()

-    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
+    def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
         """
-        Read video frames with a generator.
+        Read video frames in batches with a generator.

         Yields:
-            np.ndarray: video frame
+            Tuple[int, np.ndarray]: (frame index, video frame)
         """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the start of the video
+        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+        frame_idx = 0

         while self.cap.isOpened():
             ret, frame = self.cap.read()
             if not ret:
                 break
-            yield frame
+
+            # downscale to reduce memory usage
+            frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
+            yield frame_idx, frame
+
+            frame_idx += 1
+
+            # run garbage collection periodically
+            if frame_idx % 1000 == 0:
+                gc.collect()

-    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
+    def detect_shot_boundaries(self, threshold: int = 30) -> List[int]:
         """
-        Detect shot boundaries with frame differencing.
+        Detect shot boundaries in a batch-processing fashion.

         Args:
-            frames: list of video frames
             threshold: difference threshold

         Returns:
             List[int]: indices of shot-boundary frames
         """
         shot_boundaries = []
-        for i in range(1, len(frames)):
-            prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
-            curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
-            diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
-            if diff > threshold:
-                shot_boundaries.append(i)
+        prev_frame = None
+        prev_idx = -1
+
+        for frame_idx, curr_frame in self.preprocess_video():
+            if prev_frame is not None:
+                # convert to grayscale and reduce resolution for performance
+                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
+                curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
+
+                diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
+                if diff > threshold:
+                    shot_boundaries.append(frame_idx)
+
+            prev_frame = curr_frame.copy()
+            prev_idx = frame_idx
+
+            # free memory that is no longer needed
+            del curr_frame
+            if frame_idx % 100 == 0:
+                gc.collect()
+
         return shot_boundaries
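For intuition (an aside, not part of the commit): the boundary test above is the mean absolute difference between consecutive grayscale frames, and a frame whose average change exceeds the threshold (30 by default) is treated as the start of a new shot. A self-contained sketch of the same metric on synthetic frames:

import numpy as np

def frame_diff(prev_gray: np.ndarray, curr_gray: np.ndarray) -> float:
    # Mean absolute intensity difference, as used in detect_shot_boundaries.
    return float(np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float))))

dark = np.full((240, 320), 40, dtype=np.uint8)     # uniform dark frame
bright = np.full((240, 320), 90, dtype=np.uint8)   # uniform brighter frame
print(frame_diff(dark, dark))    # 0.0  -> same shot
print(frame_diff(dark, bright))  # 50.0 -> above the default threshold of 30: shot boundary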

-    def filter_keyframes_by_time(self, keyframes: List[np.ndarray],
-                                 keyframe_indices: List[int]) -> Tuple[List[np.ndarray], List[int]]:
-        """
-        Filter keyframes so that there is at most one keyframe per second.
-
-        Args:
-            keyframes: list of keyframes
-            keyframe_indices: list of keyframe indices
-
-        Returns:
-            Tuple[List[np.ndarray], List[int]]: filtered keyframes and their frame indices
-        """
-        if not keyframes or not keyframe_indices:
-            return keyframes, keyframe_indices
-
-        filtered_frames = []
-        filtered_indices = []
-        last_second = -1
-
-        for frame, idx in zip(keyframes, keyframe_indices):
-            current_second = idx // self.fps
-            if current_second != last_second:
-                filtered_frames.append(frame)
-                filtered_indices.append(idx)
-                last_second = current_second
-
-        return filtered_frames, filtered_indices
+    def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
+        """
+        Process the frames of a single shot.
+
+        Args:
+            shot_frames: list of frames in the shot
+
+        Returns:
+            Tuple[np.ndarray, int]: (keyframe, frame index)
+        """
+        if not shot_frames:
+            return None, -1
+
+        # extract features
+        frame_features = []
+        frame_indices = []
+
+        for idx, frame in shot_frames:
+            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+            # reduce feature dimensionality to save memory
+            resized_gray = cv2.resize(gray, (32, 32))
+            frame_features.append(resized_gray.flatten())
+            frame_indices.append(idx)
+
+        frame_features = np.array(frame_features)
+
+        # use MiniBatchKMeans instead of KMeans to reduce memory usage
+        kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
+                                 random_state=0).fit(frame_features)
+
+        center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
+
+        return shot_frames[center_idx][1], frame_indices[center_idx]
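Worth noting (an observation, not part of the commit): with n_clusters=1 the MiniBatchKMeans centre is essentially the mean feature vector, so process_shot picks the frame whose 32x32 grayscale thumbnail is closest to the shot's average appearance; the main benefit of MiniBatchKMeans here is its bounded memory footprint. A rough NumPy-only equivalent of the selection step:

import numpy as np

def select_keyframe_index(frame_features: np.ndarray) -> int:
    # With a single cluster the centroid approximates the mean feature vector,
    # so the chosen frame is the one closest to the shot's average appearance.
    center = frame_features.mean(axis=0)
    return int(np.argmin(np.sum((frame_features - center) ** 2, axis=1)))

features = np.random.default_rng(0).random((8, 32 * 32))  # eight fake frame thumbnails
print(select_keyframe_index(features))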

-    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
-        """
-        Extract a keyframe from each shot, keeping at most one keyframe per second.
-
-        Args:
-            frames: list of video frames
-            shot_boundaries: list of shot boundaries
-
-        Returns:
-            Tuple[List[np.ndarray], List[int]]: keyframes and their frame indices
-        """
-        keyframes = []
-        keyframe_indices = []
-
-        for i in range(len(shot_boundaries)):
-            start = shot_boundaries[i - 1] if i > 0 else 0
-            end = shot_boundaries[i]
-            shot_frames = frames[start:end]
-
-            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
-                                       for frame in shot_frames])
-            kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
-            center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
-
-            keyframes.append(shot_frames[center_idx])
-            keyframe_indices.append(start + center_idx)
-
-        # drop extra keyframes within the same second
-        filtered_keyframes, filtered_indices = self.filter_keyframes_by_time(keyframes, keyframe_indices)
-
-        return filtered_keyframes, filtered_indices
-
-    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
-                       output_dir: str) -> None:
-        """
-        Save keyframes to the given directory, named keyframe_<frame index>_<timestamp>.jpg.
-
-        Args:
-            keyframes: list of keyframes
-            keyframe_indices: list of keyframe indices
-            output_dir: output directory
-        """
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-
-        for keyframe, frame_idx in zip(keyframes, keyframe_indices):
-            # compute the timestamp in seconds
-            timestamp = frame_idx / self.fps
-            # convert the timestamp to HH:MM:SS
-            hours = int(timestamp // 3600)
-            minutes = int((timestamp % 3600) // 60)
-            seconds = int(timestamp % 60)
-            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
-
-            # build the new file-name format: keyframe_<frame index>_<timestamp>.jpg
-            output_path = os.path.join(output_dir,
-                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-            cv2.imwrite(output_path, keyframe)
-
-        print(f"Saved {len(keyframes)} keyframes to {output_dir}")
-
-    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
-        """
-        Extract frames by the given frame numbers.
-
-        Args:
-            frame_numbers: list of frame numbers to extract
-            output_folder: output folder path
-        """
-        if not frame_numbers:
-            raise ValueError("No frame numbers provided")
-
-        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
-            raise ValueError("Invalid frame numbers present")
-
-        if not os.path.exists(output_folder):
-            os.makedirs(output_folder)
-
-        for frame_number in frame_numbers:
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-            ret, frame = self.cap.read()
-
-            if ret:
-                # compute the timestamp
-                timestamp = frame_number / self.fps
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
-
-                # use the same naming format as keyframes
-                output_path = os.path.join(output_folder,
-                                           f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
-                cv2.imwrite(output_path, frame)
-                print(f"Extracted and saved frame {frame_number}")
-            else:
-                print(f"Failed to read frame {frame_number}")
-
-    @staticmethod
-    def extract_numbers_from_folder(folder_path: str) -> List[int]:
-        """
-        Extract frame numbers from a folder of saved keyframes.
-
-        Args:
-            folder_path: path to the keyframe folder
-
-        Returns:
-            List[int]: sorted list of frame numbers
-        """
-        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
-        # regex matches the file-name format keyframe_000123_010534.jpg
-        pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
-        numbers = []
-        for f in files:
-            match = pattern.search(f)
-            if match:
-                numbers.append(int(match.group(1)))
-        return sorted(numbers)
+    def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
+        """
+        Extract keyframes as a generator.
+
+        Args:
+            shot_boundaries: list of shot boundaries
+
+        Yields:
+            Tuple[np.ndarray, int]: (keyframe, frame index)
+        """
+        shot_frames = []
+        current_shot_start = 0
+
+        for frame_idx, frame in self.preprocess_video():
+            if frame_idx in shot_boundaries:
+                if shot_frames:
+                    keyframe, keyframe_idx = self.process_shot(shot_frames)
+                    if keyframe is not None:
+                        yield keyframe, keyframe_idx
+
+                    # free memory
+                    shot_frames.clear()
+                    gc.collect()
+
+                current_shot_start = frame_idx
+
+            shot_frames.append((frame_idx, frame))
+
+            # cap the number of frames buffered for a single shot
+            if len(shot_frames) > self.batch_size:
+                keyframe, keyframe_idx = self.process_shot(shot_frames)
+                if keyframe is not None:
+                    yield keyframe, keyframe_idx
+                shot_frames.clear()
+                gc.collect()
+
+        # process the last shot
+        if shot_frames:
+            keyframe, keyframe_idx = self.process_shot(shot_frames)
+            if keyframe is not None:
+                yield keyframe, keyframe_idx

     def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
         """
-        Process the video and extract keyframes.
+        Process the video and extract keyframes, using batch processing.

         Args:
             output_dir: output directory
             skip_seconds: seconds to skip at the start of the video
         """
-        # compute the number of frames to skip
-        skip_frames = int(skip_seconds * self.fps)
-
-        # read all frames into memory
-        frames = list(self.preprocess_video())
-
-        # skip the frames for the requested number of seconds
-        frames = frames[skip_frames:]
-
-        if not frames:
-            raise ValueError(f"No frames left to process after skipping {skip_seconds} seconds")
-
-        shot_boundaries = self.detect_shot_boundaries(frames)
-        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
-
-        # adjust the keyframe indices by the number of skipped frames
-        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
-        self.save_keyframes(keyframes, adjusted_indices, output_dir)
+        try:
+            # create the output directory
+            os.makedirs(output_dir, exist_ok=True)
+
+            # compute the number of frames to skip
+            skip_frames = int(skip_seconds * self.fps)
+            self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
+
+            # detect shot boundaries
+            logger.info("Detecting shot boundaries...")
+            shot_boundaries = self.detect_shot_boundaries()
+
+            # extract keyframes
+            logger.info("Extracting keyframes...")
+            frame_count = 0
+
+            for keyframe, frame_idx in self.extract_keyframes(shot_boundaries):
+                if frame_idx < skip_frames:
+                    continue
+
+                # compute the timestamp
+                timestamp = frame_idx / self.fps
+                hours = int(timestamp // 3600)
+                minutes = int((timestamp % 3600) // 60)
+                seconds = int(timestamp % 60)
+                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
+
+                # save the keyframe
+                output_path = os.path.join(output_dir,
+                                           f'keyframe_{frame_idx:06d}_{time_str}.jpg')
+                cv2.imwrite(output_path, keyframe)
+                frame_count += 1
+
+                # clean up memory periodically
+                if frame_count % 10 == 0:
+                    gc.collect()
+
+            logger.info(f"Keyframe extraction finished; saved {frame_count} frames to {output_dir}")
+
+        except Exception as e:
+            logger.error(f"Video processing failed: {str(e)}")
+            raise
+        finally:
+            # make sure resources are released
+            self.cap.release()
+            gc.collect()

@@ -7,7 +7,16 @@ x-common: &common
     - ./:/NarratoAI
   environment:
     - VPN_PROXY_URL=http://host.docker.internal:7890
+    - PYTHONUNBUFFERED=1
+    - PYTHONMALLOC=malloc
+    - OPENCV_OPENCL_RUNTIME=disabled
+    - OPENCV_CPU_DISABLE=0
   restart: always
+  mem_limit: 4g
+  mem_reservation: 2g
+  memswap_limit: 6g
+  cpus: 2.0
+  cpu_shares: 1024

 services:
   webui:
@@ -16,3 +25,14 @@ services:
     ports:
       - "8501:8501"
     command: ["webui"]
+    logging:
+      driver: "json-file"
+      options:
+        max-size: "200m"
+        max-file: "3"
+    tmpfs:
+      - /tmp:size=1G
+    ulimits:
+      nofile:
+        soft: 65536
+        hard: 65536