From 3d76bff4427e1c9dffc9a51ab18432dde22388fd Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 3 Apr 2026 01:30:51 +0800 Subject: [PATCH] perf(documentary): add fast frame extraction and cache keys --- .gitignore | 1 + .../documentary/frame_analysis_service.py | 31 ++++++++ app/utils/utils.py | 28 ++++--- app/utils/video_processor.py | 76 +++++++++++++++++++ ...test_documentary_frame_analysis_service.py | 43 +++++++++++ ...st_video_processor_documentary_unittest.py | 46 +++++++++++ 6 files changed, 214 insertions(+), 11 deletions(-) create mode 100644 tests/test_video_processor_documentary_unittest.py diff --git a/.gitignore b/.gitignore index bf7a572..9cf0620 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ AGENTS.md CLAUDE.md tests/* !tests/test_documentary_frame_analysis_service.py +!tests/test_video_processor_documentary_unittest.py diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py index cdb0ad3..51b7aa8 100644 --- a/app/services/documentary/frame_analysis_service.py +++ b/app/services/documentary/frame_analysis_service.py @@ -1,3 +1,6 @@ +import os + +from app.utils import utils from app.services.documentary.frame_analysis_models import FrameBatchResult @@ -45,3 +48,31 @@ JSON 必须包含以下键: fallback_summary=fallback_summary, error_message=error_message, ) + + def _build_cache_key( + self, + video_path: str, + interval_seconds: float, + prompt_version: str, + model_name: str, + batch_size: int, + max_concurrency: int, + ) -> str: + try: + video_mtime = os.path.getmtime(video_path) + except OSError: + video_mtime = 0 + + payload = "|".join( + [ + str(video_path), + str(video_mtime), + str(interval_seconds), + str(prompt_version), + str(model_name), + str(batch_size), + str(max_concurrency), + "documentary-frame-analysis-v2", + ] + ) + return utils.md5(payload) diff --git a/app/utils/utils.py b/app/utils/utils.py index d101dce..98e8d1c 100644 --- a/app/utils/utils.py +++ b/app/utils/utils.py @@ -570,29 +570,35 @@ def temp_dir(sub_dir: str = ""): return d -def clear_keyframes_cache(video_path: str = None): +def clear_keyframes_cache(video_path: str = None, cache_scope: str = "keyframes"): """ 清理关键帧缓存 Args: video_path: 视频文件路径,如果指定则只清理该视频的缓存 + cache_scope: 缓存作用域目录,默认 keyframes """ try: - keyframes_dir = os.path.join(temp_dir(), "keyframes") - if not os.path.exists(keyframes_dir): + cache_dir = os.path.join(temp_dir(), cache_scope) + if not os.path.exists(cache_dir): return + import shutil + if video_path: - # 理指定视频的缓存 + # 清理指定视频的缓存(兼容前缀扩展键) video_hash = md5(video_path + str(os.path.getmtime(video_path))) - video_keyframes_dir = os.path.join(keyframes_dir, video_hash) - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - logger.info(f"已清理视频关键帧缓存: {video_path}") + for entry in os.listdir(cache_dir): + if not entry.startswith(video_hash): + continue + target_path = os.path.join(cache_dir, entry) + if os.path.isdir(target_path): + shutil.rmtree(target_path) + else: + os.remove(target_path) + logger.info(f"已清理视频关键帧缓存: {video_path}") else: # 清理所有缓存 - import shutil - shutil.rmtree(keyframes_dir) + shutil.rmtree(cache_dir) logger.info("已清理所有关键帧缓存") except Exception as e: diff --git a/app/utils/video_processor.py b/app/utils/video_processor.py index 6c46737..14b113a 100644 --- a/app/utils/video_processor.py +++ b/app/utils/video_processor.py @@ -185,6 +185,82 @@ class VideoProcessor: return frame_numbers + def extract_frames_by_interval_with_fallback(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]: + """ + 先尝试单次 ffmpeg 快路径抽帧,失败时回退到高兼容方案。 + """ + os.makedirs(output_dir, exist_ok=True) + + try: + return self._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds) + except Exception as exc: + logger.warning(f"快路径抽帧失败,回退到兼容模式: {exc}") + self.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds) + return self._collect_extracted_frame_paths(output_dir) + + def _extract_frames_fast_path(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]: + """ + 使用单次 ffmpeg 命令按固定间隔抽帧,随后重命名为既有 keyframe 约定格式。 + """ + if interval_seconds <= 0: + raise ValueError("interval_seconds must be > 0") + + os.makedirs(output_dir, exist_ok=True) + raw_pattern = os.path.join(output_dir, "fastframe_%06d.jpg") + cmd = [ + "ffmpeg", + "-hide_banner", + "-loglevel", + "error", + "-i", + self.video_path, + "-vf", + f"fps=1/{interval_seconds}", + "-q:v", + "2", + "-start_number", + "0", + "-y", + raw_pattern, + ] + subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120) + + raw_files = sorted( + filename + for filename in os.listdir(output_dir) + if re.fullmatch(r"fastframe_\d{6}\.jpg", filename) + ) + if not raw_files: + raise RuntimeError("Fast-path extraction produced no frames") + + renamed_files: List[str] = [] + for index, filename in enumerate(raw_files): + timestamp = index * interval_seconds + frame_number = int(timestamp * self.fps) + token = self._format_timestamp_token(timestamp) + source_path = os.path.join(output_dir, filename) + target_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{token}.jpg") + os.replace(source_path, target_path) + renamed_files.append(target_path) + + return renamed_files + + @staticmethod + def _format_timestamp_token(timestamp: float) -> str: + hours = int(timestamp // 3600) + minutes = int((timestamp % 3600) // 60) + seconds = int(timestamp % 60) + milliseconds = int((timestamp % 1) * 1000) + return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}" + + @staticmethod + def _collect_extracted_frame_paths(output_dir: str) -> List[str]: + return sorted( + os.path.join(output_dir, name) + for name in os.listdir(output_dir) + if name.endswith(".jpg") + ) + def _extract_single_frame_optimized(self, timestamp: float, output_path: str, use_hw_accel: bool, hwaccel_type: str) -> bool: """ diff --git a/tests/test_documentary_frame_analysis_service.py b/tests/test_documentary_frame_analysis_service.py index d5c4c11..1cf7284 100644 --- a/tests/test_documentary_frame_analysis_service.py +++ b/tests/test_documentary_frame_analysis_service.py @@ -1,7 +1,11 @@ import unittest +import os +from tempfile import TemporaryDirectory +from unittest.mock import patch from app.services.documentary.frame_analysis_models import DocumentaryAnalysisConfig from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService +from app.utils import utils class DocumentaryFrameAnalysisServiceTests(unittest.TestCase): @@ -62,6 +66,45 @@ class DocumentaryFrameAnalysisServiceTests(unittest.TestCase): self.assertFalse(hasattr(batch, "observations")) self.assertFalse(hasattr(batch, "summary")) + def test_cache_key_changes_when_interval_changes(self): + service = DocumentaryFrameAnalysisService() + + with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0): + key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2) + + self.assertNotEqual(key_a, key_b) + + def test_cache_key_changes_when_model_changes(self): + service = DocumentaryFrameAnalysisService() + + with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0): + key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + key_b = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-b", 10, 2) + + self.assertNotEqual(key_a, key_b) + + def test_clear_keyframes_cache_respects_scope_and_prefix_match(self): + with TemporaryDirectory() as temp_root: + analysis_dir = os.path.join(temp_root, "analysis") + os.makedirs(analysis_dir, exist_ok=True) + + with patch("app.utils.utils.os.path.getmtime", return_value=123.0): + prefix = utils.md5("video.mp4" + "123.0") + + target_dir = os.path.join(analysis_dir, f"{prefix}_interval3") + keep_dir = os.path.join(analysis_dir, "other_video") + os.makedirs(target_dir, exist_ok=True) + os.makedirs(keep_dir, exist_ok=True) + + with patch("app.utils.utils.temp_dir", return_value=temp_root), patch( + "app.utils.utils.os.path.getmtime", return_value=123.0 + ): + utils.clear_keyframes_cache(video_path="video.mp4", cache_scope="analysis") + + self.assertFalse(os.path.exists(target_dir)) + self.assertTrue(os.path.exists(keep_dir)) + class DocumentaryAnalysisConfigTests(unittest.TestCase): def test_config_rejects_non_positive_frame_interval(self): diff --git a/tests/test_video_processor_documentary_unittest.py b/tests/test_video_processor_documentary_unittest.py new file mode 100644 index 0000000..d8851f2 --- /dev/null +++ b/tests/test_video_processor_documentary_unittest.py @@ -0,0 +1,46 @@ +import os +import unittest +from tempfile import TemporaryDirectory +from unittest.mock import patch + +from app.utils.video_processor import VideoProcessor + + +class VideoProcessorDocumentaryTests(unittest.TestCase): + @patch.object(VideoProcessor, "_extract_frames_fast_path", return_value=["a.jpg"]) + def test_extract_frames_by_interval_prefers_fast_path(self, fast_path): + processor = VideoProcessor.__new__(VideoProcessor) + processor.video_path = "demo.mp4" + processor.duration = 6.0 + processor.fps = 25.0 + + result = processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=3.0) + + self.assertEqual(["a.jpg"], result) + fast_path.assert_called_once_with("/tmp/out", interval_seconds=3.0) + + def test_extract_frames_by_interval_falls_back_to_ultra_compatible(self): + processor = VideoProcessor.__new__(VideoProcessor) + processor.video_path = "demo.mp4" + processor.duration = 6.0 + processor.fps = 25.0 + + with TemporaryDirectory() as output_dir: + expected_frame_path = os.path.join(output_dir, "keyframe_000000_000000000.jpg") + + def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0): + with open(expected_frame_path, "wb") as frame_file: + frame_file.write(b"frame") + return [0] + + with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=RuntimeError("fast path failed")) as fast_path, patch.object( + VideoProcessor, + "extract_frames_by_interval_ultra_compatible", + side_effect=ultra_compatible_fallback, + autospec=True, + ) as fallback: + result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0) + + self.assertEqual([expected_frame_path], result) + fast_path.assert_called_once_with(output_dir, interval_seconds=3.0) + fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)