mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-12 03:02:48 +00:00
移除 opencv 和 sklearn 提取关键帧的代码
This commit is contained in:
parent
c3ea0bcc69
commit
f6c3f1640b
@ -3,10 +3,11 @@ import json
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
import requests
|
||||||
|
from app.utils import video_processor
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import List, Dict, Any, Callable
|
from typing import List, Dict, Any, Callable
|
||||||
|
|
||||||
from app.utils import utils, gemini_analyzer, video_processor, video_processor_v2
|
from app.utils import utils, gemini_analyzer, video_processor
|
||||||
from app.utils.script_generator import ScriptProcessor
|
from app.utils.script_generator import ScriptProcessor
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
@ -105,19 +106,12 @@ class ScriptGenerator:
|
|||||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if config.frames.get("version") == "v2":
|
processor = video_processor.VideoProcessor(video_path)
|
||||||
processor = video_processor_v2.VideoProcessor(video_path)
|
|
||||||
processor.process_video_pipeline(
|
processor.process_video_pipeline(
|
||||||
output_dir=video_keyframes_dir,
|
output_dir=video_keyframes_dir,
|
||||||
skip_seconds=skip_seconds,
|
skip_seconds=skip_seconds,
|
||||||
threshold=threshold
|
threshold=threshold
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
processor = video_processor.VideoProcessor(video_path)
|
|
||||||
processor.process_video(
|
|
||||||
output_dir=video_keyframes_dir,
|
|
||||||
skip_seconds=skip_seconds
|
|
||||||
)
|
|
||||||
|
|
||||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||||
if filename.endswith('.jpg'):
|
if filename.endswith('.jpg'):
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import re
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from faster_whisper import WhisperModel
|
# from faster_whisper import WhisperModel
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
@ -45,12 +45,25 @@ def create(audio_file, subtitle_file: str = ""):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 尝试使用 CUDA,如果失败则回退到 CPU
|
# 首先使用CPU模式,不触发CUDA检查
|
||||||
|
use_cuda = False
|
||||||
|
try:
|
||||||
|
# 在函数中延迟导入torch,而不是在全局范围内
|
||||||
|
# 使用安全的方式检查CUDA可用性
|
||||||
|
def check_cuda_available():
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.is_available():
|
return torch.cuda.is_available()
|
||||||
try:
|
except (ImportError, RuntimeError) as e:
|
||||||
|
logger.warning(f"检查CUDA可用性时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 仅当明确需要时才检查CUDA
|
||||||
|
use_cuda = check_cuda_available()
|
||||||
|
|
||||||
|
if use_cuda:
|
||||||
logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
|
logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
|
||||||
|
try:
|
||||||
model = WhisperModel(
|
model = WhisperModel(
|
||||||
model_size_or_path=model_path,
|
model_size_or_path=model_path,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
@ -63,18 +76,18 @@ def create(audio_file, subtitle_file: str = ""):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
|
logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
|
||||||
logger.warning("回退到 CPU 模式")
|
logger.warning("回退到 CPU 模式")
|
||||||
device = "cpu"
|
use_cuda = False
|
||||||
compute_type = "int8"
|
|
||||||
else:
|
else:
|
||||||
logger.info("未检测到 CUDA,使用 CPU 模式")
|
logger.info("使用 CPU 模式")
|
||||||
device = "cpu"
|
except Exception as e:
|
||||||
compute_type = "int8"
|
logger.warning(f"CUDA检查过程出错: {e}")
|
||||||
except ImportError:
|
logger.warning("默认使用CPU模式")
|
||||||
logger.warning("未安装 torch,使用 CPU 模式")
|
use_cuda = False
|
||||||
device = "cpu"
|
|
||||||
compute_type = "int8"
|
|
||||||
|
|
||||||
if device == "cpu":
|
# 如果CUDA不可用或加载失败,使用CPU
|
||||||
|
if not use_cuda:
|
||||||
|
device = "cpu"
|
||||||
|
compute_type = "int8"
|
||||||
logger.info(f"使用 CPU 加载模型: {model_path}")
|
logger.info(f"使用 CPU 加载模型: {model_path}")
|
||||||
model = WhisperModel(
|
model = WhisperModel(
|
||||||
model_size_or_path=model_path,
|
model_size_or_path=model_path,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import pysrt
|
# import pysrt
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import List
|
from typing import List
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import tiktoken
|
# import tiktoken
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@ -94,12 +94,12 @@ class OpenAIGenerator(BaseGenerator):
|
|||||||
"user": "script_generator"
|
"user": "script_generator"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 初始化token计数器
|
# # 初始化token计数器
|
||||||
try:
|
# try:
|
||||||
self.encoding = tiktoken.encoding_for_model(self.model_name)
|
# self.encoding = tiktoken.encoding_for_model(self.model_name)
|
||||||
except KeyError:
|
# except KeyError:
|
||||||
logger.warning(f"未找到模型 {self.model_name} 的专用编码器,使用默认编码器")
|
# logger.warning(f"未找到模型 {self.model_name} 的专用编码器,使用默认编码器")
|
||||||
self.encoding = tiktoken.get_encoding("cl100k_base")
|
# self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
def _generate(self, messages: list, params: dict) -> any:
|
def _generate(self, messages: list, params: dict) -> any:
|
||||||
"""实现OpenAI特定的生成逻辑"""
|
"""实现OpenAI特定的生成逻辑"""
|
||||||
|
|||||||
@ -1,237 +1,349 @@
|
|||||||
import cv2
|
"""
|
||||||
import numpy as np
|
视频帧提取工具
|
||||||
from sklearn.cluster import MiniBatchKMeans
|
|
||||||
|
这个模块提供了简单高效的视频帧提取功能。主要特点:
|
||||||
|
1. 使用ffmpeg进行视频处理,支持硬件加速
|
||||||
|
2. 按指定时间间隔提取视频关键帧
|
||||||
|
3. 支持多种视频格式
|
||||||
|
4. 支持高清视频帧输出
|
||||||
|
5. 直接从原视频提取高质量关键帧
|
||||||
|
|
||||||
|
不依赖OpenCV和sklearn等库,只使用ffmpeg作为外部依赖,降低了安装和使用的复杂度。
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import List, Tuple, Generator
|
import time
|
||||||
|
import subprocess
|
||||||
|
from typing import List, Dict
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import gc
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class VideoProcessor:
|
class VideoProcessor:
|
||||||
def __init__(self, video_path: str, batch_size: int = 100):
|
def __init__(self, video_path: str):
|
||||||
"""
|
"""
|
||||||
初始化视频处理器
|
初始化视频处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: 视频文件路径
|
video_path: 视频文件路径
|
||||||
batch_size: 批处理大小,控制内存使用
|
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
raise FileNotFoundError(f"视频文件不存在: {video_path}")
|
raise FileNotFoundError(f"视频文件不存在: {video_path}")
|
||||||
|
|
||||||
self.video_path = video_path
|
self.video_path = video_path
|
||||||
self.batch_size = batch_size
|
self.video_info = self._get_video_info()
|
||||||
self.cap = cv2.VideoCapture(video_path)
|
self.fps = float(self.video_info.get('fps', 25))
|
||||||
|
self.duration = float(self.video_info.get('duration', 0))
|
||||||
|
self.width = int(self.video_info.get('width', 0))
|
||||||
|
self.height = int(self.video_info.get('height', 0))
|
||||||
|
self.total_frames = int(self.fps * self.duration)
|
||||||
|
|
||||||
if not self.cap.isOpened():
|
def _get_video_info(self) -> Dict[str, str]:
|
||||||
raise RuntimeError(f"无法打开视频文件: {video_path}")
|
|
||||||
|
|
||||||
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
||||||
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""析构函数,确保视频资源被释放"""
|
|
||||||
if hasattr(self, 'cap'):
|
|
||||||
self.cap.release()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
|
|
||||||
"""
|
"""
|
||||||
使用生成器方式分批读取视频帧
|
使用ffprobe获取视频信息
|
||||||
|
|
||||||
Yields:
|
|
||||||
Tuple[int, np.ndarray]: (帧索引, 视频帧)
|
|
||||||
"""
|
|
||||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
|
||||||
frame_idx = 0
|
|
||||||
|
|
||||||
while self.cap.isOpened():
|
|
||||||
ret, frame = self.cap.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 降低分辨率以减少内存使用
|
|
||||||
frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
|
|
||||||
yield frame_idx, frame
|
|
||||||
|
|
||||||
frame_idx += 1
|
|
||||||
|
|
||||||
# 定期进行垃圾回收
|
|
||||||
if frame_idx % 1000 == 0:
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
|
|
||||||
"""
|
|
||||||
使用批处理方式检测镜头边界
|
|
||||||
|
|
||||||
Args:
|
|
||||||
threshold: 差异阈值
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[int]: 镜头边界帧的索引列表
|
Dict[str, str]: 包含视频基本信息的字典
|
||||||
"""
|
"""
|
||||||
shot_boundaries = []
|
cmd = [
|
||||||
prev_frame = None
|
"ffprobe",
|
||||||
prev_idx = -1
|
"-v", "error",
|
||||||
|
"-select_streams", "v:0",
|
||||||
|
"-show_entries", "stream=width,height,r_frame_rate,duration",
|
||||||
|
"-of", "default=noprint_wrappers=1:nokey=0",
|
||||||
|
self.video_path
|
||||||
|
]
|
||||||
|
|
||||||
pbar = tqdm(self.preprocess_video(),
|
try:
|
||||||
total=self.total_frames,
|
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||||
desc="检测镜头边界",
|
lines = result.stdout.strip().split('\n')
|
||||||
unit="帧")
|
info = {}
|
||||||
|
for line in lines:
|
||||||
|
if '=' in line:
|
||||||
|
key, value = line.split('=', 1)
|
||||||
|
info[key] = value
|
||||||
|
|
||||||
for frame_idx, curr_frame in pbar:
|
# 处理帧率(可能是分数形式)
|
||||||
if prev_frame is not None:
|
if 'r_frame_rate' in info:
|
||||||
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
|
try:
|
||||||
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
|
num, den = map(int, info['r_frame_rate'].split('/'))
|
||||||
|
info['fps'] = str(num / den)
|
||||||
|
except ValueError:
|
||||||
|
info['fps'] = info.get('r_frame_rate', '25')
|
||||||
|
|
||||||
diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
|
return info
|
||||||
if diff > threshold:
|
|
||||||
shot_boundaries.append(frame_idx)
|
|
||||||
pbar.set_postfix({"检测到边界": len(shot_boundaries)})
|
|
||||||
|
|
||||||
prev_frame = curr_frame.copy()
|
except subprocess.CalledProcessError as e:
|
||||||
prev_idx = frame_idx
|
logger.error(f"获取视频信息失败: {e.stderr}")
|
||||||
|
return {
|
||||||
|
'width': '1280',
|
||||||
|
'height': '720',
|
||||||
|
'fps': '25',
|
||||||
|
'duration': '0'
|
||||||
|
}
|
||||||
|
|
||||||
del curr_frame
|
def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
|
||||||
if frame_idx % 100 == 0:
|
use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return shot_boundaries
|
|
||||||
|
|
||||||
def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
|
|
||||||
"""
|
"""
|
||||||
处理单个镜头的帧
|
按指定时间间隔提取视频帧
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shot_frames: 镜头中的帧列表
|
output_dir: 输出目录
|
||||||
|
interval_seconds: 帧提取间隔(秒)
|
||||||
|
use_hw_accel: 是否使用硬件加速
|
||||||
|
skip_seconds: 跳过视频开头的秒数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[np.ndarray, int]: (关键帧, 帧索引)
|
List[int]: 提取的帧号列表
|
||||||
"""
|
"""
|
||||||
if not shot_frames:
|
if not os.path.exists(output_dir):
|
||||||
return None, -1
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
frame_features = []
|
# 计算起始时间和帧提取点
|
||||||
frame_indices = []
|
start_time = skip_seconds
|
||||||
|
end_time = self.duration
|
||||||
|
extraction_times = []
|
||||||
|
|
||||||
for idx, frame in tqdm(shot_frames,
|
current_time = start_time
|
||||||
desc="处理镜头帧",
|
while current_time < end_time:
|
||||||
unit="帧",
|
extraction_times.append(current_time)
|
||||||
leave=False):
|
current_time += interval_seconds
|
||||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
|
||||||
resized_gray = cv2.resize(gray, (32, 32))
|
|
||||||
frame_features.append(resized_gray.flatten())
|
|
||||||
frame_indices.append(idx)
|
|
||||||
|
|
||||||
frame_features = np.array(frame_features)
|
if not extraction_times:
|
||||||
|
logger.warning("未找到需要提取的帧")
|
||||||
|
return []
|
||||||
|
|
||||||
kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
|
# 确定硬件加速器选项
|
||||||
random_state=0).fit(frame_features)
|
hw_accel = []
|
||||||
|
if use_hw_accel:
|
||||||
|
# 尝试检测可用的硬件加速器
|
||||||
|
hw_accel_options = self._detect_hw_accelerator()
|
||||||
|
if hw_accel_options:
|
||||||
|
hw_accel = hw_accel_options
|
||||||
|
logger.info(f"使用硬件加速: {' '.join(hw_accel)}")
|
||||||
|
else:
|
||||||
|
logger.warning("未检测到可用的硬件加速器,使用软件解码")
|
||||||
|
|
||||||
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
|
# 提取帧
|
||||||
|
frame_numbers = []
|
||||||
|
for i, timestamp in enumerate(tqdm(extraction_times, desc="提取视频帧")):
|
||||||
|
frame_number = int(timestamp * self.fps)
|
||||||
|
frame_numbers.append(frame_number)
|
||||||
|
|
||||||
return shot_frames[center_idx][1], frame_indices[center_idx]
|
# 格式化时间戳字符串 (HHMMSSmmm)
|
||||||
|
hours = int(timestamp // 3600)
|
||||||
|
minutes = int((timestamp % 3600) // 60)
|
||||||
|
seconds = int(timestamp % 60)
|
||||||
|
milliseconds = int((timestamp % 1) * 1000)
|
||||||
|
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
|
||||||
|
|
||||||
def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
|
output_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{time_str}.jpg")
|
||||||
|
|
||||||
|
# 使用ffmpeg提取单帧
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加硬件加速参数
|
||||||
|
cmd.extend(hw_accel)
|
||||||
|
|
||||||
|
cmd.extend([
|
||||||
|
"-ss", str(timestamp),
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-vframes", "1",
|
||||||
|
"-q:v", "1", # 最高质量
|
||||||
|
"-y",
|
||||||
|
output_path
|
||||||
|
])
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(cmd, check=True, capture_output=True)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.warning(f"提取帧 {frame_number} 失败: {e.stderr}")
|
||||||
|
|
||||||
|
logger.info(f"成功提取了 {len(frame_numbers)} 个视频帧")
|
||||||
|
return frame_numbers
|
||||||
|
|
||||||
|
def _detect_hw_accelerator(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
使用生成器方式提取关键帧
|
检测系统可用的硬件加速器
|
||||||
|
|
||||||
Args:
|
Returns:
|
||||||
shot_boundaries: 镜头边界列表
|
List[str]: 硬件加速器ffmpeg命令参数
|
||||||
|
|
||||||
Yields:
|
|
||||||
Tuple[np.ndarray, int]: (关键帧, 帧索引)
|
|
||||||
"""
|
"""
|
||||||
shot_frames = []
|
# 检测操作系统
|
||||||
current_shot_start = 0
|
import platform
|
||||||
|
system = platform.system().lower()
|
||||||
|
|
||||||
for frame_idx, frame in self.preprocess_video():
|
# 测试不同的硬件加速器
|
||||||
if frame_idx in shot_boundaries:
|
accelerators = []
|
||||||
if shot_frames:
|
|
||||||
keyframe, keyframe_idx = self.process_shot(shot_frames)
|
|
||||||
if keyframe is not None:
|
|
||||||
yield keyframe, keyframe_idx
|
|
||||||
|
|
||||||
# 清理内存
|
if system == 'darwin': # macOS
|
||||||
shot_frames.clear()
|
# 测试 videotoolbox (Apple 硬件加速)
|
||||||
gc.collect()
|
test_cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"-hwaccel", "videotoolbox",
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-t", "0.1",
|
||||||
|
"-f", "null",
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(test_cmd, capture_output=True, check=True)
|
||||||
|
return ["-hwaccel", "videotoolbox"]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
current_shot_start = frame_idx
|
elif system == 'linux':
|
||||||
|
# 测试 VAAPI
|
||||||
|
test_cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"-hwaccel", "vaapi",
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-t", "0.1",
|
||||||
|
"-f", "null",
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(test_cmd, capture_output=True, check=True)
|
||||||
|
return ["-hwaccel", "vaapi"]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
shot_frames.append((frame_idx, frame))
|
# 尝试 CUDA
|
||||||
|
test_cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"-hwaccel", "cuda",
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-t", "0.1",
|
||||||
|
"-f", "null",
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(test_cmd, capture_output=True, check=True)
|
||||||
|
return ["-hwaccel", "cuda"]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
# 控制单个镜头的最大帧数
|
elif system == 'windows':
|
||||||
if len(shot_frames) > self.batch_size:
|
# 测试 CUDA
|
||||||
keyframe, keyframe_idx = self.process_shot(shot_frames)
|
test_cmd = [
|
||||||
if keyframe is not None:
|
"ffmpeg",
|
||||||
yield keyframe, keyframe_idx
|
"-hide_banner",
|
||||||
shot_frames.clear()
|
"-loglevel", "error",
|
||||||
gc.collect()
|
"-hwaccel", "cuda",
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-t", "0.1",
|
||||||
|
"-f", "null",
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(test_cmd, capture_output=True, check=True)
|
||||||
|
return ["-hwaccel", "cuda"]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
# 处理最后一个镜头
|
# 测试 D3D11VA
|
||||||
if shot_frames:
|
test_cmd = [
|
||||||
keyframe, keyframe_idx = self.process_shot(shot_frames)
|
"ffmpeg",
|
||||||
if keyframe is not None:
|
"-hide_banner",
|
||||||
yield keyframe, keyframe_idx
|
"-loglevel", "error",
|
||||||
|
"-hwaccel", "d3d11va",
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-t", "0.1",
|
||||||
|
"-f", "null",
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(test_cmd, capture_output=True, check=True)
|
||||||
|
return ["-hwaccel", "d3d11va"]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
|
# 测试 DXVA2
|
||||||
|
test_cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"-hwaccel", "dxva2",
|
||||||
|
"-i", self.video_path,
|
||||||
|
"-t", "0.1",
|
||||||
|
"-f", "null",
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(test_cmd, capture_output=True, check=True)
|
||||||
|
return ["-hwaccel", "dxva2"]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 如果没有找到可用的硬件加速器
|
||||||
|
return []
|
||||||
|
|
||||||
|
def process_video_pipeline(self,
|
||||||
|
output_dir: str,
|
||||||
|
skip_seconds: float = 0.0,
|
||||||
|
threshold: int = 20, # 此参数保留但不使用
|
||||||
|
compressed_width: int = 320, # 此参数保留但不使用
|
||||||
|
keep_temp: bool = False, # 此参数保留但不使用
|
||||||
|
interval_seconds: float = 5.0,
|
||||||
|
use_hw_accel: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
处理视频并提取关键帧,使用分批处理方式
|
执行简化的视频处理流程,直接从原视频按固定时间间隔提取帧
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
skip_seconds: 跳过视频开头的秒数
|
skip_seconds: 跳过视频开头的秒数
|
||||||
|
threshold: 保留参数,不使用
|
||||||
|
compressed_width: 保留参数,不使用
|
||||||
|
keep_temp: 保留参数,不使用
|
||||||
|
interval_seconds: 帧提取间隔(秒)
|
||||||
|
use_hw_accel: 是否使用硬件加速
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# 创建输出目录
|
# 创建输出目录
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
# 计算要跳过的帧数
|
try:
|
||||||
skip_frames = int(skip_seconds * self.fps)
|
# 直接从原视频提取关键帧
|
||||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
|
logger.info("从视频直接提取关键帧...")
|
||||||
|
self.extract_frames_by_interval(
|
||||||
|
output_dir,
|
||||||
|
interval_seconds=interval_seconds,
|
||||||
|
use_hw_accel=use_hw_accel,
|
||||||
|
skip_seconds=skip_seconds
|
||||||
|
)
|
||||||
|
|
||||||
# 检测镜头边界
|
logger.info(f"处理完成!视频帧已保存在: {output_dir}")
|
||||||
logger.info("开始检测镜头边界...")
|
|
||||||
shot_boundaries = self.detect_shot_boundaries()
|
|
||||||
|
|
||||||
# 提取关键帧
|
|
||||||
logger.info("开始提取关键帧...")
|
|
||||||
frame_count = 0
|
|
||||||
|
|
||||||
pbar = tqdm(self.extract_keyframes(shot_boundaries),
|
|
||||||
desc="提取关键帧",
|
|
||||||
unit="帧")
|
|
||||||
|
|
||||||
for keyframe, frame_idx in pbar:
|
|
||||||
if frame_idx < skip_frames:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 计算时间戳
|
|
||||||
timestamp = frame_idx / self.fps
|
|
||||||
hours = int(timestamp // 3600)
|
|
||||||
minutes = int((timestamp % 3600) // 60)
|
|
||||||
seconds = int(timestamp % 60)
|
|
||||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
|
|
||||||
|
|
||||||
# 保存关键帧
|
|
||||||
output_path = os.path.join(output_dir,
|
|
||||||
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
|
|
||||||
cv2.imwrite(output_path, keyframe)
|
|
||||||
frame_count += 1
|
|
||||||
|
|
||||||
pbar.set_postfix({"已保存": frame_count})
|
|
||||||
|
|
||||||
if frame_count % 10 == 0:
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
logger.info(f"关键帧提取完成,共保存 {frame_count} 帧到 {output_dir}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"视频处理失败: {str(e)}")
|
import traceback
|
||||||
|
logger.error(f"视频处理失败: \n{traceback.format_exc()}")
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
# 确保资源被释放
|
|
||||||
self.cap.release()
|
if __name__ == "__main__":
|
||||||
gc.collect()
|
import time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
processor = VideoProcessor("./resource/videos/test.mp4")
|
||||||
|
|
||||||
|
# 设置间隔为3秒提取帧
|
||||||
|
processor.process_video_pipeline(
|
||||||
|
output_dir="output",
|
||||||
|
interval_seconds=3.0,
|
||||||
|
use_hw_accel=True
|
||||||
|
)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"处理完成!总耗时: {end_time - start_time:.2f} 秒")
|
||||||
|
|||||||
@ -1,382 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.cluster import KMeans
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from typing import List, Tuple, Generator
|
|
||||||
from loguru import logger
|
|
||||||
import subprocess
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class VideoProcessor:
|
|
||||||
def __init__(self, video_path: str):
|
|
||||||
"""
|
|
||||||
初始化视频处理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video_path: 视频文件路径
|
|
||||||
"""
|
|
||||||
if not os.path.exists(video_path):
|
|
||||||
raise FileNotFoundError(f"视频文件不存在: {video_path}")
|
|
||||||
|
|
||||||
self.video_path = video_path
|
|
||||||
self.cap = cv2.VideoCapture(video_path)
|
|
||||||
|
|
||||||
if not self.cap.isOpened():
|
|
||||||
raise RuntimeError(f"无法打开视频文件: {video_path}")
|
|
||||||
|
|
||||||
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
||||||
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""析构函数,确保视频资源被释放"""
|
|
||||||
if hasattr(self, 'cap'):
|
|
||||||
self.cap.release()
|
|
||||||
|
|
||||||
def preprocess_video(self) -> Generator[np.ndarray, None, None]:
|
|
||||||
"""
|
|
||||||
使用生成器方式读取视频帧
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
np.ndarray: 视频帧
|
|
||||||
"""
|
|
||||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始
|
|
||||||
while self.cap.isOpened():
|
|
||||||
ret, frame = self.cap.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
|
|
||||||
"""
|
|
||||||
使用帧差法检测镜头边界
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frames: 视频帧列表
|
|
||||||
threshold: 差异阈值,默认值调低为30
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[int]: 镜头边界帧的索引列表
|
|
||||||
"""
|
|
||||||
shot_boundaries = []
|
|
||||||
if len(frames) < 2: # 添加帧数检查
|
|
||||||
logger.warning("视频帧数过少,无法检测场景边界")
|
|
||||||
return [len(frames) - 1] # 返回最后一帧作为边界
|
|
||||||
|
|
||||||
for i in range(1, len(frames)):
|
|
||||||
prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
|
|
||||||
curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
# 计算帧差
|
|
||||||
diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
|
|
||||||
|
|
||||||
if diff > threshold:
|
|
||||||
shot_boundaries.append(i)
|
|
||||||
|
|
||||||
# 如果没有检测到任何边界,至少返回最后一帧
|
|
||||||
if not shot_boundaries:
|
|
||||||
logger.warning("未检测到场景边界,将视频作为单个场景处理")
|
|
||||||
shot_boundaries.append(len(frames) - 1)
|
|
||||||
|
|
||||||
return shot_boundaries
|
|
||||||
|
|
||||||
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
|
|
||||||
List[np.ndarray], List[int]]:
|
|
||||||
"""
|
|
||||||
从每个镜头中提取关键帧
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frames: 视频帧列表
|
|
||||||
shot_boundaries: 镜头边界列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
|
|
||||||
"""
|
|
||||||
keyframes = []
|
|
||||||
keyframe_indices = []
|
|
||||||
|
|
||||||
for i in tqdm(range(len(shot_boundaries)), desc="提取关键帧"):
|
|
||||||
start = shot_boundaries[i - 1] if i > 0 else 0
|
|
||||||
end = shot_boundaries[i]
|
|
||||||
shot_frames = frames[start:end]
|
|
||||||
|
|
||||||
if not shot_frames:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 将每一帧转换为灰度图并展平为一维数组
|
|
||||||
frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
|
|
||||||
for frame in shot_frames])
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试使用 KMeans
|
|
||||||
kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
|
|
||||||
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"KMeans 聚类失败,使用备选方案: {str(e)}")
|
|
||||||
# 备选方案:选择镜头中间的帧作为关键帧
|
|
||||||
center_idx = len(shot_frames) // 2
|
|
||||||
|
|
||||||
keyframes.append(shot_frames[center_idx])
|
|
||||||
keyframe_indices.append(start + center_idx)
|
|
||||||
|
|
||||||
return keyframes, keyframe_indices
|
|
||||||
|
|
||||||
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
|
|
||||||
output_dir: str, desc: str = "保存关键帧") -> None:
|
|
||||||
"""
|
|
||||||
保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg
|
|
||||||
时间戳精确到毫秒,格式为:HHMMSSmmm
|
|
||||||
"""
|
|
||||||
if not os.path.exists(output_dir):
|
|
||||||
os.makedirs(output_dir)
|
|
||||||
|
|
||||||
for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
|
|
||||||
total=len(keyframes),
|
|
||||||
desc=desc):
|
|
||||||
# 计算精确到毫秒的时间戳
|
|
||||||
timestamp = frame_idx / self.fps
|
|
||||||
hours = int(timestamp // 3600)
|
|
||||||
minutes = int((timestamp % 3600) // 60)
|
|
||||||
seconds = int(timestamp % 60)
|
|
||||||
milliseconds = int((timestamp % 1) * 1000) # 计算毫秒部分
|
|
||||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
|
|
||||||
|
|
||||||
output_path = os.path.join(output_dir,
|
|
||||||
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
|
|
||||||
cv2.imwrite(output_path, keyframe)
|
|
||||||
|
|
||||||
def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
|
|
||||||
"""
|
|
||||||
根据指定的帧号提取帧,如果多个帧在同一毫秒内,只保留一个
|
|
||||||
"""
|
|
||||||
if not frame_numbers:
|
|
||||||
raise ValueError("未提供帧号列表")
|
|
||||||
|
|
||||||
if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
|
|
||||||
raise ValueError("存在无效的帧号")
|
|
||||||
|
|
||||||
if not os.path.exists(output_folder):
|
|
||||||
os.makedirs(output_folder)
|
|
||||||
|
|
||||||
# 用于记录已处理的时间戳(毫秒)
|
|
||||||
processed_timestamps = set()
|
|
||||||
|
|
||||||
for frame_number in tqdm(frame_numbers, desc="提取高清帧"):
|
|
||||||
# 计算精确到毫秒的时间戳
|
|
||||||
timestamp = frame_number / self.fps
|
|
||||||
timestamp_ms = int(timestamp * 1000) # 转换为毫秒
|
|
||||||
|
|
||||||
# 如果这一毫秒已经处理过,跳过
|
|
||||||
if timestamp_ms in processed_timestamps:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
|
||||||
ret, frame = self.cap.read()
|
|
||||||
|
|
||||||
if ret:
|
|
||||||
# 记录这一毫秒已经处理
|
|
||||||
processed_timestamps.add(timestamp_ms)
|
|
||||||
|
|
||||||
# 计算时间戳字符串
|
|
||||||
hours = int(timestamp // 3600)
|
|
||||||
minutes = int((timestamp % 3600) // 60)
|
|
||||||
seconds = int(timestamp % 60)
|
|
||||||
milliseconds = int((timestamp % 1) * 1000) # 计算毫秒部分
|
|
||||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
|
|
||||||
|
|
||||||
output_path = os.path.join(output_folder,
|
|
||||||
f"keyframe_{frame_number:06d}_{time_str}.jpg")
|
|
||||||
cv2.imwrite(output_path, frame)
|
|
||||||
else:
|
|
||||||
logger.info(f"无法读取帧 {frame_number}")
|
|
||||||
|
|
||||||
logger.info(f"共提取了 {len(processed_timestamps)} 个不同时间戳的帧")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_numbers_from_folder(folder_path: str) -> List[int]:
|
|
||||||
"""
|
|
||||||
从文件夹中提取帧号
|
|
||||||
|
|
||||||
Args:
|
|
||||||
folder_path: 关键帧文件夹路径
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[int]: 排序后的帧号列表
|
|
||||||
"""
|
|
||||||
files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
|
|
||||||
# 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534123.jpg
|
|
||||||
pattern = re.compile(r'keyframe_(\d+)_\d{9}\.jpg$')
|
|
||||||
numbers = []
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
match = pattern.search(f)
|
|
||||||
if match:
|
|
||||||
numbers.append(int(match.group(1)))
|
|
||||||
else:
|
|
||||||
logger.warning(f"文件名格式不匹配: {f}")
|
|
||||||
|
|
||||||
if not numbers:
|
|
||||||
logger.error(f"在目录 {folder_path} 中未找到有效的关键帧文件")
|
|
||||||
|
|
||||||
return sorted(numbers)
|
|
||||||
|
|
||||||
def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
|
|
||||||
"""
|
|
||||||
处理视频并提取关键帧
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_dir: 输出目录
|
|
||||||
skip_seconds: 跳过视频开头的秒数
|
|
||||||
"""
|
|
||||||
skip_frames = int(skip_seconds * self.fps)
|
|
||||||
|
|
||||||
logger.info("读取视频帧...")
|
|
||||||
frames = []
|
|
||||||
for frame in tqdm(self.preprocess_video(),
|
|
||||||
total=self.total_frames,
|
|
||||||
desc="读取视频"):
|
|
||||||
frames.append(frame)
|
|
||||||
|
|
||||||
frames = frames[skip_frames:]
|
|
||||||
|
|
||||||
if not frames:
|
|
||||||
raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
|
|
||||||
|
|
||||||
logger.info("检测场景边界...")
|
|
||||||
shot_boundaries = self.detect_shot_boundaries(frames, threshold)
|
|
||||||
logger.info(f"检测到 {len(shot_boundaries)} 个场景边界")
|
|
||||||
|
|
||||||
keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
|
|
||||||
|
|
||||||
adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
|
|
||||||
self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="保存压缩关键帧")
|
|
||||||
|
|
||||||
def process_video_pipeline(self,
|
|
||||||
output_dir: str,
|
|
||||||
skip_seconds: float = 0,
|
|
||||||
threshold: int = 20, # 降低默认阈值
|
|
||||||
compressed_width: int = 320,
|
|
||||||
keep_temp: bool = False) -> None:
|
|
||||||
"""
|
|
||||||
执行完整的视频处理流程
|
|
||||||
|
|
||||||
Args:
|
|
||||||
threshold: 降低默认阈值为20,使场景检测更敏感
|
|
||||||
"""
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
temp_dir = os.path.join(output_dir, 'temp')
|
|
||||||
compressed_dir = os.path.join(temp_dir, 'compressed')
|
|
||||||
mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
|
|
||||||
hd_frames_dir = output_dir
|
|
||||||
|
|
||||||
os.makedirs(temp_dir, exist_ok=True)
|
|
||||||
os.makedirs(compressed_dir, exist_ok=True)
|
|
||||||
os.makedirs(mini_frames_dir, exist_ok=True)
|
|
||||||
os.makedirs(hd_frames_dir, exist_ok=True)
|
|
||||||
|
|
||||||
mini_processor = None
|
|
||||||
compressed_video = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 压缩视频
|
|
||||||
video_name = os.path.splitext(os.path.basename(self.video_path))[0]
|
|
||||||
compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")
|
|
||||||
|
|
||||||
# 获取原始视频的宽度和高度
|
|
||||||
original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
||||||
original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
||||||
|
|
||||||
logger.info("步骤1: 压缩视频...")
|
|
||||||
if original_width > original_height:
|
|
||||||
# 横版视频
|
|
||||||
scale_filter = f'scale={compressed_width}:-1'
|
|
||||||
else:
|
|
||||||
# 竖版视频
|
|
||||||
scale_filter = f'scale=-1:{compressed_width}'
|
|
||||||
|
|
||||||
ffmpeg_cmd = [
|
|
||||||
'ffmpeg', '-i', self.video_path,
|
|
||||||
'-vf', scale_filter,
|
|
||||||
'-y',
|
|
||||||
compressed_video
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
logger.error(f"FFmpeg 错误输出: {e.stderr}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 2. 从压缩视频中提取关键帧
|
|
||||||
logger.info("\n步骤2: 从压缩视频提取关键帧...")
|
|
||||||
mini_processor = VideoProcessor(compressed_video)
|
|
||||||
mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)
|
|
||||||
|
|
||||||
# 3. 从原始视频提取高清关键帧
|
|
||||||
logger.info("\n步骤3: 提取高清关键帧...")
|
|
||||||
frame_numbers = self.extract_numbers_from_folder(mini_frames_dir)
|
|
||||||
|
|
||||||
if not frame_numbers:
|
|
||||||
raise ValueError("未能从压缩视频中提取到有效的关键帧")
|
|
||||||
|
|
||||||
self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)
|
|
||||||
|
|
||||||
logger.info(f"处理完成!高清关键帧保存在: {hd_frames_dir}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
logger.error(f"视频处理失败: \n{traceback.format_exc()}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 释放资源
|
|
||||||
if mini_processor:
|
|
||||||
mini_processor.cap.release()
|
|
||||||
del mini_processor
|
|
||||||
|
|
||||||
# 确保视频文件句柄被释放
|
|
||||||
if hasattr(self, 'cap'):
|
|
||||||
self.cap.release()
|
|
||||||
|
|
||||||
# 等待资源释放
|
|
||||||
import time
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
if not keep_temp:
|
|
||||||
try:
|
|
||||||
# 先删除压缩视频文件
|
|
||||||
if compressed_video and os.path.exists(compressed_video):
|
|
||||||
try:
|
|
||||||
os.remove(compressed_video)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除压缩视频失败: {e}")
|
|
||||||
|
|
||||||
# 再删除临时目录
|
|
||||||
import shutil
|
|
||||||
if os.path.exists(temp_dir):
|
|
||||||
max_retries = 3
|
|
||||||
for i in range(max_retries):
|
|
||||||
try:
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if i == max_retries - 1:
|
|
||||||
logger.warning(f"清理临时文件失败: {e}")
|
|
||||||
else:
|
|
||||||
time.sleep(1) # 等待1秒后重试
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info("临时文件已清理")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"清理临时文件时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
processor = VideoProcessor("E:\\projects\\NarratoAI\\resource\\videos\\test.mp4")
|
|
||||||
processor.process_video_pipeline(output_dir="output")
|
|
||||||
end_time = time.time()
|
|
||||||
print(f"处理完成!总耗时: {end_time - start_time:.2f} 秒")
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
[app]
|
[app]
|
||||||
project_version="0.5.3"
|
project_version="0.6.0"
|
||||||
# 支持视频理解的大模型提供商
|
# 支持视频理解的大模型提供商
|
||||||
# gemini
|
# gemini
|
||||||
# qwenvl
|
# qwenvl
|
||||||
|
|||||||
@ -1,38 +1,45 @@
|
|||||||
|
# 必须项
|
||||||
requests~=2.32.0
|
requests~=2.32.0
|
||||||
moviepy==2.1.1
|
moviepy==2.1.1
|
||||||
edge-tts==6.1.19
|
edge-tts==6.1.19
|
||||||
streamlit~=1.45.0
|
streamlit~=1.45.0
|
||||||
|
watchdog==6.0.0
|
||||||
|
loguru~=0.7.3
|
||||||
|
tomli~=2.2.1
|
||||||
|
|
||||||
openai~=1.77.0
|
openai~=1.77.0
|
||||||
google-generativeai>=0.8.5
|
google-generativeai>=0.8.5
|
||||||
|
|
||||||
loguru~=0.7.2
|
# 待优化项
|
||||||
fastapi~=0.115.4
|
# opencv-python==4.11.0.86
|
||||||
uvicorn~=0.27.1
|
# scikit-learn==1.6.1
|
||||||
pydantic~=2.11.4
|
|
||||||
|
|
||||||
faster-whisper~=1.0.1
|
# fastapi~=0.115.4
|
||||||
tomli~=2.0.1
|
# uvicorn~=0.27.1
|
||||||
aiohttp~=3.10.10
|
# pydantic~=2.11.4
|
||||||
httpx==0.27.2
|
|
||||||
urllib3~=2.2.1
|
|
||||||
|
|
||||||
python-multipart~=0.0.9
|
# faster-whisper~=1.0.1
|
||||||
redis==5.0.3
|
# tomli~=2.0.1
|
||||||
opencv-python~=4.10.0.84
|
# aiohttp~=3.10.10
|
||||||
azure-cognitiveservices-speech~=1.37.0
|
# httpx==0.27.2
|
||||||
git-changelog~=2.5.2
|
# urllib3~=2.2.1
|
||||||
watchdog==5.0.2
|
|
||||||
pydub==0.25.1
|
|
||||||
psutil>=5.9.0
|
|
||||||
scikit-learn~=1.5.2
|
|
||||||
pillow==10.3.0
|
|
||||||
python-dotenv~=1.0.1
|
|
||||||
|
|
||||||
tqdm>=4.66.6
|
# python-multipart~=0.0.9
|
||||||
tenacity>=9.0.0
|
# redis==5.0.3
|
||||||
tiktoken==0.8.0
|
# opencv-python~=4.10.0.84
|
||||||
pysrt==1.1.2
|
# azure-cognitiveservices-speech~=1.37.0
|
||||||
transformers==4.50.0
|
# git-changelog~=2.5.2
|
||||||
|
# watchdog==5.0.2
|
||||||
|
# pydub==0.25.1
|
||||||
|
# psutil>=5.9.0
|
||||||
|
# scikit-learn~=1.5.2
|
||||||
|
# pillow==10.3.0
|
||||||
|
# python-dotenv~=1.0.1
|
||||||
|
|
||||||
|
# tqdm>=4.66.6
|
||||||
|
# tenacity>=9.0.0
|
||||||
|
# tiktoken==0.8.0
|
||||||
|
# pysrt==1.1.2
|
||||||
|
# transformers==4.50.0
|
||||||
|
|
||||||
# yt-dlp==2025.4.30
|
# yt-dlp==2025.4.30
|
||||||
72
webui.py
72
webui.py
@ -1,7 +1,7 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from uuid import uuid4
|
from loguru import logger
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
|
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
|
||||||
review_settings, merge_settings, system_settings
|
review_settings, merge_settings, system_settings
|
||||||
@ -18,7 +18,7 @@ st.set_page_config(
|
|||||||
initial_sidebar_state="auto",
|
initial_sidebar_state="auto",
|
||||||
menu_items={
|
menu_items={
|
||||||
"Report a bug": "https://github.com/linyqh/NarratoAI/issues",
|
"Report a bug": "https://github.com/linyqh/NarratoAI/issues",
|
||||||
'About': f"# NarratoAI:sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
|
'About': f"# Narrato:blue[AI] :sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
|
||||||
f"自动化影视解说视频详情请移步:https://github.com/linyqh/NarratoAI"
|
f"自动化影视解说视频详情请移步:https://github.com/linyqh/NarratoAI"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -37,17 +37,7 @@ def init_log():
|
|||||||
_lvl = "DEBUG"
|
_lvl = "DEBUG"
|
||||||
|
|
||||||
def format_record(record):
|
def format_record(record):
|
||||||
# 增加更多需要过滤的警告消息
|
# 简化日志格式化处理,不尝试按特定字符串过滤torch相关内容
|
||||||
ignore_messages = [
|
|
||||||
"Examining the path of torch.classes raised",
|
|
||||||
"torch.cuda.is_available()",
|
|
||||||
"CUDA initialization"
|
|
||||||
]
|
|
||||||
|
|
||||||
for msg in ignore_messages:
|
|
||||||
if msg in record["message"]:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
file_path = record["file"].path
|
file_path = record["file"].path
|
||||||
relative_path = os.path.relpath(file_path, config.root_dir)
|
relative_path = os.path.relpath(file_path, config.root_dir)
|
||||||
record["file"].path = f"./{relative_path}"
|
record["file"].path = f"./{relative_path}"
|
||||||
@ -59,8 +49,25 @@ def init_log():
|
|||||||
'- <level>{message}</>' + "\n"
|
'- <level>{message}</>' + "\n"
|
||||||
return _format
|
return _format
|
||||||
|
|
||||||
# 优化日志过滤器
|
# 替换为更简单的过滤方式,避免在过滤时访问message内容
|
||||||
def log_filter(record):
|
# 此处先不设置复杂的过滤器,等应用启动后再动态添加
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
level=_lvl,
|
||||||
|
format=format_record,
|
||||||
|
colorize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 应用启动后,可以再添加更复杂的过滤器
|
||||||
|
def setup_advanced_filters():
|
||||||
|
"""在应用完全启动后设置高级过滤器"""
|
||||||
|
try:
|
||||||
|
for handler_id in logger._core.handlers:
|
||||||
|
logger.remove(handler_id)
|
||||||
|
|
||||||
|
# 重新添加带有高级过滤的处理器
|
||||||
|
def advanced_filter(record):
|
||||||
|
"""更复杂的过滤器,在应用启动后安全使用"""
|
||||||
ignore_messages = [
|
ignore_messages = [
|
||||||
"Examining the path of torch.classes raised",
|
"Examining the path of torch.classes raised",
|
||||||
"torch.cuda.is_available()",
|
"torch.cuda.is_available()",
|
||||||
@ -73,8 +80,21 @@ def init_log():
|
|||||||
level=_lvl,
|
level=_lvl,
|
||||||
format=format_record,
|
format=format_record,
|
||||||
colorize=True,
|
colorize=True,
|
||||||
filter=log_filter
|
filter=advanced_filter
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# 如果过滤器设置失败,确保日志仍然可用
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
level=_lvl,
|
||||||
|
format=format_record,
|
||||||
|
colorize=True
|
||||||
|
)
|
||||||
|
logger.error(f"设置高级日志过滤器失败: {e}")
|
||||||
|
|
||||||
|
# 将高级过滤器设置放到启动主逻辑后
|
||||||
|
import threading
|
||||||
|
threading.Timer(5.0, setup_advanced_filters).start()
|
||||||
|
|
||||||
|
|
||||||
def init_global_state():
|
def init_global_state():
|
||||||
@ -177,11 +197,18 @@ def main():
|
|||||||
"""主函数"""
|
"""主函数"""
|
||||||
init_log()
|
init_log()
|
||||||
init_global_state()
|
init_global_state()
|
||||||
utils.init_resources()
|
|
||||||
|
|
||||||
st.title(f"NarratoAI :sunglasses:📽️")
|
# 仅初始化基本资源,避免过早地加载依赖PyTorch的资源
|
||||||
|
# 检查是否能分解utils.init_resources()为基本资源和高级资源(如依赖PyTorch的资源)
|
||||||
|
try:
|
||||||
|
utils.init_resources()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"资源初始化时出现警告: {e}")
|
||||||
|
|
||||||
|
st.title(f"Narrato:blue[AI]:sunglasses: 📽️")
|
||||||
st.write(tr("Get Help"))
|
st.write(tr("Get Help"))
|
||||||
|
|
||||||
|
# 首先渲染不依赖PyTorch的UI部分
|
||||||
# 渲染基础设置面板
|
# 渲染基础设置面板
|
||||||
basic_settings.render_basic_settings(tr)
|
basic_settings.render_basic_settings(tr)
|
||||||
# 渲染合并设置
|
# 渲染合并设置
|
||||||
@ -196,13 +223,16 @@ def main():
|
|||||||
audio_settings.render_audio_panel(tr)
|
audio_settings.render_audio_panel(tr)
|
||||||
with panel[2]:
|
with panel[2]:
|
||||||
subtitle_settings.render_subtitle_panel(tr)
|
subtitle_settings.render_subtitle_panel(tr)
|
||||||
# 渲染系统设置面板
|
|
||||||
system_settings.render_system_panel(tr)
|
|
||||||
|
|
||||||
# 渲染视频审查面板
|
# 渲染视频审查面板
|
||||||
review_settings.render_review_panel(tr)
|
review_settings.render_review_panel(tr)
|
||||||
|
|
||||||
# 渲染生成按钮和处理逻辑
|
# 放到最后渲染可能使用PyTorch的部分
|
||||||
|
# 渲染系统设置面板
|
||||||
|
with panel[2]:
|
||||||
|
system_settings.render_system_panel(tr)
|
||||||
|
|
||||||
|
# 放到最后渲染生成按钮和处理逻辑
|
||||||
render_generate_button()
|
render_generate_button()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -285,8 +285,8 @@ def render_merge_settings(tr):
|
|||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
if "moviepy" in error_message.lower():
|
if "moviepy" in error_message.lower():
|
||||||
st.error(tr("Error processing video files. Please check if the videos are valid MP4 files."))
|
st.error(tr("Error processing video files. Please check if the videos are valid MP4 files."))
|
||||||
elif "pysrt" in error_message.lower():
|
# elif "pysrt" in error_message.lower():
|
||||||
st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
|
# st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
|
||||||
else:
|
else:
|
||||||
st.error(f"{tr('Error during merge')}: {error_message}")
|
st.error(f"{tr('Error during merge')}: {error_message}")
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import requests
|
||||||
|
from app.utils import video_processor
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
@ -12,7 +13,7 @@ from urllib3.util.retry import Retry
|
|||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.utils.script_generator import ScriptProcessor
|
from app.utils.script_generator import ScriptProcessor
|
||||||
from app.utils import utils, video_processor, video_processor_v2, qwenvl_analyzer
|
from app.utils import utils, video_processor, qwenvl_analyzer
|
||||||
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps, chekc_video_config
|
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps, chekc_video_config
|
||||||
|
|
||||||
|
|
||||||
@ -64,21 +65,13 @@ def generate_script_docu(tr, params):
|
|||||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||||
|
|
||||||
# 初始化视频处理器
|
# 初始化视频处理器
|
||||||
if config.frames.get("version") == "v2":
|
processor = video_processor.VideoProcessor(params.video_origin_path)
|
||||||
processor = video_processor_v2.VideoProcessor(params.video_origin_path)
|
|
||||||
# 处理视频并提取关键帧
|
# 处理视频并提取关键帧
|
||||||
processor.process_video_pipeline(
|
processor.process_video_pipeline(
|
||||||
output_dir=video_keyframes_dir,
|
output_dir=video_keyframes_dir,
|
||||||
skip_seconds=st.session_state.get('skip_seconds'),
|
skip_seconds=st.session_state.get('skip_seconds'),
|
||||||
threshold=st.session_state.get('threshold')
|
threshold=st.session_state.get('threshold')
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
processor = video_processor.VideoProcessor(params.video_origin_path)
|
|
||||||
# 处理视频并提取关键帧
|
|
||||||
processor.process_video(
|
|
||||||
output_dir=video_keyframes_dir,
|
|
||||||
skip_seconds=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取所有关键文件路径
|
# 获取所有关键文件路径
|
||||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
合并视频和字幕文件
|
合并视频和字幕文件
|
||||||
"""
|
"""
|
||||||
from moviepy import VideoFileClip, concatenate_videoclips
|
from moviepy import VideoFileClip, concatenate_videoclips
|
||||||
import pysrt
|
# import pysrt
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import psutil
|
# import psutil
|
||||||
import os
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import torch
|
|
||||||
|
|
||||||
class PerformanceMonitor:
|
class PerformanceMonitor:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -11,19 +10,35 @@ class PerformanceMonitor:
|
|||||||
|
|
||||||
logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
|
logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
|
||||||
|
|
||||||
|
# 延迟导入torch并检查CUDA
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
|
gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
|
||||||
logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
|
logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
|
||||||
|
except (ImportError, RuntimeError) as e:
|
||||||
|
# 无法导入torch或触发CUDA相关错误时,静默处理
|
||||||
|
logger.debug(f"无法获取GPU内存信息: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cleanup_resources():
|
def cleanup_resources():
|
||||||
|
# 延迟导入torch并清理CUDA
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
logger.debug("CUDA缓存已清理")
|
||||||
|
except (ImportError, RuntimeError) as e:
|
||||||
|
# 无法导入torch或触发CUDA相关错误时,静默处理
|
||||||
|
logger.debug(f"无法清理CUDA资源: {e}")
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
PerformanceMonitor.monitor_memory()
|
# 仅报告进程内存,不尝试获取GPU内存
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
memory_info = process.memory_info()
|
||||||
|
logger.debug(f"Memory usage after cleanup: {memory_info.rss / 1024 / 1024:.2f} MB")
|
||||||
|
|
||||||
def monitor_performance(func):
|
def monitor_performance(func):
|
||||||
"""性能监控装饰器"""
|
"""性能监控装饰器"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user