diff --git a/app/services/script_service.py b/app/services/script_service.py
index 37644a7..f4a6f95 100644
--- a/app/services/script_service.py
+++ b/app/services/script_service.py
@@ -3,10 +3,11 @@ import json
 import time
 import asyncio
 import requests
+from app.utils import video_processor
 from loguru import logger
 from typing import List, Dict, Any, Callable
-from app.utils import utils, gemini_analyzer, video_processor, video_processor_v2
+from app.utils import utils, gemini_analyzer, video_processor
 from app.utils.script_generator import ScriptProcessor
 from app.config import config
@@ -105,20 +106,13 @@ class ScriptGenerator:
         os.makedirs(video_keyframes_dir, exist_ok=True)
 
         try:
-            if config.frames.get("version") == "v2":
-                processor = video_processor_v2.VideoProcessor(video_path)
-                processor.process_video_pipeline(
-                    output_dir=video_keyframes_dir,
-                    skip_seconds=skip_seconds,
-                    threshold=threshold
-                )
-            else:
-                processor = video_processor.VideoProcessor(video_path)
-                processor.process_video(
-                    output_dir=video_keyframes_dir,
-                    skip_seconds=skip_seconds
-                )
-
+            processor = video_processor.VideoProcessor(video_path)
+            processor.process_video_pipeline(
+                output_dir=video_keyframes_dir,
+                skip_seconds=skip_seconds,
+                threshold=threshold
+            )
+
             for filename in sorted(os.listdir(video_keyframes_dir)):
                 if filename.endswith('.jpg'):
                     keyframe_files.append(os.path.join(video_keyframes_dir, filename))
diff --git a/app/services/subtitle.py b/app/services/subtitle.py
index 34aa2cb..c443c3f 100644
--- a/app/services/subtitle.py
+++ b/app/services/subtitle.py
@@ -4,7 +4,7 @@ import re
 import traceback
 from typing import Optional
 
-from faster_whisper import WhisperModel
+# from faster_whisper import WhisperModel
 from timeit import default_timer as timer
 from loguru import logger
 import google.generativeai as genai
@@ -45,12 +45,25 @@ def create(audio_file, subtitle_file: str = ""):
         )
         return None
 
-    # 尝试使用 CUDA,如果失败则回退到 CPU
+    # 首先使用CPU模式,不触发CUDA检查
+    use_cuda = False
     try:
-        import torch
-        if torch.cuda.is_available():
+        # 在函数中延迟导入torch,而不是在全局范围内
+        # 使用安全的方式检查CUDA可用性
+        def check_cuda_available():
+            try:
+                import torch
+                return torch.cuda.is_available()
+            except (ImportError, RuntimeError) as e:
+                logger.warning(f"检查CUDA可用性时出错: {e}")
+                return False
+
+        # 仅当明确需要时才检查CUDA
+        use_cuda = check_cuda_available()
+
+        if use_cuda:
+            logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
             try:
-                logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
                 model = WhisperModel(
                     model_size_or_path=model_path,
                     device="cuda",
@@ -63,18 +76,18 @@ def create(audio_file, subtitle_file: str = ""):
             except Exception as e:
                 logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
                 logger.warning("回退到 CPU 模式")
-                device = "cpu"
-                compute_type = "int8"
+                use_cuda = False
         else:
-            logger.info("未检测到 CUDA,使用 CPU 模式")
-            device = "cpu"
-            compute_type = "int8"
-    except ImportError:
-        logger.warning("未安装 torch,使用 CPU 模式")
+            logger.info("使用 CPU 模式")
+    except Exception as e:
+        logger.warning(f"CUDA检查过程出错: {e}")
+        logger.warning("默认使用CPU模式")
+        use_cuda = False
+
+    # 如果CUDA不可用或加载失败,使用CPU
+    if not use_cuda:
         device = "cpu"
         compute_type = "int8"
-
-    if device == "cpu":
         logger.info(f"使用 CPU 加载模型: {model_path}")
         model = WhisperModel(
             model_size_or_path=model_path,
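Note on the subtitle.py hunks above: with the top-level `from faster_whisper import WhisperModel` commented out, both the CUDA and CPU branches of `create()` still reference `WhisperModel`, so the patched function will raise `NameError` before any model is loaded. A deferred import inside the function (a minimal sketch below, assuming faster-whisper remains installed even though it is commented out in requirements.txt; `load_whisper_model` is a hypothetical helper, not part of this diff) would keep torch/CUDA out of module import time while restoring the symbol. The script_service.py hunk also leaves `video_processor` imported twice (once standalone, once in the combined import); harmless, but redundant.

```python
from loguru import logger

def load_whisper_model(model_path: str, use_cuda: bool):
    # Deferred import: faster-whisper (and, transitively, any heavy deps)
    # only load when transcription is actually requested.
    from faster_whisper import WhisperModel

    if use_cuda:
        try:
            logger.info(f"Trying to load model on CUDA: {model_path}")
            return WhisperModel(model_size_or_path=model_path,
                                device="cuda", compute_type="float16")
        except Exception as e:
            logger.warning(f"CUDA load failed, falling back to CPU: {e}")
    logger.info(f"Loading model on CPU: {model_path}")
    return WhisperModel(model_size_or_path=model_path,
                        device="cpu", compute_type="int8")
```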
diff --git a/app/services/video.py b/app/services/video.py
index 83214f9..087dbdf 100644
--- a/app/services/video.py
+++ b/app/services/video.py
@@ -1,6 +1,6 @@
 import traceback
 
-import pysrt
+# import pysrt
 from typing import Optional
 from typing import List
 from loguru import logger
diff --git a/app/utils/script_generator.py b/app/utils/script_generator.py
index 6493e82..7020782 100644
--- a/app/utils/script_generator.py
+++ b/app/utils/script_generator.py
@@ -2,7 +2,7 @@ import os
 import json
 import traceback
 from loguru import logger
-import tiktoken
+# import tiktoken
 from typing import List, Dict
 from datetime import datetime
 from openai import OpenAI
@@ -94,12 +94,12 @@ class OpenAIGenerator(BaseGenerator):
             "user": "script_generator"
         }
 
-        # 初始化token计数器
-        try:
-            self.encoding = tiktoken.encoding_for_model(self.model_name)
-        except KeyError:
-            logger.warning(f"未找到模型 {self.model_name} 的专用编码器,使用默认编码器")
-            self.encoding = tiktoken.get_encoding("cl100k_base")
+        # # 初始化token计数器
+        # try:
+        #     self.encoding = tiktoken.encoding_for_model(self.model_name)
+        # except KeyError:
+        #     logger.warning(f"未找到模型 {self.model_name} 的专用编码器,使用默认编码器")
+        #     self.encoding = tiktoken.get_encoding("cl100k_base")
 
     def _generate(self, messages: list, params: dict) -> any:
         """实现OpenAI特定的生成逻辑"""
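With `tiktoken` commented out above, `self.encoding` is never initialized; if any remaining method of `OpenAIGenerator` still reads it, that call now fails with `AttributeError`. A guarded import with a crude length heuristic (a sketch only; the `count_tokens` helper is hypothetical and not part of this diff) keeps token accounting available without making tiktoken a hard dependency:

```python
def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """Count tokens with tiktoken when available, else estimate."""
    try:
        import tiktoken  # optional dependency, commented out in requirements.txt
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # No model-specific encoder registered; use the common default.
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
    except ImportError:
        # Rough heuristic: ~4 characters per token for mixed prose.
        return max(1, len(text) // 4)
```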
diff --git a/app/utils/video_processor.py b/app/utils/video_processor.py
index 5949e6b..a754d50 100644
--- a/app/utils/video_processor.py
+++ b/app/utils/video_processor.py
@@ -1,237 +1,349 @@
-import cv2
-import numpy as np
-from sklearn.cluster import MiniBatchKMeans
+"""
+视频帧提取工具
+
+这个模块提供了简单高效的视频帧提取功能。主要特点:
+1. 使用ffmpeg进行视频处理,支持硬件加速
+2. 按指定时间间隔提取视频关键帧
+3. 支持多种视频格式
+4. 支持高清视频帧输出
+5. 直接从原视频提取高质量关键帧
+
+不依赖OpenCV和sklearn等库,只使用ffmpeg作为外部依赖,降低了安装和使用的复杂度。
+"""
+
 import os
 import re
-from typing import List, Tuple, Generator
+import time
+import subprocess
+from typing import List, Dict
 from loguru import logger
-import gc
 from tqdm import tqdm
 
 
 class VideoProcessor:
-    def __init__(self, video_path: str, batch_size: int = 100):
+    def __init__(self, video_path: str):
         """
         初始化视频处理器
-        
+
         Args:
             video_path: 视频文件路径
-            batch_size: 批处理大小,控制内存使用
         """
         if not os.path.exists(video_path):
             raise FileNotFoundError(f"视频文件不存在: {video_path}")
-        
+
         self.video_path = video_path
-        self.batch_size = batch_size
-        self.cap = cv2.VideoCapture(video_path)
-
-        if not self.cap.isOpened():
-            raise RuntimeError(f"无法打开视频文件: {video_path}")
-
-        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
+        self.video_info = self._get_video_info()
+        self.fps = float(self.video_info.get('fps', 25))
+        self.duration = float(self.video_info.get('duration', 0))
+        self.width = int(self.video_info.get('width', 0))
+        self.height = int(self.video_info.get('height', 0))
+        self.total_frames = int(self.fps * self.duration)
 
-    def __del__(self):
-        """析构函数,确保视频资源被释放"""
-        if hasattr(self, 'cap'):
-            self.cap.release()
-        gc.collect()
+    def _get_video_info(self) -> Dict[str, str]:
+        """
+        使用ffprobe获取视频信息
 
-    def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
-        """
-        使用生成器方式分批读取视频帧
-
-        Yields:
-            Tuple[int, np.ndarray]: (帧索引, 视频帧)
-        """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
-        frame_idx = 0
-
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
-                break
-
-            # 降低分辨率以减少内存使用
-            frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
-            yield frame_idx, frame
-
-            frame_idx += 1
-
-            # 定期进行垃圾回收
-            if frame_idx % 1000 == 0:
-                gc.collect()
-
-    def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
-        """
-        使用批处理方式检测镜头边界
-
-        Args:
-            threshold: 差异阈值
-
         Returns:
-            List[int]: 镜头边界帧的索引列表
+            Dict[str, str]: 包含视频基本信息的字典
         """
-        shot_boundaries = []
-        prev_frame = None
-        prev_idx = -1
-
-        pbar = tqdm(self.preprocess_video(),
-                    total=self.total_frames,
-                    desc="检测镜头边界",
-                    unit="帧")
-
-        for frame_idx, curr_frame in pbar:
-            if prev_frame is not None:
-                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
-                curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
-
-                diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
-                if diff > threshold:
-                    shot_boundaries.append(frame_idx)
-                    pbar.set_postfix({"检测到边界": len(shot_boundaries)})
-
-            prev_frame = curr_frame.copy()
-            prev_idx = frame_idx
-
-            del curr_frame
-            if frame_idx % 100 == 0:
-                gc.collect()
-
-        return shot_boundaries
+        cmd = [
+            "ffprobe",
+            "-v", "error",
+            "-select_streams", "v:0",
+            "-show_entries", "stream=width,height,r_frame_rate,duration",
+            "-of", "default=noprint_wrappers=1:nokey=0",
+            self.video_path
+        ]
 
-    def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
-        """
-        处理单个镜头的帧
-
-        Args:
-            shot_frames: 镜头中的帧列表
+        try:
+            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+            lines = result.stdout.strip().split('\n')
+            info = {}
+            for line in lines:
+                if '=' in line:
+                    key, value = line.split('=', 1)
+                    info[key] = value
+            # 处理帧率(可能是分数形式)
+            if 'r_frame_rate' in info:
+                try:
+                    num, den = map(int, info['r_frame_rate'].split('/'))
+                    info['fps'] = str(num / den)
+                except ValueError:
+                    info['fps'] = info.get('r_frame_rate', '25')
+
+            return info
+
+        except subprocess.CalledProcessError as e:
+            logger.error(f"获取视频信息失败: {e.stderr}")
+            return {
+                'width': '1280',
+                'height': '720',
+                'fps': '25',
+                'duration': '0'
+            }
+
+    def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
+                                   use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
+        """
+        按指定时间间隔提取视频帧
+
+        Args:
+            output_dir: 输出目录
+            interval_seconds: 帧提取间隔(秒)
+            use_hw_accel: 是否使用硬件加速
+            skip_seconds: 跳过视频开头的秒数
 
         Returns:
-            Tuple[np.ndarray, int]: (关键帧, 帧索引)
+            List[int]: 提取的帧号列表
         """
-        if not shot_frames:
-            return None, -1
-
-        frame_features = []
-        frame_indices = []
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
 
-        for idx, frame in tqdm(shot_frames,
-                               desc="处理镜头帧",
-                               unit="帧",
-                               leave=False):
-            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-            resized_gray = cv2.resize(gray, (32, 32))
-            frame_features.append(resized_gray.flatten())
-            frame_indices.append(idx)
-
-        frame_features = np.array(frame_features)
+        # 计算起始时间和帧提取点
+        start_time = skip_seconds
+        end_time = self.duration
+        extraction_times = []
 
-        kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
-                                 random_state=0).fit(frame_features)
+        current_time = start_time
+        while current_time < end_time:
+            extraction_times.append(current_time)
+            current_time += interval_seconds
 
-        center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
-
-        return shot_frames[center_idx][1], frame_indices[center_idx]
+        if not extraction_times:
+            logger.warning("未找到需要提取的帧")
+            return []
 
-    def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
-        """
-        使用生成器方式提取关键帧
+        # 确定硬件加速器选项
+        hw_accel = []
+        if use_hw_accel:
+            # 尝试检测可用的硬件加速器
+            hw_accel_options = self._detect_hw_accelerator()
+            if hw_accel_options:
+                hw_accel = hw_accel_options
+                logger.info(f"使用硬件加速: {' '.join(hw_accel)}")
+            else:
+                logger.warning("未检测到可用的硬件加速器,使用软件解码")
 
-        Args:
-            shot_boundaries: 镜头边界列表
+        # 提取帧
+        frame_numbers = []
+        for i, timestamp in enumerate(tqdm(extraction_times, desc="提取视频帧")):
+            frame_number = int(timestamp * self.fps)
+            frame_numbers.append(frame_number)
 
-        Yields:
-            Tuple[np.ndarray, int]: (关键帧, 帧索引)
-        """
-        shot_frames = []
-        current_shot_start = 0
 
+            # 格式化时间戳字符串 (HHMMSSmmm)
+            hours = int(timestamp // 3600)
+            minutes = int((timestamp % 3600) // 60)
+            seconds = int(timestamp % 60)
+            milliseconds = int((timestamp % 1) * 1000)
+            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
+
+            output_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{time_str}.jpg")
+
+            # 使用ffmpeg提取单帧
+            cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+            ]
+
+            # 添加硬件加速参数
+            cmd.extend(hw_accel)
+
+            cmd.extend([
+                "-ss", str(timestamp),
+                "-i", self.video_path,
+                "-vframes", "1",
+                "-q:v", "1",  # 最高质量
+                "-y",
+                output_path
+            ])
+
+            try:
+                subprocess.run(cmd, check=True, capture_output=True)
+            except subprocess.CalledProcessError as e:
+                logger.warning(f"提取帧 {frame_number} 失败: {e.stderr}")
 
-        for frame_idx, frame in self.preprocess_video():
-            if frame_idx in shot_boundaries:
-                if shot_frames:
-                    keyframe, keyframe_idx = self.process_shot(shot_frames)
-                    if keyframe is not None:
-                        yield keyframe, keyframe_idx
-
-                    # 清理内存
-                    shot_frames.clear()
-                    gc.collect()
+        logger.info(f"成功提取了 {len(frame_numbers)} 个视频帧")
+        return frame_numbers
+
+    def _detect_hw_accelerator(self) -> List[str]:
+        """
+        检测系统可用的硬件加速器
+
+        Returns:
+            List[str]: 硬件加速器ffmpeg命令参数
+        """
+        # 检测操作系统
+        import platform
+        system = platform.system().lower()
+
+        # 测试不同的硬件加速器
+        accelerators = []
+
+        if system == 'darwin':  # macOS
+            # 测试 videotoolbox (Apple 硬件加速)
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "videotoolbox",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "videotoolbox"]
+            except subprocess.CalledProcessError:
+                pass
 
-                current_shot_start = frame_idx
+        elif system == 'linux':
+            # 测试 VAAPI
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "vaapi",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "vaapi"]
+            except subprocess.CalledProcessError:
+                pass
 
-            shot_frames.append((frame_idx, frame))
-
-            # 控制单个镜头的最大帧数
-            if len(shot_frames) > self.batch_size:
-                keyframe, keyframe_idx = self.process_shot(shot_frames)
-                if keyframe is not None:
-                    yield keyframe, keyframe_idx
-                shot_frames.clear()
-                gc.collect()
+            # 尝试 CUDA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "cuda",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "cuda"]
+            except subprocess.CalledProcessError:
+                pass
+
+        elif system == 'windows':
+            # 测试 CUDA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "cuda",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "cuda"]
+            except subprocess.CalledProcessError:
+                pass
+
+            # 测试 D3D11VA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "d3d11va",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "d3d11va"]
+            except subprocess.CalledProcessError:
+                pass
+
+            # 测试 DXVA2
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "dxva2",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "dxva2"]
+            except subprocess.CalledProcessError:
+                pass
 
-        # 处理最后一个镜头
-        if shot_frames:
-            keyframe, keyframe_idx = self.process_shot(shot_frames)
-            if keyframe is not None:
-                yield keyframe, keyframe_idx
+        # 如果没有找到可用的硬件加速器
+        return []
 
-    def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
+    def process_video_pipeline(self,
+                               output_dir: str,
+                               skip_seconds: float = 0.0,
+                               threshold: int = 20,  # 此参数保留但不使用
+                               compressed_width: int = 320,  # 此参数保留但不使用
+                               keep_temp: bool = False,  # 此参数保留但不使用
+                               interval_seconds: float = 5.0,
+                               use_hw_accel: bool = True) -> None:
         """
-        处理视频并提取关键帧,使用分批处理方式
+        执行简化的视频处理流程,直接从原视频按固定时间间隔提取帧
 
         Args:
             output_dir: 输出目录
             skip_seconds: 跳过视频开头的秒数
+            threshold: 保留参数,不使用
+            compressed_width: 保留参数,不使用
+            keep_temp: 保留参数,不使用
+            interval_seconds: 帧提取间隔(秒)
+            use_hw_accel: 是否使用硬件加速
         """
+        # 创建输出目录
+        os.makedirs(output_dir, exist_ok=True)
+
         try:
-            # 创建输出目录
-            os.makedirs(output_dir, exist_ok=True)
-
-            # 计算要跳过的帧数
-            skip_frames = int(skip_seconds * self.fps)
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
-
-            # 检测镜头边界
-            logger.info("开始检测镜头边界...")
-            shot_boundaries = self.detect_shot_boundaries()
-
-            # 提取关键帧
-            logger.info("开始提取关键帧...")
-            frame_count = 0
-
-            pbar = tqdm(self.extract_keyframes(shot_boundaries),
-                        desc="提取关键帧",
-                        unit="帧")
-
-            for keyframe, frame_idx in pbar:
-                if frame_idx < skip_frames:
-                    continue
-
-                # 计算时间戳
-                timestamp = frame_idx / self.fps
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
-
-                # 保存关键帧
-                output_path = os.path.join(output_dir,
-                                           f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-                cv2.imwrite(output_path, keyframe)
-                frame_count += 1
-
-                pbar.set_postfix({"已保存": frame_count})
-
-                if frame_count % 10 == 0:
-                    gc.collect()
-
-            logger.info(f"关键帧提取完成,共保存 {frame_count} 帧到 {output_dir}")
+            # 直接从原视频提取关键帧
+            logger.info("从视频直接提取关键帧...")
+            self.extract_frames_by_interval(
+                output_dir,
+                interval_seconds=interval_seconds,
+                use_hw_accel=use_hw_accel,
+                skip_seconds=skip_seconds
+            )
+            logger.info(f"处理完成!视频帧已保存在: {output_dir}")
 
         except Exception as e:
-            logger.error(f"视频处理失败: {str(e)}")
+            import traceback
+            logger.error(f"视频处理失败: \n{traceback.format_exc()}")
             raise
-        finally:
-            # 确保资源被释放
-            self.cap.release()
-            gc.collect()
+
+
+if __name__ == "__main__":
+    import time
+
+    start_time = time.time()
+
+    # 使用示例
+    processor = VideoProcessor("./resource/videos/test.mp4")
+
+    # 设置间隔为3秒提取帧
+    processor.process_video_pipeline(
+        output_dir="output",
+        interval_seconds=3.0,
+        use_hw_accel=True
+    )
+
+    end_time = time.time()
+    print(f"处理完成!总耗时: {end_time - start_time:.2f} 秒")
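Two observations on the rewritten video_processor.py above. First, `_detect_hw_accelerator` repeats the same ffmpeg probe five times with only the `-hwaccel` value changing (and leaves an unused `accelerators = []` local); a table-driven probe would keep the per-platform candidate order in one loop. Second, in `extract_frames_by_interval`, `subprocess.run` is called without `text=True`, so `e.stderr` is logged as a raw bytes repr. A minimal consolidation sketch (same observable behavior assumed; `detect_hw_accelerator` here is a standalone stand-in for the method):

```python
import platform
import subprocess
from typing import List

# Candidate accelerators per platform, in preference order.
_CANDIDATES = {
    "darwin": ["videotoolbox"],
    "linux": ["vaapi", "cuda"],
    "windows": ["cuda", "d3d11va", "dxva2"],
}

def detect_hw_accelerator(video_path: str) -> List[str]:
    """Probe ffmpeg hardware decoders and return the first working -hwaccel args."""
    for accel in _CANDIDATES.get(platform.system().lower(), []):
        cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error",
               "-hwaccel", accel, "-i", video_path,
               "-t", "0.1", "-f", "null", "-"]
        try:
            subprocess.run(cmd, capture_output=True, check=True)
            return ["-hwaccel", accel]
        except subprocess.CalledProcessError:
            continue  # this accelerator is unavailable; try the next one
    return []  # fall back to software decoding
```

Worth noting as a design trade-off: the interval extractor spawns one ffmpeg process per frame via `-ss` seeks, which is simple and accurate but slower than a single pass with an `fps=1/interval` filter for long videos.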
diff --git a/app/utils/video_processor_v2.py b/app/utils/video_processor_v2.py
deleted file mode 100644
index 825306b..0000000
--- a/app/utils/video_processor_v2.py
+++ /dev/null
@@ -1,382 +0,0 @@
-import cv2
-import numpy as np
-from sklearn.cluster import KMeans
-import os
-import re
-from typing import List, Tuple, Generator
-from loguru import logger
-import subprocess
-from tqdm import tqdm
-
-
-class VideoProcessor:
-    def __init__(self, video_path: str):
-        """
-        初始化视频处理器
-
-        Args:
-            video_path: 视频文件路径
-        """
-        if not os.path.exists(video_path):
-            raise FileNotFoundError(f"视频文件不存在: {video_path}")
-
-        self.video_path = video_path
-        self.cap = cv2.VideoCapture(video_path)
-
-        if not self.cap.isOpened():
-            raise RuntimeError(f"无法打开视频文件: {video_path}")
-
-        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
-
-    def __del__(self):
-        """析构函数,确保视频资源被释放"""
-        if hasattr(self, 'cap'):
-            self.cap.release()
-
-    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
-        """
-        使用生成器方式读取视频帧
-
-        Yields:
-            np.ndarray: 视频帧
-        """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # 重置到视频开始
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
-                break
-            yield frame
-
-    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
-        """
-        使用帧差法检测镜头边界
-
-        Args:
-            frames: 视频帧列表
-            threshold: 差异阈值,默认值调低为30
-
-        Returns:
-            List[int]: 镜头边界帧的索引列表
-        """
-        shot_boundaries = []
-        if len(frames) < 2:  # 添加帧数检查
-            logger.warning("视频帧数过少,无法检测场景边界")
-            return [len(frames) - 1]  # 返回最后一帧作为边界
-
-        for i in range(1, len(frames)):
-            prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
-            curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
-
-            # 计算帧差
-            diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
-
-            if diff > threshold:
-                shot_boundaries.append(i)
-
-        # 如果没有检测到任何边界,至少返回最后一帧
-        if not shot_boundaries:
-            logger.warning("未检测到场景边界,将视频作为单个场景处理")
-            shot_boundaries.append(len(frames) - 1)
-
-        return shot_boundaries
-
-    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
-        List[np.ndarray], List[int]]:
-        """
-        从每个镜头中提取关键帧
-
-        Args:
-            frames: 视频帧列表
-            shot_boundaries: 镜头边界列表
-
-        Returns:
-            Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
-        """
-        keyframes = []
-        keyframe_indices = []
-
-        for i in tqdm(range(len(shot_boundaries)), desc="提取关键帧"):
-            start = shot_boundaries[i - 1] if i > 0 else 0
-            end = shot_boundaries[i]
-            shot_frames = frames[start:end]
-
-            if not shot_frames:
-                continue
-
-            # 将每一帧转换为灰度图并展平为一维数组
-            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
-                                       for frame in shot_frames])
-
-            try:
-                # 尝试使用 KMeans
-                kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
-                center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
-            except Exception as e:
-                logger.warning(f"KMeans 聚类失败,使用备选方案: {str(e)}")
-                # 备选方案:选择镜头中间的帧作为关键帧
-                center_idx = len(shot_frames) // 2
-
-            keyframes.append(shot_frames[center_idx])
-            keyframe_indices.append(start + center_idx)
-
-        return keyframes, keyframe_indices
-
-    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
-                       output_dir: str, desc: str = "保存关键帧") -> None:
-        """
-        保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg
-        时间戳精确到毫秒,格式为:HHMMSSmmm
-        """
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-
-        for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
-                                        total=len(keyframes),
-                                        desc=desc):
-            # 计算精确到毫秒的时间戳
-            timestamp = frame_idx / self.fps
-            hours = int(timestamp // 3600)
-            minutes = int((timestamp % 3600) // 60)
-            seconds = int(timestamp % 60)
-            milliseconds = int((timestamp % 1) * 1000)  # 计算毫秒部分
-            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
-
-            output_path = os.path.join(output_dir,
-                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-            cv2.imwrite(output_path, keyframe)
-
-    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
-        """
-        根据指定的帧号提取帧,如果多个帧在同一毫秒内,只保留一个
-        """
-        if not frame_numbers:
-            raise ValueError("未提供帧号列表")
-
-        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
-            raise ValueError("存在无效的帧号")
-
-        if not os.path.exists(output_folder):
-            os.makedirs(output_folder)
-
-        # 用于记录已处理的时间戳(毫秒)
-        processed_timestamps = set()
-
-        for frame_number in tqdm(frame_numbers, desc="提取高清帧"):
-            # 计算精确到毫秒的时间戳
-            timestamp = frame_number / self.fps
-            timestamp_ms = int(timestamp * 1000)  # 转换为毫秒
-
-            # 如果这一毫秒已经处理过,跳过
-            if timestamp_ms in processed_timestamps:
-                continue
-
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-            ret, frame = self.cap.read()
-
-            if ret:
-                # 记录这一毫秒已经处理
-                processed_timestamps.add(timestamp_ms)
-
-                # 计算时间戳字符串
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                milliseconds = int((timestamp % 1) * 1000)  # 计算毫秒部分
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
-
-                output_path = os.path.join(output_folder,
-                                           f"keyframe_{frame_number:06d}_{time_str}.jpg")
-                cv2.imwrite(output_path, frame)
-            else:
-                logger.info(f"无法读取帧 {frame_number}")
-
-        logger.info(f"共提取了 {len(processed_timestamps)} 个不同时间戳的帧")
-
-    @staticmethod
-    def extract_numbers_from_folder(folder_path: str) -> List[int]:
-        """
-        从文件夹中提取帧号
-
-        Args:
-            folder_path: 关键帧文件夹路径
-
-        Returns:
-            List[int]: 排序后的帧号列表
-        """
-        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
-        # 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534123.jpg
-        pattern = re.compile(r'keyframe_(\d+)_\d{9}\.jpg$')
-        numbers = []
-
-        for f in files:
-            match = pattern.search(f)
-            if match:
-                numbers.append(int(match.group(1)))
-            else:
-                logger.warning(f"文件名格式不匹配: {f}")
-
-        if not numbers:
-            logger.error(f"在目录 {folder_path} 中未找到有效的关键帧文件")
-
-        return sorted(numbers)
-
-    def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
-        """
-        处理视频并提取关键帧
-
-        Args:
-            output_dir: 输出目录
-            skip_seconds: 跳过视频开头的秒数
-        """
-        skip_frames = int(skip_seconds * self.fps)
-
-        logger.info("读取视频帧...")
-        frames = []
-        for frame in tqdm(self.preprocess_video(),
-                          total=self.total_frames,
-                          desc="读取视频"):
-            frames.append(frame)
-
-        frames = frames[skip_frames:]
-
-        if not frames:
-            raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
-
-        logger.info("检测场景边界...")
-        shot_boundaries = self.detect_shot_boundaries(frames, threshold)
-        logger.info(f"检测到 {len(shot_boundaries)} 个场景边界")
-
-        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
-
-        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
-        self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="保存压缩关键帧")
-
-    def process_video_pipeline(self,
-                               output_dir: str,
-                               skip_seconds: float = 0,
-                               threshold: int = 20,  # 降低默认阈值
-                               compressed_width: int = 320,
-                               keep_temp: bool = False) -> None:
-        """
-        执行完整的视频处理流程
-
-        Args:
-            threshold: 降低默认阈值为20,使场景检测更敏感
-        """
-        os.makedirs(output_dir, exist_ok=True)
-        temp_dir = os.path.join(output_dir, 'temp')
-        compressed_dir = os.path.join(temp_dir, 'compressed')
-        mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
-        hd_frames_dir = output_dir
-
-        os.makedirs(temp_dir, exist_ok=True)
-        os.makedirs(compressed_dir, exist_ok=True)
-        os.makedirs(mini_frames_dir, exist_ok=True)
-        os.makedirs(hd_frames_dir, exist_ok=True)
-
-        mini_processor = None
-        compressed_video = None
-
-        try:
-            # 1. 压缩视频
-            video_name = os.path.splitext(os.path.basename(self.video_path))[0]
-            compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")
-
-            # 获取原始视频的宽度和高度
-            original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-            original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
-            logger.info("步骤1: 压缩视频...")
-            if original_width > original_height:
-                # 横版视频
-                scale_filter = f'scale={compressed_width}:-1'
-            else:
-                # 竖版视频
-                scale_filter = f'scale=-1:{compressed_width}'
-
-            ffmpeg_cmd = [
-                'ffmpeg', '-i', self.video_path,
-                '-vf', scale_filter,
-                '-y',
-                compressed_video
-            ]
-
-            try:
-                subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
-            except subprocess.CalledProcessError as e:
-                logger.error(f"FFmpeg 错误输出: {e.stderr}")
-                raise
-
-            # 2. 从压缩视频中提取关键帧
-            logger.info("\n步骤2: 从压缩视频提取关键帧...")
-            mini_processor = VideoProcessor(compressed_video)
-            mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)
-
-            # 3. 从原始视频提取高清关键帧
-            logger.info("\n步骤3: 提取高清关键帧...")
-            frame_numbers = self.extract_numbers_from_folder(mini_frames_dir)
-
-            if not frame_numbers:
-                raise ValueError("未能从压缩视频中提取到有效的关键帧")
-
-            self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)
-
-            logger.info(f"处理完成!高清关键帧保存在: {hd_frames_dir}")
-
-        except Exception as e:
-            import traceback
-            logger.error(f"视频处理失败: \n{traceback.format_exc()}")
-            raise
-
-        finally:
-            # 释放资源
-            if mini_processor:
-                mini_processor.cap.release()
-                del mini_processor
-
-            # 确保视频文件句柄被释放
-            if hasattr(self, 'cap'):
-                self.cap.release()
-
-            # 等待资源释放
-            import time
-            time.sleep(0.5)
-
-            if not keep_temp:
-                try:
-                    # 先删除压缩视频文件
-                    if compressed_video and os.path.exists(compressed_video):
-                        try:
-                            os.remove(compressed_video)
-                        except Exception as e:
-                            logger.warning(f"删除压缩视频失败: {e}")
-
-                    # 再删除临时目录
-                    import shutil
-                    if os.path.exists(temp_dir):
-                        max_retries = 3
-                        for i in range(max_retries):
-                            try:
-                                shutil.rmtree(temp_dir)
-                                break
-                            except Exception as e:
-                                if i == max_retries - 1:
-                                    logger.warning(f"清理临时文件失败: {e}")
-                                else:
-                                    time.sleep(1)  # 等待1秒后重试
-                                    continue
-
-                    logger.info("临时文件已清理")
-                except Exception as e:
-                    logger.warning(f"清理临时文件时出错: {e}")
-
-
-if __name__ == "__main__":
-    import time
-
-    start_time = time.time()
-    processor = VideoProcessor("E:\\projects\\NarratoAI\\resource\\videos\\test.mp4")
-    processor.process_video_pipeline(output_dir="output")
-    end_time = time.time()
-    print(f"处理完成!总耗时: {end_time - start_time:.2f} 秒")
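With video_processor_v2.py deleted and the `config.frames` version switch removed from both call sites, scene-detection-based keyframe extraction has no replacement; `threshold`, `compressed_width`, and `keep_temp` survive on the new `process_video_pipeline` only so legacy callers keep working. A hypothetical before/after for such call sites (paths are illustrative):

```python
from app.utils import video_processor

processor = video_processor.VideoProcessor("resource/videos/test.mp4")

# A legacy v2-style call still runs; threshold is accepted but silently ignored:
processor.process_video_pipeline(output_dir="output", skip_seconds=2.0, threshold=20)

# Equivalent explicit call using the parameters that actually drive extraction now:
processor.process_video_pipeline(
    output_dir="output",
    skip_seconds=2.0,
    interval_seconds=5.0,  # one frame every 5 seconds instead of shot detection
    use_hw_accel=True,
)
```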
diff --git a/config.example.toml b/config.example.toml
index 5620744..835a8e9 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -1,5 +1,5 @@
 [app]
-    project_version="0.5.3"
+    project_version="0.6.0"
     # 支持视频理解的大模型提供商
     # gemini
     # qwenvl
diff --git a/requirements.txt b/requirements.txt
index 723fa43..207865b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,38 +1,45 @@
+# 必须项
 requests~=2.32.0
 moviepy==2.1.1
 edge-tts==6.1.19
 streamlit~=1.45.0
+watchdog==6.0.0
+loguru~=0.7.3
+tomli~=2.2.1
 openai~=1.77.0
 google-generativeai>=0.8.5
-loguru~=0.7.2
-fastapi~=0.115.4
-uvicorn~=0.27.1
-pydantic~=2.11.4
+# 待优化项
+# opencv-python==4.11.0.86
+# scikit-learn==1.6.1
 
-faster-whisper~=1.0.1
-tomli~=2.0.1
-aiohttp~=3.10.10
-httpx==0.27.2
-urllib3~=2.2.1
+# fastapi~=0.115.4
+# uvicorn~=0.27.1
+# pydantic~=2.11.4
 
-python-multipart~=0.0.9
-redis==5.0.3
-opencv-python~=4.10.0.84
-azure-cognitiveservices-speech~=1.37.0
-git-changelog~=2.5.2
-watchdog==5.0.2
-pydub==0.25.1
-psutil>=5.9.0
-scikit-learn~=1.5.2
-pillow==10.3.0
-python-dotenv~=1.0.1
+# faster-whisper~=1.0.1
+# tomli~=2.0.1
+# aiohttp~=3.10.10
+# httpx==0.27.2
+# urllib3~=2.2.1
 
-tqdm>=4.66.6
-tenacity>=9.0.0
-tiktoken==0.8.0
-pysrt==1.1.2
-transformers==4.50.0
+# python-multipart~=0.0.9
+# redis==5.0.3
+# opencv-python~=4.10.0.84
+# azure-cognitiveservices-speech~=1.37.0
+# git-changelog~=2.5.2
+# watchdog==5.0.2
+# pydub==0.25.1
+# psutil>=5.9.0
+# scikit-learn~=1.5.2
+# pillow==10.3.0
+# python-dotenv~=1.0.1
+
+# tqdm>=4.66.6
+# tenacity>=9.0.0
+# tiktoken==0.8.0
+# pysrt==1.1.2
+# transformers==4.50.0
 # yt-dlp==2025.4.30
\ No newline at end of file
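Several packages moved into the commented "待优化项" block are still referenced in the tree (faster-whisper in subtitle.py, psutil in performance.py; pysrt and tiktoken are only behind comments). Until those call sites are reworked, a guarded lazy import is one way to keep modules importable when an optional package is absent; a minimal sketch with a hypothetical `optional_import` helper:

```python
import importlib
from loguru import logger

def optional_import(module_name: str):
    """Return the module if installed, else None (and log a warning)."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        logger.warning(f"Optional dependency {module_name} is not installed; "
                       "related features are disabled")
        return None

# Example: pysrt usage degrades gracefully instead of crashing at import time.
pysrt = optional_import("pysrt")
if pysrt is not None:
    subs = pysrt.open("subtitle.srt")
```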
diff --git a/webui.py b/webui.py
index 1f605b5..5f296cd 100644
--- a/webui.py
+++ b/webui.py
@@ -1,7 +1,7 @@
 import streamlit as st
 import os
 import sys
-from uuid import uuid4
+from loguru import logger
 from app.config import config
 from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
     review_settings, merge_settings, system_settings
@@ -18,7 +18,7 @@ st.set_page_config(
     initial_sidebar_state="auto",
     menu_items={
         "Report a bug": "https://github.com/linyqh/NarratoAI/issues",
-        'About': f"# NarratoAI:sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
+        'About': f"# Narrato:blue[AI] :sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
                  f"自动化影视解说视频详情请移步:https://github.com/linyqh/NarratoAI"
     },
 )
@@ -37,17 +37,7 @@ def init_log():
     _lvl = "DEBUG"
 
     def format_record(record):
-        # 增加更多需要过滤的警告消息
-        ignore_messages = [
-            "Examining the path of torch.classes raised",
-            "torch.cuda.is_available()",
-            "CUDA initialization"
-        ]
-
-        for msg in ignore_messages:
-            if msg in record["message"]:
-                return ""
-
+        # 简化日志格式化处理,不尝试按特定字符串过滤torch相关内容
         file_path = record["file"].path
         relative_path = os.path.relpath(file_path, config.root_dir)
         record["file"].path = f"./{relative_path}"
@@ -59,23 +49,53 @@ def init_log():
                   '- {message}' + "\n"
         return _format
 
-    # 优化日志过滤器
-    def log_filter(record):
-        ignore_messages = [
-            "Examining the path of torch.classes raised",
-            "torch.cuda.is_available()",
-            "CUDA initialization"
-        ]
-        return not any(msg in record["message"] for msg in ignore_messages)
-
+    # 替换为更简单的过滤方式,避免在过滤时访问message内容
+    # 此处先不设置复杂的过滤器,等应用启动后再动态添加
     logger.add(
         sys.stdout,
         level=_lvl,
         format=format_record,
-        colorize=True,
-        filter=log_filter
+        colorize=True
     )
 
+    # 应用启动后,可以再添加更复杂的过滤器
+    def setup_advanced_filters():
+        """在应用完全启动后设置高级过滤器"""
+        try:
+            for handler_id in logger._core.handlers:
+                logger.remove(handler_id)
+
+            # 重新添加带有高级过滤的处理器
+            def advanced_filter(record):
+                """更复杂的过滤器,在应用启动后安全使用"""
+                ignore_messages = [
+                    "Examining the path of torch.classes raised",
+                    "torch.cuda.is_available()",
+                    "CUDA initialization"
+                ]
+                return not any(msg in record["message"] for msg in ignore_messages)
+
+            logger.add(
+                sys.stdout,
+                level=_lvl,
+                format=format_record,
+                colorize=True,
+                filter=advanced_filter
+            )
+        except Exception as e:
+            # 如果过滤器设置失败,确保日志仍然可用
+            logger.add(
+                sys.stdout,
+                level=_lvl,
+                format=format_record,
+                colorize=True
+            )
+            logger.error(f"设置高级日志过滤器失败: {e}")
+
+    # 将高级过滤器设置放到启动主逻辑后
+    import threading
+    threading.Timer(5.0, setup_advanced_filters).start()
+
 
 def init_global_state():
     """初始化全局状态"""
@@ -177,11 +197,18 @@ def main():
     """主函数"""
     init_log()
     init_global_state()
-    utils.init_resources()
+
+    # 仅初始化基本资源,避免过早地加载依赖PyTorch的资源
+    # 检查是否能分解utils.init_resources()为基本资源和高级资源(如依赖PyTorch的资源)
+    try:
+        utils.init_resources()
+    except Exception as e:
+        logger.warning(f"资源初始化时出现警告: {e}")
 
-    st.title(f"NarratoAI :sunglasses:📽️")
+    st.title(f"Narrato:blue[AI]:sunglasses: 📽️")
     st.write(tr("Get Help"))
 
+    # 首先渲染不依赖PyTorch的UI部分
     # 渲染基础设置面板
     basic_settings.render_basic_settings(tr)
     # 渲染合并设置
@@ -196,13 +223,16 @@ def main():
         audio_settings.render_audio_panel(tr)
     with panel[2]:
         subtitle_settings.render_subtitle_panel(tr)
-    # 渲染系统设置面板
-    system_settings.render_system_panel(tr)
-
+
     # 渲染视频审查面板
     review_settings.render_review_panel(tr)
-
-    # 渲染生成按钮和处理逻辑
+
+    # 放到最后渲染可能使用PyTorch的部分
+    # 渲染系统设置面板
+    with panel[2]:
+        system_settings.render_system_panel(tr)
+
+    # 放到最后渲染生成按钮和处理逻辑
     render_generate_button()
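One issue in the webui.py hunk above: `setup_advanced_filters` iterates `logger._core.handlers` while calling `logger.remove(handler_id)`, which mutates the dict during iteration (a `RuntimeError` on CPython) and relies on a private attribute. loguru's public `logger.remove()` with no argument already drops every handler, so the same reset can be written as the sketch below (assuming `level` and `format_record` are passed in rather than closed over, as they would be inside `init_log`):

```python
import sys
from loguru import logger

def setup_advanced_filters(level: str, format_record) -> None:
    """Re-register the stdout handler with the torch-noise filter attached."""
    ignore_messages = (
        "Examining the path of torch.classes raised",
        "torch.cuda.is_available()",
        "CUDA initialization",
    )

    def advanced_filter(record) -> bool:
        return not any(msg in record["message"] for msg in ignore_messages)

    logger.remove()  # public API: removes all previously added handlers at once
    logger.add(sys.stdout, level=level, format=format_record,
               colorize=True, filter=advanced_filter)
```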
diff --git a/webui/components/merge_settings.py b/webui/components/merge_settings.py
index 99b8b43..edaa183 100644
--- a/webui/components/merge_settings.py
+++ b/webui/components/merge_settings.py
@@ -285,8 +285,8 @@ def render_merge_settings(tr):
             error_message = str(e)
             if "moviepy" in error_message.lower():
                 st.error(tr("Error processing video files. Please check if the videos are valid MP4 files."))
-            elif "pysrt" in error_message.lower():
-                st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
+            # elif "pysrt" in error_message.lower():
+            #     st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
             else:
                 st.error(f"{tr('Error during merge')}: {error_message}")
diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py
index 6552ebf..a580f64 100644
--- a/webui/tools/generate_script_docu.py
+++ b/webui/tools/generate_script_docu.py
@@ -5,6 +5,7 @@ import time
 import asyncio
 import traceback
 import requests
+from app.utils import video_processor
 import streamlit as st
 from loguru import logger
 from requests.adapters import HTTPAdapter
@@ -12,7 +13,7 @@ from urllib3.util.retry import Retry
 
 from app.config import config
 from app.utils.script_generator import ScriptProcessor
-from app.utils import utils, video_processor, video_processor_v2, qwenvl_analyzer
+from app.utils import utils, video_processor, qwenvl_analyzer
 from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps, chekc_video_config
 
 
@@ -64,21 +65,13 @@ def generate_script_docu(tr, params):
         os.makedirs(video_keyframes_dir, exist_ok=True)
 
         # 初始化视频处理器
-        if config.frames.get("version") == "v2":
-            processor = video_processor_v2.VideoProcessor(params.video_origin_path)
-            # 处理视频并提取关键帧
-            processor.process_video_pipeline(
-                output_dir=video_keyframes_dir,
-                skip_seconds=st.session_state.get('skip_seconds'),
-                threshold=st.session_state.get('threshold')
-            )
-        else:
-            processor = video_processor.VideoProcessor(params.video_origin_path)
-            # 处理视频并提取关键帧
-            processor.process_video(
-                output_dir=video_keyframes_dir,
-                skip_seconds=0
-            )
+        processor = video_processor.VideoProcessor(params.video_origin_path)
+        # 处理视频并提取关键帧
+        processor.process_video_pipeline(
+            output_dir=video_keyframes_dir,
+            skip_seconds=st.session_state.get('skip_seconds'),
+            threshold=st.session_state.get('threshold')
+        )
 
         # 获取所有关键文件路径
         for filename in sorted(os.listdir(video_keyframes_dir)):
diff --git a/webui/utils/merge_video.py b/webui/utils/merge_video.py
index 6c8503c..9fa2b39 100644
--- a/webui/utils/merge_video.py
+++ b/webui/utils/merge_video.py
@@ -2,7 +2,7 @@
 合并视频和字幕文件
 """
 from moviepy import VideoFileClip, concatenate_videoclips
-import pysrt
+# import pysrt
 import os
diff --git a/webui/utils/performance.py b/webui/utils/performance.py
index 0eab5fa..d0af06c 100644
--- a/webui/utils/performance.py
+++ b/webui/utils/performance.py
@@ -1,7 +1,6 @@
-import psutil
+# import psutil
 import os
 from loguru import logger
-import torch
 
 
 class PerformanceMonitor:
     @staticmethod
@@ -11,19 +10,35 @@ class PerformanceMonitor:
         logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
 
-        if torch.cuda.is_available():
-            gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
-            logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+        # 延迟导入torch并检查CUDA
+        try:
+            import torch
+            if torch.cuda.is_available():
+                gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
+                logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+        except (ImportError, RuntimeError) as e:
+            # 无法导入torch或触发CUDA相关错误时,静默处理
+            logger.debug(f"无法获取GPU内存信息: {e}")
 
     @staticmethod
     def cleanup_resources():
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # 延迟导入torch并清理CUDA
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                logger.debug("CUDA缓存已清理")
+        except (ImportError, RuntimeError) as e:
+            # 无法导入torch或触发CUDA相关错误时,静默处理
+            logger.debug(f"无法清理CUDA资源: {e}")
 
         import gc
         gc.collect()
 
-        PerformanceMonitor.monitor_memory()
+        # 仅报告进程内存,不尝试获取GPU内存
+        process = psutil.Process(os.getpid())
+        memory_info = process.memory_info()
+        logger.debug(f"Memory usage after cleanup: {memory_info.rss / 1024 / 1024:.2f} MB")
 
 
 def monitor_performance(func):
    """性能监控装饰器"""
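Same pattern as subtitle.py: performance.py comments out `import psutil` at module level, yet `monitor_memory` and the new post-cleanup report still call `psutil.Process(os.getpid())`, so either static method will raise `NameError` as patched. A lazy import with a fallback (the `log_process_memory` helper below is hypothetical, a minimal sketch only) keeps the module importable without psutil:

```python
import os
from loguru import logger

def log_process_memory(label: str = "Memory usage") -> None:
    """Log the current process RSS if psutil is available; otherwise skip quietly."""
    try:
        import psutil  # optional; commented out in requirements.txt
    except ImportError:
        logger.debug(f"{label}: psutil not installed, skipping RSS report")
        return
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    logger.debug(f"{label}: {rss_mb:.2f} MB")
```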