From ee52600ae2ef0bd91092c65d585e95a6cfeabdcf Mon Sep 17 00:00:00 2001 From: linyq Date: Mon, 11 Nov 2024 15:53:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(app):=20=E4=BC=98=E5=8C=96=E5=85=B3?= =?UTF-8?q?=E9=94=AE=E5=B8=A7=E6=8F=90=E5=8F=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构 VideoProcessor 类,优化内存使用和性能 - 添加分批处理逻辑,支持大视频文件的处理 - 使用 MiniBatchKMeans 替代 KMeans,减少内存消耗 - 优化镜头边界检测和关键帧提取算法 - 增加日志记录和错误处理,提高程序的健壮性 --- app/utils/video_processor.py | 313 +++++++++++++++++------------------ docker-compose.yml | 20 +++ 2 files changed, 170 insertions(+), 163 deletions(-) diff --git a/app/utils/video_processor.py b/app/utils/video_processor.py index eb0da75..46a8971 100644 --- a/app/utils/video_processor.py +++ b/app/utils/video_processor.py @@ -1,23 +1,27 @@ import cv2 import numpy as np -from sklearn.cluster import KMeans +from sklearn.cluster import MiniBatchKMeans import os import re from typing import List, Tuple, Generator +from loguru import logger +import gc class VideoProcessor: - def __init__(self, video_path: str): + def __init__(self, video_path: str, batch_size: int = 100): """ 初始化视频处理器 Args: video_path: 视频文件路径 + batch_size: 批处理大小,控制内存使用 """ if not os.path.exists(video_path): raise FileNotFoundError(f"视频文件不存在: {video_path}") self.video_path = video_path + self.batch_size = batch_size self.cap = cv2.VideoCapture(video_path) if not self.cap.isOpened(): @@ -30,211 +34,194 @@ class VideoProcessor: """析构函数,确保视频资源被释放""" if hasattr(self, 'cap'): self.cap.release() + gc.collect() - def preprocess_video(self) -> Generator[np.ndarray, None, None]: + def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]: """ - 使用生成器方式读取视频帧 + 使用生成器方式分批读取视频帧 Yields: - np.ndarray: 视频帧 + Tuple[int, np.ndarray]: (帧索引, 视频帧) """ - self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始 + self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) + frame_idx = 0 + while self.cap.isOpened(): ret, frame = self.cap.read() if not ret: break - yield frame + + # 降低分辨率以减少内存使用 + frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5) + yield frame_idx, frame + + frame_idx += 1 + + # 定期进行垃圾回收 + if frame_idx % 1000 == 0: + gc.collect() - def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]: + def detect_shot_boundaries(self, threshold: int = 30) -> List[int]: """ - 使用帧差法检测镜头边界 + 使用批处理方式检测镜头边界 Args: - frames: 视频帧列表 threshold: 差异阈值 Returns: List[int]: 镜头边界帧的索引列表 """ shot_boundaries = [] - for i in range(1, len(frames)): - prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY) - curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY) - diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int))) - if diff > threshold: - shot_boundaries.append(i) + prev_frame = None + prev_idx = -1 + + for frame_idx, curr_frame in self.preprocess_video(): + if prev_frame is not None: + # 转换为灰度图并降低分辨率以提高性能 + prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY) + curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY) + + diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float))) + if diff > threshold: + shot_boundaries.append(frame_idx) + + prev_frame = curr_frame.copy() + prev_idx = frame_idx + + # 释放不需要的内存 + del curr_frame + if frame_idx % 100 == 0: + gc.collect() + return shot_boundaries - def filter_keyframes_by_time(self, keyframes: List[np.ndarray], - keyframe_indices: List[int]) -> Tuple[List[np.ndarray], List[int]]: + def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]: """ - 过滤关键帧,确保每秒最多只有一个关键帧 + 处理单个镜头的帧 Args: - keyframes: 关键帧列表 - keyframe_indices: 关键帧索引列表 + shot_frames: 镜头中的帧列表 Returns: - Tuple[List[np.ndarray], List[int]]: 过滤后的关键帧列表和对应的帧索引 + Tuple[np.ndarray, int]: (关键帧, 帧索引) """ - if not keyframes or not keyframe_indices: - return keyframes, keyframe_indices + if not shot_frames: + return None, -1 - filtered_frames = [] - filtered_indices = [] - last_second = -1 + # 提取特征 + frame_features = [] + frame_indices = [] - for frame, idx in zip(keyframes, keyframe_indices): - current_second = idx // self.fps - if current_second != last_second: - filtered_frames.append(frame) - filtered_indices.append(idx) - last_second = current_second - - return filtered_frames, filtered_indices + for idx, frame in shot_frames: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + # 降低特征维度以节省内存 + resized_gray = cv2.resize(gray, (32, 32)) + frame_features.append(resized_gray.flatten()) + frame_indices.append(idx) + + frame_features = np.array(frame_features) + + # 使用MiniBatchKMeans替代KMeans以减少内存使用 + kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100), + random_state=0).fit(frame_features) + + center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1)) + + return shot_frames[center_idx][1], frame_indices[center_idx] - def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]: + def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]: """ - 从每个镜头中提取关键帧,并确保每秒最多一个关键帧 + 使用生成器方式提取关键帧 Args: - frames: 视频帧列表 shot_boundaries: 镜头边界列表 - Returns: - Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引 + Yields: + Tuple[np.ndarray, int]: (关键帧, 帧索引) """ - keyframes = [] - keyframe_indices = [] + shot_frames = [] + current_shot_start = 0 - for i in range(len(shot_boundaries)): - start = shot_boundaries[i - 1] if i > 0 else 0 - end = shot_boundaries[i] - shot_frames = frames[start:end] - - frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten() - for frame in shot_frames]) - kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features) - center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1)) - - keyframes.append(shot_frames[center_idx]) - keyframe_indices.append(start + center_idx) - - # 过滤每秒多余的关键帧 - filtered_keyframes, filtered_indices = self.filter_keyframes_by_time(keyframes, keyframe_indices) - - return filtered_keyframes, filtered_indices - - def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int], - output_dir: str) -> None: - """ - 保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg - - Args: - keyframes: 关键帧列表 - keyframe_indices: 关键帧索引列表 - output_dir: 输出目录 - """ - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - for keyframe, frame_idx in zip(keyframes, keyframe_indices): - # 计算时间戳(秒) - timestamp = frame_idx / self.fps - # 将时间戳转换为 HH:MM:SS 格式 - hours = int(timestamp // 3600) - minutes = int((timestamp % 3600) // 60) - seconds = int(timestamp % 60) - time_str = f"{hours:02d}{minutes:02d}{seconds:02d}" - - # 构建新的文件名格式:keyframe_帧序号_时间戳.jpg - output_path = os.path.join(output_dir, - f'keyframe_{frame_idx:06d}_{time_str}.jpg') - cv2.imwrite(output_path, keyframe) - - print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}") - - def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None: - """ - 根据指定的帧号提取帧 - - Args: - frame_numbers: 要提取的帧号列表 - output_folder: 输出文件夹路径 - """ - if not frame_numbers: - raise ValueError("未提供帧号列表") - - if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers): - raise ValueError("存在无效的帧号") - - if not os.path.exists(output_folder): - os.makedirs(output_folder) - - for frame_number in frame_numbers: - self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) - ret, frame = self.cap.read() - - if ret: - # 计算时间戳 - timestamp = frame_number / self.fps - hours = int(timestamp // 3600) - minutes = int((timestamp % 3600) // 60) - seconds = int(timestamp % 60) - time_str = f"{hours:02d}{minutes:02d}{seconds:02d}" + for frame_idx, frame in self.preprocess_video(): + if frame_idx in shot_boundaries: + if shot_frames: + keyframe, keyframe_idx = self.process_shot(shot_frames) + if keyframe is not None: + yield keyframe, keyframe_idx + + # 清理内存 + shot_frames.clear() + gc.collect() - # 使用与关键帧相同的命名格式 - output_path = os.path.join(output_folder, - f"extracted_frame_{frame_number:06d}_{time_str}.jpg") - cv2.imwrite(output_path, frame) - print(f"已提取并保存帧 {frame_number}") - else: - print(f"无法读取帧 {frame_number}") - - @staticmethod - def extract_numbers_from_folder(folder_path: str) -> List[int]: - """ - 从文件夹中提取帧号 - - Args: - folder_path: 关键帧文件夹路径 + current_shot_start = frame_idx - Returns: - List[int]: 排序后的帧号列表 - """ - files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')] - # 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534.jpg - pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$') - numbers = [] - for f in files: - match = pattern.search(f) - if match: - numbers.append(int(match.group(1))) - return sorted(numbers) + shot_frames.append((frame_idx, frame)) + + # 控制单个镜头的最大帧数 + if len(shot_frames) > self.batch_size: + keyframe, keyframe_idx = self.process_shot(shot_frames) + if keyframe is not None: + yield keyframe, keyframe_idx + shot_frames.clear() + gc.collect() + + # 处理最后一个镜头 + if shot_frames: + keyframe, keyframe_idx = self.process_shot(shot_frames) + if keyframe is not None: + yield keyframe, keyframe_idx def process_video(self, output_dir: str, skip_seconds: float = 0) -> None: """ - 处理视频并提取关键帧 + 处理视频并提取关键帧,使用分批处理方式 Args: output_dir: 输出目录 skip_seconds: 跳过视频开头的秒数 """ - # 计算要跳过的帧数 - skip_frames = int(skip_seconds * self.fps) - - # 获取所有帧 - frames = list(self.preprocess_video()) - - # 跳过指定秒数的帧 - frames = frames[skip_frames:] - - if not frames: - raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理") - - shot_boundaries = self.detect_shot_boundaries(frames) - keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries) - - # 调整关键帧索引,加上跳过的帧数 - adjusted_indices = [idx + skip_frames for idx in keyframe_indices] - self.save_keyframes(keyframes, adjusted_indices, output_dir) + try: + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + + # 计算要跳过的帧数 + skip_frames = int(skip_seconds * self.fps) + self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames) + + # 检测镜头边界 + logger.info("开始检测镜头边界...") + shot_boundaries = self.detect_shot_boundaries() + + # 提取关键帧 + logger.info("开始提取关键帧...") + frame_count = 0 + + for keyframe, frame_idx in self.extract_keyframes(shot_boundaries): + if frame_idx < skip_frames: + continue + + # 计算时间戳 + timestamp = frame_idx / self.fps + hours = int(timestamp // 3600) + minutes = int((timestamp % 3600) // 60) + seconds = int(timestamp % 60) + time_str = f"{hours:02d}{minutes:02d}{seconds:02d}" + + # 保存关键帧 + output_path = os.path.join(output_dir, + f'keyframe_{frame_idx:06d}_{time_str}.jpg') + cv2.imwrite(output_path, keyframe) + frame_count += 1 + + # 定期清理内存 + if frame_count % 10 == 0: + gc.collect() + + logger.info(f"关键帧提取完成,共保存 {frame_count} 帧到 {output_dir}") + + except Exception as e: + logger.error(f"视频处理失败: {str(e)}") + raise + finally: + # 确保资源被释放 + self.cap.release() + gc.collect() diff --git a/docker-compose.yml b/docker-compose.yml index a28a0ca..8b36f7c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,7 +7,16 @@ x-common: &common - ./:/NarratoAI environment: - VPN_PROXY_URL=http://host.docker.internal:7890 + - PYTHONUNBUFFERED=1 + - PYTHONMALLOC=malloc + - OPENCV_OPENCL_RUNTIME=disabled + - OPENCV_CPU_DISABLE=0 restart: always + mem_limit: 4g + mem_reservation: 2g + memswap_limit: 6g + cpus: 2.0 + cpu_shares: 1024 services: webui: @@ -16,3 +25,14 @@ services: ports: - "8501:8501" command: ["webui"] + logging: + driver: "json-file" + options: + max-size: "200m" + max-file: "3" + tmpfs: + - /tmp:size=1G + ulimits: + nofile: + soft: 65536 + hard: 65536