mirror of https://github.com/linyqh/NarratoAI.git
synced 2025-12-12 03:02:48 +00:00

Remove the opencv- and sklearn-based keyframe extraction code

This commit is contained in:
parent c3ea0bcc69
commit f6c3f1640b
@@ -3,10 +3,11 @@ import json
 import time
 import asyncio
 import requests
+from app.utils import video_processor
 from loguru import logger
 from typing import List, Dict, Any, Callable

-from app.utils import utils, gemini_analyzer, video_processor, video_processor_v2
+from app.utils import utils, gemini_analyzer, video_processor
 from app.utils.script_generator import ScriptProcessor
 from app.config import config

@@ -105,19 +106,12 @@ class ScriptGenerator:
             os.makedirs(video_keyframes_dir, exist_ok=True)

             try:
-                if config.frames.get("version") == "v2":
-                    processor = video_processor_v2.VideoProcessor(video_path)
-                    processor.process_video_pipeline(
-                        output_dir=video_keyframes_dir,
-                        skip_seconds=skip_seconds,
-                        threshold=threshold
-                    )
-                else:
-                    processor = video_processor.VideoProcessor(video_path)
-                    processor.process_video(
-                        output_dir=video_keyframes_dir,
-                        skip_seconds=skip_seconds
-                    )
+                processor = video_processor.VideoProcessor(video_path)
+                processor.process_video_pipeline(
+                    output_dir=video_keyframes_dir,
+                    skip_seconds=skip_seconds,
+                    threshold=threshold
+                )

                 for filename in sorted(os.listdir(video_keyframes_dir)):
                     if filename.endswith('.jpg'):
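The hunk above collapses the old v1/v2 branch into a single code path. A minimal standalone sketch of the resulting call (the file paths and values here are illustrative, not from the repository):

```python
from app.utils import video_processor

# Hypothetical paths/values for illustration only
processor = video_processor.VideoProcessor("./resource/videos/demo.mp4")
processor.process_video_pipeline(
    output_dir="./storage/keyframes/demo",
    skip_seconds=0,
    threshold=20,  # retained for API compatibility; ignored by the new pipeline
)
```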
@@ -4,7 +4,7 @@ import re
 import traceback
 from typing import Optional

-from faster_whisper import WhisperModel
+# from faster_whisper import WhisperModel
 from timeit import default_timer as timer
 from loguru import logger
 import google.generativeai as genai

@@ -45,12 +45,25 @@ def create(audio_file, subtitle_file: str = ""):
         )
         return None

-    # Try CUDA first and fall back to CPU on failure
-    if torch.cuda.is_available():
+    # Start in CPU mode first, without triggering the CUDA check
+    use_cuda = False
+    try:
+        # Import torch lazily inside the function instead of at module scope,
+        # and check CUDA availability in a safe way
+        def check_cuda_available():
+            try:
+                import torch
+                return torch.cuda.is_available()
+            except (ImportError, RuntimeError) as e:
+                logger.warning(f"Error while checking CUDA availability: {e}")
+                return False
+
+        # Only check CUDA when it is explicitly needed
+        use_cuda = check_cuda_available()
+
+        if use_cuda:
             logger.info(f"Trying to load the model with CUDA: {model_path}")
             try:
                 model = WhisperModel(
                     model_size_or_path=model_path,
                     device="cuda",

@@ -63,18 +76,18 @@ def create(audio_file, subtitle_file: str = ""):
             except Exception as e:
                 logger.warning(f"CUDA load failed, error: {str(e)}")
                 logger.warning("Falling back to CPU mode")
                 device = "cpu"
                 compute_type = "int8"
+                use_cuda = False
         else:
             logger.info("CUDA not detected, using CPU mode")
             device = "cpu"
             compute_type = "int8"
-    except ImportError:
-        logger.warning("torch is not installed, using CPU mode")
-        device = "cpu"
-        compute_type = "int8"
-        logger.info("Using CPU mode")
+    except Exception as e:
+        logger.warning(f"Error during the CUDA check: {e}")
+        logger.warning("Defaulting to CPU mode")
+        use_cuda = False

-    if device == "cpu":
+    # Use the CPU if CUDA is unavailable or model loading failed
+    if not use_cuda:
         device = "cpu"
         compute_type = "int8"
         logger.info(f"Loading the model on CPU: {model_path}")
         model = WhisperModel(
             model_size_or_path=model_path,
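The two hunks above implement a lazy torch import with a CPU fallback. A condensed, self-contained sketch of the same pattern, assuming faster-whisper is installed and `model_path` points at a local model directory:

```python
from loguru import logger
from faster_whisper import WhisperModel

def load_whisper_model(model_path: str) -> WhisperModel:
    """Load WhisperModel on CUDA when available, otherwise fall back to CPU."""
    def cuda_available() -> bool:
        try:
            import torch  # imported lazily so a broken torch install cannot crash startup
            return torch.cuda.is_available()
        except (ImportError, RuntimeError) as e:
            logger.warning(f"CUDA check failed: {e}")
            return False

    if cuda_available():
        try:
            return WhisperModel(model_size_or_path=model_path,
                                device="cuda", compute_type="float16")
        except Exception as e:
            logger.warning(f"CUDA load failed, falling back to CPU: {e}")
    return WhisperModel(model_size_or_path=model_path,
                        device="cpu", compute_type="int8")
```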
@@ -1,6 +1,6 @@
 import traceback

-import pysrt
+# import pysrt
 from typing import Optional
 from typing import List
 from loguru import logger
@@ -2,7 +2,7 @@ import os
 import json
 import traceback
 from loguru import logger
-import tiktoken
+# import tiktoken
 from typing import List, Dict
 from datetime import datetime
 from openai import OpenAI

@@ -94,12 +94,12 @@ class OpenAIGenerator(BaseGenerator):
             "user": "script_generator"
         }

-        # Initialize the token counter
-        try:
-            self.encoding = tiktoken.encoding_for_model(self.model_name)
-        except KeyError:
-            logger.warning(f"No dedicated encoder found for model {self.model_name}, falling back to the default encoder")
-            self.encoding = tiktoken.get_encoding("cl100k_base")
+        # # Initialize the token counter
+        # try:
+        #     self.encoding = tiktoken.encoding_for_model(self.model_name)
+        # except KeyError:
+        #     logger.warning(f"No dedicated encoder found for model {self.model_name}, falling back to the default encoder")
+        #     self.encoding = tiktoken.get_encoding("cl100k_base")

     def _generate(self, messages: list, params: dict) -> any:
         """OpenAI-specific generation logic"""
@@ -1,237 +1,349 @@
-import cv2
-import numpy as np
-from sklearn.cluster import MiniBatchKMeans
+"""
+Video frame extraction tool
+
+This module provides simple, efficient video frame extraction. Key features:
+1. Uses ffmpeg for video processing, with hardware-acceleration support
+2. Extracts video keyframes at a specified time interval
+3. Supports a wide range of video formats
+4. Supports high-definition frame output
+5. Extracts high-quality keyframes directly from the original video
+
+It does not depend on OpenCV, sklearn, or similar libraries; ffmpeg is the only
+external dependency, which keeps installation and usage simple.
+"""
+
 import os
-import re
-from typing import List, Tuple, Generator
 import time
+import subprocess
+from typing import List, Dict
 from loguru import logger
-import gc
 from tqdm import tqdm


 class VideoProcessor:
-    def __init__(self, video_path: str, batch_size: int = 100):
+    def __init__(self, video_path: str):
         """
         Initialize the video processor

         Args:
             video_path: path to the video file
-            batch_size: batch size that bounds memory usage
         """
         if not os.path.exists(video_path):
             raise FileNotFoundError(f"Video file does not exist: {video_path}")

         self.video_path = video_path
-        self.batch_size = batch_size
-        self.cap = cv2.VideoCapture(video_path)
+        self.video_info = self._get_video_info()
+        self.fps = float(self.video_info.get('fps', 25))
+        self.duration = float(self.video_info.get('duration', 0))
+        self.width = int(self.video_info.get('width', 0))
+        self.height = int(self.video_info.get('height', 0))
+        self.total_frames = int(self.fps * self.duration)

-        if not self.cap.isOpened():
-            raise RuntimeError(f"Cannot open video file: {video_path}")
-
-        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
-
-    def __del__(self):
-        """Destructor; makes sure the video resources are released"""
-        if hasattr(self, 'cap'):
-            self.cap.release()
-        gc.collect()
-
-    def preprocess_video(self) -> Generator[Tuple[int, np.ndarray], None, None]:
+    def _get_video_info(self) -> Dict[str, str]:
         """
-        Read video frames in batches with a generator
-
-        Yields:
-            Tuple[int, np.ndarray]: (frame index, video frame)
-        """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
-        frame_idx = 0
-
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
-                break
-
-            # Reduce the resolution to cut memory usage
-            frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
-            yield frame_idx, frame
-
-            frame_idx += 1
-
-            # Run garbage collection periodically
-            if frame_idx % 1000 == 0:
-                gc.collect()
-
-    def detect_shot_boundaries(self, threshold: int = 70) -> List[int]:
-        """
-        Detect shot boundaries in batches
-
-        Args:
-            threshold: difference threshold
+        Use ffprobe to read the video metadata

         Returns:
-            List[int]: indices of shot-boundary frames
+            Dict[str, str]: dictionary with the basic video information
         """
-        shot_boundaries = []
-        prev_frame = None
-        prev_idx = -1
+        cmd = [
+            "ffprobe",
+            "-v", "error",
+            "-select_streams", "v:0",
+            "-show_entries", "stream=width,height,r_frame_rate,duration",
+            "-of", "default=noprint_wrappers=1:nokey=0",
+            self.video_path
+        ]

-        pbar = tqdm(self.preprocess_video(),
-                    total=self.total_frames,
-                    desc="Detecting shot boundaries",
-                    unit="frames")
+        try:
+            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+            lines = result.stdout.strip().split('\n')
+            info = {}
+            for line in lines:
+                if '=' in line:
+                    key, value = line.split('=', 1)
+                    info[key] = value

-        for frame_idx, curr_frame in pbar:
-            if prev_frame is not None:
-                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
-                curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
+            # Handle the frame rate (it may be given as a fraction)
+            if 'r_frame_rate' in info:
+                try:
+                    num, den = map(int, info['r_frame_rate'].split('/'))
+                    info['fps'] = str(num / den)
+                except ValueError:
+                    info['fps'] = info.get('r_frame_rate', '25')

-                diff = np.mean(np.abs(curr_gray.astype(float) - prev_gray.astype(float)))
-                if diff > threshold:
-                    shot_boundaries.append(frame_idx)
-                    pbar.set_postfix({"boundaries": len(shot_boundaries)})
+            return info

-            prev_frame = curr_frame.copy()
-            prev_idx = frame_idx
+        except subprocess.CalledProcessError as e:
+            logger.error(f"Failed to read video info: {e.stderr}")
+            return {
+                'width': '1280',
+                'height': '720',
+                'fps': '25',
+                'duration': '0'
+            }

-            del curr_frame
-            if frame_idx % 100 == 0:
-                gc.collect()
-
-        return shot_boundaries
-
-    def process_shot(self, shot_frames: List[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, int]:
+    def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
+                                   use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
         """
-        Process the frames of a single shot
+        Extract video frames at a fixed time interval

         Args:
-            shot_frames: list of frames in the shot
+            output_dir: output directory
+            interval_seconds: extraction interval in seconds
+            use_hw_accel: whether to use hardware acceleration
+            skip_seconds: number of seconds to skip at the start of the video

         Returns:
-            Tuple[np.ndarray, int]: (keyframe, frame index)
+            List[int]: list of extracted frame numbers
         """
-        if not shot_frames:
-            return None, -1
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)

-        frame_features = []
-        frame_indices = []
+        # Compute the start time and the extraction points
+        start_time = skip_seconds
+        end_time = self.duration
+        extraction_times = []

-        for idx, frame in tqdm(shot_frames,
-                               desc="Processing shot frames",
-                               unit="frames",
-                               leave=False):
-            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-            resized_gray = cv2.resize(gray, (32, 32))
-            frame_features.append(resized_gray.flatten())
-            frame_indices.append(idx)
+        current_time = start_time
+        while current_time < end_time:
+            extraction_times.append(current_time)
+            current_time += interval_seconds

-        frame_features = np.array(frame_features)
+        if not extraction_times:
+            logger.warning("No frames to extract")
+            return []

-        kmeans = MiniBatchKMeans(n_clusters=1, batch_size=min(len(frame_features), 100),
-                                 random_state=0).fit(frame_features)
+        # Determine the hardware-accelerator options
+        hw_accel = []
+        if use_hw_accel:
+            # Try to detect an available hardware accelerator
+            hw_accel_options = self._detect_hw_accelerator()
+            if hw_accel_options:
+                hw_accel = hw_accel_options
+                logger.info(f"Using hardware acceleration: {' '.join(hw_accel)}")
+            else:
+                logger.warning("No usable hardware accelerator detected, using software decoding")

-        center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
+        # Extract the frames
+        frame_numbers = []
+        for i, timestamp in enumerate(tqdm(extraction_times, desc="Extracting video frames")):
+            frame_number = int(timestamp * self.fps)
+            frame_numbers.append(frame_number)

-        return shot_frames[center_idx][1], frame_indices[center_idx]
+            # Format the timestamp string (HHMMSSmmm)
+            hours = int(timestamp // 3600)
+            minutes = int((timestamp % 3600) // 60)
+            seconds = int(timestamp % 60)
+            milliseconds = int((timestamp % 1) * 1000)
+            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"

-    def extract_keyframes(self, shot_boundaries: List[int]) -> Generator[Tuple[np.ndarray, int], None, None]:
+            output_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{time_str}.jpg")
+
+            # Extract a single frame with ffmpeg
+            cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+            ]
+
+            # Add the hardware-acceleration arguments
+            cmd.extend(hw_accel)
+
+            cmd.extend([
+                "-ss", str(timestamp),
+                "-i", self.video_path,
+                "-vframes", "1",
+                "-q:v", "1",  # highest quality
+                "-y",
+                output_path
+            ])
+
+            try:
+                subprocess.run(cmd, check=True, capture_output=True)
+            except subprocess.CalledProcessError as e:
+                logger.warning(f"Failed to extract frame {frame_number}: {e.stderr}")
+
+        logger.info(f"Successfully extracted {len(frame_numbers)} video frames")
+        return frame_numbers
+
+    def _detect_hw_accelerator(self) -> List[str]:
         """
-        Extract keyframes with a generator
+        Detect the hardware accelerators available on this system

-        Args:
-            shot_boundaries: list of shot boundaries
-
-        Yields:
-            Tuple[np.ndarray, int]: (keyframe, frame index)
+        Returns:
+            List[str]: ffmpeg command-line arguments for the accelerator
         """
-        shot_frames = []
-        current_shot_start = 0
+        # Detect the operating system
+        import platform
+        system = platform.system().lower()

-        for frame_idx, frame in self.preprocess_video():
-            if frame_idx in shot_boundaries:
-                if shot_frames:
-                    keyframe, keyframe_idx = self.process_shot(shot_frames)
-                    if keyframe is not None:
-                        yield keyframe, keyframe_idx
+        # Try the different hardware accelerators
+        accelerators = []

-                    # Free memory
-                    shot_frames.clear()
-                    gc.collect()
+        if system == 'darwin':  # macOS
+            # Test videotoolbox (Apple hardware acceleration)
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "videotoolbox",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "videotoolbox"]
+            except subprocess.CalledProcessError:
+                pass

-                current_shot_start = frame_idx
+        elif system == 'linux':
+            # Test VAAPI
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "vaapi",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "vaapi"]
+            except subprocess.CalledProcessError:
+                pass

-            shot_frames.append((frame_idx, frame))
+            # Try CUDA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "cuda",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "cuda"]
+            except subprocess.CalledProcessError:
+                pass

-            # Bound the number of frames per shot
-            if len(shot_frames) > self.batch_size:
-                keyframe, keyframe_idx = self.process_shot(shot_frames)
-                if keyframe is not None:
-                    yield keyframe, keyframe_idx
-                shot_frames.clear()
-                gc.collect()
+        elif system == 'windows':
+            # Test CUDA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "cuda",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "cuda"]
+            except subprocess.CalledProcessError:
+                pass

-        # Process the final shot
-        if shot_frames:
-            keyframe, keyframe_idx = self.process_shot(shot_frames)
-            if keyframe is not None:
-                yield keyframe, keyframe_idx
+            # Test D3D11VA
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "d3d11va",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "d3d11va"]
+            except subprocess.CalledProcessError:
+                pass

-    def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
+            # Test DXVA2
+            test_cmd = [
+                "ffmpeg",
+                "-hide_banner",
+                "-loglevel", "error",
+                "-hwaccel", "dxva2",
+                "-i", self.video_path,
+                "-t", "0.1",
+                "-f", "null",
+                "-"
+            ]
+            try:
+                subprocess.run(test_cmd, capture_output=True, check=True)
+                return ["-hwaccel", "dxva2"]
+            except subprocess.CalledProcessError:
+                pass
+
+        # No usable hardware accelerator was found
+        return []
+
+    def process_video_pipeline(self,
+                               output_dir: str,
+                               skip_seconds: float = 0.0,
+                               threshold: int = 20,  # parameter kept but unused
+                               compressed_width: int = 320,  # parameter kept but unused
+                               keep_temp: bool = False,  # parameter kept but unused
+                               interval_seconds: float = 5.0,
+                               use_hw_accel: bool = True) -> None:
         """
-        Process the video and extract keyframes in batches
+        Run the simplified processing flow: extract frames directly from the
+        original video at a fixed time interval

         Args:
             output_dir: output directory
             skip_seconds: number of seconds to skip at the start of the video
+            threshold: retained parameter, unused
+            compressed_width: retained parameter, unused
+            keep_temp: retained parameter, unused
+            interval_seconds: extraction interval in seconds
+            use_hw_accel: whether to use hardware acceleration
         """
         try:
+            # Create the output directory
             os.makedirs(output_dir, exist_ok=True)

-            # Compute the number of frames to skip
-            skip_frames = int(skip_seconds * self.fps)
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, skip_frames)
+            # Extract keyframes directly from the original video
+            logger.info("Extracting keyframes directly from the video...")
+            self.extract_frames_by_interval(
+                output_dir,
+                interval_seconds=interval_seconds,
+                use_hw_accel=use_hw_accel,
+                skip_seconds=skip_seconds
+            )

-            # Detect shot boundaries
-            logger.info("Detecting shot boundaries...")
-            shot_boundaries = self.detect_shot_boundaries()
-
-            # Extract keyframes
-            logger.info("Extracting keyframes...")
-            frame_count = 0
-
-            pbar = tqdm(self.extract_keyframes(shot_boundaries),
-                        desc="Extracting keyframes",
-                        unit="frames")
-
-            for keyframe, frame_idx in pbar:
-                if frame_idx < skip_frames:
-                    continue
-
-                # Compute the timestamp
-                timestamp = frame_idx / self.fps
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
-
-                # Save the keyframe
-                output_path = os.path.join(output_dir,
-                                           f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-                cv2.imwrite(output_path, keyframe)
-                frame_count += 1
-
-                pbar.set_postfix({"saved": frame_count})
-
-                if frame_count % 10 == 0:
-                    gc.collect()
-
-            logger.info(f"Keyframe extraction finished, saved {frame_count} frames to {output_dir}")
+            logger.info(f"Processing finished! Video frames saved in: {output_dir}")

         except Exception as e:
-            logger.error(f"Video processing failed: {str(e)}")
+            import traceback
+            logger.error(f"Video processing failed: \n{traceback.format_exc()}")
             raise
-        finally:
-            # Make sure the resources are released
-            self.cap.release()
-            gc.collect()


 if __name__ == "__main__":
     import time

     start_time = time.time()

+    # Usage example
+    processor = VideoProcessor("./resource/videos/test.mp4")
+
+    # Extract frames at a 3-second interval
+    processor.process_video_pipeline(
+        output_dir="output",
+        interval_seconds=3.0,
+        use_hw_accel=True
+    )

     end_time = time.time()
     print(f"Processing finished! Total time: {end_time - start_time:.2f} seconds")
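At the core of the new `extract_frames_by_interval` is a one-frame ffmpeg invocation per timestamp. A minimal standalone sketch of that command (the paths are illustrative):

```python
import subprocess

def grab_frame(video_path: str, timestamp: float, output_path: str) -> None:
    """Grab one frame at `timestamp` seconds as a high-quality JPEG."""
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-ss", str(timestamp),  # seek before the input for fast keyframe-based seeking
        "-i", video_path,
        "-vframes", "1",        # stop after a single frame
        "-q:v", "1",            # best JPEG quality
        "-y", output_path,
    ]
    subprocess.run(cmd, check=True, capture_output=True)

# Hypothetical paths, for illustration only
grab_frame("./resource/videos/test.mp4", 12.5, "frame_012500.jpg")
```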
@@ -1,382 +0,0 @@
-import cv2
-import numpy as np
-from sklearn.cluster import KMeans
-import os
-import re
-from typing import List, Tuple, Generator
-from loguru import logger
-import subprocess
-from tqdm import tqdm
-
-
-class VideoProcessor:
-    def __init__(self, video_path: str):
-        """
-        Initialize the video processor
-
-        Args:
-            video_path: path to the video file
-        """
-        if not os.path.exists(video_path):
-            raise FileNotFoundError(f"Video file does not exist: {video_path}")
-
-        self.video_path = video_path
-        self.cap = cv2.VideoCapture(video_path)
-
-        if not self.cap.isOpened():
-            raise RuntimeError(f"Cannot open video file: {video_path}")
-
-        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
-
-    def __del__(self):
-        """Destructor; makes sure the video resources are released"""
-        if hasattr(self, 'cap'):
-            self.cap.release()
-
-    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
-        """
-        Read video frames with a generator
-
-        Yields:
-            np.ndarray: video frame
-        """
-        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the start of the video
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
-                break
-            yield frame
-
-    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
-        """
-        Detect shot boundaries with frame differencing
-
-        Args:
-            frames: list of video frames
-            threshold: difference threshold, default lowered to 30
-
-        Returns:
-            List[int]: indices of shot-boundary frames
-        """
-        shot_boundaries = []
-        if len(frames) < 2:  # guard against too few frames
-            logger.warning("Too few video frames to detect scene boundaries")
-            return [len(frames) - 1]  # return the last frame as the boundary
-
-        for i in range(1, len(frames)):
-            prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
-            curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
-
-            # Compute the frame difference
-            diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
-
-            if diff > threshold:
-                shot_boundaries.append(i)
-
-        # If no boundary was detected, return at least the last frame
-        if not shot_boundaries:
-            logger.warning("No scene boundary detected, treating the video as a single scene")
-            shot_boundaries.append(len(frames) - 1)
-
-        return shot_boundaries
-
-    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
-        List[np.ndarray], List[int]]:
-        """
-        Extract a keyframe from each shot
-
-        Args:
-            frames: list of video frames
-            shot_boundaries: list of shot boundaries
-
-        Returns:
-            Tuple[List[np.ndarray], List[int]]: keyframes and their frame indices
-        """
-        keyframes = []
-        keyframe_indices = []
-
-        for i in tqdm(range(len(shot_boundaries)), desc="Extracting keyframes"):
-            start = shot_boundaries[i - 1] if i > 0 else 0
-            end = shot_boundaries[i]
-            shot_frames = frames[start:end]
-
-            if not shot_frames:
-                continue
-
-            # Convert each frame to grayscale and flatten it into a 1-D array
-            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
-                                       for frame in shot_frames])
-
-            try:
-                # Try KMeans first
-                kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
-                center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
-            except Exception as e:
-                logger.warning(f"KMeans clustering failed, using the fallback: {str(e)}")
-                # Fallback: pick the middle frame of the shot as the keyframe
-                center_idx = len(shot_frames) // 2
-
-            keyframes.append(shot_frames[center_idx])
-            keyframe_indices.append(start + center_idx)
-
-        return keyframes, keyframe_indices
-
-    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
-                       output_dir: str, desc: str = "Saving keyframes") -> None:
-        """
-        Save keyframes to the given directory, named keyframe_<frame index>_<timestamp>.jpg
-        The timestamp has millisecond precision, formatted HHMMSSmmm
-        """
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-
-        for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
-                                        total=len(keyframes),
-                                        desc=desc):
-            # Compute a millisecond-precision timestamp
-            timestamp = frame_idx / self.fps
-            hours = int(timestamp // 3600)
-            minutes = int((timestamp % 3600) // 60)
-            seconds = int(timestamp % 60)
-            milliseconds = int((timestamp % 1) * 1000)  # millisecond part
-            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
-
-            output_path = os.path.join(output_dir,
-                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
-            cv2.imwrite(output_path, keyframe)
-
-    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
-        """
-        Extract frames by frame number; if several frames fall in the same millisecond, keep only one
-        """
-        if not frame_numbers:
-            raise ValueError("No frame numbers were provided")
-
-        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
-            raise ValueError("Invalid frame number present")
-
-        if not os.path.exists(output_folder):
-            os.makedirs(output_folder)
-
-        # Track the timestamps (in milliseconds) that were already handled
-        processed_timestamps = set()
-
-        for frame_number in tqdm(frame_numbers, desc="Extracting HD frames"):
-            # Compute a millisecond-precision timestamp
-            timestamp = frame_number / self.fps
-            timestamp_ms = int(timestamp * 1000)  # convert to milliseconds
-
-            # Skip if this millisecond has been handled already
-            if timestamp_ms in processed_timestamps:
-                continue
-
-            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-            ret, frame = self.cap.read()
-
-            if ret:
-                # Record that this millisecond has been handled
-                processed_timestamps.add(timestamp_ms)
-
-                # Build the timestamp string
-                hours = int(timestamp // 3600)
-                minutes = int((timestamp % 3600) // 60)
-                seconds = int(timestamp % 60)
-                milliseconds = int((timestamp % 1) * 1000)  # millisecond part
-                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
-
-                output_path = os.path.join(output_folder,
-                                           f"keyframe_{frame_number:06d}_{time_str}.jpg")
-                cv2.imwrite(output_path, frame)
-            else:
-                logger.info(f"Cannot read frame {frame_number}")
-
-        logger.info(f"Extracted {len(processed_timestamps)} frames with distinct timestamps")
-
-    @staticmethod
-    def extract_numbers_from_folder(folder_path: str) -> List[int]:
-        """
-        Extract frame numbers from a folder
-
-        Args:
-            folder_path: path to the keyframe folder
-
-        Returns:
-            List[int]: sorted list of frame numbers
-        """
-        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
-        # Updated regex matching the new file-name format: keyframe_000123_010534123.jpg
-        pattern = re.compile(r'keyframe_(\d+)_\d{9}\.jpg$')
-        numbers = []
-
-        for f in files:
-            match = pattern.search(f)
-            if match:
-                numbers.append(int(match.group(1)))
-            else:
-                logger.warning(f"File name does not match the pattern: {f}")
-
-        if not numbers:
-            logger.error(f"No valid keyframe files found in directory {folder_path}")
-
-        return sorted(numbers)
-
-    def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
-        """
-        Process the video and extract keyframes
-
-        Args:
-            output_dir: output directory
-            skip_seconds: number of seconds to skip at the start of the video
-        """
-        skip_frames = int(skip_seconds * self.fps)
-
-        logger.info("Reading video frames...")
-        frames = []
-        for frame in tqdm(self.preprocess_video(),
-                          total=self.total_frames,
-                          desc="Reading the video"):
-            frames.append(frame)
-
-        frames = frames[skip_frames:]
-
-        if not frames:
-            raise ValueError(f"No frames left to process after skipping {skip_seconds} seconds")
-
-        logger.info("Detecting scene boundaries...")
-        shot_boundaries = self.detect_shot_boundaries(frames, threshold)
-        logger.info(f"Detected {len(shot_boundaries)} scene boundaries")
-
-        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
-
-        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
-        self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="Saving compressed keyframes")
-
-    def process_video_pipeline(self,
-                               output_dir: str,
-                               skip_seconds: float = 0,
-                               threshold: int = 20,  # lowered default threshold
-                               compressed_width: int = 320,
-                               keep_temp: bool = False) -> None:
-        """
-        Run the full video-processing pipeline
-
-        Args:
-            threshold: default lowered to 20 so scene detection is more sensitive
-        """
-        os.makedirs(output_dir, exist_ok=True)
-        temp_dir = os.path.join(output_dir, 'temp')
-        compressed_dir = os.path.join(temp_dir, 'compressed')
-        mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
-        hd_frames_dir = output_dir
-
-        os.makedirs(temp_dir, exist_ok=True)
-        os.makedirs(compressed_dir, exist_ok=True)
-        os.makedirs(mini_frames_dir, exist_ok=True)
-        os.makedirs(hd_frames_dir, exist_ok=True)
-
-        mini_processor = None
-        compressed_video = None
-
-        try:
-            # 1. Compress the video
-            video_name = os.path.splitext(os.path.basename(self.video_path))[0]
-            compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")
-
-            # Read the width and height of the original video
-            original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-            original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
-            logger.info("Step 1: compressing the video...")
-            if original_width > original_height:
-                # Landscape video
-                scale_filter = f'scale={compressed_width}:-1'
-            else:
-                # Portrait video
-                scale_filter = f'scale=-1:{compressed_width}'
-
-            ffmpeg_cmd = [
-                'ffmpeg', '-i', self.video_path,
-                '-vf', scale_filter,
-                '-y',
-                compressed_video
-            ]
-
-            try:
-                subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
-            except subprocess.CalledProcessError as e:
-                logger.error(f"FFmpeg error output: {e.stderr}")
-                raise
-
-            # 2. Extract keyframes from the compressed video
-            logger.info("\nStep 2: extracting keyframes from the compressed video...")
-            mini_processor = VideoProcessor(compressed_video)
-            mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)
-
-            # 3. Extract HD keyframes from the original video
-            logger.info("\nStep 3: extracting HD keyframes...")
-            frame_numbers = self.extract_numbers_from_folder(mini_frames_dir)
-
-            if not frame_numbers:
-                raise ValueError("Could not extract valid keyframes from the compressed video")
-
-            self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)
-
-            logger.info(f"Processing finished! HD keyframes saved in: {hd_frames_dir}")
-
-        except Exception as e:
-            import traceback
-            logger.error(f"Video processing failed: \n{traceback.format_exc()}")
-            raise
-
-        finally:
-            # Release resources
-            if mini_processor:
-                mini_processor.cap.release()
-                del mini_processor
-
-            # Make sure the video file handle is released
-            if hasattr(self, 'cap'):
-                self.cap.release()
-
-            # Wait for the resources to be released
-            import time
-            time.sleep(0.5)
-
-            if not keep_temp:
-                try:
-                    # Delete the compressed video file first
-                    if compressed_video and os.path.exists(compressed_video):
-                        try:
-                            os.remove(compressed_video)
-                        except Exception as e:
-                            logger.warning(f"Failed to delete the compressed video: {e}")
-
-                    # Then delete the temp directory
-                    import shutil
-                    if os.path.exists(temp_dir):
-                        max_retries = 3
-                        for i in range(max_retries):
-                            try:
-                                shutil.rmtree(temp_dir)
-                                break
-                            except Exception as e:
-                                if i == max_retries - 1:
-                                    logger.warning(f"Failed to clean up temp files: {e}")
-                                else:
-                                    time.sleep(1)  # wait one second, then retry
-                                    continue
-
-                    logger.info("Temp files cleaned up")
-                except Exception as e:
-                    logger.warning(f"Error while cleaning up temp files: {e}")
-
-
-if __name__ == "__main__":
-    import time
-
-    start_time = time.time()
-    processor = VideoProcessor("E:\\projects\\NarratoAI\\resource\\videos\\test.mp4")
-    processor.process_video_pipeline(output_dir="output")
-    end_time = time.time()
-    print(f"Processing finished! Total time: {end_time - start_time:.2f} seconds")
@@ -1,5 +1,5 @@
 [app]
-project_version="0.5.3"
+project_version="0.6.0"
 # LLM providers that support video understanding
 # gemini
 # qwenvl
@@ -1,38 +1,45 @@
 # Required
 requests~=2.32.0
 moviepy==2.1.1
 edge-tts==6.1.19
+streamlit~=1.45.0
+watchdog==6.0.0
+loguru~=0.7.3
+tomli~=2.2.1
+
+openai~=1.77.0
+google-generativeai>=0.8.5

-loguru~=0.7.2
-fastapi~=0.115.4
-uvicorn~=0.27.1
-pydantic~=2.11.4
+# To be optimized
+# opencv-python==4.11.0.86
+# scikit-learn==1.6.1

-faster-whisper~=1.0.1
-tomli~=2.0.1
-aiohttp~=3.10.10
-httpx==0.27.2
-urllib3~=2.2.1
+# fastapi~=0.115.4
+# uvicorn~=0.27.1
+# pydantic~=2.11.4

-python-multipart~=0.0.9
-redis==5.0.3
-opencv-python~=4.10.0.84
-azure-cognitiveservices-speech~=1.37.0
-git-changelog~=2.5.2
-watchdog==5.0.2
-pydub==0.25.1
-psutil>=5.9.0
-scikit-learn~=1.5.2
-pillow==10.3.0
-python-dotenv~=1.0.1
+# faster-whisper~=1.0.1
+# tomli~=2.0.1
+# aiohttp~=3.10.10
+# httpx==0.27.2
+# urllib3~=2.2.1

-tqdm>=4.66.6
-tenacity>=9.0.0
-tiktoken==0.8.0
-pysrt==1.1.2
-transformers==4.50.0
+# python-multipart~=0.0.9
+# redis==5.0.3
+# opencv-python~=4.10.0.84
+# azure-cognitiveservices-speech~=1.37.0
+# git-changelog~=2.5.2
+# watchdog==5.0.2
+# pydub==0.25.1
+# psutil>=5.9.0
+# scikit-learn~=1.5.2
+# pillow==10.3.0
+# python-dotenv~=1.0.1
+
+# tqdm>=4.66.6
+# tenacity>=9.0.0
+# tiktoken==0.8.0
+# pysrt==1.1.2
+# transformers==4.50.0

 # yt-dlp==2025.4.30
webui.py (72 changed lines)
@@ -1,7 +1,7 @@
 import streamlit as st
 import os
 import sys
-from uuid import uuid4
+from loguru import logger
 from app.config import config
 from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
     review_settings, merge_settings, system_settings

@@ -18,7 +18,7 @@ st.set_page_config(
     initial_sidebar_state="auto",
     menu_items={
         "Report a bug": "https://github.com/linyqh/NarratoAI/issues",
-        'About': f"# NarratoAI:sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
+        'About': f"# Narrato:blue[AI] :sunglasses: 📽️ \n #### Version: v{config.project_version} \n "
                  f"For details on automated film-commentary videos, see: https://github.com/linyqh/NarratoAI"
     },
 )

@@ -37,17 +37,7 @@ def init_log():
     _lvl = "DEBUG"

     def format_record(record):
-        # More warning messages that need to be filtered out
-        ignore_messages = [
-            "Examining the path of torch.classes raised",
-            "torch.cuda.is_available()",
-            "CUDA initialization"
-        ]
-
-        for msg in ignore_messages:
-            if msg in record["message"]:
-                return ""
-
+        # Simplify log formatting; do not try to filter torch-related content by substring here
         file_path = record["file"].path
         relative_path = os.path.relpath(file_path, config.root_dir)
         record["file"].path = f"./{relative_path}"

@@ -59,8 +49,25 @@ def init_log():
                   '- <level>{message}</>' + "\n"
         return _format

-    # Improved log filter
-    def log_filter(record):
+    # Use a simpler filtering approach that avoids touching the message content;
+    # skip complex filters here and add them dynamically after the app starts
     logger.add(
         sys.stdout,
         level=_lvl,
         format=format_record,
         colorize=True
     )

+    # More complex filters can be added once the app has started
+    def setup_advanced_filters():
+        """Set up advanced filters after the app has fully started"""
+        try:
+            for handler_id in logger._core.handlers:
+                logger.remove(handler_id)
+
+            # Re-add the handler with advanced filtering
+            def advanced_filter(record):
+                """More complex filter that is safe to use after startup"""
+                ignore_messages = [
+                    "Examining the path of torch.classes raised",
+                    "torch.cuda.is_available()",

@@ -73,8 +80,21 @@ def init_log():
                 level=_lvl,
                 format=format_record,
                 colorize=True,
-                filter=log_filter
+                filter=advanced_filter
             )
+        except Exception as e:
+            # If setting the filter fails, make sure logging still works
+            logger.add(
+                sys.stdout,
+                level=_lvl,
+                format=format_record,
+                colorize=True
+            )
+            logger.error(f"Failed to set up the advanced log filter: {e}")
+
+    # Defer the advanced-filter setup until after the main startup logic
+    import threading
+    threading.Timer(5.0, setup_advanced_filters).start()


 def init_global_state():

@@ -177,11 +197,18 @@ def main():
     """Main entry point"""
     init_log()
     init_global_state()
-    utils.init_resources()
-
-    st.title(f"NarratoAI :sunglasses:📽️")
+    # Initialize only basic resources here, to avoid loading PyTorch-dependent resources too early
+    # Check whether utils.init_resources() can be split into basic and advanced (PyTorch-dependent) resources
+    try:
+        utils.init_resources()
+    except Exception as e:
+        logger.warning(f"Warning during resource initialization: {e}")
+
+    st.title(f"Narrato:blue[AI]:sunglasses: 📽️")
     st.write(tr("Get Help"))

+    # Render the UI parts that do not depend on PyTorch first
     # Render the basic settings panel
     basic_settings.render_basic_settings(tr)
     # Render the merge settings

@@ -196,13 +223,16 @@ def main():
         audio_settings.render_audio_panel(tr)
     with panel[2]:
         subtitle_settings.render_subtitle_panel(tr)
-    # Render the system settings panel
-    system_settings.render_system_panel(tr)

     # Render the video review panel
     review_settings.render_review_panel(tr)

-    # Render the generate button and processing logic
+    # Render the parts that may use PyTorch last
+    # Render the system settings panel
+    with panel[2]:
+        system_settings.render_system_panel(tr)
+
+    # Render the generate button and its processing logic last
     render_generate_button()
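The deferred-filter setup above amounts to removing the current sinks and re-adding one with a message filter. A minimal standalone loguru sketch of that pattern (the filtered substrings come from the hunk; everything else is illustrative):

```python
import sys
from loguru import logger

def drop_torch_noise(record) -> bool:
    # Return False to suppress known-noisy torch warnings
    noisy = ("torch.cuda.is_available()", "CUDA initialization")
    return not any(s in record["message"] for s in noisy)

logger.remove()  # remove the default (or earlier) sinks
logger.add(sys.stdout, level="DEBUG", filter=drop_torch_noise)

logger.debug("visible message")
logger.warning("torch.cuda.is_available() raised ...")  # filtered out
```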
@@ -285,8 +285,8 @@ def render_merge_settings(tr):
             error_message = str(e)
             if "moviepy" in error_message.lower():
                 st.error(tr("Error processing video files. Please check if the videos are valid MP4 files."))
-            elif "pysrt" in error_message.lower():
-                st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
+            # elif "pysrt" in error_message.lower():
+            #     st.error(tr("Error processing subtitle files. Please check if the subtitles are valid SRT files."))
             else:
                 st.error(f"{tr('Error during merge')}: {error_message}")
@@ -5,6 +5,7 @@ import time
 import asyncio
 import traceback
 import requests
+from app.utils import video_processor
 import streamlit as st
 from loguru import logger
 from requests.adapters import HTTPAdapter

@@ -12,7 +13,7 @@ from urllib3.util.retry import Retry

 from app.config import config
 from app.utils.script_generator import ScriptProcessor
-from app.utils import utils, video_processor, video_processor_v2, qwenvl_analyzer
+from app.utils import utils, video_processor, qwenvl_analyzer
 from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps, chekc_video_config

@@ -64,21 +65,13 @@ def generate_script_docu(tr, params):
         os.makedirs(video_keyframes_dir, exist_ok=True)

         # Initialize the video processor
-        if config.frames.get("version") == "v2":
-            processor = video_processor_v2.VideoProcessor(params.video_origin_path)
-            # Process the video and extract keyframes
-            processor.process_video_pipeline(
-                output_dir=video_keyframes_dir,
-                skip_seconds=st.session_state.get('skip_seconds'),
-                threshold=st.session_state.get('threshold')
-            )
-        else:
-            processor = video_processor.VideoProcessor(params.video_origin_path)
-            # Process the video and extract keyframes
-            processor.process_video(
-                output_dir=video_keyframes_dir,
-                skip_seconds=0
-            )
+        processor = video_processor.VideoProcessor(params.video_origin_path)
+        # Process the video and extract keyframes
+        processor.process_video_pipeline(
+            output_dir=video_keyframes_dir,
+            skip_seconds=st.session_state.get('skip_seconds'),
+            threshold=st.session_state.get('threshold')
+        )

         # Collect all keyframe file paths
         for filename in sorted(os.listdir(video_keyframes_dir)):
@@ -2,7 +2,7 @@
 Merge video and subtitle files
 """
 from moviepy import VideoFileClip, concatenate_videoclips
-import pysrt
+# import pysrt
 import os
@@ -1,7 +1,6 @@
-import psutil
+# import psutil
 import os
 from loguru import logger
-import torch

 class PerformanceMonitor:
     @staticmethod

@@ -11,19 +10,35 @@ class PerformanceMonitor:

         logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")

-        if torch.cuda.is_available():
-            gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
-            logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+        # Import torch lazily and check CUDA
+        try:
+            import torch
+            if torch.cuda.is_available():
+                gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
+                logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+        except (ImportError, RuntimeError) as e:
+            # Handle torch import failures or CUDA-related errors quietly
+            logger.debug(f"Could not read GPU memory info: {e}")

     @staticmethod
     def cleanup_resources():
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # Import torch lazily and clean up CUDA
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                logger.debug("CUDA cache cleared")
+        except (ImportError, RuntimeError) as e:
+            # Handle torch import failures or CUDA-related errors quietly
+            logger.debug(f"Could not clean up CUDA resources: {e}")

         import gc
         gc.collect()

-        PerformanceMonitor.monitor_memory()
+        # Report only the process memory; do not try to read GPU memory
+        process = psutil.Process(os.getpid())
+        memory_info = process.memory_info()
+        logger.debug(f"Memory usage after cleanup: {memory_info.rss / 1024 / 1024:.2f} MB")

 def monitor_performance(func):
     """Performance-monitoring decorator"""
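The view cuts off at the `monitor_performance` decorator. As a rough illustration of how such a decorator is commonly wired around the monitor above — the body below is an assumption, not the repository's code:

```python
import functools

def monitor_performance(func):
    """Performance-monitoring decorator (illustrative body; the real one is truncated above)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        PerformanceMonitor.monitor_memory()  # snapshot before the call
        try:
            return func(*args, **kwargs)
        finally:
            PerformanceMonitor.cleanup_resources()  # free the CUDA cache and run gc afterwards
    return wrapper
```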