mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-13 12:12:50 +00:00
refactor(webui): 优化视觉分析批次处理逻辑
- 提取 vision_batch_size 到单独变量,提高代码可读性
- 使用 vision_batch_size 替代多次调用 config(frames.get("vision_batch_size")
- 添加调试日志,记录批次数量和每批次的图片数量
This commit is contained in:
parent
593b427061
commit
53b8cded04
1
.github/workflows/dockerImageBuild.yml
vendored
1
.github/workflows/dockerImageBuild.yml
vendored
@ -3,6 +3,7 @@ name: build_docker
|
|||||||
on:
|
on:
|
||||||
release:
|
release:
|
||||||
types: [created] # 表示在创建新的 Release 时触发
|
types: [created] # 表示在创建新的 Release 时触发
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_docker:
|
build_docker:
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class VisionAnalyzer:
|
|||||||
async def analyze_images(self,
|
async def analyze_images(self,
|
||||||
images: Union[List[str], List[PIL.Image.Image]],
|
images: Union[List[str], List[PIL.Image.Image]],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
batch_size: int = 5) -> List[Dict]:
|
batch_size: int) -> List[Dict]:
|
||||||
"""批量分析多张图片"""
|
"""批量分析多张图片"""
|
||||||
try:
|
try:
|
||||||
# 加载图片
|
# 加载图片
|
||||||
@ -82,6 +82,8 @@ class VisionAnalyzer:
|
|||||||
results = []
|
results = []
|
||||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
total_batches = (len(images) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
|
logger.debug(f"共 {total_batches} 个批次,每批次 {batch_size} 张图片")
|
||||||
|
|
||||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
||||||
for i in range(0, len(images), batch_size):
|
for i in range(0, len(images), batch_size):
|
||||||
batch = images[i:i + batch_size]
|
batch = images[i:i + batch_size]
|
||||||
|
|||||||
@ -417,11 +417,12 @@ def generate_script(tr, params):
|
|||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
# 执行异步分析
|
# 执行异步分析
|
||||||
|
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
||||||
results = loop.run_until_complete(
|
results = loop.run_until_complete(
|
||||||
analyzer.analyze_images(
|
analyzer.analyze_images(
|
||||||
images=keyframe_files,
|
images=keyframe_files,
|
||||||
prompt=config.app.get('vision_analysis_prompt'),
|
prompt=config.app.get('vision_analysis_prompt'),
|
||||||
batch_size=config.frames.get("vision_batch_size", st.session_state.get('vision_batch_size', 5))
|
batch_size=vision_batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
loop.close()
|
loop.close()
|
||||||
@ -437,8 +438,8 @@ def generate_script(tr, params):
|
|||||||
if 'error' in result:
|
if 'error' in result:
|
||||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||||
continue
|
continue
|
||||||
|
# 获取当前批次的文件列表
|
||||||
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
|
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||||
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
||||||
logger.debug(batch_files)
|
logger.debug(batch_files)
|
||||||
|
|
||||||
@ -477,7 +478,7 @@ def generate_script(tr, params):
|
|||||||
if 'error' in result:
|
if 'error' in result:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
|
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||||
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||||
|
|
||||||
frame_content = {
|
frame_content = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user