0.3.4 修改各种bug

This commit is contained in:
linyqh 2024-11-10 16:22:04 +08:00
parent 4bafd696a1
commit d0462ce91b
7 changed files with 224 additions and 99 deletions

View File

@ -251,16 +251,41 @@ def download_videos(
return video_paths return video_paths
def time_to_seconds(time_str: str) -> float:
"""
将时间字符串转换为秒数
支持格式
1. "MM:SS" (:)
2. "SS" (纯秒数)
"""
parts = time_str.split(':')
if len(parts) == 2:
minutes, seconds = map(float, parts)
return minutes * 60 + seconds
return float(time_str)
def format_timestamp(seconds: float) -> str:
"""
将秒数转换为 "MM:SS" 格式的时间字符串
"""
minutes = int(seconds) // 60
secs = int(seconds) % 60
return f"{minutes:02d}:{secs:02d}"
def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> dict: def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> dict:
""" """
保存剪辑后的视频 保存剪辑后的视频
Args: Args:
timestamp: 需要裁剪的单个时间戳'00:36-00:40' timestamp: 需要裁剪的单个时间戳支持两种格式
1. '00:36-00:40' (:-:)
2. 'SS-SS' (-)
origin_video: 原视频路径 origin_video: 原视频路径
save_dir: 存储目录 save_dir: 存储目录
Returns: Returns:
裁剪后的视频路径 裁剪后的视频路径格式为 {timestamp: video_path}
""" """
if not save_dir: if not save_dir:
save_dir = utils.storage_dir("cache_videos") save_dir = utils.storage_dir("cache_videos")
@ -276,35 +301,64 @@ def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> di
return {timestamp: video_path} return {timestamp: video_path}
try: try:
# 先加载视频获取总时长
video = VideoFileClip(origin_video)
total_duration = video.duration
# 获取目标时间段
start_str, end_str = timestamp.split('-')
start = time_to_seconds(start_str)
end = time_to_seconds(end_str)
# 验证时间段是否有效
if start >= total_duration:
logger.warning(f"起始时间 {format_timestamp(start)} ({start:.2f}秒) 超出视频总时长 {format_timestamp(total_duration)} ({total_duration:.2f}秒)")
video.close()
return {}
if end > total_duration:
logger.warning(f"结束时间 {format_timestamp(end)} ({end:.2f}秒) 超出视频总时长 {format_timestamp(total_duration)} ({total_duration:.2f}秒),将自动调整为视频结尾")
end = total_duration
if end <= start:
logger.warning(f"结束时间 {format_timestamp(end)} 必须大于起始时间 {format_timestamp(start)}")
video.close()
return {}
# 剪辑视频 # 剪辑视频
start, end = utils.split_timestamp(timestamp) duration = end - start
video = VideoFileClip(origin_video).subclip(start, end) logger.info(f"开始剪辑视频: {format_timestamp(start)} - {format_timestamp(end)},时长 {format_timestamp(duration)}")
subclip = video.subclip(start, end)
# 检查视频是否有音频轨道 try:
if video.audio is not None: # 检查视频是否有音频轨道并写入文件
video.write_videofile(video_path, logger=None) # 有音频时正常处理 subclip.write_videofile(video_path, audio=(subclip.audio is not None), logger=None)
else:
# 没有音频时使用不同的写入方式 # 验证生成的视频文件
video.write_videofile(video_path, audio=False, logger=None) if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
with VideoFileClip(video_path) as clip:
video.close() # 确保关闭视频文件 if clip.duration > 0 and clip.fps > 0:
return {timestamp: video_path}
if os.path.getsize(video_path) > 0 and os.path.exists(video_path):
try: raise ValueError("视频文件验证失败")
clip = VideoFileClip(video_path)
duration = clip.duration except Exception as e:
fps = clip.fps logger.warning(f"视频文件处理失败: {video_path} => {str(e)}")
clip.close() if os.path.exists(video_path):
if duration > 0 and fps > 0: os.remove(video_path)
return {timestamp: video_path}
except Exception as e:
logger.warning(f"视频文件验证失败: {video_path} => {str(e)}")
if os.path.exists(video_path):
os.remove(video_path)
except Exception as e: except Exception as e:
logger.warning(f"视频剪辑失败: {str(e)}") logger.warning(f"视频剪辑失败: \n{str(traceback.format_exc())}")
if os.path.exists(video_path): if os.path.exists(video_path):
os.remove(video_path) os.remove(video_path)
finally:
# 确保视频对象被正确关闭
try:
video.close()
if 'subclip' in locals():
subclip.close()
except:
pass
return {} return {}
@ -426,6 +480,4 @@ def merge_videos(video_paths, ost_list):
if __name__ == "__main__": if __name__ == "__main__":
download_videos( save_clip_video('00:50-01:41', 'E:\\projects\\NarratoAI\\resource\\videos\\WeChat_20241110144511.mp4')
"test123", ["Money Exchange Medium"], audio_duration=100, source="pixabay"
)

View File

@ -326,7 +326,7 @@ class ScriptProcessor:
start_seconds = time_to_seconds(start_str) start_seconds = time_to_seconds(start_str)
end_seconds = time_to_seconds(end_str) end_seconds = time_to_seconds(end_str)
duration = end_seconds - start_seconds duration = end_seconds - start_seconds
word_count = int(duration / 0.2) word_count = int(duration / 0.35)
return word_count return word_count
except Exception as e: except Exception as e:

View File

@ -395,6 +395,8 @@ def cut_video(params, progress_callback=None):
time_list = [i['timestamp'] for i in video_script_list] time_list = [i['timestamp'] for i in video_script_list]
total_clips = len(time_list) total_clips = len(time_list)
print("time_list", time_list)
def clip_progress(current, total): def clip_progress(current, total):
progress = int((current / total) * 100) progress = int((current / total) * 100)
@ -413,12 +415,16 @@ def cut_video(params, progress_callback=None):
st.session_state['subclip_videos'] = subclip_videos st.session_state['subclip_videos'] = subclip_videos
print("list:", subclip_videos)
for i, video_script in enumerate(video_script_list): for i, video_script in enumerate(video_script_list):
print(i)
print(video_script)
try: try:
video_script['path'] = subclip_videos[video_script['timestamp']] video_script['path'] = subclip_videos[video_script['timestamp']]
except KeyError as err: except KeyError as err:
logger.error(f"裁剪视频失败: {err}") logger.error(f"裁剪视频失败: {err}")
raise ValueError(f"裁剪视频失败: {err}") # raise ValueError(f"裁剪视频失败: {err}")
return task_id, subclip_videos return task_id, subclip_videos

View File

@ -65,16 +65,44 @@ class VideoProcessor:
shot_boundaries.append(i) shot_boundaries.append(i)
return shot_boundaries return shot_boundaries
def filter_keyframes_by_time(self, keyframes: List[np.ndarray],
keyframe_indices: List[int]) -> Tuple[List[np.ndarray], List[int]]:
"""
过滤关键帧确保每秒最多只有一个关键帧
Args:
keyframes: 关键帧列表
keyframe_indices: 关键帧索引列表
Returns:
Tuple[List[np.ndarray], List[int]]: 过滤后的关键帧列表和对应的帧索引
"""
if not keyframes or not keyframe_indices:
return keyframes, keyframe_indices
filtered_frames = []
filtered_indices = []
last_second = -1
for frame, idx in zip(keyframes, keyframe_indices):
current_second = idx // self.fps
if current_second != last_second:
filtered_frames.append(frame)
filtered_indices.append(idx)
last_second = current_second
return filtered_frames, filtered_indices
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]: def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
""" """
从每个镜头中提取关键帧 从每个镜头中提取关键帧并确保每秒最多一个关键帧
Args: Args:
frames: 视频帧列表 frames: 视频帧列表
shot_boundaries: 镜头边界列表 shot_boundaries: 镜头边界列表
Returns: Returns:
Tuple[List[np.ndarray], List[int]]: <EFBFBD><EFBFBD><EFBFBD>帧列表和对应的帧索引 Tuple[List[np.ndarray], List[int]]: 帧列表和对应的帧索引
""" """
keyframes = [] keyframes = []
keyframe_indices = [] keyframe_indices = []
@ -92,7 +120,10 @@ class VideoProcessor:
keyframes.append(shot_frames[center_idx]) keyframes.append(shot_frames[center_idx])
keyframe_indices.append(start + center_idx) keyframe_indices.append(start + center_idx)
return keyframes, keyframe_indices # 过滤每秒多余的关键帧
filtered_keyframes, filtered_indices = self.filter_keyframes_by_time(keyframes, keyframe_indices)
return filtered_keyframes, filtered_indices
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int], def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
output_dir: str) -> None: output_dir: str) -> None:
@ -206,4 +237,4 @@ class VideoProcessor:
# 调整关键帧索引,加上跳过的帧数 # 调整关键帧索引,加上跳过的帧数
adjusted_indices = [idx + skip_frames for idx in keyframe_indices] adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
self.save_keyframes(keyframes, adjusted_indices, output_dir) self.save_keyframes(keyframes, adjusted_indices, output_dir)

