mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
feat(documentary): preserve failed batches and add vision concurrency
This commit is contained in:
parent
8201911b82
commit
4d21c43b89
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from app.utils import utils
|
||||
from app.services.documentary.frame_analysis_models import FrameBatchResult
|
||||
@ -78,3 +80,65 @@ JSON 必须包含以下键:
|
||||
]
|
||||
)
|
||||
return f"{legacy_prefix}_{utils.md5(payload)}"
|
||||
|
||||
def _strip_code_fence(self, response_text: str) -> str:
|
||||
cleaned = (response_text or "").strip()
|
||||
cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def _parse_batch_response(
    self,
    *,
    batch_index: int,
    raw_response: str,
    frame_paths: list[str],
    time_range: str,
) -> FrameBatchResult:
    """Parse one vision-model batch reply into a structured ``FrameBatchResult``.

    The raw reply is first stripped of any Markdown code fence and decoded as
    JSON; anything that is not a JSON object produces a failed-batch result
    that preserves the raw response for later inspection.  Observations are
    aligned positionally with *frame_paths*; missing entries become empty
    observations so every frame path is always represented.
    """
    try:
        payload = json.loads(self._strip_code_fence(raw_response))
        if not isinstance(payload, dict):
            raise ValueError("Batch response JSON payload must be an object")
    except Exception as exc:
        # Keep the failed batch (raw text included) instead of dropping it.
        return self._build_failed_batch_result(
            batch_index=batch_index,
            raw_response=raw_response,
            error_message=str(exc),
            frame_paths=frame_paths,
            time_range=time_range,
        )

    reported = payload.get("frame_observations")
    if not isinstance(reported, list):
        reported = []

    frame_observations: list[dict] = []
    for position, path in enumerate(frame_paths):
        # Pad with an empty entry when the model returned fewer observations
        # than there are frames in this batch.
        raw_entry = reported[position] if position < len(reported) else {}
        if isinstance(raw_entry, dict):
            observation_text = str(raw_entry.get("observation", "") or "")
            timestamp_text = str(raw_entry.get("timestamp", "") or "")
        else:
            # Tolerate models that emit bare strings instead of objects.
            observation_text = str(raw_entry or "")
            timestamp_text = ""
        frame_observations.append(
            {
                "frame_path": path,
                "timestamp": timestamp_text,
                "observation": observation_text,
            }
        )

    summary = payload.get("overall_activity_summary", "")
    if not isinstance(summary, str):
        summary = str(summary or "")

    return FrameBatchResult(
        batch_index=batch_index,
        status="success",
        time_range=time_range,
        raw_response=raw_response,
        frame_paths=list(frame_paths),
        frame_observations=frame_observations,
        overall_activity_summary=summary,
    )
