diff --git a/app/config/config.py b/app/config/config.py
index a653ddc..f98a081 100644
--- a/app/config/config.py
+++ b/app/config/config.py
@@ -45,6 +45,7 @@
 whisper = _cfg.get("whisper", {})
 proxy = _cfg.get("proxy", {})
 azure = _cfg.get("azure", {})
 ui = _cfg.get("ui", {})
+frames = _cfg.get("frames", {})
 
 hostname = socket.gethostname()
diff --git a/app/utils/video_processor_v2.py b/app/utils/video_processor_v2.py
new file mode 100644
index 0000000..038064a
--- /dev/null
+++ b/app/utils/video_processor_v2.py
@@ -0,0 +1,294 @@
+import cv2
+import numpy as np
+from sklearn.cluster import KMeans
+import os
+import re
+from typing import List, Tuple, Generator
+from loguru import logger
+import subprocess
+from tqdm import tqdm
+
+
+class VideoProcessor:
+    def __init__(self, video_path: str):
+        """
+        初始化视频处理器
+
+        Args:
+            video_path: 视频文件路径
+        """
+        if not os.path.exists(video_path):
+            raise FileNotFoundError(f"视频文件不存在: {video_path}")
+
+        self.video_path = video_path
+        self.cap = cv2.VideoCapture(video_path)
+
+        if not self.cap.isOpened():
+            raise RuntimeError(f"无法打开视频文件: {video_path}")
+
+        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
+
+    def __del__(self):
+        """析构函数,确保视频资源被释放"""
+        if hasattr(self, 'cap'):
+            self.cap.release()
+
+    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
+        """
+        使用生成器方式读取视频帧
+
+        Yields:
+            np.ndarray: 视频帧
+        """
+        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # 重置到视频开始
+        while self.cap.isOpened():
+            ret, frame = self.cap.read()
+            if not ret:
+                break
+            yield frame
+
+    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
+        """
+        使用帧差法检测镜头边界
+
+        Args:
+            frames: 视频帧列表
+            threshold: 差异阈值
+
+        Returns:
+            List[int]: 镜头边界帧的索引列表
+        """
+        shot_boundaries = []
+        for i in range(1, len(frames)):
+            prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
+            curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
+            diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
+            if diff > threshold:
+                shot_boundaries.append(i)
+        return shot_boundaries
+
+    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
+        List[np.ndarray], List[int]]:
+        """
+        从每个镜头中提取关键帧
+
+        Args:
+            frames: 视频帧列表
+            shot_boundaries: 镜头边界列表
+
+        Returns:
+            Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
+        """
+        keyframes = []
+        keyframe_indices = []
+
+        for i in range(len(shot_boundaries)):
+            start = shot_boundaries[i - 1] if i > 0 else 0
+            end = shot_boundaries[i]
+            shot_frames = frames[start:end]
+
+            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
+                                       for frame in shot_frames])
+            kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
+            center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
+
+            keyframes.append(shot_frames[center_idx])
+            keyframe_indices.append(start + center_idx)
+
+        return keyframes, keyframe_indices
+
+    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
+                       output_dir: str, desc: str = "保存关键帧") -> None:
+        """
+        保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg
+
+        Args:
+            keyframes: 关键帧列表
+            keyframe_indices: 关键帧索引列表
+            output_dir: 输出目录
+            desc: 进度条描述
+        """
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
+                                        total=len(keyframes),
+                                        desc=desc):
+            timestamp = frame_idx / self.fps
+            hours = int(timestamp // 3600)
+            minutes = int((timestamp % 3600) // 60)
+            seconds = int(timestamp % 60)
+            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
+
+            output_path = os.path.join(output_dir,
+                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
+            cv2.imwrite(output_path, keyframe)
+
+    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
+        """
+        根据指定的帧号提取帧,如果多个帧在同一秒内,只保留一个
+
+        Args:
+            frame_numbers: 要提取的帧号列表
+            output_folder: 输出文件夹路径
+        """
+        if not frame_numbers:
+            raise ValueError("未提供帧号列表")
+
+        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
+            raise ValueError("存在无效的帧号")
+
+        if not os.path.exists(output_folder):
+            os.makedirs(output_folder)
+
+        # 用于记录已处理的时间戳(秒)
+        processed_seconds = set()
+
+        for frame_number in tqdm(frame_numbers, desc="提取高清帧"):
+            # 计算时间戳(秒)
+            timestamp_seconds = int(frame_number / self.fps)
+
+            # 如果这一秒已经处理过,跳过
+            if timestamp_seconds in processed_seconds:
+                continue
+
+            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+            ret, frame = self.cap.read()
+
+            if ret:
+                # 记录这一秒已经处理
+                processed_seconds.add(timestamp_seconds)
+
+                # 计算时间戳字符串
+                hours = int(timestamp_seconds // 3600)
+                minutes = int((timestamp_seconds % 3600) // 60)
+                seconds = int(timestamp_seconds % 60)
+                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
+
+                output_path = os.path.join(output_folder,
+                                           f"keyframe_{frame_number:06d}_{time_str}.jpg")
+                cv2.imwrite(output_path, frame)
+            else:
+                logger.info(f"无法读取帧 {frame_number}")
+
+        logger.info(f"共提取了 {len(processed_seconds)} 个不同时间戳的帧")
+
+    @staticmethod
+    def extract_numbers_from_folder(folder_path: str) -> List[int]:
+        """
+        从文件夹中提取帧号
+
+        Args:
+            folder_path: 关键帧文件夹路径
+
+        Returns:
+            List[int]: 排序后的帧号列表
+        """
+        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
+        # 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534.jpg
+        pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
+        numbers = []
+        for f in files:
+            match = pattern.search(f)
+            if match:
+                numbers.append(int(match.group(1)))
+        return sorted(numbers)
+
+    def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
+        """
+        处理视频并提取关键帧
+
+        Args:
+            output_dir: 输出目录
+            skip_seconds: 跳过视频开头的秒数
+        """
+        skip_frames = int(skip_seconds * self.fps)
+
+        logger.info("读取视频帧...")
+        frames = []
+        for frame in tqdm(self.preprocess_video(),
+                          total=self.total_frames,
+                          desc="读取视频"):
+            frames.append(frame)
+
+        frames = frames[skip_frames:]
+
+        if not frames:
+            raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
+
+        logger.info("\n检测场景边界...")
+        shot_boundaries = self.detect_shot_boundaries(frames, threshold)
+        logger.info(f"检测到 {len(shot_boundaries)} 个场景边界")
+
+        logger.info("\n提取关键帧...")
+        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
+
+        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
+        self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="保存压缩关键帧")
+
+    def process_video_pipeline(self,
+                               output_dir: str,
+                               skip_seconds: float = 0,
+                               threshold: int = 30,
+                               compressed_width: int = 320,
+                               keep_temp: bool = False) -> None:
+        """
+        执行完整的视频处理流程:压缩、提取关键帧、导出高清帧
+        """
+        os.makedirs(output_dir, exist_ok=True)
+        temp_dir = os.path.join(output_dir, 'temp')
+        compressed_dir = os.path.join(temp_dir, 'compressed')
+        mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
+        hd_frames_dir = output_dir
+
+        os.makedirs(temp_dir, exist_ok=True)
+        os.makedirs(compressed_dir, exist_ok=True)
+        os.makedirs(mini_frames_dir, exist_ok=True)
+        os.makedirs(hd_frames_dir, exist_ok=True)
+
+        try:
+            # 1. 压缩视频
+            video_name = os.path.splitext(os.path.basename(self.video_path))[0]
+            compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")
+
+            logger.info("步骤1: 压缩视频...")
+            ffmpeg_cmd = [
+                'ffmpeg', '-i', self.video_path,
+                '-vf', f'scale={compressed_width}:-1',
+                '-y',
+                compressed_video
+            ]
+            subprocess.run(ffmpeg_cmd, check=True)
+
+            # 2. 从压缩视频中提取关键帧
+            logger.info("\n步骤2: 从压缩视频提取关键帧...")
+            mini_processor = VideoProcessor(compressed_video)
+            mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)
+
+            # 3. 从原始视频提取高清关键帧
+            logger.info("\n步骤3: 提取高清关键帧...")
+            frame_numbers = mini_processor.extract_numbers_from_folder(mini_frames_dir)
+            self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)
+
+            logger.info(f"\n处理完成!")
+            logger.info(f"高清关键帧保存在: {hd_frames_dir}")
+
+        finally:
+            if not keep_temp:
+                import shutil
+                try:
+                    shutil.rmtree(temp_dir)
+                    logger.info("临时文件已清理")
+                except Exception as e:
+                    logger.info(f"清理临时文件时出错: {e}")
+
+
+if __name__ == "__main__":
+    import time
+
+    start_time = time.time()
+    processor = VideoProcessor("best.mp4")
+    processor.process_video_pipeline(output_dir="output4")
+    end_time = time.time()
+    print(f"处理完成!总耗时: {end_time - start_time:.2f} 秒")
diff --git a/config.example.toml b/config.example.toml
index 1dc539d..46f8319 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -5,7 +5,6 @@
 # NarratoAPI
 
 # qwen2-vl (待增加)
     vision_llm_provider="gemini"
-    vision_batch_size = 7
     vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
 
@@ -170,4 +169,15 @@
 # Azure Speech API Key
 # Get your API key at https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/SpeechServices
     speech_key=""
-    speech_region=""
\ No newline at end of file
+    speech_region=""
+
+[frames]
+    skip_seconds = 0
+    # threshold(差异阈值)用于判断两个连续帧之间是否发生了场景切换
+    # 较小的阈值(如 20):更敏感,能捕捉到细微的场景变化,但可能会误判,关键帧图片更多
+    # 较大的阈值(如 40):更保守,只捕捉明显的场景切换,但可能会漏掉渐变场景,关键帧图片更少
+    # 默认值 30:在实践中是一个比较平衡的选择
+    threshold = 30
+    version = "v2"
+    # 大模型单次处理的关键帧数量
+    vision_batch_size = 5
diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py
index bfbe297..4e4aea6 100644
--- a/webui/components/script_settings.py
+++ b/webui/components/script_settings.py
@@ -14,7 +14,7 @@
 from loguru import logger
 from app.config import config
 from app.models.schema import VideoClipParams
 from app.utils.script_generator import ScriptProcessor
-from app.utils import utils, check_script, vision_analyzer, video_processor
+from app.utils import utils, check_script, vision_analyzer, video_processor, video_processor_v2
 from webui.utils import file_utils
 
@@ -318,13 +318,21 @@
         os.makedirs(video_keyframes_dir, exist_ok=True)
 
         # 初始化视频处理器
-        processor = video_processor.VideoProcessor(params.video_origin_path)
-
-        # 处理视频并提取关键帧
-        processor.process_video(
-            output_dir=video_keyframes_dir,
-            skip_seconds=0
-        )
+        if config.frames.get("version") == "v2":
+            processor = video_processor_v2.VideoProcessor(params.video_origin_path)
+            # 处理视频并提取关键帧
+            processor.process_video_pipeline(
+                output_dir=video_keyframes_dir,
+                skip_seconds=config.frames.get("skip_seconds", 0),
+                threshold=config.frames.get("threshold", 30)
+            )
+        else:
+            processor = video_processor.VideoProcessor(params.video_origin_path)
+            # 处理视频并提取关键帧
+            processor.process_video(
+                output_dir=video_keyframes_dir,
+                skip_seconds=0
+            )
 
         # 获取所有关键帧文件路径
         for filename in sorted(os.listdir(video_keyframes_dir)):
@@ -380,7 +388,7 @@
                 analyzer.analyze_images(
                     images=keyframe_files,
                     prompt=config.app.get('vision_analysis_prompt'),
-                    batch_size=config.app.get("vision_batch_size", 5)
+                    batch_size=config.frames.get("vision_batch_size", 5)
                 )
             )
             loop.close()
@@ -397,7 +405,7 @@
                     logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
                     continue
 
-                batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
+                batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
                 logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
                 logger.debug(batch_files)
 
@@ -436,7 +444,7 @@
             if 'error' in result:
                 continue
 
-            batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
+            batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
             _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
 
             frame_content = {
@@ -612,14 +620,14 @@
         if script is None:
             st.error("生成脚本失败,请检查日志")
             st.stop()
-        logger.info(f"脚本生成完成\n{script} \n{type(script)}")
+        logger.info(f"脚本生成完成")
 
         if isinstance(script, list):
             st.session_state['video_clip_json'] = script
         elif isinstance(script, str):
            st.session_state['video_clip_json'] = json.loads(script)
-        update_progress(90, "脚本生成完成")
+        update_progress(80, "脚本生成完成")
-        time.sleep(0.5)
+        time.sleep(0.1)
         progress_bar.progress(100)
         status_text.text("脚本生成完成!")
         st.success("视频脚本生成成功!")