View File

@ -100,7 +100,7 @@ class VisionAnalyzer:
except Exception as e: except Exception as e:
retry_count += 1 retry_count += 1
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}\n{traceback.format_exc()}" error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if retry_count >= 3: if retry_count >= 3:
@ -111,7 +111,7 @@ class VisionAnalyzer:
'model_used': self.model_name 'model_used': self.model_name
}) })
else: else:
logger.info(f"批次 {i // batch_size} 处理失败等待60秒后重试...") logger.info(f"批次 {i // batch_size} 处理失败等待60秒后重试当前批次...")
await asyncio.sleep(60) await asyncio.sleep(60)
pbar.update(1) pbar.update(1)

View File

@ -1,11 +1,11 @@
[app] [app]
project_version="0.3.2" project_version="0.3.4"
# 支持视频理解的大模型提供商 # 支持视频理解的大模型提供商
# gemini # gemini
# NarratoAPI # NarratoAPI
# qwen2-vl (待增加) # qwen2-vl (待增加)
vision_llm_provider="gemini" vision_llm_provider="gemini"
vision_batch_size = 5 vision_batch_size = 7
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价" vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
########## Vision Gemini API Key ########## Vision Gemini API Key

View File

@ -18,6 +18,69 @@ from app.utils import utils, check_script, vision_analyzer, video_processor
from webui.utils import file_utils from webui.utils import file_utils
def get_batch_timestamps(batch_files, prev_batch_files=None):
"""
获取一批文件的时间戳范围
返回: (first_timestamp, last_timestamp, timestamp_range)
文件名格式: keyframe_001253_000050.jpg
其中 000050 表示 00:00:50 (50)
000101 表示 00:01:01 (1分1秒)
Args:
batch_files: 当前批次的文件列表
prev_batch_files: 上一个批次的文件列表用于处理单张图片的情况
"""
if not batch_files:
logger.warning("Empty batch files")
return "00:00", "00:00", "00:00-00:00"
# 如果当前批次只有一张图片,且有上一个批次的文件,则使用上一批次的最后一张作为首帧
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
logger.debug(f"单张图片批次,使用上一批次最后一帧作为首帧: {first_frame}")
else:
# 提取首帧和尾帧的时间戳
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
# 从文件名中提取时间信息
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000050
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000101
# 转换为分:秒格式
def format_timestamp(time_str):
# 时间格式为 MMSS如 0050 表示 00:50, 0101 表示 01:01
if len(time_str) < 4:
logger.warning(f"Invalid timestamp format: {time_str}")
return "00:00"
minutes = int(time_str[-4:-2]) # 取后4位的前2位作为分钟
seconds = int(time_str[-2:]) # 取后2位作为秒数
# 处理进位
if seconds >= 60:
minutes += seconds // 60
seconds = seconds % 60
return f"{minutes:02d}:{seconds:02d}"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
logger.debug(f"解析时间戳: {first_frame} -> {first_timestamp}, {last_frame} -> {last_timestamp}")
return first_timestamp, last_timestamp, timestamp_range
def get_batch_files(keyframe_files, result, batch_size=5):
"""
获取当前批次的图片文件
"""
batch_start = result['batch_index'] * batch_size
batch_end = min(batch_start + batch_size, len(keyframe_files))
return keyframe_files[batch_start:batch_end]
def render_script_panel(tr): def render_script_panel(tr):
"""渲染脚本配置面板""" """渲染脚本配置面板"""
with st.container(border=True): with st.container(border=True):
@ -238,7 +301,7 @@ def generate_script(tr, params):
# 检查是否已经提取过关键帧 # 检查是否已经提取过关键帧
keyframe_files = [] keyframe_files = []
if os.path.exists(video_keyframes_dir): if os.path.exists(video_keyframes_dir):
# 取已有的关键帧文件 # 取已有的关键帧文件
for filename in sorted(os.listdir(video_keyframes_dir)): for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'): if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename)) keyframe_files.append(os.path.join(video_keyframes_dir, filename))
@ -282,7 +345,7 @@ def generate_script(tr, params):
except Exception as cleanup_err: except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}") logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取败: {str(e)}") raise Exception(f"关键帧提取<EFBFBD><EFBFBD>败: {str(e)}")
# 根据不同的 LLM 提供商处理 # 根据不同的 LLM 提供商处理
vision_llm_provider = st.session_state.get('vision_llm_providers').lower() vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
@ -327,38 +390,27 @@ def generate_script(tr, params):
# 合并所有批次的析结果 # 合并所有批次的析结果
frame_analysis = "" frame_analysis = ""
prev_batch_files = None
for result in results: for result in results:
if 'error' in result: if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue continue
# 获取当前批次的图片文件 batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5) logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files)) logger.debug(batch_files)
batch_files = keyframe_files[batch_start:batch_end]
# 提取首帧和尾帧的时间戳 first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
first_frame = os.path.basename(batch_files[0]) logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}")
last_frame = os.path.basename(batch_files[-1])
# 从文件名中提取时间信息
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
# 转换为分:秒格式
def format_timestamp(time_str):
seconds = int(time_str)
minutes = seconds // 60
seconds = seconds % 60
return f"{minutes:02d}:{seconds:02d}"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
# 添加带时间戳的分析结果 # 添加带时间戳的分析结果
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response'] frame_analysis += result['response']
frame_analysis += "\n" frame_analysis += "\n"
# 更新上一个批次的文件
prev_batch_files = batch_files
if not frame_analysis.strip(): if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果") raise Exception("未能生成有效的帧分析结果")
@ -378,45 +430,27 @@ def generate_script(tr, params):
# 构建帧内容列表 # 构建帧内容列表
frame_content_list = [] frame_content_list = []
prev_batch_files = None
for i, result in enumerate(results): for i, result in enumerate(results):
if 'error' in result: if 'error' in result:
continue continue
# 获取当前批次的图片文件 batch_files = get_batch_files(keyframe_files, result, config.app.get("vision_batch_size", 5))
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5) _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
batch_files = keyframe_files[batch_start:batch_end]
# 提取首帧和尾帧的时间戳
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
# 从文件名中提取时间信息
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
# 转换为分:秒格式
def format_timestamp(time_str):
seconds = int(time_str)
minutes = seconds // 60
seconds = seconds % 60
return f"{minutes:02d}:{seconds:02d}"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
# 构建时间戳范围字符串 (MM:SS-MM:SS 格式)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
frame_content = { frame_content = {
"timestamp": timestamp_range, # 使用时间范围格式 "MM:SS-MM:SS" "timestamp": timestamp_range,
"picture": result['response'], # 图片分析结果 "picture": result['response'],
"narration": "", # 将由 ScriptProcessor 生成 "narration": "",
"OST": 2 # 默认值 "OST": 2
} }
frame_content_list.append(frame_content) frame_content_list.append(frame_content)
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}") logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
# 更新上一个批次的文件
prev_batch_files = batch_files
if not frame_content_list: if not frame_content_list:
raise Exception("没有有效的帧内容可以处理") raise Exception("没有有效的帧内容可以处理")
@ -442,13 +476,15 @@ def generate_script(tr, params):
) )
adapter = HTTPAdapter(max_retries=retry_strategy) adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("https://", adapter) session.mount("https://", adapter)
try:
response = session.post( response = session.post(
f"{config.app.get('narrato_api_url')}/video/config", f"{config.app.get('narrato_api_url')}/video/config",
params=api_params, params=api_params,
timeout=30, timeout=30,
verify=True # 启用证书验证 verify=True # 启用证书验证
) )
except:
pass
custom_prompt = st.session_state.get('custom_prompt', '') custom_prompt = st.session_state.get('custom_prompt', '')
processor = ScriptProcessor( processor = ScriptProcessor(
@ -621,7 +657,7 @@ def save_script(tr, video_clip_json_details):
# 显示成功消息 # 显示成功消息
st.success(tr("Script saved successfully")) st.success(tr("Script saved successfully"))
# 强制重新加载页面<EFBFBD><EFBFBD>更新选择框 # 强制重新加载页面更新选择框
time.sleep(0.5) # 给一点时间让用户看到成功消息 time.sleep(0.5) # 给一点时间让用户看到成功消息
st.rerun() st.rerun()