mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
perf(documentary): add fast frame extraction and cache keys
This commit is contained in:
parent
40a48cc9ff
commit
3d76bff442
1
.gitignore
vendored
1
.gitignore
vendored
@ -47,3 +47,4 @@ AGENTS.md
|
||||
CLAUDE.md
|
||||
tests/*
|
||||
!tests/test_documentary_frame_analysis_service.py
|
||||
!tests/test_video_processor_documentary_unittest.py
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import os
|
||||
|
||||
from app.utils import utils
|
||||
from app.services.documentary.frame_analysis_models import FrameBatchResult
|
||||
|
||||
|
||||
@ -45,3 +48,31 @@ JSON 必须包含以下键:
|
||||
fallback_summary=fallback_summary,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
def _build_cache_key(
|
||||
self,
|
||||
video_path: str,
|
||||
interval_seconds: float,
|
||||
prompt_version: str,
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
max_concurrency: int,
|
||||
) -> str:
|
||||
try:
|
||||
video_mtime = os.path.getmtime(video_path)
|
||||
except OSError:
|
||||
video_mtime = 0
|
||||
|
||||
payload = "|".join(
|
||||
[
|
||||
str(video_path),
|
||||
str(video_mtime),
|
||||
str(interval_seconds),
|
||||
str(prompt_version),
|
||||
str(model_name),
|
||||
str(batch_size),
|
||||
str(max_concurrency),
|
||||
"documentary-frame-analysis-v2",
|
||||
]
|
||||
)
|
||||
return utils.md5(payload)
|
||||
|
||||
@ -570,29 +570,35 @@ def temp_dir(sub_dir: str = ""):
|
||||
return d
|
||||
|
||||
|
||||
def clear_keyframes_cache(video_path: str = None):
|
||||
def clear_keyframes_cache(video_path: str = None, cache_scope: str = "keyframes"):
|
||||
"""
|
||||
清理关键帧缓存
|
||||
Args:
|
||||
video_path: 视频文件路径,如果指定则只清理该视频的缓存
|
||||
cache_scope: 缓存作用域目录,默认 keyframes
|
||||
"""
|
||||
try:
|
||||
keyframes_dir = os.path.join(temp_dir(), "keyframes")
|
||||
if not os.path.exists(keyframes_dir):
|
||||
cache_dir = os.path.join(temp_dir(), cache_scope)
|
||||
if not os.path.exists(cache_dir):
|
||||
return
|
||||
|
||||
import shutil
|
||||
|
||||
if video_path:
|
||||
# 理指定视频的缓存
|
||||
# 清理指定视频的缓存(兼容前缀扩展键)
|
||||
video_hash = md5(video_path + str(os.path.getmtime(video_path)))
|
||||
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
import shutil
|
||||
shutil.rmtree(video_keyframes_dir)
|
||||
logger.info(f"已清理视频关键帧缓存: {video_path}")
|
||||
for entry in os.listdir(cache_dir):
|
||||
if not entry.startswith(video_hash):
|
||||
continue
|
||||
target_path = os.path.join(cache_dir, entry)
|
||||
if os.path.isdir(target_path):
|
||||
shutil.rmtree(target_path)
|
||||
else:
|
||||
os.remove(target_path)
|
||||
logger.info(f"已清理视频关键帧缓存: {video_path}")
|
||||
else:
|
||||
# 清理所有缓存
|
||||
import shutil
|
||||
shutil.rmtree(keyframes_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
logger.info("已清理所有关键帧缓存")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -185,6 +185,82 @@ class VideoProcessor:
|
||||
|
||||
return frame_numbers
|
||||
|
||||
def extract_frames_by_interval_with_fallback(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
|
||||
"""
|
||||
先尝试单次 ffmpeg 快路径抽帧,失败时回退到高兼容方案。
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
return self._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds)
|
||||
except Exception as exc:
|
||||
logger.warning(f"快路径抽帧失败,回退到兼容模式: {exc}")
|
||||
self.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds)
|
||||
return self._collect_extracted_frame_paths(output_dir)
|
||||
|
||||
def _extract_frames_fast_path(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
|
||||
"""
|
||||
使用单次 ffmpeg 命令按固定间隔抽帧,随后重命名为既有 keyframe 约定格式。
|
||||
"""
|
||||
if interval_seconds <= 0:
|
||||
raise ValueError("interval_seconds must be > 0")
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
raw_pattern = os.path.join(output_dir, "fastframe_%06d.jpg")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-hide_banner",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-i",
|
||||
self.video_path,
|
||||
"-vf",
|
||||
f"fps=1/{interval_seconds}",
|
||||
"-q:v",
|
||||
"2",
|
||||
"-start_number",
|
||||
"0",
|
||||
"-y",
|
||||
raw_pattern,
|
||||
]
|
||||
subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120)
|
||||
|
||||
raw_files = sorted(
|
||||
filename
|
||||
for filename in os.listdir(output_dir)
|
||||
if re.fullmatch(r"fastframe_\d{6}\.jpg", filename)
|
||||
)
|
||||
if not raw_files:
|
||||
raise RuntimeError("Fast-path extraction produced no frames")
|
||||
|
||||
renamed_files: List[str] = []
|
||||
for index, filename in enumerate(raw_files):
|
||||
timestamp = index * interval_seconds
|
||||
frame_number = int(timestamp * self.fps)
|
||||
token = self._format_timestamp_token(timestamp)
|
||||
source_path = os.path.join(output_dir, filename)
|
||||
target_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{token}.jpg")
|
||||
os.replace(source_path, target_path)
|
||||
renamed_files.append(target_path)
|
||||
|
||||
return renamed_files
|
||||
|
||||
@staticmethod
|
||||
def _format_timestamp_token(timestamp: float) -> str:
|
||||
hours = int(timestamp // 3600)
|
||||
minutes = int((timestamp % 3600) // 60)
|
||||
seconds = int(timestamp % 60)
|
||||
milliseconds = int((timestamp % 1) * 1000)
|
||||
return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
|
||||
|
||||
@staticmethod
|
||||
def _collect_extracted_frame_paths(output_dir: str) -> List[str]:
|
||||
return sorted(
|
||||
os.path.join(output_dir, name)
|
||||
for name in os.listdir(output_dir)
|
||||
if name.endswith(".jpg")
|
||||
)
|
||||
|
||||
def _extract_single_frame_optimized(self, timestamp: float, output_path: str,
|
||||
use_hw_accel: bool, hwaccel_type: str) -> bool:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
import unittest
|
||||
import os
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.services.documentary.frame_analysis_models import DocumentaryAnalysisConfig
|
||||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
|
||||
@ -62,6 +66,45 @@ class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
|
||||
self.assertFalse(hasattr(batch, "observations"))
|
||||
self.assertFalse(hasattr(batch, "summary"))
|
||||
|
||||
def test_cache_key_changes_when_interval_changes(self):
|
||||
service = DocumentaryFrameAnalysisService()
|
||||
|
||||
with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
|
||||
key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
|
||||
key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
|
||||
|
||||
self.assertNotEqual(key_a, key_b)
|
||||
|
||||
def test_cache_key_changes_when_model_changes(self):
|
||||
service = DocumentaryFrameAnalysisService()
|
||||
|
||||
with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
|
||||
key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
|
||||
key_b = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-b", 10, 2)
|
||||
|
||||
self.assertNotEqual(key_a, key_b)
|
||||
|
||||
def test_clear_keyframes_cache_respects_scope_and_prefix_match(self):
|
||||
with TemporaryDirectory() as temp_root:
|
||||
analysis_dir = os.path.join(temp_root, "analysis")
|
||||
os.makedirs(analysis_dir, exist_ok=True)
|
||||
|
||||
with patch("app.utils.utils.os.path.getmtime", return_value=123.0):
|
||||
prefix = utils.md5("video.mp4" + "123.0")
|
||||
|
||||
target_dir = os.path.join(analysis_dir, f"{prefix}_interval3")
|
||||
keep_dir = os.path.join(analysis_dir, "other_video")
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
os.makedirs(keep_dir, exist_ok=True)
|
||||
|
||||
with patch("app.utils.utils.temp_dir", return_value=temp_root), patch(
|
||||
"app.utils.utils.os.path.getmtime", return_value=123.0
|
||||
):
|
||||
utils.clear_keyframes_cache(video_path="video.mp4", cache_scope="analysis")
|
||||
|
||||
self.assertFalse(os.path.exists(target_dir))
|
||||
self.assertTrue(os.path.exists(keep_dir))
|
||||
|
||||
|
||||
class DocumentaryAnalysisConfigTests(unittest.TestCase):
|
||||
def test_config_rejects_non_positive_frame_interval(self):
|
||||
|
||||
46
tests/test_video_processor_documentary_unittest.py
Normal file
46
tests/test_video_processor_documentary_unittest.py
Normal file
@ -0,0 +1,46 @@
|
||||
import os
|
||||
import unittest
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.utils.video_processor import VideoProcessor
|
||||
|
||||
|
||||
class VideoProcessorDocumentaryTests(unittest.TestCase):
|
||||
@patch.object(VideoProcessor, "_extract_frames_fast_path", return_value=["a.jpg"])
|
||||
def test_extract_frames_by_interval_prefers_fast_path(self, fast_path):
|
||||
processor = VideoProcessor.__new__(VideoProcessor)
|
||||
processor.video_path = "demo.mp4"
|
||||
processor.duration = 6.0
|
||||
processor.fps = 25.0
|
||||
|
||||
result = processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=3.0)
|
||||
|
||||
self.assertEqual(["a.jpg"], result)
|
||||
fast_path.assert_called_once_with("/tmp/out", interval_seconds=3.0)
|
||||
|
||||
def test_extract_frames_by_interval_falls_back_to_ultra_compatible(self):
|
||||
processor = VideoProcessor.__new__(VideoProcessor)
|
||||
processor.video_path = "demo.mp4"
|
||||
processor.duration = 6.0
|
||||
processor.fps = 25.0
|
||||
|
||||
with TemporaryDirectory() as output_dir:
|
||||
expected_frame_path = os.path.join(output_dir, "keyframe_000000_000000000.jpg")
|
||||
|
||||
def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
|
||||
with open(expected_frame_path, "wb") as frame_file:
|
||||
frame_file.write(b"frame")
|
||||
return [0]
|
||||
|
||||
with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=RuntimeError("fast path failed")) as fast_path, patch.object(
|
||||
VideoProcessor,
|
||||
"extract_frames_by_interval_ultra_compatible",
|
||||
side_effect=ultra_compatible_fallback,
|
||||
autospec=True,
|
||||
) as fallback:
|
||||
result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)
|
||||
|
||||
self.assertEqual([expected_frame_path], result)
|
||||
fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
|
||||
fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)
|
||||
Loading…
x
Reference in New Issue
Block a user