mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-12 03:02:48 +00:00
perf(vision): 优化视觉分析流程和批量处理逻辑
- 移除了 vision_analysis_prompt 配置项 - 优化了 Gemini 和 QwenVL 分析器的批量处理逻辑 - 更新了文档生成脚本和 UI 组件以适应新的分析流程 - 调整了视频帧提取相关函数,移除了不必要的 skip_seconds 参数 - 更新了中文翻译文件,添加了新的批处理大小相关提示
This commit is contained in:
parent
82823297f2
commit
2dc83bc18e
@ -61,7 +61,6 @@ class VisionAnalyzer:
|
|||||||
try:
|
try:
|
||||||
# 加载图片
|
# 加载图片
|
||||||
if isinstance(images[0], str):
|
if isinstance(images[0], str):
|
||||||
logger.info("正在加载图片...")
|
|
||||||
images = self.load_images(images)
|
images = self.load_images(images)
|
||||||
|
|
||||||
# 验证图片列表
|
# 验证图片列表
|
||||||
@ -81,11 +80,14 @@ class VisionAnalyzer:
|
|||||||
|
|
||||||
images = valid_images
|
images = valid_images
|
||||||
results = []
|
results = []
|
||||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
# 视频帧总数除以批量处理大小,如果有小数则+1
|
||||||
|
batches_needed = len(images) // batch_size
|
||||||
|
if len(images) % batch_size > 0:
|
||||||
|
batches_needed += 1
|
||||||
|
|
||||||
|
logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次")
|
||||||
|
|
||||||
logger.debug(f"共 {total_batches} 个批次,每批次 {batch_size} 张图片")
|
with tqdm(total=batches_needed, desc="分析进度") as pbar:
|
||||||
|
|
||||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
|
||||||
for i in range(0, len(images), batch_size):
|
for i in range(0, len(images), batch_size):
|
||||||
batch = images[i:i + batch_size]
|
batch = images[i:i + batch_size]
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
@ -93,8 +95,8 @@ class VisionAnalyzer:
|
|||||||
while retry_count < 3:
|
while retry_count < 3:
|
||||||
try:
|
try:
|
||||||
# 在每个批次处理前添加小延迟
|
# 在每个批次处理前添加小延迟
|
||||||
if i > 0:
|
# if i > 0:
|
||||||
await asyncio.sleep(2)
|
# await asyncio.sleep(2)
|
||||||
|
|
||||||
# 确保每个批次的图片都是有效的
|
# 确保每个批次的图片都是有效的
|
||||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||||
|
|||||||
@ -80,7 +80,7 @@ class QwenAnalyzer:
|
|||||||
# 添加文本提示
|
# 添加文本提示
|
||||||
content.append({
|
content.append({
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": prompt
|
"text": prompt % (len(content), len(content), len(content))
|
||||||
})
|
})
|
||||||
|
|
||||||
# 调用API
|
# 调用API
|
||||||
@ -102,7 +102,7 @@ class QwenAnalyzer:
|
|||||||
async def analyze_images(self,
|
async def analyze_images(self,
|
||||||
images: Union[List[str], List[PIL.Image.Image]],
|
images: Union[List[str], List[PIL.Image.Image]],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
batch_size: int = 5) -> List[Dict]:
|
batch_size: int) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
批量分析多张图片
|
批量分析多张图片
|
||||||
Args:
|
Args:
|
||||||
@ -118,7 +118,6 @@ class QwenAnalyzer:
|
|||||||
|
|
||||||
# 加载图片
|
# 加载图片
|
||||||
if isinstance(images[0], str):
|
if isinstance(images[0], str):
|
||||||
logger.info("正在加载图片...")
|
|
||||||
images = self.load_images(images)
|
images = self.load_images(images)
|
||||||
|
|
||||||
# 验证图片列表
|
# 验证图片列表
|
||||||
@ -141,9 +140,14 @@ class QwenAnalyzer:
|
|||||||
|
|
||||||
images = valid_images
|
images = valid_images
|
||||||
results = []
|
results = []
|
||||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
# 视频帧总数除以批量处理大小,如果有小数则+1
|
||||||
|
batches_needed = len(images) // batch_size
|
||||||
|
if len(images) % batch_size > 0:
|
||||||
|
batches_needed += 1
|
||||||
|
|
||||||
|
logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次")
|
||||||
|
|
||||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
with tqdm(total=batches_needed, desc="分析进度") as pbar:
|
||||||
for i in range(0, len(images), batch_size):
|
for i in range(0, len(images), batch_size):
|
||||||
batch = images[i:i + batch_size]
|
batch = images[i:i + batch_size]
|
||||||
batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
|
batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
|
||||||
@ -151,9 +155,9 @@ class QwenAnalyzer:
|
|||||||
|
|
||||||
while retry_count < 3:
|
while retry_count < 3:
|
||||||
try:
|
try:
|
||||||
# 在每个批次处理前<EFBFBD><EFBFBD>加小延迟
|
# 在每个批次处理前添加小延迟
|
||||||
if i > 0:
|
# if i > 0:
|
||||||
await asyncio.sleep(2)
|
# await asyncio.sleep(0.5)
|
||||||
|
|
||||||
# 确保每个批次的图片都是有效的
|
# 确保每个批次的图片都是有效的
|
||||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||||
@ -209,7 +213,7 @@ class QwenAnalyzer:
|
|||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
response_text = result['response']
|
response_text = result['response']
|
||||||
|
|
||||||
# 如果有图片路径信息,<EFBFBD><EFBFBD><EFBFBD>用它来生成文件名
|
# 如果有图片路径信息,用它来生成文件名
|
||||||
if result.get('image_paths'):
|
if result.get('image_paths'):
|
||||||
image_paths = result['image_paths']
|
image_paths = result['image_paths']
|
||||||
img_name_start = Path(image_paths[0]).stem.split('_')[-1]
|
img_name_start = Path(image_paths[0]).stem.split('_')[-1]
|
||||||
|
|||||||
@ -84,7 +84,7 @@ class VideoProcessor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
|
def extract_frames_by_interval(self, output_dir: str, interval_seconds: float = 5.0,
|
||||||
use_hw_accel: bool = True, skip_seconds: float = 0.0) -> List[int]:
|
use_hw_accel: bool = True) -> List[int]:
|
||||||
"""
|
"""
|
||||||
按指定时间间隔提取视频帧
|
按指定时间间隔提取视频帧
|
||||||
|
|
||||||
@ -92,7 +92,6 @@ class VideoProcessor:
|
|||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
interval_seconds: 帧提取间隔(秒)
|
interval_seconds: 帧提取间隔(秒)
|
||||||
use_hw_accel: 是否使用硬件加速
|
use_hw_accel: 是否使用硬件加速
|
||||||
skip_seconds: 跳过视频开头的秒数
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[int]: 提取的帧号列表
|
List[int]: 提取的帧号列表
|
||||||
@ -101,7 +100,7 @@ class VideoProcessor:
|
|||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
# 计算起始时间和帧提取点
|
# 计算起始时间和帧提取点
|
||||||
start_time = skip_seconds
|
start_time = 0
|
||||||
end_time = self.duration
|
end_time = self.duration
|
||||||
extraction_times = []
|
extraction_times = []
|
||||||
|
|
||||||
@ -291,7 +290,6 @@ class VideoProcessor:
|
|||||||
|
|
||||||
def process_video_pipeline(self,
|
def process_video_pipeline(self,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
skip_seconds: float = 0.0,
|
|
||||||
interval_seconds: float = 5.0, # 帧提取间隔(秒)
|
interval_seconds: float = 5.0, # 帧提取间隔(秒)
|
||||||
use_hw_accel: bool = True) -> None:
|
use_hw_accel: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
@ -299,7 +297,6 @@ class VideoProcessor:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
skip_seconds: 跳过视频开头的秒数
|
|
||||||
interval_seconds: 帧提取间隔(秒)
|
interval_seconds: 帧提取间隔(秒)
|
||||||
use_hw_accel: 是否使用硬件加速
|
use_hw_accel: 是否使用硬件加速
|
||||||
"""
|
"""
|
||||||
@ -312,8 +309,7 @@ class VideoProcessor:
|
|||||||
self.extract_frames_by_interval(
|
self.extract_frames_by_interval(
|
||||||
output_dir,
|
output_dir,
|
||||||
interval_seconds=interval_seconds,
|
interval_seconds=interval_seconds,
|
||||||
use_hw_accel=use_hw_accel,
|
use_hw_accel=use_hw_accel
|
||||||
skip_seconds=skip_seconds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"处理完成!视频帧已保存在: {output_dir}")
|
logger.info(f"处理完成!视频帧已保存在: {output_dir}")
|
||||||
|
|||||||
@ -4,7 +4,6 @@
|
|||||||
# gemini
|
# gemini
|
||||||
# qwenvl
|
# qwenvl
|
||||||
vision_llm_provider="qwenvl"
|
vision_llm_provider="qwenvl"
|
||||||
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
|
|
||||||
|
|
||||||
########## Vision Gemini API Key
|
########## Vision Gemini API Key
|
||||||
vision_gemini_api_key = ""
|
vision_gemini_api_key = ""
|
||||||
@ -181,4 +180,4 @@
|
|||||||
threshold = 30
|
threshold = 30
|
||||||
version = "v2"
|
version = "v2"
|
||||||
# 大模型单次处理的关键帧数量
|
# 大模型单次处理的关键帧数量
|
||||||
vision_batch_size = 5
|
vision_batch_size = 10
|
||||||
|
|||||||
@ -220,10 +220,19 @@ def render_script_buttons(tr, params):
|
|||||||
st.number_input(
|
st.number_input(
|
||||||
tr("Frame Interval (seconds)"),
|
tr("Frame Interval (seconds)"),
|
||||||
min_value=0,
|
min_value=0,
|
||||||
value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 5)),
|
value=st.session_state.get('frame_interval_input', config.frames.get('frame_interval_input', 3)),
|
||||||
help=tr("Frame Interval (seconds) (More keyframes consume more tokens)"),
|
help=tr("Frame Interval (seconds) (More keyframes consume more tokens)"),
|
||||||
key="frame_interval_input"
|
key="frame_interval_input"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with input_cols[1]:
|
||||||
|
st.number_input(
|
||||||
|
tr("Batch Size"),
|
||||||
|
min_value=0,
|
||||||
|
value=st.session_state.get('vision_batch_size', config.frames.get('vision_batch_size', 10)),
|
||||||
|
help=tr("Batch Size (More keyframes consume more tokens)"),
|
||||||
|
key="vision_batch_size"
|
||||||
|
)
|
||||||
|
|
||||||
# 生成/加载按钮
|
# 生成/加载按钮
|
||||||
if script_path == "auto":
|
if script_path == "auto":
|
||||||
|
|||||||
@ -115,7 +115,6 @@
|
|||||||
"Text Generation Model Settings": "文案生成模型设置",
|
"Text Generation Model Settings": "文案生成模型设置",
|
||||||
"LLM Model Name": "大语言模型名称",
|
"LLM Model Name": "大语言模型名称",
|
||||||
"LLM Model API Key": "大语言模型 API 密钥",
|
"LLM Model API Key": "大语言模型 API 密钥",
|
||||||
"Batch Size": "批处理大小",
|
|
||||||
"Text Model Provider": "文案生成模型提供商",
|
"Text Model Provider": "文案生成模型提供商",
|
||||||
"Text API Key": "文案生成 API 密钥",
|
"Text API Key": "文案生成 API 密钥",
|
||||||
"Text Base URL": "文案生成接口地址",
|
"Text Base URL": "文案生成接口地址",
|
||||||
@ -194,6 +193,8 @@
|
|||||||
"Original Volume": "视频音量",
|
"Original Volume": "视频音量",
|
||||||
"Auto Generate": "纪录片解说 (画面解说)",
|
"Auto Generate": "纪录片解说 (画面解说)",
|
||||||
"Frame Interval (seconds)": "帧间隔 (秒)",
|
"Frame Interval (seconds)": "帧间隔 (秒)",
|
||||||
"Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)"
|
"Frame Interval (seconds) (More keyframes consume more tokens)": "帧间隔 (秒) (更多关键帧消耗更多令牌)",
|
||||||
|
"Batch Size": "批处理大小",
|
||||||
|
"Batch Size (More keyframes consume more tokens)": "批处理大小, 每批处理越少消耗 token 越多"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -9,7 +9,6 @@ from app.utils import video_processor
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
from urllib3.util.retry import Retry
|
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.utils.script_generator import ScriptProcessor
|
from app.utils.script_generator import ScriptProcessor
|
||||||
@ -38,8 +37,9 @@ def generate_script_docu(params):
|
|||||||
if not params.video_origin_path:
|
if not params.video_origin_path:
|
||||||
st.error("请先选择视频文件")
|
st.error("请先选择视频文件")
|
||||||
return
|
return
|
||||||
|
"""
|
||||||
# ===================提取键帧===================
|
1. 提取键帧
|
||||||
|
"""
|
||||||
update_progress(10, "正在提取关键帧...")
|
update_progress(10, "正在提取关键帧...")
|
||||||
|
|
||||||
# 创建临时目录用于存储关键帧
|
# 创建临时目录用于存储关键帧
|
||||||
@ -95,9 +95,11 @@ def generate_script_docu(params):
|
|||||||
|
|
||||||
raise Exception(f"关键帧提取失败: {str(e)}")
|
raise Exception(f"关键帧提取失败: {str(e)}")
|
||||||
|
|
||||||
# 根据不同的 LLM 提供商处理
|
"""
|
||||||
|
2. 视觉分析
|
||||||
|
"""
|
||||||
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
|
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
|
||||||
logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
|
logger.debug(f"VLM 视觉大模型提供商: {vision_llm_provider}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ===================初始化视觉分析器===================
|
# ===================初始化视觉分析器===================
|
||||||
@ -131,10 +133,32 @@ def generate_script_docu(params):
|
|||||||
|
|
||||||
# 执行异步分析
|
# 执行异步分析
|
||||||
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
||||||
|
vision_analysis_prompt = """
|
||||||
|
我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
|
||||||
|
|
||||||
|
首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
|
||||||
|
然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
|
||||||
|
|
||||||
|
请务必使用 JSON 格式输出你的结果。JSON 结构应如下:
|
||||||
|
{
|
||||||
|
"frame_observations": [
|
||||||
|
{
|
||||||
|
"frame_number": 1, // 或其他标识帧的方式
|
||||||
|
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
|
||||||
|
},
|
||||||
|
// ... 更多帧的观察 ...
|
||||||
|
],
|
||||||
|
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
|
||||||
|
}
|
||||||
|
|
||||||
|
请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素
|
||||||
|
|
||||||
|
请只返回 JSON 字符串,不要包含任何其他解释性文字。
|
||||||
|
"""
|
||||||
results = loop.run_until_complete(
|
results = loop.run_until_complete(
|
||||||
analyzer.analyze_images(
|
analyzer.analyze_images(
|
||||||
images=keyframe_files,
|
images=keyframe_files,
|
||||||
prompt=config.app.get('vision_analysis_prompt'),
|
prompt=vision_analysis_prompt,
|
||||||
batch_size=vision_batch_size
|
batch_size=vision_batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user