mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 02:12:50 +00:00
perf(vision): 优化视觉分析流程和批量处理逻辑
- 移除了 vision_analysis_prompt 配置项 - 优化了 Gemini 和 QwenVL 分析器的批量处理逻辑 - 更新了文档生成脚本和 UI 组件以适应新的分析流程 - 调整了视频帧提取相关函数,移除了不必要的 skip_seconds 参数 - 更新了中文翻译文件,添加了新的批处理大小相关提示
This commit is contained in:
parent
82823297f2
commit
2dc83bc18e
@ -61,7 +61,6 @@ class VisionAnalyzer:
|
||||
try:
|
||||
# 加载图片
|
||||
if isinstance(images[0], str):
|
||||
logger.info("正在加载图片...")
|
||||
images = self.load_images(images)
|
||||
|
||||
# 验证图片列表
|
||||
@ -81,11 +80,14 @@ class VisionAnalyzer:
|
||||
|
||||
images = valid_images
|
||||
results = []
|
||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
||||
# 视频帧总数除以批量处理大小,如果有小数则+1
|
||||
batches_needed = len(images) // batch_size
|
||||
if len(images) % batch_size > 0:
|
||||
batches_needed += 1
|
||||
|
||||
logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次")
|
||||
|
||||
logger.debug(f"共 {total_batches} 个批次,每批次 {batch_size} 张图片")
|
||||
|
||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
||||
with tqdm(total=batches_needed, desc="分析进度") as pbar:
|
||||
for i in range(0, len(images), batch_size):
|
||||
batch = images[i:i + batch_size]
|
||||
retry_count = 0
|
||||
@ -93,8 +95,8 @@ class VisionAnalyzer:
|
||||
while retry_count < 3:
|
||||
try:
|
||||
# 在每个批次处理前添加小延迟
|
||||
if i > 0:
|
||||
await asyncio.sleep(2)
|
||||
# if i > 0:
|
||||
# await asyncio.sleep(2)
|
||||
|
||||
# 确保每个批次的图片都是有效的
|
||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||
|
||||
@ -80,7 +80,7 @@ class QwenAnalyzer:
|
||||
# 添加文本提示
|
||||
content.append({
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
"text": prompt % (len(content), len(content), len(content))
|
||||
})
|
||||
|
||||
# 调用API
|
||||
@ -102,7 +102,7 @@ class QwenAnalyzer:
|
||||
async def analyze_images(self,
|
||||
images: Union[List[str], List[PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 5) -> List[Dict]:
|
||||
batch_size: int) -> List[Dict]:
|
||||
"""
|
||||
批量分析多张图片
|
||||
Args:
|
||||
@ -118,7 +118,6 @@ class QwenAnalyzer:
|
||||
|
||||
# 加载图片
|
||||
if isinstance(images[0], str):
|
||||
logger.info("正在加载图片...")
|
||||
images = self.load_images(images)
|
||||
|
||||
# 验证图片列表
|
||||
@ -141,9 +140,14 @@ class QwenAnalyzer:
|
||||
|
||||
images = valid_images
|
||||
results = []
|
||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
||||
# 视频帧总数除以批量处理大小,如果有小数则+1
|
||||
batches_needed = len(images) // batch_size
|
||||
if len(images) % batch_size > 0:
|
||||
batches_needed += 1
|
||||
|
||||
logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次")
|
||||
|
||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
||||
with tqdm(total=batches_needed, desc="分析进度") as pbar:
|
||||
for i in range(0, len(images), batch_size):
|
||||
batch = images[i:i + batch_size]
|
||||
batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
|
||||
@ -151,9 +155,9 @@ class QwenAnalyzer:
|
||||
|
||||
while retry_count < 3:
|
||||
try:
|
||||
# 在每个批次处理前<EFBFBD><EFBFBD>加小延迟
|
||||
if i > 0:
|
||||
await asyncio.sleep(2)
|
||||
# 在每个批次处理前添加小延迟
|
||||
# if i > 0:
|
||||
# await asyncio.sleep(0.5)
|
||||
|
||||
# 确保每个批次的图片都是有效的
|
||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||
@ -209,7 +213,7 @@ class QwenAnalyzer:
|
||||
for i, result in enumerate(results):
|
||||
response_text = result['response']
|
||||
|
||||
# 如果有图片路径信息,<EFBFBD><EFBFBD><EFBFBD>用它来生成文件名
|
||||
# 如果有图片路径信息,用它来生成文件名
|
||||
if result.get('image_paths'):
|
||||
image_paths = result['image_paths']
|
||||
img_name_start = Path(image_paths[0]).stem.split('_')[-1]
|
||||
|
||||
@ -84,7 +84,7 @@ class VideoProcessor:
|
||||
}
|
||||
|
||||
def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
|
||||
use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
|
||||
use_hw_accel: bool = True) -> List[int]:
|
||||
"""
|
||||
按指定时间间隔提取视频帧
|
||||
|
||||
@ -92,7 +92,6 @@ class VideoProcessor:
|
||||
output_dir: 输出目录
|
||||
interval_seconds: 帧提取间隔(秒)
|
||||
use_hw_accel: 是否使用硬件加速
|
||||
skip_seconds: 跳过视频开头的秒数
|
||||
|
||||
Returns:
|
||||
List[int]: 提取的帧号列表
|
||||
@ -101,7 +100,7 @@ class VideoProcessor:
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# 计算起始时间和帧提取点
|
||||
start_time = skip_seconds
|
||||
start_time = 0
|
||||
end_time = self.duration
|
||||
extraction_times = []
|
||||
|
||||
@ -291,7 +290,6 @@ class VideoProcessor:
|
||||
|
||||
def process_video_pipeline(self,
|
||||
output_dir: str,
|
||||
skip_seconds: float = 0.0,
|
||||
interval_seconds: float = 5.0, # 帧提取间隔(秒)
|
||||
use_hw_accel: bool = True) -> None:
|
||||
"""
|
||||
@ -299,7 +297,6 @@ class VideoProcessor:
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
skip_seconds: 跳过视频开头的秒数
|
||||
interval_seconds: 帧提取间隔(秒)
|
||||
use_hw_accel: 是否使用硬件加速
|
||||
"""
|
||||
@ -312,8 +309,7 @@ class VideoProcessor:
|
||||
self.extract_frames_by_interval(
|
||||
output_dir,
|
||||
interval_seconds=interval_seconds,
|
||||
use_hw_accel=use_hw_accel,
|
||||
skip_seconds=skip_seconds
|
||||
use_hw_accel=use_hw_accel
|
||||
)
|
||||
|
||||
logger.info(f"处理完成!视频帧已保存在: {output_dir}")
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
# gemini
|
||||
# qwenvl
|
||||
vision_llm_provider="qwenvl"
|
||||
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
|
||||
|
||||
########## Vision Gemini API Key
|
||||
vision_gemini_api_key = ""
|
||||
@ -181,4 +180,4 @@
|
||||
threshold = 30
|
||||
version = "v2"
|
||||
# 大模型单次处理的关键帧数量
|
||||
vision_batch_size = 5
|
||||
vision_batch_size = 10
|
||||
|
||||
@ -220,10 +220,19 @@ def render_script_buttons(tr, params):
|
||||
st.number_input(
|
||||
tr("Frame Interval (seconds)"),
|
||||
min_value=0,
|
||||
value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 5)),
|
||||
value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 3)),
|
||||
help=tr("Frame Interval (seconds) (More keyframes consume more tokens)"),
|
||||
key="frame_interval_input"
|
||||
)
|
||||
|
||||
with input_cols[1]:
|
||||
st.number_input(
|
||||
tr("Batch Size"),
|
||||
min_value=0,
|
||||
value=st.session_state.get('vision_batch_size', config.frames.get('vision_batch_size', 10)),
|
||||
help=tr("Batch Size (More keyframes consume more tokens)"),
|
||||
key="vision_batch_size"
|
||||
)
|
||||
|
||||
# 生成/加载按钮
|
||||
if script_path == "auto":
|
||||
|
||||
@ -115,7 +115,6 @@
|
||||
"Text Generation Model Settings": "文案生成模型设置",
|
||||
"LLM Model Name": "大语言模型名称",
|
||||
"LLM Model API Key": "大语言模型 API 密钥",
|
||||
"Batch Size": "批处理大小",
|
||||
"Text Model Provider": "文案生成模型提供商",
|
||||
"Text API Key": "文案生成 API 密钥",
|
||||
"Text Base URL": "文案生成接口地址",
|
||||
@ -194,6 +193,8 @@
|
||||
"Original Volume": "视频音量",
|
||||
"Auto Generate": "纪录片解说 (画面解说)",
|
||||
"Frame Interval (seconds)": "帧间隔 (秒)",
|
||||
"Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)"
|
||||
"Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)",
|
||||
"Batch Size": "批处理大小",
|
||||
"Batch Size (More keyframes consume more tokens)": "批处理大小, 每批处理越少消耗 token 越多"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -9,7 +9,6 @@ from app.utils import video_processor
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from app.config import config
|
||||
from app.utils.script_generator import ScriptProcessor
|
||||
@ -38,8 +37,9 @@ def generate_script_docu(params):
|
||||
if not params.video_origin_path:
|
||||
st.error("请先选择视频文件")
|
||||
return
|
||||
|
||||
# ===================提取键帧===================
|
||||
"""
|
||||
1. 提取键帧
|
||||
"""
|
||||
update_progress(10, "正在提取关键帧...")
|
||||
|
||||
# 创建临时目录用于存储关键帧
|
||||
@ -95,9 +95,11 @@ def generate_script_docu(params):
|
||||
|
||||
raise Exception(f"关键帧提取失败: {str(e)}")
|
||||
|
||||
# 根据不同的 LLM 提供商处理
|
||||
"""
|
||||
2. 视觉分析
|
||||
"""
|
||||
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
|
||||
logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
|
||||
logger.debug(f"VLM 视觉大模型提供商: {vision_llm_provider}")
|
||||
|
||||
try:
|
||||
# ===================初始化视觉分析器===================
|
||||
@ -131,10 +133,32 @@ def generate_script_docu(params):
|
||||
|
||||
# 执行异步分析
|
||||
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
||||
vision_analysis_prompt = """
|
||||
我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
|
||||
|
||||
首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
|
||||
然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
|
||||
|
||||
请务必使用 JSON 格式输出你的结果。JSON 结构应如下:
|
||||
{
|
||||
"frame_observations": [
|
||||
{
|
||||
"frame_number": 1, // 或其他标识帧的方式
|
||||
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
|
||||
},
|
||||
// ... 更多帧的观察 ...
|
||||
],
|
||||
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
|
||||
}
|
||||
|
||||
请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素
|
||||
|
||||
请只返回 JSON 字符串,不要包含任何其他解释性文字。
|
||||
"""
|
||||
results = loop.run_until_complete(
|
||||
analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
prompt=vision_analysis_prompt,
|
||||
batch_size=vision_batch_size
|
||||
)
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user