feat(documentary): preserve failed batches and add vision concurrency

This commit is contained in:
linyq 2026-04-03 01:54:47 +08:00
parent 8201911b82
commit 4d21c43b89
7 changed files with 205 additions and 13 deletions

View File

@ -1,4 +1,6 @@
import json
import os
import re
from app.utils import utils
from app.services.documentary.frame_analysis_models import FrameBatchResult
@ -78,3 +80,65 @@ JSON 必须包含以下键:
]
)
return f"{legacy_prefix}_{utils.md5(payload)}"
def _strip_code_fence(self, response_text: str) -> str:
cleaned = (response_text or "").strip()
cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
cleaned = re.sub(r"\s*```$", "", cleaned)
return cleaned.strip()
def _parse_batch_response(
    self,
    *,
    batch_index: int,
    raw_response: str,
    frame_paths: list[str],
    time_range: str,
) -> FrameBatchResult:
    """Parse a raw vision-model batch reply into a FrameBatchResult.

    The reply is expected to be a (possibly code-fenced) JSON object with
    "frame_observations" and "overall_activity_summary" keys. A reply that
    is not valid JSON, or whose top level is not an object, is preserved as
    a failed batch instead of raising, so one bad batch cannot abort the run.

    Args:
        batch_index: Zero-based index of this batch.
        raw_response: Raw text returned by the vision model.
        frame_paths: Frame image paths sent in this batch, in order.
        time_range: Human-readable time span covered by the batch.

    Returns:
        A FrameBatchResult with status "success", or a failed result built
        by _build_failed_batch_result when the payload is unusable.
    """
    try:
        payload = json.loads(self._strip_code_fence(raw_response))
        if not isinstance(payload, dict):
            raise ValueError("Batch response JSON payload must be an object")
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError subclass, so this narrow catch
        # covers both decode failures and the non-object check above without
        # swallowing unrelated programming errors (the previous broad
        # `except Exception` would have hidden those as "failed batches").
        return self._build_failed_batch_result(
            batch_index=batch_index,
            raw_response=raw_response,
            error_message=str(exc),
            frame_paths=frame_paths,
            time_range=time_range,
        )

    raw_observations = payload.get("frame_observations")
    if not isinstance(raw_observations, list):
        raw_observations = []

    # Align observations to the frames we actually sent: extra entries are
    # dropped, missing ones become empty observations.
    frame_observations: list[dict] = []
    for index, frame_path in enumerate(frame_paths):
        entry = raw_observations[index] if index < len(raw_observations) else {}
        if isinstance(entry, dict):
            observation = str(entry.get("observation", "") or "")
            timestamp = str(entry.get("timestamp", "") or "")
        else:
            # Tolerate models that return plain strings instead of objects.
            observation = str(entry or "")
            timestamp = ""
        frame_observations.append(
            {
                "frame_path": frame_path,
                "timestamp": timestamp,
                "observation": observation,
            }
        )

    summary = payload.get("overall_activity_summary", "")
    if not isinstance(summary, str):
        summary = str(summary or "")

    return FrameBatchResult(
        batch_index=batch_index,
        status="success",
        time_range=time_range,
        raw_response=raw_response,
        frame_paths=list(frame_paths),
        frame_observations=frame_observations,
        overall_activity_summary=summary,
    )

View File

@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider):
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
max_concurrency: int = 1,
**kwargs) -> List[str]:
"""
分析图片并返回结果
@ -116,6 +117,7 @@ class VisionModelProvider(BaseLLMProvider):
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
batch_size: 批处理大小
max_concurrency: 最大并发批次数(实现支持时生效)
**kwargs: 其他参数
Returns:

View File

@ -159,7 +159,8 @@ class VisionAnalyzerAdapter:
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10) -> List[Dict[str, Any]]:
batch_size: int = 10,
max_concurrency: int = 1) -> List[Dict[str, Any]]:
"""
分析图片 - 兼容原有接口
@ -167,6 +168,7 @@ class VisionAnalyzerAdapter:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
max_concurrency: 最大并发批次数
Returns:
分析结果列表(格式与旧实现兼容)
@ -177,7 +179,8 @@ class VisionAnalyzerAdapter:
images=images,
prompt=prompt,
provider=self.provider,
batch_size=batch_size
batch_size=batch_size,
max_concurrency=max_concurrency,
)
# 转换为旧格式以保持向后兼容性

View File

