perf(vision): optimize the vision analysis pipeline and batch-processing logic

- Removed the vision_analysis_prompt config entry
- Reworked the batch-count logic in the Gemini and QwenVL analyzers (a sketch of the equivalence follows the commit metadata below)
- Updated the documentary script generation flow and UI components to match the new analysis pipeline
- Adjusted the video frame extraction functions, dropping the unneeded skip_seconds parameter
- Updated the Chinese translation file with new batch-size strings
linyq 2025-05-07 18:44:37 +08:00
parent 82823297f2
commit 2dc83bc18e
7 changed files with 70 additions and 35 deletions
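
Both analyzers now derive the number of VLM requests from an explicit floor-division-plus-remainder step instead of the previous one-line ceiling division. A minimal sketch of the equivalence, reusing the variable names from the diff; the sample values are illustrative only:

    import math

    images = list(range(23))   # e.g. 23 extracted keyframes (illustrative)
    batch_size = 10            # default raised to 10 in the config hunk below

    # New form introduced by this commit: floor division, then +1 if a partial batch remains
    batches_needed = len(images) // batch_size
    if len(images) % batch_size > 0:
        batches_needed += 1

    # Equivalent to the ceiling division the commit removes
    assert batches_needed == (len(images) + batch_size - 1) // batch_size == math.ceil(len(images) / batch_size)  # 3 batches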


@@ -61,7 +61,6 @@ class VisionAnalyzer:
         try:
             # 加载图片
             if isinstance(images[0], str):
-                logger.info("正在加载图片...")
                 images = self.load_images(images)
 
             # 验证图片列表
@@ -81,11 +80,14 @@ class VisionAnalyzer:
             images = valid_images
 
             results = []
-            total_batches = (len(images) + batch_size - 1) // batch_size
-
-            logger.debug(f"{total_batches} 个批次,每批次 {batch_size} 张图片")
-            with tqdm(total=total_batches, desc="分析进度") as pbar:
+            # 视频帧总数除以批量处理大小,如果有小数则+1
+            batches_needed = len(images) // batch_size
+            if len(images) % batch_size > 0:
+                batches_needed += 1
+            logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed}")
+
+            with tqdm(total=batches_needed, desc="分析进度") as pbar:
                 for i in range(0, len(images), batch_size):
                     batch = images[i:i + batch_size]
                     retry_count = 0
@@ -93,8 +95,8 @@ class VisionAnalyzer:
                     while retry_count < 3:
                         try:
                             # 在每个批次处理前添加小延迟
-                            if i > 0:
-                                await asyncio.sleep(2)
+                            # if i > 0:
+                            #     await asyncio.sleep(2)
 
                             # 确保每个批次的图片都是有效的
                             valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]


@@ -80,7 +80,7 @@ class QwenAnalyzer:
             # 添加文本提示
             content.append({
                 "type": "text",
-                "text": prompt
+                "text": prompt % (len(content), len(content), len(content))
             })
 
             # 调用API
@@ -102,7 +102,7 @@ class QwenAnalyzer:
     async def analyze_images(self,
                              images: Union[List[str], List[PIL.Image.Image]],
                              prompt: str,
-                             batch_size: int = 5) -> List[Dict]:
+                             batch_size: int) -> List[Dict]:
         """
         批量分析多张图片
 
         Args:
@@ -118,7 +118,6 @@ class QwenAnalyzer:
         try:
             # 加载图片
             if isinstance(images[0], str):
-                logger.info("正在加载图片...")
                 images = self.load_images(images)
 
             # 验证图片列表
@@ -141,9 +140,14 @@ class QwenAnalyzer:
             images = valid_images
 
             results = []
-            total_batches = (len(images) + batch_size - 1) // batch_size
+            # 视频帧总数除以批量处理大小,如果有小数则+1
+            batches_needed = len(images) // batch_size
+            if len(images) % batch_size > 0:
+                batches_needed += 1
+            logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed}")
 
-            with tqdm(total=total_batches, desc="分析进度") as pbar:
+            with tqdm(total=batches_needed, desc="分析进度") as pbar:
                 for i in range(0, len(images), batch_size):
                     batch = images[i:i + batch_size]
                     batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
@@ -151,9 +155,9 @@ class QwenAnalyzer:
                     while retry_count < 3:
                         try:
-                            # 在每个批次处理前��加小延迟
-                            if i > 0:
-                                await asyncio.sleep(2)
+                            # 在每个批次处理前加小延迟
+                            # if i > 0:
+                            #     await asyncio.sleep(0.5)
 
                             # 确保每个批次的图片都是有效的
                             valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
@@ -209,7 +213,7 @@ class QwenAnalyzer:
         for i, result in enumerate(results):
             response_text = result['response']
 
-            # 如果有图片路径信息,���用它来生成文件名
+            # 如果有图片路径信息,用它来生成文件名
            if result.get('image_paths'):
                image_paths = result['image_paths']
                img_name_start = Path(image_paths[0]).stem.split('_')[-1]


@@ -84,7 +84,7 @@ class VideoProcessor:
         }
 
     def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
-                                   use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
+                                   use_hw_accel: bool = True) -> List[int]:
         """
         按指定时间间隔提取视频帧
@@ -92,7 +92,6 @@
             output_dir: 输出目录
             interval_seconds: 帧提取间隔
             use_hw_accel: 是否使用硬件加速
-            skip_seconds: 跳过视频开头的秒数
 
         Returns:
             List[int]: 提取的帧号列表
@@ -101,7 +100,7 @@
             os.makedirs(output_dir)
 
         # 计算起始时间和帧提取点
-        start_time = skip_seconds
+        start_time = 0
         end_time = self.duration
         extraction_times = []
@@ -291,7 +290,6 @@
     def process_video_pipeline(self,
                                output_dir: str,
-                               skip_seconds: float = 0.0,
                                interval_seconds: float = 5.0,  # 帧提取间隔(秒)
                                use_hw_accel: bool = True) -> None:
         """
@@ -299,7 +297,6 @@
         Args:
             output_dir: 输出目录
-            skip_seconds: 跳过视频开头的秒数
             interval_seconds: 帧提取间隔
             use_hw_accel: 是否使用硬件加速
         """
@@ -312,8 +309,7 @@
         self.extract_frames_by_interval(
             output_dir,
             interval_seconds=interval_seconds,
-            use_hw_accel=use_hw_accel,
-            skip_seconds=skip_seconds
+            use_hw_accel=use_hw_accel
         )
 
         logger.info(f"处理完成!视频帧已保存在: {output_dir}")


@@ -4,7 +4,6 @@
 # gemini
 # qwenvl
 vision_llm_provider="qwenvl"
-vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
 
 ########## Vision Gemini API Key
 vision_gemini_api_key = ""
@@ -181,4 +180,4 @@
 threshold = 30
 version = "v2"
 # 大模型单次处理的关键帧数量
-vision_batch_size = 5
+vision_batch_size = 10


@@ -220,10 +220,19 @@ def render_script_buttons(tr, params):
             st.number_input(
                 tr("Frame Interval (seconds)"),
                 min_value=0,
-                value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 5)),
+                value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 3)),
                 help=tr("Frame Interval (seconds) (More keyframes consume more tokens)"),
                 key="frame_interval_input"
             )
+        with input_cols[1]:
+            st.number_input(
+                tr("Batch Size"),
+                min_value=0,
+                value=st.session_state.get('vision_batch_size', config.frames.get('vision_batch_size', 10)),
+                help=tr("Batch Size (More keyframes consume more tokens)"),
+                key="vision_batch_size"
+            )
 
     # 生成/加载按钮
     if script_path == "auto":


@@ -115,7 +115,6 @@
         "Text Generation Model Settings": "文案生成模型设置",
         "LLM Model Name": "大语言模型名称",
         "LLM Model API Key": "大语言模型 API 密钥",
-        "Batch Size": "批处理大小",
         "Text Model Provider": "文案生成模型提供商",
         "Text API Key": "文案生成 API 密钥",
         "Text Base URL": "文案生成接口地址",
@@ -194,6 +193,8 @@
         "Original Volume": "视频音量",
         "Auto Generate": "纪录片解说 (画面解说)",
         "Frame Interval (seconds)": "帧间隔 (秒)",
-        "Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)"
+        "Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)",
+        "Batch Size": "批处理大小",
+        "Batch Size (More keyframes consume more tokens)": "批处理大小, 每批处理越少消耗 token 越多"
     }
 }


@@ -9,7 +9,6 @@ from app.utils import video_processor
 import streamlit as st
 from loguru import logger
 from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
 
 from app.config import config
 from app.utils.script_generator import ScriptProcessor
@@ -38,8 +37,9 @@ def generate_script_docu(params):
     if not params.video_origin_path:
         st.error("请先选择视频文件")
         return
 
-    # ===================提取键帧===================
+    """
+    1. 提取键帧
+    """
     update_progress(10, "正在提取关键帧...")
 
     # 创建临时目录用于存储关键帧
@@ -95,9 +95,11 @@ def generate_script_docu(params):
         raise Exception(f"关键帧提取失败: {str(e)}")
 
-    # 根据不同的 LLM 提供商处理
+    """
+    2. 视觉分析
+    """
     vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
-    logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
+    logger.debug(f"VLM 视觉大模型提供商: {vision_llm_provider}")
 
     try:
         # ===================初始化视觉分析器===================
@@ -131,10 +133,32 @@ def generate_script_docu(params):
             # 执行异步分析
             vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
+            vision_analysis_prompt = """
+            我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
+            首先,请详细描述每一帧的关键视觉信息,包含主要内容、人物、动作和场景。
+            然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
+            请务必使用 JSON 格式输出你的结果,JSON 结构应如下:
+            {
+                "frame_observations": [
+                    {
+                        "frame_number": 1, // 或其他标识帧的方式
+                        "observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
+                    },
+                    // ... 更多帧的观察 ...
+                ],
+                "overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
+            }
+            请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素。
+            请只返回 JSON 字符串,不要包含任何其他解释性文字。
+            """
             results = loop.run_until_complete(
                 analyzer.analyze_images(
                     images=keyframe_files,
-                    prompt=config.app.get('vision_analysis_prompt'),
+                    prompt=vision_analysis_prompt,
                     batch_size=vision_batch_size
                 )
             )
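
For reference, a minimal sketch of how the three %s placeholders in the new vision_analysis_prompt are expected to be filled for each batch; the QwenAnalyzer hunk above substitutes len(content) (the number of image entries already appended to the request), while the names below are illustrative stand-ins:

    # Illustrative only: all three %s placeholders receive the same number, telling the
    # model how many frames it was given and how many frame_observations entries to return.
    prompt_template = "我提供了 %s 张视频帧……我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素"  # abridged stand-in
    frame_count = 10  # e.g. len(valid_batch): frames sent to the VLM in this call
    filled_prompt = prompt_template % (frame_count, frame_count, frame_count)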