mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-12 11:22:51 +00:00
feat(video): 实现关键帧提取 V2 版本
- 新增 VideoProcessor 类,实现视频预处理、场景边界检测、关键帧提取等功能 - 在 config.example.toml 中添加 frames 配置项,用于控制关键帧提取参数 - 修改 script_settings.py,支持使用新的 VideoProcessor 进行关键帧提取 - 优化关键帧提取流程,提高处理效率和准确性
This commit is contained in:
parent
d1cbaaf040
commit
cc44aab181
@ -45,6 +45,7 @@ whisper = _cfg.get("whisper", {})
|
|||||||
proxy = _cfg.get("proxy", {})
|
proxy = _cfg.get("proxy", {})
|
||||||
azure = _cfg.get("azure", {})
|
azure = _cfg.get("azure", {})
|
||||||
ui = _cfg.get("ui", {})
|
ui = _cfg.get("ui", {})
|
||||||
|
frames = _cfg.get("frames", {})
|
||||||
|
|
||||||
hostname = socket.gethostname()
|
hostname = socket.gethostname()
|
||||||
|
|
||||||
|
|||||||
294
app/utils/video_processor_v2.py
Normal file
294
app/utils/video_processor_v2.py
Normal file
@ -0,0 +1,294 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import List, Tuple, Generator
|
||||||
|
from loguru import logger
|
||||||
|
import subprocess
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class VideoProcessor:
    """Shot-aware keyframe extractor.

    Detects shot boundaries with a frame-difference heuristic, picks one
    representative frame per shot, and offers a full pipeline that first
    compresses the video with ffmpeg for fast analysis and then re-extracts
    the chosen frames from the original file in full resolution.
    """

    def __init__(self, video_path: str):
        """
        Initialize the video processor.

        Args:
            video_path: path to the video file

        Raises:
            FileNotFoundError: if the file does not exist
            RuntimeError: if OpenCV cannot open the file or reports an
                invalid frame rate
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)

        if not self.cap.isOpened():
            raise RuntimeError(f"无法打开视频文件: {video_path}")

        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
        # A zero/negative FPS would cause a ZeroDivisionError later when
        # converting frame indices to timestamps - fail fast instead.
        if self.fps <= 0:
            raise RuntimeError(f"无效的视频帧率: {self.fps} ({video_path})")

    def __del__(self):
        """Destructor: make sure the capture handle is released."""
        if hasattr(self, 'cap'):
            self.cap.release()

    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
        """
        Lazily read the video from the beginning, frame by frame.

        Yields:
            np.ndarray: a single BGR frame
        """
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the start
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if not ret:
                break
            yield frame

    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
        """
        Detect shot boundaries via mean absolute grayscale frame difference.

        Args:
            frames: list of BGR frames
            threshold: mean grayscale difference above which a cut is assumed

        Returns:
            List[int]: indices of frames that start a new shot
        """
        shot_boundaries = []
        prev_gray = None
        # Convert each frame to grayscale exactly once; the naive pairwise
        # loop converted every interior frame twice.
        for i, frame in enumerate(frames):
            curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(int)
            if prev_gray is not None:
                diff = np.mean(np.abs(curr_gray - prev_gray))
                if diff > threshold:
                    shot_boundaries.append(i)
            prev_gray = curr_gray
        return shot_boundaries

    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[
        List[np.ndarray], List[int]]:
        """
        Extract one representative keyframe from each shot.

        The representative frame is the one closest (in flattened grayscale
        pixel space) to the mean of its shot - i.e. the centroid of a
        single-cluster KMeans, computed directly without the clustering
        machinery.

        Args:
            frames: list of BGR frames
            shot_boundaries: indices of shot starts

        Returns:
            Tuple[List[np.ndarray], List[int]]: keyframes and their frame
            indices (relative to ``frames``)
        """
        keyframes = []
        keyframe_indices = []

        # Build [start, end) segments covering the whole frame list.  The
        # previous version dropped the segment after the last boundary and
        # produced no keyframe at all when no boundary was detected.
        cuts = [0] + list(shot_boundaries) + [len(frames)]
        for start, end in zip(cuts, cuts[1:]):
            shot_frames = frames[start:end]
            if not shot_frames:
                continue  # e.g. a boundary at index 0

            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
                                       for frame in shot_frames])
            # Centroid of a 1-cluster KMeans is simply the feature mean.
            center = frame_features.mean(axis=0)
            center_idx = int(np.argmin(np.sum((frame_features - center) ** 2, axis=1)))

            keyframes.append(shot_frames[center_idx])
            keyframe_indices.append(start + center_idx)

        return keyframes, keyframe_indices

    @staticmethod
    def _format_timestamp(timestamp: float) -> str:
        """Format a timestamp in seconds as an ``HHMMSS`` string."""
        hours = int(timestamp // 3600)
        minutes = int((timestamp % 3600) // 60)
        seconds = int(timestamp % 60)
        return f"{hours:02d}{minutes:02d}{seconds:02d}"

    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
                       output_dir: str, desc: str = "保存关键帧") -> None:
        """
        Save keyframes as ``keyframe_<frame index>_<HHMMSS>.jpg``.

        Args:
            keyframes: keyframes to write
            keyframe_indices: frame index of each keyframe
            output_dir: destination directory (created if missing)
            desc: progress-bar description
        """
        os.makedirs(output_dir, exist_ok=True)

        for keyframe, frame_idx in tqdm(zip(keyframes, keyframe_indices),
                                        total=len(keyframes),
                                        desc=desc):
            time_str = self._format_timestamp(frame_idx / self.fps)
            output_path = os.path.join(output_dir,
                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
            cv2.imwrite(output_path, keyframe)

    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
        """
        Extract the given frames from the video, keeping at most one frame
        per second of video time.

        Args:
            frame_numbers: frame indices to extract
            output_folder: destination directory (created if missing)

        Raises:
            ValueError: if the list is empty or contains an out-of-range index
        """
        if not frame_numbers:
            raise ValueError("未提供帧号列表")

        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
            raise ValueError("存在无效的帧号")

        os.makedirs(output_folder, exist_ok=True)

        # Seconds (of video time) already written - used for deduplication.
        processed_seconds = set()

        for frame_number in tqdm(frame_numbers, desc="提取高清帧"):
            timestamp_seconds = int(frame_number / self.fps)

            # Only keep the first frame that falls within a given second.
            if timestamp_seconds in processed_seconds:
                continue

            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = self.cap.read()

            if ret:
                processed_seconds.add(timestamp_seconds)
                time_str = self._format_timestamp(timestamp_seconds)
                output_path = os.path.join(output_folder,
                                           f"keyframe_{frame_number:06d}_{time_str}.jpg")
                cv2.imwrite(output_path, frame)
            else:
                # An unreadable frame is an anomaly, not routine progress.
                logger.warning(f"无法读取帧 {frame_number}")

        logger.info(f"共提取了 {len(processed_seconds)} 个不同时间戳的帧")

    @staticmethod
    def extract_numbers_from_folder(folder_path: str) -> List[int]:
        """
        Collect frame numbers from keyframe filenames in a folder.

        Filenames are expected to look like ``keyframe_000123_010534.jpg``.

        Args:
            folder_path: folder containing keyframe images

        Returns:
            List[int]: sorted frame numbers
        """
        # The pattern itself enforces the ".jpg" suffix.
        pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
        numbers = []
        for filename in os.listdir(folder_path):
            match = pattern.search(filename)
            if match:
                numbers.append(int(match.group(1)))
        return sorted(numbers)

    def process_video(self, output_dir: str, skip_seconds: float = 0, threshold: int = 30) -> None:
        """
        Read the video, detect shots and save one keyframe per shot.

        NOTE(review): all frames are held in memory, so this is intended to
        run on a compressed/low-resolution copy (see
        ``process_video_pipeline``).

        Args:
            output_dir: directory for the saved keyframes
            skip_seconds: seconds to skip at the beginning of the video
            threshold: shot-boundary difference threshold

        Raises:
            ValueError: if skipping leaves no frames to process
        """
        skip_frames = int(skip_seconds * self.fps)

        logger.info("读取视频帧...")
        frames = list(tqdm(self.preprocess_video(),
                           total=self.total_frames,
                           desc="读取视频"))

        frames = frames[skip_frames:]
        if not frames:
            raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")

        logger.info("\n检测场景边界...")
        shot_boundaries = self.detect_shot_boundaries(frames, threshold)
        logger.info(f"检测到 {len(shot_boundaries)} 个场景边界")

        logger.info("\n提取关键帧...")
        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)

        # Re-base indices onto the original (unskipped) timeline so the
        # saved filenames reference real frame numbers.
        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
        self.save_keyframes(keyframes, adjusted_indices, output_dir, desc="保存压缩关键帧")

    def process_video_pipeline(self,
                               output_dir: str,
                               skip_seconds: float = 0,
                               threshold: int = 30,
                               compressed_width: int = 320,
                               keep_temp: bool = False) -> None:
        """
        Full pipeline: compress, extract keyframes, export HD frames.

        Args:
            output_dir: directory for the final high-resolution keyframes
            skip_seconds: seconds to skip at the beginning of the video
            threshold: shot-boundary difference threshold
            compressed_width: width (px) of the temporary analysis copy
            keep_temp: keep the temporary working directory when True

        Raises:
            subprocess.CalledProcessError: if ffmpeg fails
        """
        temp_dir = os.path.join(output_dir, 'temp')
        compressed_dir = os.path.join(temp_dir, 'compressed')
        mini_frames_dir = os.path.join(temp_dir, 'mini_frames')
        hd_frames_dir = output_dir

        # makedirs creates all intermediate directories (output_dir, temp_dir).
        os.makedirs(compressed_dir, exist_ok=True)
        os.makedirs(mini_frames_dir, exist_ok=True)
        os.makedirs(hd_frames_dir, exist_ok=True)

        try:
            # 1. Compress the video for fast in-memory analysis.
            video_name = os.path.splitext(os.path.basename(self.video_path))[0]
            compressed_video = os.path.join(compressed_dir, f"{video_name}_compressed.mp4")

            logger.info("步骤1: 压缩视频...")
            ffmpeg_cmd = [
                'ffmpeg', '-i', self.video_path,
                '-vf', f'scale={compressed_width}:-1',
                '-y',
                compressed_video
            ]
            subprocess.run(ffmpeg_cmd, check=True)

            # 2. Extract keyframes from the compressed copy.
            logger.info("\n步骤2: 从压缩视频提取关键帧...")
            mini_processor = VideoProcessor(compressed_video)
            try:
                mini_processor.process_video(mini_frames_dir, skip_seconds, threshold)
            finally:
                # Release the capture explicitly so the temp directory can
                # be removed below (Windows keeps open files locked).
                mini_processor.cap.release()

            # 3. Re-extract those frames from the original video in HD.
            logger.info("\n步骤3: 提取高清关键帧...")
            frame_numbers = self.extract_numbers_from_folder(mini_frames_dir)
            self.extract_frames_by_numbers(frame_numbers, hd_frames_dir)

            logger.info("\n处理完成!")
            logger.info(f"高清关键帧保存在: {hd_frames_dir}")

        finally:
            if not keep_temp:
                import shutil
                try:
                    shutil.rmtree(temp_dir)
                    logger.info("临时文件已清理")
                except Exception as e:
                    # Leftover temp files are not fatal, but worth flagging.
                    logger.warning(f"清理临时文件时出错: {e}")
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import time

    # Simple manual smoke test: run the full pipeline on a local file
    # and report the wall-clock duration.
    t0 = time.time()
    runner = VideoProcessor("best.mp4")
    runner.process_video_pipeline(output_dir="output4")
    elapsed = time.time() - t0
    print(f"处理完成!总耗时: {elapsed:.2f} 秒")
|
||||||
@ -5,7 +5,6 @@
|
|||||||
# NarratoAPI
|
# NarratoAPI
|
||||||
# qwen2-vl (待增加)
|
# qwen2-vl (待增加)
|
||||||
vision_llm_provider="gemini"
|
vision_llm_provider="gemini"
|
||||||
vision_batch_size = 7
|
|
||||||
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
|
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
|
||||||
|
|
||||||
########## Vision Gemini API Key
|
########## Vision Gemini API Key
|
||||||
@ -170,4 +169,15 @@
|
|||||||
# Azure Speech API Key
|
# Azure Speech API Key
|
||||||
# Get your API key at https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/SpeechServices
|
# Get your API key at https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/SpeechServices
|
||||||
speech_key=""
|
speech_key=""
|
||||||
speech_region=""
|
speech_region=""
|
||||||
|
|
||||||
|
[frames]
|
||||||
|
skip_seconds = 0
|
||||||
|
# threshold(差异阈值)用于判断两个连续帧之间是否发生了场景切换
|
||||||
|
# 较小的阈值(如 20):更敏感,能捕捉到细微的场景变化,但可能会误判,关键帧图片更多
|
||||||
|
# 较大的阈值(如 40):更保守,只捕捉明显的场景切换,但可能会漏掉渐变场景,关键帧图片更少
|
||||||
|
# 默认值 30:在实践中是一个比较平衡的选择
|
||||||
|
threshold = 30
|
||||||
|
version = "v2"
|
||||||
|
# 大模型单次处理的关键帧数量
|
||||||
|
vision_batch_size = 5
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from loguru import logger
|
|||||||
from app.config import config
|
from app.config import config
|
||||||
from app.models.schema import VideoClipParams
|
from app.models.schema import VideoClipParams
|
||||||
from app.utils.script_generator import ScriptProcessor
|
from app.utils.script_generator import ScriptProcessor
|
||||||
from app.utils import utils, check_script, vision_analyzer, video_processor
|
from app.utils import utils, check_script, vision_analyzer, video_processor, video_processor_v2
|
||||||
from webui.utils import file_utils
|
from webui.utils import file_utils
|
||||||
|
|
||||||
|
|
||||||
@ -318,13 +318,21 @@ def generate_script(tr, params):
|
|||||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||||
|
|
||||||
# 初始化视频处理器
|
# 初始化视频处理器
|
||||||
processor = video_processor.VideoProcessor(params.video_origin_path)
|
if config.frames.get("version") == "v2":
|
||||||
|
processor = video_processor_v2.VideoProcessor(params.video_origin_path)
|
||||||
# 处理视频并提取关键帧
|
# 处理视频并提取关键帧
|
||||||
processor.process_video(
|
processor.process_video_pipeline(
|
||||||
output_dir=video_keyframes_dir,
|
output_dir=video_keyframes_dir,
|
||||||
skip_seconds=0
|
skip_seconds=config.frames.get("skip_seconds", 0),
|
||||||
)
|
threshold=config.frames.get("threshold", 30)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
processor = video_processor.VideoProcessor(params.video_origin_path)
|
||||||
|
# 处理视频并提取关键帧
|
||||||
|
processor.process_video(
|
||||||
|
output_dir=video_keyframes_dir,
|
||||||
|
skip_seconds=0
|
||||||
|
)
|
||||||
|
|
||||||
# 获取所有关键帧文件路径
|
# 获取所有关键帧文件路径
|
||||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||||
@ -380,7 +388,7 @@ def generate_script(tr, params):
|
|||||||
analyzer.analyze_images(
|
analyzer.analyze_images(
|
||||||
images=keyframe_files,
|
images=keyframe_files,
|
||||||
prompt=config.app.get('vision_analysis_prompt'),
|
prompt=config.app.get('vision_analysis_prompt'),
|
||||||
batch_size=config.app.get("vision_batch_size", 5)
|
batch_size=config.frames.get("vision_batch_size", 5)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
loop.close()
|
loop.close()
|
||||||
@ -397,7 +405,7 @@ def generate_script(tr, params):
|
|||||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
|
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
|
||||||
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
||||||
logger.debug(batch_files)
|
logger.debug(batch_files)
|
||||||
|
|
||||||
@ -436,7 +444,7 @@ def generate_script(tr, params):
|
|||||||
if 'error' in result:
|
if 'error' in result:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
|
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
|
||||||
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||||
|
|
||||||
frame_content = {
|
frame_content = {
|
||||||
@ -612,14 +620,14 @@ def generate_script(tr, params):
|
|||||||
if script is None:
|
if script is None:
|
||||||
st.error("生成脚本失败,请检查日志")
|
st.error("生成脚本失败,请检查日志")
|
||||||
st.stop()
|
st.stop()
|
||||||
logger.info(f"脚本生成完成\n{script} \n{type(script)}")
|
logger.info(f"脚本生成完成")
|
||||||
if isinstance(script, list):
|
if isinstance(script, list):
|
||||||
st.session_state['video_clip_json'] = script
|
st.session_state['video_clip_json'] = script
|
||||||
elif isinstance(script, str):
|
elif isinstance(script, str):
|
||||||
st.session_state['video_clip_json'] = json.loads(script)
|
st.session_state['video_clip_json'] = json.loads(script)
|
||||||
update_progress(90, "脚本生成完成")
|
update_progress(80, "脚本生成完成")
|
||||||
|
|
||||||
time.sleep(0.5)
|
time.sleep(0.1)
|
||||||
progress_bar.progress(100)
|
progress_bar.progress(100)
|
||||||
status_text.text("脚本生成完成!")
|
status_text.text("脚本生成完成!")
|
||||||
st.success("视频脚本生成成功!")
|
st.success("视频脚本生成成功!")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user