diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py
index 24521a6..05dfa13 100644
--- a/app/services/documentary/frame_analysis_service.py
+++ b/app/services/documentary/frame_analysis_service.py
@@ -1,4 +1,6 @@
+import json
 import os
+import re
 
 from app.utils import utils
 from app.services.documentary.frame_analysis_models import FrameBatchResult
@@ -78,3 +80,65 @@ The JSON must contain the following keys:
             ]
         )
         return f"{legacy_prefix}_{utils.md5(payload)}"
+
+    def _strip_code_fence(self, response_text: str) -> str:
+        cleaned = (response_text or "").strip()
+        cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
+        cleaned = re.sub(r"\s*```$", "", cleaned)
+        return cleaned.strip()
+
+    def _parse_batch_response(
+        self,
+        *,
+        batch_index: int,
+        raw_response: str,
+        frame_paths: list[str],
+        time_range: str,
+    ) -> FrameBatchResult:
+        try:
+            payload = json.loads(self._strip_code_fence(raw_response))
+            if not isinstance(payload, dict):
+                raise ValueError("Batch response JSON payload must be an object")
+        except Exception as exc:
+            return self._build_failed_batch_result(
+                batch_index=batch_index,
+                raw_response=raw_response,
+                error_message=str(exc),
+                frame_paths=frame_paths,
+                time_range=time_range,
+            )
+
+        raw_observations = payload.get("frame_observations")
+        if not isinstance(raw_observations, list):
+            raw_observations = []
+
+        frame_observations: list[dict] = []
+        for index, frame_path in enumerate(frame_paths):
+            entry = raw_observations[index] if index < len(raw_observations) else {}
+            if isinstance(entry, dict):
+                observation = str(entry.get("observation", "") or "")
+                timestamp = str(entry.get("timestamp", "") or "")
+            else:
+                observation = str(entry or "")
+                timestamp = ""
+            frame_observations.append(
+                {
+                    "frame_path": frame_path,
+                    "timestamp": timestamp,
+                    "observation": observation,
+                }
+            )
+
+        summary = payload.get("overall_activity_summary", "")
+        if not isinstance(summary, str):
+            summary = str(summary or "")
+
+        return FrameBatchResult(
+            batch_index=batch_index,
+            status="success",
+            time_range=time_range,
+            raw_response=raw_response,
+            frame_paths=list(frame_paths),
+            frame_observations=frame_observations,
+            overall_activity_summary=summary,
+        )
diff --git a/app/services/llm/base.py b/app/services/llm/base.py
index 87f1368..737ceb9 100644
--- a/app/services/llm/base.py
+++ b/app/services/llm/base.py
@@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider):
                              images: List[Union[str, Path, PIL.Image.Image]],
                              prompt: str,
                              batch_size: int = 10,
+                             max_concurrency: int = 1,
                              **kwargs) -> List[str]:
         """
         Analyze images and return the results
@@ -116,6 +117,7 @@
             images: List of image paths or PIL image objects
             prompt: Analysis prompt
             batch_size: Batch size
+            max_concurrency: Maximum number of concurrent batches (takes effect when the implementation supports it)
             **kwargs: Additional parameters
 
         Returns:
diff --git a/app/services/llm/migration_adapter.py b/app/services/llm/migration_adapter.py
index 49ac75a..68615a8 100644
--- a/app/services/llm/migration_adapter.py
+++ b/app/services/llm/migration_adapter.py
@@ -159,7 +159,8 @@ class VisionAnalyzerAdapter:
     async def analyze_images(self,
                              images: List[Union[str, Path, PIL.Image.Image]],
                              prompt: str,
-                             batch_size: int = 10) -> List[Dict[str, Any]]:
+                             batch_size: int = 10,
+                             max_concurrency: int = 1) -> List[Dict[str, Any]]:
         """
         Analyze images - compatible with the original interface
 
@@ -167,6 +168,7 @@
             images: List of images
             prompt: Analysis prompt
             batch_size: Batch size
+            max_concurrency: Maximum number of concurrent batches
 
         Returns:
             List of analysis results, in a format compatible with the old implementation
@@ -177,7 +179,8 @@
             images=images,
             prompt=prompt,
             provider=self.provider,
-            batch_size=batch_size
+            batch_size=batch_size,
+            max_concurrency=max_concurrency,
         )
 
         # Convert to the old format to keep backward compatibility
diff --git a/app/services/llm/openai_compatible_provider.py b/app/services/llm/openai_compatible_provider.py
index 6423ec9..b91c6dc 100644
--- a/app/services/llm/openai_compatible_provider.py
+++ b/app/services/llm/openai_compatible_provider.py
@@ -4,6 +4,7 @@ OpenAI-compatible provider implementation
 Uses the official OpenAI SDK to call OpenAI-compatible APIs, supporting text and vision models.
 """
 
+import asyncio
 import io
 import base64
 import re
@@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider)
         images: List[Union[str, Path, PIL.Image.Image]],
         prompt: str,
         batch_size: int = 10,
+        max_concurrency: int = 1,
         **kwargs,
     ) -> List[str]:
         logger.info(f"Starting analysis of {len(images)} images via the OpenAI-compatible API ({self.model_name})")
 
         processed_images = self._prepare_images(images)
-        results: List[str] = []
+        if not processed_images:
+            return []
 
-        for i in range(0, len(processed_images), batch_size):
-            batch = processed_images[i : i + batch_size]
-            logger.info(f"Processing batch {i // batch_size + 1} with {len(batch)} images")
-            try:
-                result = await self._analyze_batch(batch, prompt, **kwargs)
-                results.append(result)
-            except Exception as exc:
-                logger.error(f"Batch {i // batch_size + 1} failed: {exc}")
-                results.append(f"Batch processing failed: {exc}")
+        bounded_concurrency = max(1, int(max_concurrency))
+        semaphore = asyncio.Semaphore(bounded_concurrency)
+        batches = [
+            (index // batch_size, processed_images[index : index + batch_size])
+            for index in range(0, len(processed_images), batch_size)
+        ]
 
-        return results
+        async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
+            logger.info(f"Processing batch {batch_index + 1} with {len(batch)} images")
+            async with semaphore:
+                try:
+                    result = await self._analyze_batch(batch, prompt, **kwargs)
+                    return batch_index, result
+                except Exception as exc:
+                    logger.error(f"Batch {batch_index + 1} failed: {exc}")
+                    return batch_index, f"Batch processing failed: {exc}"
+
+        completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches))
+        completed.sort(key=lambda item: item[0])
+        return [result for _, result in completed]
 
     async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
         content = [{"type": "text", "text": prompt}]
diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py
index faa4e80..8393ee3 100644
--- a/app/services/llm/test_openai_compat_unittest.py
+++ b/app/services/llm/test_openai_compat_unittest.py
@@ -1,10 +1,12 @@
 """Minimal regression tests for the OpenAI-compatible provider."""
 
+import asyncio
 import unittest
 
 from app.config import config
 from app.services.llm.base import TextModelProvider
 from app.services.llm.manager import LLMServiceManager
+from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
 from app.services.llm.providers import register_all_providers
 
 
@@ -63,5 +65,54 @@ class OpenAICompatManagerTests(unittest.TestCase):
         self.assertEqual("https://new.example/v1", provider.base_url)
 
 
+class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
+    async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
+        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
+        provider._prepare_images = lambda images: list(images)
+
+        async def fake_analyze_batch(batch, prompt, **kwargs):
+            delays = {"a": 0.03, "c": 0.01, "e": 0.0}
+            await asyncio.sleep(delays[batch[0]])
+            return f"batch-{batch[0]}"
+
+        provider._analyze_batch = fake_analyze_batch
+
+        result = await provider.analyze_images(
+            images=["a", "b", "c", "d", "e", "f"],
+            prompt="prompt",
+            batch_size=2,
+            max_concurrency=2,
+        )
+
+        self.assertEqual(["batch-a", "batch-c", "batch-e"], result)
+
+    async def test_analyze_images_respects_max_concurrency_limit(self):
+        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
+        provider._prepare_images = lambda images: list(images)
+
+        in_flight = 0
+        max_in_flight = 0
+
+        async def fake_analyze_batch(batch, prompt, **kwargs):
+            nonlocal in_flight, max_in_flight
+            in_flight += 1
+            max_in_flight = max(max_in_flight, in_flight)
+            await asyncio.sleep(0.02)
+            in_flight -= 1
+            return f"batch-{batch[0]}"
+
+        provider._analyze_batch = fake_analyze_batch
+
+        result = await provider.analyze_images(
+            images=["a", "b", "c", "d", "e", "f"],
+            prompt="prompt",
+            batch_size=1,
+            max_concurrency=2,
+        )
+
+        self.assertEqual(6, len(result))
+        self.assertEqual(2, max_in_flight)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/config.example.toml b/config.example.toml
index f226c34..4d49d61 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -152,3 +152,6 @@
 
     # Number of keyframes the large model processes in a single request
     vision_batch_size = 10
+
+    # Maximum number of concurrent vision batches (OpenAI-compatible provider)
+    vision_max_concurrency = 2
diff --git a/tests/test_documentary_frame_analysis_service.py b/tests/test_documentary_frame_analysis_service.py
index 1ac0875..1d3415f 100644
--- a/tests/test_documentary_frame_analysis_service.py
+++ b/tests/test_documentary_frame_analysis_service.py
@@ -66,6 +66,63 @@ class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
         self.assertFalse(hasattr(batch, "observations"))
         self.assertFalse(hasattr(batch, "summary"))
 
+    def test_parse_batch_returns_failed_result_when_json_is_invalid(self):
+        service = DocumentaryFrameAnalysisService()
+
+        batch = service._parse_batch_response(
+            batch_index=0,
+            raw_response="plain text",
+            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
+            time_range="00:00:00,000-00:00:03,000",
+        )
+
+        self.assertEqual("failed", batch.status)
+        self.assertEqual("plain text", batch.raw_response)
+        self.assertEqual(["/tmp/keyframe_000000_000000000.jpg"], batch.frame_paths)
+        self.assertEqual([], batch.frame_observations)
+        self.assertEqual("", batch.overall_activity_summary)
+
+    def test_parse_batch_parses_code_fenced_json_into_structured_result(self):
+        service = DocumentaryFrameAnalysisService()
+        raw_response = """```json
+{
+    "frame_observations": [
+        {"observation": "View of the first frame"},
+        {"observation": "View of the second frame"}
+    ],
+    "overall_activity_summary": "A person walks from the room to the street"
+}
+```"""
+
+        batch = service._parse_batch_response(
+            batch_index=1,
+            raw_response=raw_response,
+            frame_paths=[
+                "/tmp/keyframe_000000_000000000.jpg",
+                "/tmp/keyframe_000075_000003000.jpg",
+            ],
+            time_range="00:00:00,000-00:00:06,000",
+        )
+
+        self.assertEqual("success", batch.status)
+        self.assertEqual(
+            [
+                {
+                    "frame_path": "/tmp/keyframe_000000_000000000.jpg",
+                    "timestamp": "",
+                    "observation": "View of the first frame",
+                },
+                {
+                    "frame_path": "/tmp/keyframe_000075_000003000.jpg",
+                    "timestamp": "",
+                    "observation": "View of the second frame",
+                },
+            ],
+            batch.frame_observations,
+        )
+        self.assertEqual("A person walks from the room to the street", batch.overall_activity_summary)
+        self.assertEqual("", batch.fallback_summary)
+
     def test_cache_key_changes_when_interval_changes(self):
         service = DocumentaryFrameAnalysisService()
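A minimal usage sketch of the new option, assuming a configured VisionAnalyzerAdapter instance named vision_analyzer and a list of extracted keyframe paths named frame_paths (both names are hypothetical, not taken from the patch):

    # Fan the vision batches out with bounded concurrency; per-batch results come back in batch order.
    results = await vision_analyzer.analyze_images(
        images=frame_paths,
        prompt="Describe what happens in each keyframe",
        batch_size=10,      # mirrors vision_batch_size in config.example.toml
        max_concurrency=2,  # mirrors the new vision_max_concurrency setting
    )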