mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 18:42:49 +00:00
feat(video): 实现关键帧提取 V2 版本
- 新增 VideoProcessor 类,实现视频预处理、场景边界检测、关键帧提取等功能 - 在 config.example.toml 中添加 frames 配置项,用于控制关键帧提取参数 - 修改 script_settings.py,支持使用新的 VideoProcessor 进行关键帧提取 - 优化关键帧提取流程,提高处理效率和准确性
This commit is contained in:
parent
d1cbaaf040
commit
cc44aab181
@ -45,6 +45,7 @@ whisper = _cfg.get("whisper", {})
|
||||
proxy = _cfg.get("proxy", {})
|
||||
azure = _cfg.get("azure", {})
|
||||
ui = _cfg.get("ui", {})
|
||||
frames = _cfg.get("frames", {})
|
||||
|
||||
hostname = socket.gethostname()
|
||||
|
||||
|
||||
294
app/utils/video_processor_v2.py
Normal file
294
app/utils/video_processor_v2.py
Normal file
@ -0,0 +1,294 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sklearn.cluster import KMeans
|
||||
import os
|
||||
import re
|
||||
from typing import List, Tuple, Generator
|
||||
from loguru import logger
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class VideoProcessor:
    """Extract representative keyframes from a video.

    Pipeline: read frames with OpenCV, detect shot boundaries via mean
    absolute grayscale frame difference, then pick one keyframe per shot
    (the frame closest to the shot's mean grayscale image).
    """

    def __init__(self, video_path: str):
        """
        Initialize the video processor.

        Args:
            video_path: Path to the video file.

        Raises:
            FileNotFoundError: If the file does not exist.
            RuntimeError: If OpenCV cannot open the file.
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)

        if not self.cap.isOpened():
            raise RuntimeError(f"无法打开视频文件: {video_path}")

        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report an fps of 0; fall back to a sane default so
        # the timestamp math below never divides by zero.
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS)) or 30

    def release(self) -> None:
        """Release the underlying OpenCV capture handle (safe to call twice)."""
        if hasattr(self, 'cap') and self.cap is not None:
            self.cap.release()

    def __del__(self):
        """Destructor: make sure the video resource is released."""
        self.release()

    def __enter__(self):
        """Support use as a context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the capture handle on context exit; never swallow errors."""
        self.release()
        return False

    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
        """
        Lazily read the video frame by frame.

        Yields:
            np.ndarray: The next BGR video frame.
        """
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the first frame
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if not ret:
                break
            yield frame

    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
        """
        Detect shot boundaries using frame differencing.

        Args:
            frames: List of BGR video frames.
            threshold: Mean-absolute-difference threshold above which two
                consecutive frames are considered a shot cut.

        Returns:
            List[int]: Indices of frames that start a new shot.
        """
        shot_boundaries = []
        for i in range(1, len(frames)):
            prev_gray = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
            curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
            # Cast to int before subtracting to avoid uint8 wraparound.
            diff = np.mean(np.abs(curr_gray.astype(int) - prev_gray.astype(int)))
            if diff > threshold:
                shot_boundaries.append(i)
        return shot_boundaries

    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
        List[np.ndarray], List[int]]:
        """
        Extract one keyframe per shot.

        The keyframe is the frame whose flattened grayscale image is closest
        (squared L2) to the shot's mean image.  (This is equivalent to the
        previous KMeans(n_clusters=1) approach — a single cluster's centroid
        is the arithmetic mean — without the sklearn fitting overhead.)

        Fix over the previous version: the segment after the LAST boundary is
        now included, and a video with no detected boundaries yields one
        keyframe for the whole clip instead of none.

        Args:
            frames: List of BGR video frames.
            shot_boundaries: Shot boundary indices (ascending).

        Returns:
            Tuple[List[np.ndarray], List[int]]: Keyframes and their indices
            into ``frames``.
        """
        keyframes = []
        keyframe_indices = []

        # Segment the frame list into shots:
        # [0, b0), [b0, b1), ..., [b_last, len(frames))
        cuts = [0] + list(shot_boundaries) + [len(frames)]
        for start, end in zip(cuts[:-1], cuts[1:]):
            shot_frames = frames[start:end]
            if not shot_frames:
                continue  # degenerate segment (e.g. boundary at index 0)

            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
                                       for frame in shot_frames])
            center = frame_features.mean(axis=0)
            center_idx = int(np.argmin(np.sum((frame_features - center) ** 2, axis=1)))

            keyframes.append(shot_frames[center_idx])
            keyframe_indices.append(start + center_idx)

        return keyframes, keyframe_indices

    @staticmethod
    def _format_hms(total_seconds: float) -> str:
        """Format a duration in seconds as a zero-padded HHMMSS string."""
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = int(total_seconds % 60)
        return f"{hours:02d}{minutes:02d}{seconds:02d}"

    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
                       output_dir: str, desc: str = "保存关键帧") -> None:
        """
        Save keyframes to a directory as ``keyframe_<frame#>_<HHMMSS>.jpg``.

        Args:
            keyframes: Keyframe images.
            keyframe_indices: Frame index of each keyframe (same order).
            output_dir: Output directory (created if missing).
            desc: Progress-bar label.
        """
        os.makedirs(output_dir, exist_ok=True)

        for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
                                        total=len(keyframes),
                                        desc=desc):
            time_str = self._format_hms(frame_idx / self.fps)
            output_path = os.path.join(output_dir,
                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
            cv2.imwrite(output_path, keyframe)

    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
        """
        Extract specific frames by index; frames falling in the same second
        are de-duplicated (only the first is kept).

        Args:
            frame_numbers: Frame indices to extract.
            output_folder: Output folder path (created if missing).

        Raises:
            ValueError: If the list is empty or contains out-of-range indices.
        """
        if not frame_numbers:
            raise ValueError("未提供帧号列表")

        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
            raise ValueError("存在无效的帧号")

        os.makedirs(output_folder, exist_ok=True)

        # Seconds for which a frame has already been written.
        processed_seconds = set()

        for frame_number in tqdm(frame_numbers, desc="提取高清帧"):
            timestamp_seconds = int(frame_number / self.fps)

            if timestamp_seconds in processed_seconds:
                continue  # already have a frame for this second

            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = self.cap.read()

            if ret:
                processed_seconds.add(timestamp_seconds)
                time_str = self._format_hms(timestamp_seconds)
                output_path = os.path.join(output_folder,
                                           f"keyframe_{frame_number:06d}_{time_str}.jpg")
                cv2.imwrite(output_path, frame)
            else:
                logger.info(f"无法读取帧 {frame_number}")

        logger.info(f"共提取了 {len(processed_seconds)} 个不同时间戳的帧")

    @staticmethod
    def extract_numbers_from_folder(folder_path: str) -> List[int]:
        """
        Collect frame numbers from keyframe file names in a folder.

        Expected file-name format: ``keyframe_000123_010534.jpg``.

        Args:
            folder_path: Path to a folder of keyframe JPEGs.

        Returns:
            List[int]: Sorted frame numbers.
        """
        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
        pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
        numbers = []
        for f in files:
            match = pattern.search(f)
            if match:
                numbers.append(int(match.group(1)))
        return sorted(numbers)

    def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
        """
        Read the video, detect shots, and save one keyframe per shot.

        Args:
            output_dir: Output directory for keyframe JPEGs.
            skip_seconds: Seconds to skip at the start of the video.
            threshold: Shot-boundary difference threshold
                (see :meth:`detect_shot_boundaries`).

        Raises:
            ValueError: If no frames remain after skipping.
        """
        skip_frames = int(skip_seconds * self.fps)

        logger.info("读取视频帧...")
        frames = []
        for frame in tqdm(self.preprocess_video(),
                          total=self.total_frames,
                          desc="读取视频"):
            frames.append(frame)

        frames = frames[skip_frames:]

        if not frames:
            raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")

        logger.info("\n检测场景边界...")
        shot_boundaries = self.detect_shot_boundaries(frames, threshold)
        logger.info(f"检测到 {len(shot_boundaries)} 个场景边界")

        logger.info("\n提取关键帧...")
        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)

        # Shift indices back into the original (un-skipped) frame numbering.
        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
        self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="保存压缩关键帧")

    def process_video_pipeline(self,
                               output_dir: str,
                               skip_seconds: float = 0,
                               threshold: int = 30,
                               compressed_width: int = 320,
                               keep_temp: bool = False) -> None:
        """
        Full pipeline: compress the video with ffmpeg, detect keyframes on
        the small copy, then re-extract those frames in full resolution.

        Args:
            output_dir: Destination for the HD keyframes.
            skip_seconds: Seconds to skip at the start of the video.
            threshold: Shot-boundary difference threshold.
            compressed_width: Width (px) of the temporary compressed video.
            keep_temp: Keep the temp directory (compressed video and
                low-res keyframes) instead of deleting it.
        """
        os.makedirs(output_dir, exist_ok=True)
        temp_dir = os.path.join(output_dir, 'temp')
        compressed_dir = os.path.join(temp_dir, 'compressed')
        mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
        hd_frames_dir = output_dir

        os.makedirs(temp_dir, exist_ok=True)
        os.makedirs(compressed_dir, exist_ok=True)
        os.makedirs(mini_frames_dir, exist_ok=True)
        os.makedirs(hd_frames_dir, exist_ok=True)

        try:
            # 1. Compress the video (scale to compressed_width, keep aspect).
            video_name = os.path.splitext(os.path.basename(self.video_path))[0]
            compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")

            logger.info("步骤1: 压缩视频...")
            ffmpeg_cmd = [
                'ffmpeg', '-i', self.video_path,
                '-vf', f'scale={compressed_width}:-1',
                '-y',
                compressed_video
            ]
            subprocess.run(ffmpeg_cmd, check=True)

            # 2. Detect keyframes on the compressed copy (much faster).
            logger.info("\n步骤2: 从压缩视频提取关键帧...")
            mini_processor = VideoProcessor(compressed_video)
            mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)

            # 3. Re-extract the same frame numbers from the original video.
            logger.info("\n步骤3: 提取高清关键帧...")
            frame_numbers = mini_processor.extract_numbers_from_folder(mini_frames_dir)
            # Release the compressed video's capture handle before the temp
            # directory is removed (open files cannot be deleted on Windows).
            mini_processor.release()
            self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)

            logger.info(f"\n处理完成!")
            logger.info(f"高清关键帧保存在: {hd_frames_dir}")

        finally:
            if not keep_temp:
                import shutil
                try:
                    shutil.rmtree(temp_dir)
                    logger.info("临时文件已清理")
                except Exception as e:
                    logger.info(f"清理临时文件时出错: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import time

    # Smoke-test: run the full pipeline on a local sample video and report
    # the wall-clock duration.
    start_time = time.time()
    VideoProcessor("best.mp4").process_video_pipeline(output_dir="output4")
    end_time = time.time()
    print(f"处理完成!总耗时: {end_time - start_time:.2f} 秒")
|
||||
@ -5,7 +5,6 @@
|
||||
# NarratoAPI
|
||||
# qwen2-vl (待增加)
|
||||
vision_llm_provider="gemini"
|
||||
vision_batch_size = 7
|
||||
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
|
||||
|
||||
########## Vision Gemini API Key
|
||||
@ -170,4 +169,15 @@
|
||||
# Azure Speech API Key
|
||||
# Get your API key at https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/SpeechServices
|
||||
speech_key=""
|
||||
speech_region=""
|
||||
speech_region=""
|
||||
|
||||
[frames]
|
||||
skip_seconds = 0
|
||||
# threshold(差异阈值)用于判断两个连续帧之间是否发生了场景切换
|
||||
# 较小的阈值(如 20):更敏感,能捕捉到细微的场景变化,但可能会误判,关键帧图片更多
|
||||
# 较大的阈值(如 40):更保守,只捕捉明显的场景切换,但可能会漏掉渐变场景,关键帧图片更少
|
||||
# 默认值 30:在实践中是一个比较平衡的选择
|
||||
threshold = 30
|
||||
version = "v2"
|
||||
# 大模型单次处理的关键帧数量
|
||||
vision_batch_size = 5
|
||||
|
||||
@ -14,7 +14,7 @@ from loguru import logger
|
||||
from app.config import config
|
||||
from app.models.schema import VideoClipParams
|
||||
from app.utils.script_generator import ScriptProcessor
|
||||
from app.utils import utils, check_script, vision_analyzer, video_processor
|
||||
from app.utils import utils, check_script, vision_analyzer, video_processor, video_processor_v2
|
||||
from webui.utils import file_utils
|
||||
|
||||
|
||||
@ -318,13 +318,21 @@ def generate_script(tr, params):
|
||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||
|
||||
# 初始化视频处理器
|
||||
processor = video_processor.VideoProcessor(params.video_origin_path)
|
||||
|
||||
# 处理视频并提取关键帧
|
||||
processor.process_video(
|
||||
output_dir=video_keyframes_dir,
|
||||
skip_seconds=0
|
||||
)
|
||||
if config.frames.get("version") == "v2":
|
||||
processor = video_processor_v2.VideoProcessor(params.video_origin_path)
|
||||
# 处理视频并提取关键帧
|
||||
processor.process_video_pipeline(
|
||||
output_dir=video_keyframes_dir,
|
||||
skip_seconds=config.frames.get("skip_seconds", 0),
|
||||
threshold=config.frames.get("threshold", 30)
|
||||
)
|
||||
else:
|
||||
processor = video_processor.VideoProcessor(params.video_origin_path)
|
||||
# 处理视频并提取关键帧
|
||||
processor.process_video(
|
||||
output_dir=video_keyframes_dir,
|
||||
skip_seconds=0
|
||||
)
|
||||
|
||||
# 获取所有关键帧文件路径
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
@ -380,7 +388,7 @@ def generate_script(tr, params):
|
||||
analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
batch_size=config.app.get("vision_batch_size", 5)
|
||||
batch_size=config.frames.get("vision_batch_size", 5)
|
||||
)
|
||||
)
|
||||
loop.close()
|
||||
@ -397,7 +405,7 @@ def generate_script(tr, params):
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
|
||||
batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
|
||||
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
|
||||
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
||||
logger.debug(batch_files)
|
||||
|
||||
@ -436,7 +444,7 @@ def generate_script(tr, params):
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
|
||||
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
|
||||
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
frame_content = {
|
||||
@ -612,14 +620,14 @@ def generate_script(tr, params):
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
st.stop()
|
||||
logger.info(f"脚本生成完成\n{script} \n{type(script)}")
|
||||
logger.info(f"脚本生成完成")
|
||||
if isinstance(script, list):
|
||||
st.session_state['video_clip_json'] = script
|
||||
elif isinstance(script, str):
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
update_progress(90, "脚本生成完成")
|
||||
update_progress(80, "脚本生成完成")
|
||||
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.1)
|
||||
progress_bar.progress(100)
|
||||
status_text.text("脚本生成完成!")
|
||||
st.success("视频脚本生成成功!")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user