@ -4,6 +4,7 @@ OpenAI 兼容提供商实现
使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口支持文本和视觉模型
"""
import asyncio
import io
import base64
import re
@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider)
async def analyze_images(
    self,
    images: List[Union[str, Path, PIL.Image.Image]],
    prompt: str,
    batch_size: int = 10,
    max_concurrency: int = 1,
    **kwargs,
) -> List[str]:
    """Analyze images in fixed-size batches, optionally running batches concurrently.

    Args:
        images: Image paths or PIL image objects.
        prompt: Analysis prompt sent with every batch.
        batch_size: Number of images per request batch.
        max_concurrency: Maximum number of batches in flight at once
            (values below 1 are clamped to 1, i.e. serial execution).
        **kwargs: Extra options forwarded to the underlying batch call.

    Returns:
        One response string per batch, in batch order; a failed batch
        contributes an error-message string instead of raising.
    """
    logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
    processed_images = self._prepare_images(images)
    if not processed_images:
        return []

    bounded_concurrency = max(1, int(max_concurrency))
    semaphore = asyncio.Semaphore(bounded_concurrency)
    batches = [
        (index // batch_size, processed_images[index : index + batch_size])
        for index in range(0, len(processed_images), batch_size)
    ]

    async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
        async with semaphore:
            # Log only after the semaphore is acquired: logging before it would
            # claim a batch is "processing" while it is still queued behind the
            # concurrency limit.
            logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张图片")
            try:
                result = await self._analyze_batch(batch, prompt, **kwargs)
                return batch_index, result
            except Exception as exc:
                # Best-effort per batch: record the failure, keep the others.
                logger.error(f"批次 {batch_index + 1} 处理失败: {exc}")
                return batch_index, f"批次处理失败: {exc}"

    completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches))
    # gather preserves argument order, but sort by index to be explicit.
    completed.sort(key=lambda item: item[0])
    return [result for _, result in completed]
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
content = [{"type": "text", "text": prompt}]

View File

@ -1,10 +1,12 @@
"""OpenAI 兼容 provider 的最小回归测试。"""
import asyncio
import unittest
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.manager import LLMServiceManager
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
from app.services.llm.providers import register_all_providers
@ -63,5 +65,54 @@ class OpenAICompatManagerTests(unittest.TestCase):
self.assertEqual("https://new.example/v1", provider.base_url)
class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
    """Concurrency regressions for OpenAICompatibleVisionProvider.analyze_images."""

    @staticmethod
    def _provider_with_passthrough_images():
        # Bypass real image preprocessing so plain strings can stand in for frames.
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)
        return provider

    async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
        """Out-of-order completion must not change result ordering."""
        provider = self._provider_with_passthrough_images()

        async def fake_analyze_batch(batch, prompt, **kwargs):
            # The first batch finishes last; results must still follow batch index.
            sleep_for = {"a": 0.03, "c": 0.01, "e": 0.0}[batch[0]]
            await asyncio.sleep(sleep_for)
            return f"batch-{batch[0]}"

        provider._analyze_batch = fake_analyze_batch
        result = await provider.analyze_images(
            images=["a", "b", "c", "d", "e", "f"],
            prompt="prompt",
            batch_size=2,
            max_concurrency=2,
        )
        self.assertEqual(["batch-a", "batch-c", "batch-e"], result)

    async def test_analyze_images_respects_max_concurrency_limit(self):
        """No more than max_concurrency batches may run at the same time."""
        provider = self._provider_with_passthrough_images()
        active = 0
        peak_active = 0

        async def fake_analyze_batch(batch, prompt, **kwargs):
            nonlocal active, peak_active
            active += 1
            peak_active = max(peak_active, active)
            await asyncio.sleep(0.02)
            active -= 1
            return f"batch-{batch[0]}"

        provider._analyze_batch = fake_analyze_batch
        result = await provider.analyze_images(
            images=["a", "b", "c", "d", "e", "f"],
            prompt="prompt",
            batch_size=1,
            max_concurrency=2,
        )
        self.assertEqual(6, len(result))
        self.assertEqual(2, peak_active)
if __name__ == "__main__":
unittest.main()

View File

@ -152,3 +152,6 @@
# 大模型单次处理的关键帧数量
vision_batch_size = 10
# 视觉批处理最大并发批次数(OpenAI 兼容 provider)
vision_max_concurrency = 2

View File

@ -66,6 +66,63 @@ class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
self.assertFalse(hasattr(batch, "observations"))
self.assertFalse(hasattr(batch, "summary"))
def test_parse_batch_returns_failed_result_when_json_is_invalid(self):
    """A non-JSON model reply yields a failed batch that keeps the raw text."""
    service = DocumentaryFrameAnalysisService()
    frame_paths = ["/tmp/keyframe_000000_000000000.jpg"]
    batch = service._parse_batch_response(
        batch_index=0,
        raw_response="plain text",
        frame_paths=frame_paths,
        time_range="00:00:00,000-00:00:03,000",
    )
    self.assertEqual("failed", batch.status)
    self.assertEqual("plain text", batch.raw_response)
    self.assertEqual(frame_paths, batch.frame_paths)
    self.assertEqual([], batch.frame_observations)
    self.assertEqual("", batch.overall_activity_summary)
def test_parse_batch_parses_code_fenced_json_into_structured_result(self):
    # A code-fenced JSON reply should be unwrapped and mapped onto the
    # provided frame paths, pairing each observation with its frame in order.
    service = DocumentaryFrameAnalysisService()
    raw_response = """```json
{
"frame_observations": [
{"observation": "第一帧画面"},
{"observation": "第二帧画面"}
],
"overall_activity_summary": "人物从房间走到街道"
}
```"""
    batch = service._parse_batch_response(
        batch_index=1,
        raw_response=raw_response,
        frame_paths=[
            "/tmp/keyframe_000000_000000000.jpg",
            "/tmp/keyframe_000075_000003000.jpg",
        ],
        time_range="00:00:00,000-00:00:06,000",
    )
    self.assertEqual("success", batch.status)
    # Timestamps default to "" because the fixture entries omit "timestamp".
    self.assertEqual(
        [
            {
                "frame_path": "/tmp/keyframe_000000_000000000.jpg",
                "timestamp": "",
                "observation": "第一帧画面",
            },
            {
                "frame_path": "/tmp/keyframe_000075_000003000.jpg",
                "timestamp": "",
                "observation": "第二帧画面",
            },
        ],
        batch.frame_observations,
    )
    self.assertEqual("人物从房间走到街道", batch.overall_activity_summary)
    # No fallback summary is synthesized on the success path.
    self.assertEqual("", batch.fallback_summary)
def test_cache_key_changes_when_interval_changes(self):
service = DocumentaryFrameAnalysisService()