|
||||
|
||||
@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider):
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
max_concurrency: int = 1,
|
||||
**kwargs) -> List[str]:
|
||||
"""
|
||||
分析图片并返回结果
|
||||
@ -116,6 +117,7 @@ class VisionModelProvider(BaseLLMProvider):
|
||||
images: 图片路径列表或PIL图片对象列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
max_concurrency: 最大并发批次数(实现支持时生效)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
|
||||
@ -159,7 +159,8 @@ class VisionAnalyzerAdapter:
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10) -> List[Dict[str, Any]]:
|
||||
batch_size: int = 10,
|
||||
max_concurrency: int = 1) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析图片 - 兼容原有接口
|
||||
|
||||
@ -167,6 +168,7 @@ class VisionAnalyzerAdapter:
|
||||
images: 图片列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
max_concurrency: 最大并发批次数
|
||||
|
||||
Returns:
|
||||
分析结果列表,格式与旧实现兼容
|
||||
@ -177,7 +179,8 @@ class VisionAnalyzerAdapter:
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
provider=self.provider,
|
||||
batch_size=batch_size
|
||||
batch_size=batch_size,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
|
||||
# 转换为旧格式以保持向后兼容性
|
||||
|
||||
@ -4,6 +4,7 @@ OpenAI 兼容提供商实现
|
||||
使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口,支持文本和视觉模型。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import base64
|
||||
import re
|
||||
@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider)
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10,
|
||||
max_concurrency: int = 1,
|
||||
**kwargs,
|
||||
) -> List[str]:
|
||||
logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
|
||||
|
||||
processed_images = self._prepare_images(images)
|
||||
results: List[str] = []
|
||||
if not processed_images:
|
||||
return []
|
||||
|
||||
for i in range(0, len(processed_images), batch_size):
|
||||
batch = processed_images[i : i + batch_size]
|
||||
logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片")
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt, **kwargs)
|
||||
results.append(result)
|
||||
except Exception as exc:
|
||||
logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}")
|
||||
results.append(f"批次处理失败: {exc}")
|
||||
bounded_concurrency = max(1, int(max_concurrency))
|
||||
semaphore = asyncio.Semaphore(bounded_concurrency)
|
||||
batches = [
|
||||
(index // batch_size, processed_images[index : index + batch_size])
|
||||
for index in range(0, len(processed_images), batch_size)
|
||||
]
|
||||
|
||||
return results
|
||||
async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
|
||||
logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张图片")
|
||||
async with semaphore:
|
||||
try:
|
||||
result = await self._analyze_batch(batch, prompt, **kwargs)
|
||||
return batch_index, result
|
||||
except Exception as exc:
|
||||
logger.error(f"批次 {batch_index + 1} 处理失败: {exc}")
|
||||
return batch_index, f"批次处理失败: {exc}"
|
||||
|
||||
completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches))
|
||||
completed.sort(key=lambda item: item[0])
|
||||
return [result for _, result in completed]
|
||||
|
||||
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
"""OpenAI 兼容 provider 的最小回归测试。"""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from app.config import config
|
||||
from app.services.llm.base import TextModelProvider
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
|
||||
from app.services.llm.providers import register_all_providers
|
||||
|
||||
|
||||
@ -63,5 +65,54 @@ class OpenAICompatManagerTests(unittest.TestCase):
|
||||
self.assertEqual("https://new.example/v1", provider.base_url)
|
||||
|
||||
|
||||
class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
    """Concurrency behaviour of ``OpenAICompatibleVisionProvider.analyze_images``."""

    async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
        """Results must come back in batch order even when batches finish out of order."""
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)

        # Earlier batches sleep longer, so completion order is reversed.
        sleep_by_head = {"a": 0.03, "c": 0.01, "e": 0.0}

        async def stubbed_batch(batch, prompt, **kwargs):
            await asyncio.sleep(sleep_by_head[batch[0]])
            return f"batch-{batch[0]}"

        provider._analyze_batch = stubbed_batch

        result = await provider.analyze_images(
            images=list("abcdef"),
            prompt="prompt",
            batch_size=2,
            max_concurrency=2,
        )

        self.assertEqual(["batch-a", "batch-c", "batch-e"], result)

    async def test_analyze_images_respects_max_concurrency_limit(self):
        """No more than ``max_concurrency`` batches may be in flight at once."""
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)

        active = 0
        peak = 0

        async def stubbed_batch(batch, prompt, **kwargs):
            nonlocal active, peak
            active += 1
            peak = max(peak, active)
            # Sleep long enough that concurrent batches overlap measurably.
            await asyncio.sleep(0.02)
            active -= 1
            return f"batch-{batch[0]}"

        provider._analyze_batch = stubbed_batch

        result = await provider.analyze_images(
            images=list("abcdef"),
            prompt="prompt",
            batch_size=1,
            max_concurrency=2,
        )

        self.assertEqual(6, len(result))
        self.assertEqual(2, peak)
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
@ -152,3 +152,6 @@
|
||||
|
||||
# 大模型单次处理的关键帧数量
|
||||
vision_batch_size = 10
|
||||
|
||||
# 视觉批处理最大并发批次数(OpenAI 兼容 provider)
|
||||
vision_max_concurrency = 2
|
||||
|
||||
@ -66,6 +66,63 @@ class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
|
||||
self.assertFalse(hasattr(batch, "observations"))
|
||||
self.assertFalse(hasattr(batch, "summary"))
|
||||
|
||||
def test_parse_batch_returns_failed_result_when_json_is_invalid(self):
    """A non-JSON response yields a 'failed' batch that preserves the raw text."""
    service = DocumentaryFrameAnalysisService()
    frame_paths = ["/tmp/keyframe_000000_000000000.jpg"]

    batch = service._parse_batch_response(
        batch_index=0,
        raw_response="plain text",
        frame_paths=frame_paths,
        time_range="00:00:00,000-00:00:03,000",
    )

    self.assertEqual("failed", batch.status)
    self.assertEqual("plain text", batch.raw_response)
    self.assertEqual(frame_paths, batch.frame_paths)
    self.assertEqual([], batch.frame_observations)
    self.assertEqual("", batch.overall_activity_summary)
|
||||
|
||||
def test_parse_batch_parses_code_fenced_json_into_structured_result(self):
    """A ```json fenced payload is parsed into per-frame structured observations."""
    service = DocumentaryFrameAnalysisService()
    raw_response = """```json
{
"frame_observations": [
{"observation": "第一帧画面"},
{"observation": "第二帧画面"}
],
"overall_activity_summary": "人物从房间走到街道"
}
```"""
    paths = [
        "/tmp/keyframe_000000_000000000.jpg",
        "/tmp/keyframe_000075_000003000.jpg",
    ]

    batch = service._parse_batch_response(
        batch_index=1,
        raw_response=raw_response,
        frame_paths=paths,
        time_range="00:00:00,000-00:00:06,000",
    )

    self.assertEqual("success", batch.status)
    # Observations pair positionally with the supplied frame paths.
    expected_observations = [
        {"frame_path": path, "timestamp": "", "observation": text}
        for path, text in zip(paths, ["第一帧画面", "第二帧画面"])
    ]
    self.assertEqual(expected_observations, batch.frame_observations)
    self.assertEqual("人物从房间走到街道", batch.overall_activity_summary)
    self.assertEqual("", batch.fallback_summary)
|
||||
|
||||
def test_cache_key_changes_when_interval_changes(self):
|
||||
service = DocumentaryFrameAnalysisService()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user