Remove the OpenCV- and sklearn-based keyframe extraction code

This commit is contained in:
linyq 2025-05-07 15:41:01 +08:00
parent c3ea0bcc69
commit f6c3f1640b
13 changed files with 478 additions and 696 deletions

View File

@@ -3,10 +3,11 @@ import json
 import time
 import asyncio
 import requests
+from app.utils import video_processor
 from loguru import logger
 from typing import List, Dict, Any, Callable
-from app.utils import utils, gemini_analyzer, video_processor, video_processor_v2
+from app.utils import utils, gemini_analyzer, video_processor
 from app.utils.script_generator import ScriptProcessor
 from app.config import config
@@ -105,19 +106,12 @@ class ScriptGenerator:
         os.makedirs(video_keyframes_dir, exist_ok=True)
         try:
-            if config.frames.get("version") == "v2":
-                processor = video_processor_v2.VideoProcessor(video_path)
-                processor.process_video_pipeline(
-                    output_dir=video_keyframes_dir,
-                    skip_seconds=skip_seconds,
-                    threshold=threshold
-                )
-            else:
-                processor = video_processor.VideoProcessor(video_path)
-                processor.process_video(
-                    output_dir=video_keyframes_dir,
-                    skip_seconds=skip_seconds
-                )
+            processor = video_processor.VideoProcessor(video_path)
+            processor.process_video_pipeline(
+                output_dir=video_keyframes_dir,
+                skip_seconds=skip_seconds,
+                threshold=threshold
+            )
             for filename in sorted(os.listdir(video_keyframes_dir)):
                 if filename.endswith('.jpg'):

View File

@@ -4,7 +4,7 @@ import re
 import traceback
 from typing import Optional
-from faster_whisper import WhisperModel
+# from faster_whisper import WhisperModel
 from timeit import default_timer as timer
 from loguru import logger
 import google.generativeai as genai
@@ -45,12 +45,25 @@ def create(audio_file, subtitle_file: str = ""):
         )
         return None
 
-    # Try CUDA and fall back to CPU on failure
-    try:
-        import torch
-        if torch.cuda.is_available():
-            try:
-                logger.info(f"Trying to load the model with CUDA: {model_path}")
-                model = WhisperModel(
-                    model_size_or_path=model_path,
-                    device="cuda",
+    # Start in CPU mode without triggering any CUDA check
+    use_cuda = False
+    try:
+        # Import torch lazily inside the function rather than at module scope,
+        # and check CUDA availability in a safe way
+        def check_cuda_available():
+            try:
+                import torch
+                return torch.cuda.is_available()
+            except (ImportError, RuntimeError) as e:
+                logger.warning(f"Error while checking CUDA availability: {e}")
+                return False
+
+        # Only probe CUDA when it is explicitly needed
+        use_cuda = check_cuda_available()
+        if use_cuda:
+            logger.info(f"Trying to load the model with CUDA: {model_path}")
+            try:
+                model = WhisperModel(
+                    model_size_or_path=model_path,
+                    device="cuda",
@@ -63,18 +76,18 @@ def create(audio_file, subtitle_file: str = ""):
             except Exception as e:
                 logger.warning(f"CUDA load failed: {str(e)}")
                 logger.warning("Falling back to CPU mode")
-                device = "cpu"
-                compute_type = "int8"
+                use_cuda = False
         else:
-            logger.info("CUDA not detected; using CPU mode")
-            device = "cpu"
-            compute_type = "int8"
-    except ImportError:
-        logger.warning("torch is not installed; using CPU mode")
-        device = "cpu"
-        compute_type = "int8"
+            logger.info("Using CPU mode")
+    except Exception as e:
+        logger.warning(f"Error during the CUDA check: {e}")
+        logger.warning("Defaulting to CPU mode")
+        use_cuda = False
 
-    if device == "cpu":
+    # Use the CPU when CUDA is unavailable or the CUDA load failed
+    if not use_cuda:
+        device = "cpu"
+        compute_type = "int8"
         logger.info(f"Loading the model on CPU: {model_path}")
         model = WhisperModel(
             model_size_or_path=model_path,
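
For reference, the load-with-fallback logic in the hunk above distills into a standalone helper. This is a minimal sketch rather than the project's code; it assumes faster-whisper is installed and that model_path points at a valid local model or model identifier:

from loguru import logger


def load_whisper_model(model_path: str):
    """Load a faster-whisper model, preferring CUDA but falling back to CPU int8."""
    from faster_whisper import WhisperModel  # lazy import, mirroring the commit's approach

    def cuda_available() -> bool:
        # Import torch inside the function so a broken CUDA install cannot
        # crash module import; treat RuntimeError as "no CUDA".
        try:
            import torch
            return torch.cuda.is_available()
        except (ImportError, RuntimeError) as e:
            logger.warning(f"CUDA availability check failed: {e}")
            return False

    if cuda_available():
        try:
            return WhisperModel(model_size_or_path=model_path,
                                device="cuda", compute_type="float16")
        except Exception as e:
            logger.warning(f"CUDA load failed ({e}); falling back to CPU")
    return WhisperModel(model_size_or_path=model_path,
                        device="cpu", compute_type="int8")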

View File

@@ -1,6 +1,6 @@
 import traceback
-import pysrt
+# import pysrt
 from typing import Optional
 from typing import List
 from loguru import logger

View File

@@ -2,7 +2,7 @@ import os
 import json
 import traceback
 from loguru import logger
-import tiktoken
+# import tiktoken
 from typing import List, Dict
 from datetime import datetime
 from openai import OpenAI
@@ -94,12 +94,12 @@ class OpenAIGenerator(BaseGenerator):
             "user": "script_generator"
         }
 
-        # Initialize the token counter
-        try:
-            self.encoding = tiktoken.encoding_for_model(self.model_name)
-        except KeyError:
-            logger.warning(f"No dedicated encoder found for model {self.model_name}; using the default encoder")
-            self.encoding = tiktoken.get_encoding("cl100k_base")
+        # # Initialize the token counter
+        # try:
+        #     self.encoding = tiktoken.encoding_for_model(self.model_name)
+        # except KeyError:
+        #     logger.warning(f"No dedicated encoder found for model {self.model_name}; using the default encoder")
+        #     self.encoding = tiktoken.get_encoding("cl100k_base")
 
     def _generate(self, messages: list, params: dict) -> any:
         """OpenAI-specific generation logic"""

View File

@@ -1,237 +1,349 @@
-import cv2
-import numpy as np
-from sklearn.cluster import MiniBatchKMeans
-import os
-import re
-from typing import List, Tuple, Generator
-from loguru import logger
-import gc
-from tqdm import tqdm
-
-
-class VideoProcessor:
-    def __init__(self, video_path: str, batch_size: int = 100):
-        """
-        Initialize the video processor
-        Args:
-            video_path: path to the video file
-            batch_size: batch size that bounds memory usage
-        """
-        if not os.path.exists(video_path):
-            raise FileNotFoundError(f"Video file does not exist: {video_path}")
-        self.video_path = video_path
-        self.batch_size = batch_size
-        self.cap = cv2.VideoCapture(video_path)
-        if not self.cap.isOpened():
-            raise RuntimeError(f"Cannot open video file: {video_path}")
-        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
-
-    def __del__(self):
-        """Destructor; make sure video resources are released"""
-        if hasattr(self, 'cap'):
-            self.cap.release()
-        gc.collect()
-
-    def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
-        """
-        Read video frames in batches via a generator
-        Yields:
-            Tuple[int, np.ndarray]: (frame index, video frame)
-        """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
-        frame_idx = 0
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
-                break
-            # Downscale to reduce memory usage
-            frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
-            yield frame_idx, frame
-            frame_idx += 1
-            # Collect garbage periodically
-            if frame_idx % 1000 == 0:
-                gc.collect()
-
-    def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
-        """
-        Detect shot boundaries in batches
-        Args:
-            threshold: difference threshold
-        Returns:
-            List[int]: indices of shot-boundary frames
-        """
-        shot_boundaries = []
-        prev_frame = None
-        prev_idx = -1
-        pbar = tqdm(self.preprocess_video(),
-                    total=self.total_frames,
-                    desc="Detecting shot boundaries",
-                    unit="frame")
-        for frame_idx, curr_frame in pbar:
-            if prev_frame is not None:
-                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
-                curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
-                diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
-                if diff > threshold:
-                    shot_boundaries.append(frame_idx)
-                    pbar.set_postfix({"boundaries": len(shot_boundaries)})
-            prev_frame = curr_frame.copy()
-            prev_idx = frame_idx
-            del curr_frame
-            if frame_idx % 100 == 0:
-                gc.collect()
-        return shot_boundaries
-
-    def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
-        """
-        Process the frames of a single shot
-        Args:
-            shot_frames: list of frames in the shot
-        Returns:
-            Tuple[np.ndarray, int]: (keyframe, frame index)
-        """
-        if not shot_frames:
-            return None, -1
-        frame_features = []
-        frame_indices = []
-        for idx, frame in tqdm(shot_frames,
-                               desc="Processing shot frames",
-                               unit="frame",
-                               leave=False):
-            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-            resized_gray = cv2.resize(gray, (32, 32))
-            frame_features.append(resized_gray.flatten())
-            frame_indices.append(idx)
-        frame_features = np.array(frame_features)
-        kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
-                                 random_state=0).fit(frame_features)
-        center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
-        return shot_frames[center_idx][1], frame_indices[center_idx]
-
-    def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
-        """
-        Extract keyframes via a generator
-        Args:
-            shot_boundaries: list of shot boundaries
-        Yields:
-            Tuple[np.ndarray, int]: (keyframe, frame index)
-        """
-        shot_frames = []
-        current_shot_start = 0
-        for frame_idx, frame in self.preprocess_video():
-            if frame_idx in shot_boundaries:
-                if shot_frames:
-                    keyframe, keyframe_idx = self.process_shot(shot_frames)
-                    if keyframe is not None:
-                        yield keyframe, keyframe_idx
-                # Free memory
-                shot_frames.clear()
-                gc.collect()
-                current_shot_start = frame_idx
-            shot_frames.append((frame_idx, frame))
-            # Cap the number of frames held for a single shot
-            if len(shot_frames) > self.batch_size:
-                keyframe, keyframe_idx = self.process_shot(shot_frames)
-                if keyframe is not None:
-                    yield keyframe, keyframe_idx
-                shot_frames.clear()
-                gc.collect()
-        # Handle the last shot
-        if shot_frames:
-            keyframe, keyframe_idx = self.process_shot(shot_frames)
-            if keyframe is not None:
-                yield keyframe, keyframe_idx
-
-    def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
-        """
-        Process the video and extract keyframes (batched)
-        Args:
-            output_dir: output directory
-            skip_seconds: seconds to skip at the start of the video
-        """
-        try:
-            # Create the output directory
-            os.makedirs(output_dir, exist_ok=True)
-            # Compute the number of frames to skip
-            skip_frames = int(skip_seconds * self.fps)
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
-            # Detect shot boundaries
-            logger.info("Detecting shot boundaries...")
-            shot_boundaries = self.detect_shot_boundaries()
-            # Extract keyframes
-            logger.info("Extracting keyframes...")
-            frame_count = 0
-            pbar = tqdm(self.extract_keyframes(shot_boundaries),
-                        desc="Extracting keyframes",
-                        unit="frame")
-            for keyframe, frame_idx in pbar:
-                if frame_idx < skip_frames:
-                    continue
-                # Compute the timestamp
-                timestamp = frame_idx / self.fps
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
-                # Save the keyframe
-                output_path = os.path.join(output_dir,
-                                           f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-                cv2.imwrite(output_path, keyframe)
-                frame_count += 1
-                pbar.set_postfix({"saved": frame_count})
-                if frame_count % 10 == 0:
-                    gc.collect()
-            logger.info(f"Keyframe extraction finished; saved {frame_count} frames to {output_dir}")
-        except Exception as e:
-            logger.error(f"Video processing failed: {str(e)}")
-            raise
-        finally:
-            # Make sure resources are released
-            self.cap.release()
-            gc.collect()
+"""
+Video frame extraction tool
+
+This module provides simple, efficient video frame extraction. Key features:
+1. Uses ffmpeg for video processing, with hardware-acceleration support
+2. Extracts video keyframes at a specified time interval
+3. Supports many video formats
+4. Supports high-definition frame output
+5. Extracts high-quality keyframes directly from the original video
+
+It does not depend on libraries such as OpenCV and sklearn; ffmpeg is the only
+external dependency, which reduces installation and usage complexity.
+"""
+import os
+import re
+import time
+import subprocess
+from typing import List, Dict
+from loguru import logger
+from tqdm import tqdm
+
+
+class VideoProcessor:
+    def __init__(self, video_path: str):
+        """
+        Initialize the video processor
+        Args:
+            video_path: path to the video file
+        """
+        if not os.path.exists(video_path):
+            raise FileNotFoundError(f"Video file does not exist: {video_path}")
+        self.video_path = video_path
+        self.video_info = self._get_video_info()
+        self.fps = float(self.video_info.get('fps', 25))
+        self.duration = float(self.video_info.get('duration', 0))
+        self.width = int(self.video_info.get('width', 0))
+        self.height = int(self.video_info.get('height', 0))
+        self.total_frames = int(self.fps * self.duration)
+
+    def _get_video_info(self) -> Dict[str, str]:
+        """
+        Read video metadata with ffprobe
+        Returns:
+            Dict[str, str]: dictionary with basic video information
+        """
+        cmd = [
+            "ffprobe",
+            "-v", "error",
+            "-select_streams", "v:0",
+            "-show_entries", "stream=width,height,r_frame_rate,duration",
+            "-of", "default=noprint_wrappers=1:nokey=0",
+            self.video_path
+        ]
+        try:
+            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+            lines = result.stdout.strip().split('\n')
+            info = {}
+            for line in lines:
+                if '=' in line:
+                    key, value = line.split('=', 1)
+                    info[key] = value
+            # Handle the frame rate (it may be expressed as a fraction)
+            if 'r_frame_rate' in info:
+                try:
+                    num, den = map(int, info['r_frame_rate'].split('/'))
+                    info['fps'] = str(num / den)
+                except ValueError:
+                    info['fps'] = info.get('r_frame_rate', '25')
+            return info
+        except subprocess.CalledProcessError as e:
+            logger.error(f"Failed to read video info: {e.stderr}")
+            return {
+                'width': '1280',
+                'height': '720',
+                'fps': '25',
+                'duration': '0'
+            }
+
+    def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
+                                   use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
+        """
+        Extract video frames at the given time interval
+        Args:
+            output_dir: output directory
+            interval_seconds: interval between extracted frames
+            use_hw_accel: whether to use hardware acceleration
+            skip_seconds: seconds to skip at the start of the video
+        Returns:
+            List[int]: list of extracted frame numbers
+        """
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        # Compute the start time and the extraction points
+        start_time = skip_seconds
+        end_time = self.duration
+        extraction_times = []
+        current_time = start_time
+        while current_time < end_time:
+            extraction_times.append(current_time)
+            current_time += interval_seconds
+        if not extraction_times:
+            logger.warning("No frames to extract")
+            return []
+        # Choose hardware-accelerator options
+        hw_accel = []
+        if use_hw_accel:
+            # Try to detect an available hardware accelerator
+            hw_accel_options = self._detect_hw_accelerator()
+            if hw_accel_options:
+                hw_accel = hw_accel_options
+                logger.info(f"Using hardware acceleration: {' '.join(hw_accel)}")
+            else:
+                logger.warning("No usable hardware accelerator detected; using software decoding")
+        # Extract the frames
+        frame_numbers = []
+        for i, timestamp in enumerate(tqdm(extraction_times, desc="Extracting video frames")):
+            frame_number = int(timestamp * self.fps)
+            frame_numbers.append(frame_number)
+            # Format the timestamp string (HHMMSSmmm)
+            hours = int(timestamp // 3600)
+            minutes = int((timestamp % 3600) // 60)
+            seconds = int(timestamp % 60)
+            milliseconds = int((timestamp % 1) * 1000)
+            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
+            output_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{time_str}.jpg")
+            # Extract a single frame with ffmpeg
+            cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+            ]
+            # Add hardware-acceleration arguments
+            cmd.extend(hw_accel)
+            cmd.extend([
+                "-ss", str(timestamp),
+                "-i", self.video_path,
+                "-vframes", "1",
+                "-q:v", "1",  # highest quality
+                "-y",
+                output_path
+            ])
+            try:
+                subprocess.run(cmd, check=True, capture_output=True)
+            except subprocess.CalledProcessError as e:
+                logger.warning(f"Failed to extract frame {frame_number}: {e.stderr}")
+        logger.info(f"Successfully extracted {len(frame_numbers)} video frames")
+        return frame_numbers
+
+    def _detect_hw_accelerator(self) -> List[str]:
+        """
+        Detect a hardware accelerator available on this system
+        Returns:
+            List[str]: ffmpeg command arguments for the accelerator
+        """
+        # Detect the operating system
+        import platform
+        system = platform.system().lower()
+        # Probe the candidate hardware accelerators
+        accelerators = []
+
+        if system == 'darwin':  # macOS
+            # Test videotoolbox (Apple hardware acceleration)
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "videotoolbox",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "videotoolbox"]
+            except subprocess.CalledProcessError:
+                pass
+
+        elif system == 'linux':
+            # Test VAAPI
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "vaapi",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "vaapi"]
+            except subprocess.CalledProcessError:
+                pass
+
+            # Try CUDA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "cuda",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "cuda"]
+            except subprocess.CalledProcessError:
+                pass
+
+        elif system == 'windows':
+            # Test CUDA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "cuda",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "cuda"]
+            except subprocess.CalledProcessError:
+                pass
+
+            # Test D3D11VA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "d3d11va",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "d3d11va"]
+            except subprocess.CalledProcessError:
+                pass
+
+            # Test DXVA2
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "dxva2",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "dxva2"]
+            except subprocess.CalledProcessError:
+                pass
+
+        # No usable hardware accelerator found
+        return []
+
+    def process_video_pipeline(self,
+                               output_dir: str,
+                               skip_seconds: float = 0.0,
+                               threshold: int = 20,  # kept for compatibility, unused
+                               compressed_width: int = 320,  # kept for compatibility, unused
+                               keep_temp: bool = False,  # kept for compatibility, unused
+                               interval_seconds: float = 5.0,
+                               use_hw_accel: bool = True) -> None:
+        """
+        Run the simplified processing flow: extract frames from the original video at a fixed time interval
+        Args:
+            output_dir: output directory
+            skip_seconds: seconds to skip at the start of the video
+            threshold: kept for compatibility, unused
+            compressed_width: kept for compatibility, unused
+            keep_temp: kept for compatibility, unused
+            interval_seconds: interval between extracted frames
+            use_hw_accel: whether to use hardware acceleration
+        """
+        # Create the output directory
+        os.makedirs(output_dir, exist_ok=True)
+        try:
+            # Extract keyframes directly from the original video
+            logger.info("Extracting keyframes directly from the video...")
+            self.extract_frames_by_interval(
+                output_dir,
+                interval_seconds=interval_seconds,
+                use_hw_accel=use_hw_accel,
+                skip_seconds=skip_seconds
+            )
+            logger.info(f"Done! Video frames saved to: {output_dir}")
+        except Exception as e:
+            import traceback
+            logger.error(f"Video processing failed: \n{traceback.format_exc()}")
+            raise
+
+
+if __name__ == "__main__":
+    import time
+    start_time = time.time()
+    # Usage example
+    processor = VideoProcessor("./resource/videos/test.mp4")
+    # Extract a frame every 3 seconds
+    processor.process_video_pipeline(
+        output_dir="output",
+        interval_seconds=3.0,
+        use_hw_accel=True
+    )
+    end_time = time.time()
+    print(f"Done! Total time: {end_time - start_time:.2f} seconds")
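
A note on the design: the new extractor spawns one ffmpeg process per frame, which keeps per-frame seeking cheap but adds process overhead for long videos. A hedged alternative sketch uses ffmpeg's fps filter to grab all interval frames in a single pass; the sequential output numbering differs from the keyframe_<frame>_<timestamp>.jpg scheme above, so it is not a drop-in replacement:

import subprocess


def extract_frames_single_pass(video_path: str, output_dir: str,
                               interval_seconds: float = 5.0,
                               skip_seconds: float = 0.0) -> None:
    """One ffmpeg invocation; frames are numbered sequentially by the fps filter."""
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-ss", str(skip_seconds),             # seek before decoding (fast input seek)
        "-i", video_path,
        "-vf", f"fps=1/{interval_seconds}",   # one frame every interval_seconds
        "-q:v", "1",                          # highest JPEG quality
        "-y",
        f"{output_dir}/frame_%06d.jpg",
    ]
    subprocess.run(cmd, check=True, capture_output=True)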

View File

@@ -1,382 +0,0 @@
-import cv2
-import numpy as np
-from sklearn.cluster import KMeans
-import os
-import re
-from typing import List, Tuple, Generator
-from loguru import logger
-import subprocess
-from tqdm import tqdm
-
-
-class VideoProcessor:
-    def __init__(self, video_path: str):
-        """
-        Initialize the video processor
-        Args:
-            video_path: path to the video file
-        """
-        if not os.path.exists(video_path):
-            raise FileNotFoundError(f"Video file does not exist: {video_path}")
-        self.video_path = video_path
-        self.cap = cv2.VideoCapture(video_path)
-        if not self.cap.isOpened():
-            raise RuntimeError(f"Cannot open video file: {video_path}")
-        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
-
-    def __del__(self):
-        """Destructor; make sure video resources are released"""
-        if hasattr(self, 'cap'):
-            self.cap.release()
-
-    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
-        """
-        Read video frames via a generator
-        Yields:
-            np.ndarray: video frame
-        """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the start of the video
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
-                break
-            yield frame
-
-    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
-        """
-        Detect shot boundaries with frame differencing
-        Args:
-            frames: list of video frames
-            threshold: difference threshold (default lowered to 30)
-        Returns:
-            List[int]: indices of shot-boundary frames
-        """
-        shot_boundaries = []
-        if len(frames) < 2:  # guard against too few frames
-            logger.warning("Too few frames to detect scene boundaries")
-            return [len(frames) - 1]  # return the last frame as the boundary
-        for i in range(1, len(frames)):
-            prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
-            curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
-            # Compute the frame difference
-            diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
-            if diff > threshold:
-                shot_boundaries.append(i)
-        # If no boundary was detected, return at least the last frame
-        if not shot_boundaries:
-            logger.warning("No scene boundary detected; treating the video as a single scene")
-            shot_boundaries.append(len(frames) - 1)
-        return shot_boundaries
-
-    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
-        List[np.ndarray], List[int]]:
-        """
-        Extract a keyframe from each shot
-        Args:
-            frames: list of video frames
-            shot_boundaries: list of shot boundaries
-        Returns:
-            Tuple[List[np.ndarray], List[int]]: keyframes and their frame indices
-        """
-        keyframes = []
-        keyframe_indices = []
-        for i in tqdm(range(len(shot_boundaries)), desc="Extracting keyframes"):
-            start = shot_boundaries[i - 1] if i > 0 else 0
-            end = shot_boundaries[i]
-            shot_frames = frames[start:end]
-            if not shot_frames:
-                continue
-            # Convert each frame to grayscale and flatten it into a 1-D array
-            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
-                                       for frame in shot_frames])
-            try:
-                # Try KMeans
-                kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
-                center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
-            except Exception as e:
-                logger.warning(f"KMeans clustering failed; using the fallback: {str(e)}")
-                # Fallback: use the middle frame of the shot as the keyframe
-                center_idx = len(shot_frames) // 2
-            keyframes.append(shot_frames[center_idx])
-            keyframe_indices.append(start + center_idx)
-        return keyframes, keyframe_indices
-
-    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
-                       output_dir: str, desc: str = "Saving keyframes") -> None:
-        """
-        Save keyframes to the given directory, named keyframe_<frame>_<timestamp>.jpg
-        with the timestamp in HHMMSSmmm format (millisecond precision)
-        """
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-        for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
-                                        total=len(keyframes),
-                                        desc=desc):
-            # Compute the timestamp with millisecond precision
-            timestamp = frame_idx / self.fps
-            hours = int(timestamp // 3600)
-            minutes = int((timestamp % 3600) // 60)
-            seconds = int(timestamp % 60)
-            milliseconds = int((timestamp % 1) * 1000)  # millisecond part
-            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
-            output_path = os.path.join(output_dir,
-                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-            cv2.imwrite(output_path, keyframe)
-
-    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
-        """
-        Extract frames by frame number; if several frames fall in the same millisecond, keep only one
-        """
-        if not frame_numbers:
-            raise ValueError("No frame numbers provided")
-        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
-            raise ValueError("Invalid frame number found")
-        if not os.path.exists(output_folder):
-            os.makedirs(output_folder)
-        # Track the timestamps (in milliseconds) that were already handled
-        processed_timestamps = set()
-        for frame_number in tqdm(frame_numbers, desc="Extracting HD frames"):
-            # Compute the timestamp with millisecond precision
-            timestamp = frame_number / self.fps
-            timestamp_ms = int(timestamp * 1000)  # convert to milliseconds
-            # Skip if this millisecond was already handled
-            if timestamp_ms in processed_timestamps:
-                continue
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-            ret, frame = self.cap.read()
-            if ret:
-                # Mark this millisecond as handled
-                processed_timestamps.add(timestamp_ms)
-                # Build the timestamp string
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                milliseconds = int((timestamp % 1) * 1000)  # millisecond part
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
-                output_path = os.path.join(output_folder,
-                                           f"keyframe_{frame_number:06d}_{time_str}.jpg")
-                cv2.imwrite(output_path, frame)
-            else:
-                logger.info(f"Cannot read frame {frame_number}")
-        logger.info(f"Extracted {len(processed_timestamps)} frames with distinct timestamps")
-
-    @staticmethod
-    def extract_numbers_from_folder(folder_path: str) -> List[int]:
-        """
-        Extract frame numbers from a folder
-        Args:
-            folder_path: path to the keyframe folder
-        Returns:
-            List[int]: sorted list of frame numbers
-        """
-        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
-        # Updated regex to match the new file-name format: keyframe_000123_010534123.jpg
-        pattern = re.compile(r'keyframe_(\d+)_\d{9}\.jpg$')
-        numbers = []
-        for f in files:
-            match = pattern.search(f)
-            if match:
-                numbers.append(int(match.group(1)))
-            else:
-                logger.warning(f"File name does not match the expected format: {f}")
-        if not numbers:
-            logger.error(f"No valid keyframe files found in directory {folder_path}")
-        return sorted(numbers)
-
-    def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
-        """
-        Process the video and extract keyframes
-        Args:
-            output_dir: output directory
-            skip_seconds: seconds to skip at the start of the video
-        """
-        skip_frames = int(skip_seconds * self.fps)
-        logger.info("Reading video frames...")
-        frames = []
-        for frame in tqdm(self.preprocess_video(),
-                          total=self.total_frames,
-                          desc="Reading video"):
-            frames.append(frame)
-        frames = frames[skip_frames:]
-        if not frames:
-            raise ValueError(f"No frames left to process after skipping {skip_seconds} seconds")
-        logger.info("Detecting scene boundaries...")
-        shot_boundaries = self.detect_shot_boundaries(frames, threshold)
-        logger.info(f"Detected {len(shot_boundaries)} scene boundaries")
-        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
-        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
-        self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="Saving compressed keyframes")
-
-    def process_video_pipeline(self,
-                               output_dir: str,
-                               skip_seconds: float = 0,
-                               threshold: int = 20,  # lowered default threshold
-                               compressed_width: int = 320,
-                               keep_temp: bool = False) -> None:
-        """
-        Run the full video-processing pipeline
-        Args:
-            threshold: default lowered to 20 to make scene detection more sensitive
-        """
-        os.makedirs(output_dir, exist_ok=True)
-        temp_dir = os.path.join(output_dir, 'temp')
-        compressed_dir = os.path.join(temp_dir, 'compressed')
-        mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
-        hd_frames_dir = output_dir
-        os.makedirs(temp_dir, exist_ok=True)
-        os.makedirs(compressed_dir, exist_ok=True)
-        os.makedirs(mini_frames_dir, exist_ok=True)
-        os.makedirs(hd_frames_dir, exist_ok=True)
-        mini_processor = None
-        compressed_video = None
-        try:
-            # 1. Compress the video
-            video_name = os.path.splitext(os.path.basename(self.video_path))[0]
-            compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")
-            # Read the original video's width and height
-            original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-            original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-            logger.info("Step 1: compressing the video...")
-            if original_width > original_height:
-                # Landscape video
-                scale_filter = f'scale={compressed_width}:-1'
-            else:
-                # Portrait video
-                scale_filter = f'scale=-1:{compressed_width}'
-            ffmpeg_cmd = [
-                'ffmpeg', '-i', self.video_path,
-                '-vf', scale_filter,
-                '-y',
-                compressed_video
-            ]
-            try:
-                subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
-            except subprocess.CalledProcessError as e:
-                logger.error(f"FFmpeg error output: {e.stderr}")
-                raise
-            # 2. Extract keyframes from the compressed video
-            logger.info("\nStep 2: extracting keyframes from the compressed video...")
-            mini_processor = VideoProcessor(compressed_video)
-            mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)
-            # 3. Extract HD keyframes from the original video
-            logger.info("\nStep 3: extracting HD keyframes...")
-            frame_numbers = self.extract_numbers_from_folder(mini_frames_dir)
-            if not frame_numbers:
-                raise ValueError("No valid keyframes could be extracted from the compressed video")
-            self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)
-            logger.info(f"Done! HD keyframes saved to: {hd_frames_dir}")
-        except Exception as e:
-            import traceback
-            logger.error(f"Video processing failed: \n{traceback.format_exc()}")
-            raise
-        finally:
-            # Release resources
-            if mini_processor:
-                mini_processor.cap.release()
-                del mini_processor
-            # Make sure the video file handle is released
-            if hasattr(self, 'cap'):
-                self.cap.release()
-            # Wait for resources to be released
-            import time
-            time.sleep(0.5)
-            if not keep_temp:
-                try:
-                    # Delete the compressed video file first
-                    if compressed_video and os.path.exists(compressed_video):
-                        try:
-                            os.remove(compressed_video)
-                        except Exception as e:
-                            logger.warning(f"Failed to delete the compressed video: {e}")
-                    # Then delete the temp directory
-                    import shutil
-                    if os.path.exists(temp_dir):
-                        max_retries = 3
-                        for i in range(max_retries):
-                            try:
-                                shutil.rmtree(temp_dir)
-                                break
-                            except Exception as e:
-                                if i == max_retries - 1:
-                                    logger.warning(f"Failed to clean up temp files: {e}")
-                                else:
-                                    time.sleep(1)  # wait 1 second and retry
-                                    continue
-                    logger.info("Temp files cleaned up")
-                except Exception as e:
-                    logger.warning(f"Error while cleaning up temp files: {e}")
-
-
-if __name__ == "__main__":
-    import time
-    start_time = time.time()
-    processor = VideoProcessor("E:\\projects\\NarratoAI\\resource\\videos\\test.mp4")
-    processor.process_video_pipeline(output_dir="output")
-    end_time = time.time()
-    print(f"Done! Total time: {end_time - start_time:.2f} seconds")

View File

@@ -1,5 +1,5 @@
 [app]
-project_version="0.5.3"
+project_version="0.6.0"
 # Large-model providers that support video understanding
 # gemini
 # qwenvl

View File

@@ -1,38 +1,45 @@
+# Required
 requests~=2.32.0
 moviepy==2.1.1
 edge-tts==6.1.19
 streamlit~=1.45.0
+watchdog==6.0.0
+loguru~=0.7.3
+tomli~=2.2.1
 openai~=1.77.0
 google-generativeai>=0.8.5
-loguru~=0.7.2
-fastapi~=0.115.4
-uvicorn~=0.27.1
-pydantic~=2.11.4
-faster-whisper~=1.0.1
-tomli~=2.0.1
-aiohttp~=3.10.10
-httpx==0.27.2
-urllib3~=2.2.1
-python-multipart~=0.0.9
-redis==5.0.3
-opencv-python~=4.10.0.84
-azure-cognitiveservices-speech~=1.37.0
-git-changelog~=2.5.2
-watchdog==5.0.2
-pydub==0.25.1
-psutil>=5.9.0
-scikit-learn~=1.5.2
-pillow==10.3.0
-python-dotenv~=1.0.1
-tqdm>=4.66.6
-tenacity>=9.0.0
-tiktoken==0.8.0
-pysrt==1.1.2
-transformers==4.50.0
+
+# To be revisited
+# opencv-python==4.11.0.86
+# scikit-learn==1.6.1
+
+# fastapi~=0.115.4
+# uvicorn~=0.27.1
+# pydantic~=2.11.4
+
+# faster-whisper~=1.0.1
+# tomli~=2.0.1
+# aiohttp~=3.10.10
+# httpx==0.27.2
+# urllib3~=2.2.1
+
+# python-multipart~=0.0.9
+# redis==5.0.3
+# opencv-python~=4.10.0.84
+# azure-cognitiveservices-speech~=1.37.0
+# git-changelog~=2.5.2
+# watchdog==5.0.2
+# pydub==0.25.1
+# psutil>=5.9.0
+# scikit-learn~=1.5.2
+# pillow==10.3.0
+# python-dotenv~=1.0.1
+# tqdm>=4.66.6
+# tenacity>=9.0.0
+# tiktoken==0.8.0
+# pysrt==1.1.2
+# transformers==4.50.0
 # yt-dlp==2025.4.30

View File

@@ -1,7 +1,7 @@
 import streamlit as st
 import os
 import sys
-from uuid import uuid4
+from loguru import logger
 from app.config import config
 from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
     review_settings, merge_settings, system_settings
@@ -18,7 +18,7 @@ st.set_page_config(
     initial_sidebar_state="auto",
     menu_items={
         "Report a bug": "https://github.com/linyqh/NarratoAI/issues",
-        'About': f"# NarratoAI:sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
+        'About': f"# Narrato:blue[AI] :sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
                  f"Automated video narration; see https://github.com/linyqh/NarratoAI for details"
     },
 )
@@ -37,17 +37,7 @@ def init_log():
     _lvl = "DEBUG"
 
     def format_record(record):
-        # Additional warning messages that need filtering
-        ignore_messages = [
-            "Examining the path of torch.classes raised",
-            "torch.cuda.is_available()",
-            "CUDA initialization"
-        ]
-        for msg in ignore_messages:
-            if msg in record["message"]:
-                return ""
-
+        # Simplified log formatting; do not try to filter torch-related content by message text here
         file_path = record["file"].path
         relative_path = os.path.relpath(file_path, config.root_dir)
         record["file"].path = f"./{relative_path}"
@@ -59,8 +49,25 @@ def init_log():
             '- <level>{message}</>' + "\n"
         return _format
 
-    # Tuned log filter
-    def log_filter(record):
-        ignore_messages = [
-            "Examining the path of torch.classes raised",
-            "torch.cuda.is_available()",
+    # Use a simpler filtering approach for now and avoid reading the message content while filtering;
+    # install the more complex filter dynamically once the app has started
+    logger.add(
+        sys.stdout,
+        level=_lvl,
+        format=format_record,
+        colorize=True
+    )
+
+    # A more sophisticated filter can be added after the app has started
+    def setup_advanced_filters():
+        """Install the advanced filters once the app has fully started"""
+        try:
+            for handler_id in logger._core.handlers:
+                logger.remove(handler_id)
+
+            # Re-add the handler with advanced filtering
+            def advanced_filter(record):
+                """More complex filter; safe to use after startup"""
+                ignore_messages = [
+                    "Examining the path of torch.classes raised",
+                    "torch.cuda.is_available()",
@@ -73,8 +80,21 @@ def init_log():
-        level=_lvl,
-        format=format_record,
-        colorize=True,
-        filter=log_filter
-    )
+                level=_lvl,
+                format=format_record,
+                colorize=True,
+                filter=advanced_filter
+            )
+        except Exception as e:
+            # If installing the filter fails, make sure logging still works
+            logger.add(
+                sys.stdout,
+                level=_lvl,
+                format=format_record,
+                colorize=True
+            )
+            logger.error(f"Failed to set up the advanced log filter: {e}")
+
+    # Defer the advanced-filter setup until after the main startup logic
+    import threading
+    threading.Timer(5.0, setup_advanced_filters).start()
 
 
 def init_global_state():
@@ -177,11 +197,18 @@ def main():
     """Main entry point"""
     init_log()
     init_global_state()
-    utils.init_resources()
 
-    st.title(f"NarratoAI :sunglasses:📽️")
+    # Initialize only basic resources here, to avoid loading PyTorch-dependent resources too early
+    # Check whether utils.init_resources() can be split into basic and advanced (PyTorch-dependent) resources
+    try:
+        utils.init_resources()
+    except Exception as e:
+        logger.warning(f"Warning during resource initialization: {e}")
+
+    st.title(f"Narrato:blue[AI]:sunglasses: 📽️")
     st.write(tr("Get Help"))
 
+    # Render the PyTorch-independent parts of the UI first
     # Render the basic settings panel
     basic_settings.render_basic_settings(tr)
     # Render the merge settings
@@ -196,13 +223,16 @@ def main():
         audio_settings.render_audio_panel(tr)
     with panel[2]:
         subtitle_settings.render_subtitle_panel(tr)
-        # Render the system settings panel
-        system_settings.render_system_panel(tr)
 
     # Render the review panel
     review_settings.render_review_panel(tr)
 
-    # Render the generate button and processing logic
+    # Render the parts that may use PyTorch last
+    # Render the system settings panel
+    with panel[2]:
+        system_settings.render_system_panel(tr)
+
+    # Render the generate button and processing logic last
    render_generate_button()
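
One caveat on the hunk above: setup_advanced_filters iterates logger._core.handlers, a private loguru structure, and removes handlers while iterating it. A hedged alternative sketch that relies only on loguru's public API (handler removal plus a message filter; the snippet list mirrors the commit's):

from loguru import logger
import sys

IGNORED_SNIPPETS = (
    "Examining the path of torch.classes raised",
    "torch.cuda.is_available()",
    "CUDA initialization",
)


def install_filtered_handler(level: str = "DEBUG") -> None:
    """Replace all handlers with one that drops known torch warning noise."""
    logger.remove()  # public API: with no arguments, removes every handler
    logger.add(
        sys.stdout,
        level=level,
        colorize=True,
        filter=lambda record: not any(s in record["message"] for s in IGNORED_SNIPPETS),
    )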

View File

@@ -285,8 +285,8 @@ def render_merge_settings(tr):
         error_message = str(e)
         if "moviepy" in error_message.lower():
             st.error(tr("Error processing video files. Please check if the videos are valid MP4 files."))
-        elif "pysrt" in error_message.lower():
-            st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
+        # elif "pysrt" in error_message.lower():
+        #     st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
         else:
             st.error(f"{tr('Error during merge')}: {error_message}")

View File

@@ -5,6 +5,7 @@ import time
 import asyncio
 import traceback
 import requests
+from app.utils import video_processor
 import streamlit as st
 from loguru import logger
 from requests.adapters import HTTPAdapter
@@ -12,7 +13,7 @@ from urllib3.util.retry import Retry
 from app.config import config
 from app.utils.script_generator import ScriptProcessor
-from app.utils import utils, video_processor, video_processor_v2, qwenvl_analyzer
+from app.utils import utils, video_processor, qwenvl_analyzer
 from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps, chekc_video_config
@@ -64,21 +65,13 @@ def generate_script_docu(tr, params):
     os.makedirs(video_keyframes_dir, exist_ok=True)
 
     # Initialize the video processor
-    if config.frames.get("version") == "v2":
-        processor = video_processor_v2.VideoProcessor(params.video_origin_path)
-        # Process the video and extract keyframes
-        processor.process_video_pipeline(
-            output_dir=video_keyframes_dir,
-            skip_seconds=st.session_state.get('skip_seconds'),
-            threshold=st.session_state.get('threshold')
-        )
-    else:
-        processor = video_processor.VideoProcessor(params.video_origin_path)
-        # Process the video and extract keyframes
-        processor.process_video(
-            output_dir=video_keyframes_dir,
-            skip_seconds=0
-        )
+    processor = video_processor.VideoProcessor(params.video_origin_path)
+    # Process the video and extract keyframes
+    processor.process_video_pipeline(
+        output_dir=video_keyframes_dir,
+        skip_seconds=st.session_state.get('skip_seconds'),
+        threshold=st.session_state.get('threshold')
+    )
 
     # Collect all keyframe file paths
     for filename in sorted(os.listdir(video_keyframes_dir)):

View File

@@ -2,7 +2,7 @@
 Merge video and subtitle files
 """
 from moviepy import VideoFileClip, concatenate_videoclips
-import pysrt
+# import pysrt
 import os
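
Since pysrt is now disabled, subtitle handling in this module needs a replacement if it is re-enabled. One possible minimal sketch of an SRT parser using only the standard library (the function name and return shape are illustrative, not the project's API):

import re
from typing import List, Tuple


def parse_srt(path: str) -> List[Tuple[float, float, str]]:
    """Parse an SRT file into (start_seconds, end_seconds, text) tuples."""
    def to_seconds(ts: str) -> float:
        h, m, rest = ts.split(":")
        s, ms = rest.split(",")
        return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000.0

    # Match "HH:MM:SS,mmm --> HH:MM:SS,mmm" followed by the cue text,
    # terminated by a blank line or the end of the file.
    pattern = re.compile(
        r"(\d{2}:\d{2}:\d{2},\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2},\d{3})\s*\n(.*?)(?:\n\s*\n|\Z)",
        re.S,
    )
    with open(path, encoding="utf-8") as f:
        content = f.read()
    return [(to_seconds(a), to_seconds(b), text.strip())
            for a, b, text in pattern.findall(content)]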

View File

@@ -1,7 +1,6 @@
-import psutil
+# import psutil
 import os
 from loguru import logger
-import torch
 
 
 class PerformanceMonitor:
     @staticmethod
@@ -11,19 +10,35 @@ class PerformanceMonitor:
         logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
 
-        if torch.cuda.is_available():
-            gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
-            logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+        # Import torch lazily and check CUDA
+        try:
+            import torch
+            if torch.cuda.is_available():
+                gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
+                logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+        except (ImportError, RuntimeError) as e:
+            # Handle quietly when torch cannot be imported or a CUDA error is raised
+            logger.debug(f"Could not read GPU memory info: {e}")
 
     @staticmethod
     def cleanup_resources():
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # Import torch lazily and clear the CUDA cache
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                logger.debug("CUDA cache cleared")
+        except (ImportError, RuntimeError) as e:
+            # Handle quietly when torch cannot be imported or a CUDA error is raised
+            logger.debug(f"Could not clean up CUDA resources: {e}")
+
         import gc
         gc.collect()
-        PerformanceMonitor.monitor_memory()
+
+        # Report process memory only; do not try to read GPU memory
+        process = psutil.Process(os.getpid())
+        memory_info = process.memory_info()
+        logger.debug(f"Memory usage after cleanup: {memory_info.rss / 1024 / 1024:.2f} MB")
 
 
 def monitor_performance(func):
     """Performance-monitoring decorator"""