diff --git a/app/utils/gemini_analyzer.py b/app/utils/gemini_analyzer.py index 07306c5..7236a9e 100644 --- a/app/utils/gemini_analyzer.py +++ b/app/utils/gemini_analyzer.py @@ -61,7 +61,6 @@ class VisionAnalyzer: try: # 加载图片 if isinstance(images[0], str): - logger.info("正在加载图片...") images = self.load_images(images) # 验证图片列表 @@ -81,11 +80,14 @@ class VisionAnalyzer: images = valid_images results = [] - total_batches = (len(images) + batch_size - 1) // batch_size + # 视频帧总数除以批量处理大小,如果有小数则+1 + batches_needed = len(images) // batch_size + if len(images) % batch_size > 0: + batches_needed += 1 + + logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次") - logger.debug(f"共 {total_batches} 个批次,每批次 {batch_size} 张图片") - - with tqdm(total=total_batches, desc="分析进度") as pbar: + with tqdm(total=batches_needed, desc="分析进度") as pbar: for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] retry_count = 0 @@ -93,8 +95,8 @@ class VisionAnalyzer: while retry_count < 3: try: # 在每个批次处理前添加小延迟 - if i > 0: - await asyncio.sleep(2) + # if i > 0: + # await asyncio.sleep(2) # 确保每个批次的图片都是有效的 valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)] diff --git a/app/utils/qwenvl_analyzer.py b/app/utils/qwenvl_analyzer.py index 54e6e36..ec4de39 100644 --- a/app/utils/qwenvl_analyzer.py +++ b/app/utils/qwenvl_analyzer.py @@ -80,7 +80,7 @@ class QwenAnalyzer: # 添加文本提示 content.append({ "type": "text", - "text": prompt + "text": prompt % (len(content), len(content), len(content)) }) # 调用API @@ -102,7 +102,7 @@ class QwenAnalyzer: async def analyze_images(self, images: Union[List[str], List[PIL.Image.Image]], prompt: str, - batch_size: int = 5) -> List[Dict]: + batch_size: int) -> List[Dict]: """ 批量分析多张图片 Args: @@ -118,7 +118,6 @@ class QwenAnalyzer: # 加载图片 if isinstance(images[0], str): - logger.info("正在加载图片...") images = self.load_images(images) # 验证图片列表 @@ -141,9 +140,14 @@ class QwenAnalyzer: images = valid_images results = [] - total_batches = (len(images) + batch_size - 1) // batch_size + # 视频帧总数除以批量处理大小,如果有小数则+1 + batches_needed = len(images) // batch_size + if len(images) % batch_size > 0: + batches_needed += 1 + + logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次") - with tqdm(total=total_batches, desc="分析进度") as pbar: + with tqdm(total=batches_needed, desc="分析进度") as pbar: for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] batch_paths = valid_paths[i:i + batch_size] if valid_paths else None @@ -151,9 +155,9 @@ class QwenAnalyzer: while retry_count < 3: try: - # 在每个批次处理前��加小延迟 - if i > 0: - await asyncio.sleep(2) + # 在每个批次处理前添加小延迟 + # if i > 0: + # await asyncio.sleep(0.5) # 确保每个批次的图片都是有效的 valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)] @@ -209,7 +213,7 @@ class QwenAnalyzer: for i, result in enumerate(results): response_text = result['response'] - # 如果有图片路径信息,���用它来生成文件名 + # 如果有图片路径信息,用它来生成文件名 if result.get('image_paths'): image_paths = result['image_paths'] img_name_start = Path(image_paths[0]).stem.split('_')[-1] diff --git a/app/utils/video_processor.py b/app/utils/video_processor.py index d10f8a7..1d3dd9b 100644 --- a/app/utils/video_processor.py +++ b/app/utils/video_processor.py @@ -84,7 +84,7 @@ class VideoProcessor: } def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0, - use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]: + use_hw_accel: bool = True) -> List[int]: """ 按指定时间间隔提取视频帧 @@ -92,7 +92,6 @@ class VideoProcessor: output_dir: 输出目录 interval_seconds: 帧提取间隔(秒) use_hw_accel: 是否使用硬件加速 - skip_seconds: 跳过视频开头的秒数 Returns: List[int]: 提取的帧号列表 @@ -101,7 +100,7 @@ class VideoProcessor: os.makedirs(output_dir) # 计算起始时间和帧提取点 - start_time = skip_seconds + start_time = 0 end_time = self.duration extraction_times = [] @@ -291,7 +290,6 @@ class VideoProcessor: def process_video_pipeline(self, output_dir: str, - skip_seconds: float = 0.0, interval_seconds: float = 5.0, # 帧提取间隔(秒) use_hw_accel: bool = True) -> None: """ @@ -299,7 +297,6 @@ class VideoProcessor: Args: output_dir: 输出目录 - skip_seconds: 跳过视频开头的秒数 interval_seconds: 帧提取间隔(秒) use_hw_accel: 是否使用硬件加速 """ @@ -312,8 +309,7 @@ class VideoProcessor: self.extract_frames_by_interval( output_dir, interval_seconds=interval_seconds, - use_hw_accel=use_hw_accel, - skip_seconds=skip_seconds + use_hw_accel=use_hw_accel ) logger.info(f"处理完成!视频帧已保存在: {output_dir}") diff --git a/config.example.toml b/config.example.toml index 835a8e9..7576350 100644 --- a/config.example.toml +++ b/config.example.toml @@ -4,7 +4,6 @@ # gemini # qwenvl vision_llm_provider="qwenvl" - vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价" ########## Vision Gemini API Key vision_gemini_api_key = "" @@ -181,4 +180,4 @@ threshold = 30 version = "v2" # 大模型单次处理的关键帧数量 - vision_batch_size = 5 + vision_batch_size = 10 diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index 1e681b3..7840886 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -220,10 +220,19 @@ def render_script_buttons(tr, params): st.number_input( tr("Frame Interval (seconds)"), min_value=0, - value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 5)), + value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 3)), help=tr("Frame Interval (seconds) (More keyframes consume more tokens)"), key="frame_interval_input" ) + + with input_cols[1]: + st.number_input( + tr("Batch Size"), + min_value=0, + value=st.session_state.get('vision_batch_size', config.frames.get('vision_batch_size', 10)), + help=tr("Batch Size (More keyframes consume more tokens)"), + key="vision_batch_size" + ) # 生成/加载按钮 if script_path == "auto": diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index 7a7a387..6aa7fbc 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -115,7 +115,6 @@ "Text Generation Model Settings": "文案生成模型设置", "LLM Model Name": "大语言模型名称", "LLM Model API Key": "大语言模型 API 密钥", - "Batch Size": "批处理大小", "Text Model Provider": "文案生成模型提供商", "Text API Key": "文案生成 API 密钥", "Text Base URL": "文案生成接口地址", @@ -194,6 +193,8 @@ "Original Volume": "视频音量", "Auto Generate": "纪录片解说 (画面解说)", "Frame Interval (seconds)": "帧间隔 (秒)", - "Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)" + "Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)", + "Batch Size": "批处理大小", + "Batch Size (More keyframes consume more tokens)": "批处理大小, 每批处理越少消耗 token 越多" } -} +} \ No newline at end of file diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py index 7069215..ee388dc 100644 --- a/webui/tools/generate_script_docu.py +++ b/webui/tools/generate_script_docu.py @@ -9,7 +9,6 @@ from app.utils import video_processor import streamlit as st from loguru import logger from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry from app.config import config from app.utils.script_generator import ScriptProcessor @@ -38,8 +37,9 @@ def generate_script_docu(params): if not params.video_origin_path: st.error("请先选择视频文件") return - - # ===================提取键帧=================== + """ + 1. 提取键帧 + """ update_progress(10, "正在提取关键帧...") # 创建临时目录用于存储关键帧 @@ -95,9 +95,11 @@ def generate_script_docu(params): raise Exception(f"关键帧提取失败: {str(e)}") - # 根据不同的 LLM 提供商处理 + """ + 2. 视觉分析 + """ vision_llm_provider = st.session_state.get('vision_llm_providers').lower() - logger.debug(f"Vision LLM 提供商: {vision_llm_provider}") + logger.debug(f"VLM 视觉大模型提供商: {vision_llm_provider}") try: # ===================初始化视觉分析器=================== @@ -131,10 +133,32 @@ def generate_script_docu(params): # 执行异步分析 vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size") + vision_analysis_prompt = """ +我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。 + +首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。 +然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。 + +请务必使用 JSON 格式输出你的结果。JSON 结构应如下: +{ + "frame_observations": [ + { + "frame_number": 1, // 或其他标识帧的方式 + "observation": "描述每张视频帧中的主要内容、人物、动作和场景。" + }, + // ... 更多帧的观察 ... + ], + "overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。" +} + +请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素 + +请只返回 JSON 字符串,不要包含任何其他解释性文字。 + """ results = loop.run_until_complete( analyzer.analyze_images( images=keyframe_files, - prompt=config.app.get('vision_analysis_prompt'), + prompt=vision_analysis_prompt, batch_size=vision_batch_size ) )