diff --git a/.gitignore b/.gitignore index bd3c487..a1c25bd 100644 --- a/.gitignore +++ b/.gitignore @@ -39,9 +39,15 @@ bug清单.md task.md .claude/* .serena/* +.worktrees/ # OpenSpec: 忽略活动的变更提案,但保留归档和规范 openspec/* AGENTS.md CLAUDE.md -tests/* \ No newline at end of file +tests/* +!tests/test_documentary_frame_analysis_service.py +!tests/test_video_processor_documentary_unittest.py +!tests/test_script_service_documentary_unittest.py +!tests/test_generate_narration_script_documentary_unittest.py +!tests/test_generate_script_docu_unittest.py diff --git a/README-en.md b/README-en.md index 8bedd97..5c57ef6 100644 --- a/README-en.md +++ b/README-en.md @@ -33,6 +33,7 @@ NarratoAI is an automated video narration tool that provides an all-in-one solut ## Latest News +- 2026.04.03 Released version 0.7.8, refactored the documentary frame-analysis pipeline with a shared service and improved extraction, caching, vision batching, and narration generation - 2025.05.11 Released new version 0.6.0, supports **short drama commentary** and optimized editing process - 2025.03.06 Released new version 0.5.2, supports DeepSeek R1 and DeepSeek V3 models for short drama mixing - 2024.12.16 Released new version 0.3.9, supports Alibaba Qwen2-VL model for video understanding; supports short drama mixing diff --git a/README.md b/README.md index 6c4c4fc..6f1a00f 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ NarratoAI 是一个自动化影视解说工具,基于LLM实现文案撰写、 本项目仅供学习和研究使用,不得商用。如需商业授权,请联系作者。 ## 最新资讯 +- 2026.04.03 发布新版本 0.7.8,重构纪录片逐帧分析链路,统一共享服务并优化抽帧、缓存、视觉并发与文案生成流程 - 2026.03.27 发布新版本 0.7.7,出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路 - 2025.11.20 发布新版本 0.7.5,新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持 - 2025.10.15 发布新版本 0.7.3,升级大模型供应商管理能力 diff --git a/app/services/documentary/__init__.py b/app/services/documentary/__init__.py new file mode 100644 index 0000000..3b9a020 --- /dev/null +++ b/app/services/documentary/__init__.py @@ -0,0 +1,13 @@ +from 
app.services.documentary.frame_analysis_models import ( + DocumentaryAnalysisConfig, + FrameBatchResult, +) +from app.services.documentary.frame_analysis_service import ( + DocumentaryFrameAnalysisService, +) + +__all__ = [ + "DocumentaryAnalysisConfig", + "FrameBatchResult", + "DocumentaryFrameAnalysisService", +] diff --git a/app/services/documentary/frame_analysis_models.py b/app/services/documentary/frame_analysis_models.py new file mode 100644 index 0000000..6ac419c --- /dev/null +++ b/app/services/documentary/frame_analysis_models.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass, field + + +@dataclass(slots=True) +class DocumentaryAnalysisConfig: + video_path: str + frame_interval_seconds: float + vision_batch_size: int + vision_llm_provider: str + vision_model_name: str + custom_prompt: str = "" + max_concurrency: int = 2 + + def __post_init__(self) -> None: + if self.frame_interval_seconds <= 0: + raise ValueError("frame_interval_seconds must be > 0") + if self.vision_batch_size <= 0: + raise ValueError("vision_batch_size must be > 0") + if self.max_concurrency <= 0: + raise ValueError("max_concurrency must be > 0") + + +@dataclass(slots=True) +class FrameBatchResult: + batch_index: int + status: str + time_range: str + raw_response: str + frame_paths: list[str] = field(default_factory=list) + frame_observations: list[dict] = field(default_factory=list) + overall_activity_summary: str = "" + fallback_summary: str = "" + error_message: str = "" diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py new file mode 100644 index 0000000..cbb794a --- /dev/null +++ b/app/services/documentary/frame_analysis_service.py @@ -0,0 +1,761 @@ +import asyncio +import json +import os +import re +from datetime import datetime +from typing import Any, Callable + +from loguru import logger + +from app.config import config +from app.services.documentary.frame_analysis_models import FrameBatchResult +from 
app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown +from app.services.llm.migration_adapter import create_vision_analyzer +from app.utils import utils, video_processor + + +class DocumentaryFrameAnalysisService: + PROMPT_TEMPLATE = """ +我提供了 {frame_count} 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。 +首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。 +然后,基于所有帧的分析,请用简洁的语言总结整个视频片段中发生的主要活动或事件流程。 +请务必使用 JSON 格式输出。 +JSON 必须包含以下键: +- frame_observations: 数组,且长度必须为 {frame_count} +- overall_activity_summary: 字符串,描述整个批次主要活动 +示例结构: +{{ + "frame_observations": [ + {{"timestamp": "00:00:00,000", "observation": "画面描述"}} + ], + "overall_activity_summary": "本批次主要活动总结" +}} +请务必不要遗漏视频帧,我提供了 {frame_count} 张视频帧,frame_observations 必须包含 {frame_count} 个元素 +请只返回 JSON 字符串,不要附加解释文字。 +""".strip() + + async def generate_documentary_script( + self, + *, + video_path: str, + video_theme: str = "", + custom_prompt: str = "", + frame_interval_input: int | float | None = None, + vision_batch_size: int | None = None, + vision_llm_provider: str | None = None, + progress_callback: Callable[[float, str], None] | None = None, + vision_api_key: str | None = None, + vision_model_name: str | None = None, + vision_base_url: str | None = None, + max_concurrency: int | None = None, + ) -> list[dict]: + progress = progress_callback or (lambda _p, _m: None) + analysis_result = await self.analyze_video( + video_path=video_path, + video_theme=video_theme, + custom_prompt=custom_prompt, + frame_interval_input=frame_interval_input, + vision_batch_size=vision_batch_size, + vision_llm_provider=vision_llm_provider, + progress_callback=progress_callback, + vision_api_key=vision_api_key, + vision_model_name=vision_model_name, + vision_base_url=vision_base_url, + max_concurrency=max_concurrency, + ) + analysis_json_path = analysis_result["analysis_json_path"] + + progress(80, "正在生成解说文案...") + text_provider = config.app.get("text_llm_provider", "openai").lower() + text_api_key = 
config.app.get(f"text_{text_provider}_api_key") + text_model = config.app.get(f"text_{text_provider}_model_name") + text_base_url = config.app.get(f"text_{text_provider}_base_url") + if not text_api_key or not text_model: + raise ValueError( + f"未配置 {text_provider} 的文本模型参数。" + f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name" + ) + + markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) + narration_input = self._build_narration_input( + markdown_output=markdown_output, + video_theme=video_theme, + custom_prompt=custom_prompt, + ) + narration_raw = generate_narration( + narration_input, + text_api_key, + base_url=text_base_url, + model=text_model, + ) + narration_items = self._parse_narration_items(narration_raw) + + final_script = [{**item, "OST": 2} for item in narration_items] + progress(100, "脚本生成完成") + return final_script + + async def analyze_video( + self, + *, + video_path: str, + video_theme: str = "", + custom_prompt: str = "", + frame_interval_input: int | float | None = None, + vision_batch_size: int | None = None, + vision_llm_provider: str | None = None, + progress_callback: Callable[[float, str], None] | None = None, + vision_api_key: str | None = None, + vision_model_name: str | None = None, + vision_base_url: str | None = None, + max_concurrency: int | None = None, + ) -> dict[str, Any]: + progress = progress_callback or (lambda _p, _m: None) + + if not video_path or not os.path.exists(video_path): + raise FileNotFoundError(f"视频文件不存在: {video_path}") + + frame_interval_seconds = self._resolve_frame_interval(frame_interval_input) + batch_size = self._resolve_batch_size(vision_batch_size) + concurrency = self._resolve_max_concurrency(max_concurrency) + provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower() + + api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key") + model_name = ( + vision_model_name if vision_model_name is 
not None else config.app.get(f"vision_{provider}_model_name") + ) + base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "") + if not api_key or not model_name: + raise ValueError( + f"未配置 {provider} 的 API Key 或模型名称。" + f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name" + ) + + progress(10, "正在提取关键帧...") + keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds) + progress(25, f"关键帧准备完成,共 {len(keyframe_files)} 帧") + + progress(30, "正在初始化视觉分析器...") + analyzer = create_vision_analyzer( + provider=provider, + api_key=api_key, + model=model_name, + base_url=base_url, + ) + + batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size) + if not batches: + raise RuntimeError("未能构建任何关键帧批次") + + progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...") + batch_results = await self._analyze_batches( + analyzer=analyzer, + batches=batches, + custom_prompt=custom_prompt, + video_theme=video_theme, + max_concurrency=concurrency, + progress_callback=progress, + ) + + progress(65, "正在整理分析结果...") + sorted_batches = self._sort_batch_results(batch_results) + artifact = self._build_analysis_artifact( + sorted_batches, + video_path=video_path, + frame_interval_seconds=frame_interval_seconds, + vision_batch_size=batch_size, + vision_llm_provider=provider, + vision_model_name=model_name, + max_concurrency=concurrency, + ) + analysis_json_path = self._save_analysis_artifact(artifact) + video_clip_json = self._build_video_clip_json(sorted_batches) + + progress(75, "逐帧分析完成") + return { + "analysis_json_path": analysis_json_path, + "analysis_artifact": artifact, + "video_clip_json": video_clip_json, + "keyframe_files": keyframe_files, + } + + def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]: + parsed = self._repair_narration_payload(narration_raw) + + items: list[dict[str, Any]] = [] + if isinstance(parsed, dict): + raw_items = parsed.get("items") + if 
isinstance(raw_items, list): + items = [item for item in raw_items if isinstance(item, dict)] + + if not items: + raise ValueError("解说文案格式错误,无法解析JSON或缺少items字段") + + return items + + def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str: + context_lines: list[str] = [] + if (video_theme or "").strip(): + context_lines.append(f"视频主题:{video_theme.strip()}") + if (custom_prompt or "").strip(): + context_lines.append(f"补充创作要求:{custom_prompt.strip()}") + + if not context_lines: + return markdown_output + + context_block = "\n".join(f"- {line}" for line in context_lines) + return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n" + + def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None: + def load_json_candidate(payload: str) -> dict[str, Any] | None: + try: + parsed = json.loads(payload) + return parsed if isinstance(parsed, dict) else None + except Exception: + return None + + cleaned = (narration_raw or "").strip() + if not cleaned: + return None + + candidates: list[str] = [cleaned] + candidates.append(cleaned.replace("{{", "{").replace("}}", "}")) + + json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL) + if json_block: + candidates.append(json_block.group(1).strip()) + + start = cleaned.find("{") + end = cleaned.rfind("}") + if start >= 0 and end > start: + candidates.append(cleaned[start : end + 1]) + + for candidate in candidates: + parsed = load_json_candidate(candidate) + if parsed is not None: + return parsed + + fixed = cleaned.replace("{{", "{").replace("}}", "}") + fixed_start = fixed.find("{") + fixed_end = fixed.rfind("}") + if fixed_start >= 0 and fixed_end > fixed_start: + fixed = fixed[fixed_start : fixed_end + 1] + + fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE) + fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE) + fixed = re.sub(r",\s*}", "}", fixed) + fixed = re.sub(r",\s*]", "]", fixed) + fixed = re.sub(r"'([^']*)'\s*:", 
r'"\1":', fixed) + fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed) + fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed) + + return load_json_candidate(fixed) + + def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float: + interval = frame_interval_input + if interval in (None, ""): + interval = config.frames.get("frame_interval_input", 3) + try: + value = float(interval) + except (TypeError, ValueError): + value = 3.0 + if value <= 0: + raise ValueError("frame_interval_input must be > 0") + return value + + def _resolve_batch_size(self, vision_batch_size: int | None) -> int: + size = vision_batch_size or config.frames.get("vision_batch_size", 10) + try: + value = int(size) + except (TypeError, ValueError): + value = 10 + if value <= 0: + raise ValueError("vision_batch_size must be > 0") + return value + + def _resolve_max_concurrency(self, max_concurrency: int | None) -> int: + value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2) + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = 1 + return max(1, parsed) + + def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]: + keyframes_root = os.path.join(utils.temp_dir(), "keyframes") + os.makedirs(keyframes_root, exist_ok=True) + cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds) + cache_dir = os.path.join(keyframes_root, cache_key) + os.makedirs(cache_dir, exist_ok=True) + + cached_files = self._collect_keyframe_paths(cache_dir) + if cached_files: + logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)} 帧") + return cached_files + + processor = video_processor.VideoProcessor(video_path) + extracted = processor.extract_frames_by_interval_with_fallback( + output_dir=cache_dir, + interval_seconds=frame_interval_seconds, + ) + keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg")) + 
if not keyframe_files: + keyframe_files = self._collect_keyframe_paths(cache_dir) + if not keyframe_files: + raise RuntimeError("未提取到任何关键帧") + + logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)} 帧") + return keyframe_files + + def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str: + try: + video_mtime = os.path.getmtime(video_path) + except OSError: + video_mtime = 0 + + legacy_prefix = utils.md5(f"{video_path}{video_mtime}") + payload = "|".join( + [ + str(video_path), + str(video_mtime), + str(frame_interval_seconds), + "documentary-keyframes-v2", + ] + ) + return f"{legacy_prefix}_{utils.md5(payload)}" + + @staticmethod + def _collect_keyframe_paths(cache_dir: str) -> list[str]: + if not os.path.exists(cache_dir): + return [] + return sorted( + os.path.join(cache_dir, name) + for name in os.listdir(cache_dir) + if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name) + ) + + @staticmethod + def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]: + return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)] + + async def _analyze_batches( + self, + *, + analyzer: Any, + batches: list[list[str]], + custom_prompt: str, + video_theme: str, + max_concurrency: int, + progress_callback: Callable[[float, str], None], + ) -> list[FrameBatchResult]: + semaphore = asyncio.Semaphore(max(1, max_concurrency)) + total = len(batches) + done = 0 + done_lock = asyncio.Lock() + + batch_time_ranges: list[str] = [] + previous_batch_files: list[str] | None = None + for batch_files in batches: + _, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files) + batch_time_ranges.append(time_range) + previous_batch_files = batch_files + + async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult: + nonlocal done + prompt = self._build_batch_prompt( + frame_count=len(frame_paths), + video_theme=video_theme, + 
custom_prompt=custom_prompt, + ) + try: + async with semaphore: + raw_results = await analyzer.analyze_images( + images=frame_paths, + prompt=prompt, + batch_size=max(1, len(frame_paths)), + max_concurrency=1, + ) + raw_response, error_message = self._extract_batch_response(raw_results) + if error_message: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response=raw_response, + error_message=error_message, + frame_paths=frame_paths, + time_range=time_range, + ) + return self._parse_batch_response( + batch_index=batch_index, + raw_response=raw_response, + frame_paths=frame_paths, + time_range=time_range, + ) + except Exception as exc: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response="", + error_message=str(exc), + frame_paths=frame_paths, + time_range=time_range, + ) + finally: + async with done_lock: + done += 1 + progress = 40 + (done / max(1, total)) * 25 + progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...") + + tasks = [ + run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index]) + for index, batch_files in enumerate(batches) + ] + return await asyncio.gather(*tasks) + + def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str: + prompt = self._build_analysis_prompt(frame_count=frame_count) + extra_lines: list[str] = [] + if (video_theme or "").strip(): + extra_lines.append(f"视频主题:{video_theme.strip()}") + if (custom_prompt or "").strip(): + extra_lines.append(custom_prompt.strip()) + if not extra_lines: + return prompt + + extras = "\n".join(f"- {line}" for line in extra_lines) + return f"{prompt}\n\n补充分析要求:\n{extras}" + + def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]: + if not raw_results: + return "", "Batch response is empty" + + first_result = raw_results[0] if isinstance(raw_results, list) else raw_results + if isinstance(first_result, dict): + raw_response = 
str(first_result.get("response", "") or "") + error_message = str(first_result.get("error", "") or "") + if error_message: + if not raw_response: + raw_response = error_message + return raw_response, error_message + if not raw_response.strip(): + return raw_response, "Batch response is empty" + return raw_response, "" + + raw_response = str(first_result or "") + if not raw_response.strip(): + return raw_response, "Batch response is empty" + return raw_response, "" + + def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]: + return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index)) + + def _build_analysis_artifact( + self, + batch_results: list[FrameBatchResult], + *, + video_path: str, + frame_interval_seconds: float, + vision_batch_size: int, + vision_llm_provider: str, + vision_model_name: str, + max_concurrency: int, + ) -> dict[str, Any]: + sorted_batches = self._sort_batch_results(batch_results) + + batch_dicts: list[dict[str, Any]] = [] + frame_observations: list[dict[str, Any]] = [] + overall_activity_summaries: list[dict[str, Any]] = [] + + for batch in sorted_batches: + batch_payload = { + "batch_index": batch.batch_index, + "status": batch.status, + "time_range": batch.time_range, + "raw_response": batch.raw_response, + "frame_paths": list(batch.frame_paths), + "frame_observations": list(batch.frame_observations), + "overall_activity_summary": batch.overall_activity_summary, + "fallback_summary": batch.fallback_summary, + "error_message": batch.error_message, + } + batch_dicts.append(batch_payload) + + for observation in batch.frame_observations: + observation_payload = dict(observation) + observation_payload["batch_index"] = batch.batch_index + observation_payload["time_range"] = batch.time_range + frame_observations.append(observation_payload) + + summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip() + if summary_text: + 
overall_activity_summaries.append( + { + "batch_index": batch.batch_index, + "time_range": batch.time_range, + "summary": summary_text, + } + ) + + return { + "artifact_version": "documentary-frame-analysis-v2", + "generated_at": datetime.now().isoformat(), + "video_path": video_path, + "frame_interval_seconds": frame_interval_seconds, + "vision_batch_size": vision_batch_size, + "vision_llm_provider": vision_llm_provider, + "vision_model_name": vision_model_name, + "vision_max_concurrency": max_concurrency, + "batches": batch_dicts, + # 向后兼容旧解析器结构 + "frame_observations": frame_observations, + "overall_activity_summaries": overall_activity_summaries, + } + + def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str: + analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis") + os.makedirs(analysis_dir, exist_ok=True) + + filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + file_path = os.path.join(analysis_dir, filename) + suffix = 1 + while os.path.exists(file_path): + filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json" + file_path = os.path.join(analysis_dir, filename) + suffix += 1 + + with open(file_path, "w", encoding="utf-8") as fp: + json.dump(artifact, fp, ensure_ascii=False, indent=2) + logger.info(f"分析结果已保存到: {file_path}") + return file_path + + def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]: + clips: list[dict] = [] + for batch in self._sort_batch_results(batch_results): + picture = self._build_batch_picture(batch) + clips.append( + { + "timestamp": batch.time_range, + "picture": picture, + "narration": "", + "OST": 2, + } + ) + return clips + + def _build_batch_picture(self, batch: FrameBatchResult) -> str: + summary = (batch.overall_activity_summary or "").strip() + if summary: + return summary + + fallback = (batch.fallback_summary or "").strip() + if fallback: + return fallback + + observation_lines = [] + for frame in 
batch.frame_observations: + timestamp = str(frame.get("timestamp", "") or "").strip() + observation = str(frame.get("observation", "") or "").strip() + if timestamp and observation: + observation_lines.append(f"{timestamp}: {observation}") + elif observation: + observation_lines.append(observation) + if observation_lines: + return " ".join(observation_lines) + + raw_response = (batch.raw_response or "").strip() + if raw_response: + return raw_response[:200] + return "该批次分析失败,未返回可用描述。" + + def _time_range_sort_key(self, time_range: str) -> tuple[int, str]: + start = (time_range or "").split("-", 1)[0].strip() + return self._timestamp_to_milliseconds(start), time_range + + @staticmethod + def _timestamp_to_milliseconds(timestamp: str) -> int: + text = (timestamp or "").strip() + try: + if "," in text: + time_part, ms_part = text.split(",", 1) + milliseconds = int(ms_part) + else: + time_part = text + milliseconds = 0 + + parts = [int(part) for part in time_part.split(":") if part] + while len(parts) < 3: + parts.insert(0, 0) + hours, minutes, seconds = parts[-3], parts[-2], parts[-1] + return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds + except Exception: + return 0 + + def _get_batch_timestamps( + self, + batch_files: list[str], + prev_batch_files: list[str] | None = None, + ) -> tuple[str, str, str]: + if not batch_files: + return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000" + + if len(batch_files) == 1 and prev_batch_files: + first_frame = os.path.basename(prev_batch_files[-1]) + last_frame = os.path.basename(batch_files[0]) + else: + first_frame = os.path.basename(batch_files[0]) + last_frame = os.path.basename(batch_files[-1]) + + first_timestamp = self._timestamp_from_keyframe_name(first_frame) + last_timestamp = self._timestamp_from_keyframe_name(last_frame) + return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}" + + def _timestamp_from_keyframe_name(self, filename: str) -> str: + match = 
re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename) + if not match: + return "00:00:00,000" + token = match.group(1) + hours = int(token[0:2]) + minutes = int(token[2:4]) + seconds = int(token[4:6]) + milliseconds = int(token[6:9]) + return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}" + + def _build_analysis_prompt(self, frame_count: int) -> str: + return self.PROMPT_TEMPLATE.format(frame_count=frame_count) + + def _build_failed_batch_result( + self, + *, + batch_index: int, + raw_response: str, + error_message: str, + frame_paths: list[str], + time_range: str, + ) -> FrameBatchResult: + fallback_summary = (raw_response or "").strip()[:200] + if not fallback_summary: + fallback_summary = f"Batch {batch_index} analysis failed: {error_message or 'unknown error'}" + + return FrameBatchResult( + batch_index=batch_index, + status="failed", + time_range=time_range, + raw_response=raw_response, + frame_paths=list(frame_paths), + fallback_summary=fallback_summary, + error_message=error_message, + ) + + def _build_cache_key( + self, + video_path: str, + interval_seconds: float, + prompt_version: str, + model_name: str, + batch_size: int, + max_concurrency: int, + ) -> str: + try: + video_mtime = os.path.getmtime(video_path) + except OSError: + video_mtime = 0 + + legacy_prefix = utils.md5(f"{video_path}{video_mtime}") + + payload = "|".join( + [ + str(video_path), + str(video_mtime), + str(interval_seconds), + str(prompt_version), + str(model_name), + str(batch_size), + str(max_concurrency), + "documentary-frame-analysis-v2", + ] + ) + return f"{legacy_prefix}_{utils.md5(payload)}" + + def _strip_code_fence(self, response_text: str) -> str: + cleaned = (response_text or "").strip() + cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned) + cleaned = re.sub(r"\s*```$", "", cleaned) + return cleaned.strip() + + def _parse_batch_response( + self, + *, + batch_index: int, + raw_response: str, + frame_paths: list[str], + time_range: str, + ) -> 
FrameBatchResult: + try: + payload = json.loads(self._strip_code_fence(raw_response)) + except Exception as exc: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response=raw_response, + error_message=str(exc), + frame_paths=frame_paths, + time_range=time_range, + ) + + validation_error = self._validate_batch_payload_contract(payload, expected_frame_count=len(frame_paths)) + if validation_error: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response=raw_response, + error_message=validation_error, + frame_paths=frame_paths, + time_range=time_range, + ) + + raw_observations = payload["frame_observations"] + + frame_observations: list[dict] = [] + for index, frame_path in enumerate(frame_paths): + entry = raw_observations[index] if index < len(raw_observations) else {} + if isinstance(entry, dict): + observation = str(entry.get("observation", "") or "") + timestamp = str(entry.get("timestamp", "") or "") + else: + observation = str(entry or "") + timestamp = "" + frame_observations.append( + { + "frame_path": frame_path, + "timestamp": timestamp, + "observation": observation, + } + ) + + raw_summary = payload.get("overall_activity_summary", "") + if isinstance(raw_summary, str): + summary = raw_summary + elif raw_summary is None: + summary = "" + else: + summary = str(raw_summary) + + return FrameBatchResult( + batch_index=batch_index, + status="success", + time_range=time_range, + raw_response=raw_response, + frame_paths=list(frame_paths), + frame_observations=frame_observations, + overall_activity_summary=summary, + ) + + def _validate_batch_payload_contract(self, payload: object, *, expected_frame_count: int) -> str: + if not isinstance(payload, dict): + return "Batch response JSON payload must be an object" + + if "frame_observations" not in payload or not isinstance(payload["frame_observations"], list): + return "Batch response must include frame_observations as a list" + + if 
len(payload["frame_observations"]) < expected_frame_count: + return ( + "Batch response frame_observations length is shorter than provided frame_paths: " + f"{len(payload['frame_observations'])} < {expected_frame_count}" + ) + + return "" diff --git a/app/services/generate_narration_script.py b/app/services/generate_narration_script.py index 80fcf1a..e18622b 100644 --- a/app/services/generate_narration_script.py +++ b/app/services/generate_narration_script.py @@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path): with open(json_file_path, 'r', encoding='utf-8') as file: data = json.load(file) - # 初始化Markdown字符串 + def time_to_milliseconds(time_text): + time_text = (time_text or "").strip() + if not time_text: + return 0 + try: + if "," in time_text: + hhmmss, ms = time_text.split(",", 1) + milliseconds = int(ms) + else: + hhmmss = time_text + milliseconds = 0 + + parts = [int(part) for part in hhmmss.split(":") if part] + while len(parts) < 3: + parts.insert(0, 0) + hours, minutes, seconds = parts[-3], parts[-2], parts[-1] + return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds + except Exception: + return 0 + + def batch_sort_key(batch): + time_range = batch.get("time_range", "") + start = time_range.split("-", 1)[0].strip() + return time_to_milliseconds(start), batch.get("batch_index", 0) + markdown = "" - - # 获取总结和帧观察数据 + + # 新结构:按批次保存完整分析产物 + if isinstance(data.get("batches"), list): + ordered_batches = sorted(data.get("batches", []), key=batch_sort_key) + + for i, batch in enumerate(ordered_batches, 1): + time_range = batch.get("time_range", "") + summary = ( + batch.get("overall_activity_summary") + or batch.get("summary") + or batch.get("fallback_summary") + or "" + ) + observations = batch.get("frame_observations") or batch.get("observations") or [] + + markdown += f"## 片段 {i}\n" + markdown += f"- 时间范围:{time_range}\n" + markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n" + markdown += "- 详细描述:\n" + + for frame in 
observations: + timestamp = frame.get("timestamp", "") + observation = frame.get("observation", "") + markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n" + + markdown += "\n" + + return markdown + + # 兼容旧结构 summaries = data.get('overall_activity_summaries', []) frame_observations = data.get('frame_observations', []) - - # 按批次组织数据 + batch_frames = {} for frame in frame_observations: batch_index = frame.get('batch_index') if batch_index not in batch_frames: batch_frames[batch_index] = [] batch_frames[batch_index].append(frame) - - # 生成Markdown内容 + for i, summary in enumerate(summaries, 1): batch_index = summary.get('batch_index') time_range = summary.get('time_range', '') batch_summary = summary.get('summary', '') - + markdown += f"## 片段 {i}\n" markdown += f"- 时间范围:{time_range}\n" - - # 添加片段描述 markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else f"- 片段描述:\n" - markdown += "- 详细描述:\n" - - # 添加该批次的帧观察详情 + frames = batch_frames.get(batch_index, []) for frame in frames: timestamp = frame.get('timestamp', '') observation = frame.get('observation', '') - - # 直接使用原始文本,不进行分割 markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n" - + markdown += "\n" - + return markdown except Exception as e: diff --git a/app/services/llm/base.py b/app/services/llm/base.py index 87f1368..737ceb9 100644 --- a/app/services/llm/base.py +++ b/app/services/llm/base.py @@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider): images: List[Union[str, Path, PIL.Image.Image]], prompt: str, batch_size: int = 10, + max_concurrency: int = 1, **kwargs) -> List[str]: """ 分析图片并返回结果 @@ -116,6 +117,7 @@ class VisionModelProvider(BaseLLMProvider): images: 图片路径列表或PIL图片对象列表 prompt: 分析提示词 batch_size: 批处理大小 + max_concurrency: 最大并发批次数(实现支持时生效) **kwargs: 其他参数 Returns: diff --git a/app/services/llm/migration_adapter.py b/app/services/llm/migration_adapter.py index 49ac75a..7bd5142 100644 --- a/app/services/llm/migration_adapter.py 
+++ b/app/services/llm/migration_adapter.py @@ -5,7 +5,6 @@ """ import asyncio -import json from typing import List, Dict, Any, Optional, Union from pathlib import Path import PIL.Image @@ -13,6 +12,7 @@ from loguru import logger from .unified_service import UnifiedLLMService from .exceptions import LLMServiceError +from .manager import LLMServiceManager # 导入新的提示词管理系统 from app.services.prompts import PromptManager @@ -110,41 +110,11 @@ class LegacyLLMAdapter: temperature=1.5, response_format="json" ) - - # 使用增强的JSON解析器 - from webui.tools.generate_short_summary import parse_and_fix_json - parsed_result = parse_and_fix_json(result) - - if not parsed_result: - logger.error("无法解析LLM返回的JSON数据") - # 返回一个基本的JSON结构而不是错误字符串 - return json.dumps({ - "items": [ - { - "_id": 1, - "timestamp": "00:00:00-00:00:10", - "picture": "解析失败,请检查LLM输出", - "narration": "解说文案生成失败,请重试" - } - ] - }, ensure_ascii=False) - - # 确保返回的是JSON字符串 - return json.dumps(parsed_result, ensure_ascii=False) + return result if isinstance(result, str) else str(result) except Exception as e: logger.error(f"生成解说文案失败: {str(e)}") - # 返回一个基本的JSON结构而不是错误字符串 - return json.dumps({ - "items": [ - { - "_id": 1, - "timestamp": "00:00:00-00:00:10", - "picture": "生成失败", - "narration": f"解说文案生成失败: {str(e)}" - } - ] - }, ensure_ascii=False) + raise class VisionAnalyzerAdapter: @@ -155,11 +125,29 @@ class VisionAnalyzerAdapter: self.api_key = api_key self.model = model self.base_url = base_url + + def _build_provider_with_explicit_settings(self): + provider_name = (self.provider or "").lower() + if not LLMServiceManager.is_registered(): + from .providers import register_all_providers + + register_all_providers() + + provider_class = LLMServiceManager._vision_providers.get(provider_name) + if provider_class is None: + raise LLMServiceError(f"视觉模型提供商未注册: {provider_name}") + + return provider_class( + api_key=self.api_key, + model_name=self.model, + base_url=self.base_url, + ) async def analyze_images(self, images: 
List[Union[str, Path, PIL.Image.Image]], prompt: str, - batch_size: int = 10) -> List[Dict[str, Any]]: + batch_size: int = 10, + max_concurrency: int = 1) -> List[Dict[str, Any]]: """ 分析图片 - 兼容原有接口 @@ -167,17 +155,20 @@ class VisionAnalyzerAdapter: images: 图片列表 prompt: 分析提示词 batch_size: 批处理大小 + max_concurrency: 最大并发批次数 Returns: 分析结果列表,格式与旧实现兼容 """ try: - # 使用统一服务分析图片 - results = await UnifiedLLMService.analyze_images( + provider = self._build_provider_with_explicit_settings() + results = await provider.analyze_images( images=images, prompt=prompt, - provider=self.provider, - batch_size=batch_size + batch_size=batch_size, + max_concurrency=max_concurrency, + api_key=self.api_key, + api_base=self.base_url, ) # 转换为旧格式以保持向后兼容性 diff --git a/app/services/llm/openai_compatible_provider.py b/app/services/llm/openai_compatible_provider.py index 6423ec9..b91c6dc 100644 --- a/app/services/llm/openai_compatible_provider.py +++ b/app/services/llm/openai_compatible_provider.py @@ -4,6 +4,7 @@ OpenAI 兼容提供商实现 使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口,支持文本和视觉模型。 """ +import asyncio import io import base64 import re @@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider) images: List[Union[str, Path, PIL.Image.Image]], prompt: str, batch_size: int = 10, + max_concurrency: int = 1, **kwargs, ) -> List[str]: logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片") processed_images = self._prepare_images(images) - results: List[str] = [] + if not processed_images: + return [] - for i in range(0, len(processed_images), batch_size): - batch = processed_images[i : i + batch_size] - logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片") - try: - result = await self._analyze_batch(batch, prompt, **kwargs) - results.append(result) - except Exception as exc: - logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}") - results.append(f"批次处理失败: {exc}") + bounded_concurrency = max(1, int(max_concurrency)) + semaphore = 
asyncio.Semaphore(bounded_concurrency) + batches = [ + (index // batch_size, processed_images[index : index + batch_size]) + for index in range(0, len(processed_images), batch_size) + ] - return results + async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]: + logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张图片") + async with semaphore: + try: + result = await self._analyze_batch(batch, prompt, **kwargs) + return batch_index, result + except Exception as exc: + logger.error(f"批次 {batch_index + 1} 处理失败: {exc}") + return batch_index, f"批次处理失败: {exc}" + + completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches)) + completed.sort(key=lambda item: item[0]) + return [result for _, result in completed] async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str: content = [{"type": "text", "text": prompt}] diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py index faa4e80..acef31a 100644 --- a/app/services/llm/test_openai_compat_unittest.py +++ b/app/services/llm/test_openai_compat_unittest.py @@ -1,10 +1,14 @@ """OpenAI 兼容 provider 的最小回归测试。""" +import asyncio import unittest +from unittest.mock import patch from app.config import config from app.services.llm.base import TextModelProvider from app.services.llm.manager import LLMServiceManager +from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter +from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider from app.services.llm.providers import register_all_providers @@ -63,5 +67,128 @@ class OpenAICompatManagerTests(unittest.TestCase): self.assertEqual("https://new.example/v1", provider.base_url) +class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase): + async def test_analyze_images_keeps_batch_order_when_running_concurrently(self): + provider = 
OpenAICompatibleVisionProvider(api_key="k", model_name="m") + provider._prepare_images = lambda images: list(images) + + async def fake_analyze_batch(batch, prompt, **kwargs): + delays = {"a": 0.03, "c": 0.01, "e": 0.0} + await asyncio.sleep(delays[batch[0]]) + return f"batch-{batch[0]}" + + provider._analyze_batch = fake_analyze_batch + + result = await provider.analyze_images( + images=["a", "b", "c", "d", "e", "f"], + prompt="prompt", + batch_size=2, + max_concurrency=2, + ) + + self.assertEqual(["batch-a", "batch-c", "batch-e"], result) + + async def test_analyze_images_respects_max_concurrency_limit(self): + provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m") + provider._prepare_images = lambda images: list(images) + + in_flight = 0 + max_in_flight = 0 + + async def fake_analyze_batch(batch, prompt, **kwargs): + nonlocal in_flight, max_in_flight + in_flight += 1 + max_in_flight = max(max_in_flight, in_flight) + await asyncio.sleep(0.02) + in_flight -= 1 + return f"batch-{batch[0]}" + + provider._analyze_batch = fake_analyze_batch + + result = await provider.analyze_images( + images=["a", "b", "c", "d", "e", "f"], + prompt="prompt", + batch_size=1, + max_concurrency=2, + ) + + self.assertEqual(6, len(result)) + self.assertEqual(2, max_in_flight) + + +class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase): + class _CapturingVisionProvider: + last_init: tuple[str, str, str | None] | None = None + last_call_kwargs: dict | None = None + + def __init__(self, api_key: str, model_name: str, base_url: str | None = None): + self.api_key = api_key + self.model_name = model_name + self.base_url = base_url + ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url) + + async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs): + ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs) + return 
[f"{self.model_name}|{self.api_key}|{self.base_url}"] + + def setUp(self): + _reset_manager_state() + self._original_app = dict(config.app) + + def tearDown(self): + _reset_manager_state() + config.app.clear() + config.app.update(self._original_app) + + async def test_adapter_uses_explicit_settings_instead_of_global_config(self): + LLMServiceManager.register_vision_provider("openai", self._CapturingVisionProvider) + config.app["vision_openai_api_key"] = "config-key" + config.app["vision_openai_model_name"] = "config-model" + config.app["vision_openai_base_url"] = "https://config.example/v1" + + adapter = VisionAnalyzerAdapter( + provider="openai", + api_key="explicit-key", + model="explicit-model", + base_url="https://explicit.example/v1", + ) + result = await adapter.analyze_images( + images=["/tmp/keyframe_000001_000000100.jpg"], + prompt="描述画面", + batch_size=1, + max_concurrency=1, + ) + + self.assertEqual( + ("explicit-key", "explicit-model", "https://explicit.example/v1"), + self._CapturingVisionProvider.last_init, + ) + self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"]) + self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"]) + self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"]) + + +class LegacyNarrationAdapterBehaviorTests(unittest.TestCase): + def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self): + raw_payload = "not-json-at-all ::: ???" 
+ + with patch( + "app.services.llm.migration_adapter.PromptManager.get_prompt", + return_value="prompt", + ), patch( + "app.services.llm.migration_adapter._run_async_safely", + return_value=raw_payload, + ): + result = LegacyLLMAdapter.generate_narration( + markdown_content="markdown", + api_key="test-key", + base_url="https://example.com/v1", + model="test-model", + ) + + self.assertEqual(raw_payload, result) + self.assertNotIn('"items"', result) + + if __name__ == "__main__": unittest.main() diff --git a/app/services/script_service.py b/app/services/script_service.py index 34a17a6..47c329c 100644 --- a/app/services/script_service.py +++ b/app/services/script_service.py @@ -1,324 +1,40 @@ -import os -import json -import time -import asyncio -import requests -from app.utils import video_processor -from loguru import logger -from typing import List, Dict, Any, Callable +from typing import Any, Callable -from app.utils import utils, gemini_analyzer, video_processor -from app.utils.script_generator import ScriptProcessor -from app.config import config +from loguru import logger + +from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService class ScriptGenerator: - def __init__(self): - self.temp_dir = utils.temp_dir() - self.keyframes_dir = os.path.join(self.temp_dir, "keyframes") - + def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None): + self.documentary_service = documentary_service or DocumentaryFrameAnalysisService() + async def generate_script( self, video_path: str, video_theme: str = "", custom_prompt: str = "", - frame_interval_input: int = 5, + frame_interval_input: int | None = None, skip_seconds: int = 0, threshold: int = 30, - vision_batch_size: int = 5, - vision_llm_provider: str = "gemini", - progress_callback: Callable[[float, str], None] = None - ) -> List[Dict[Any, Any]]: - """ - 生成视频脚本的核心逻辑 - - Args: - video_path: 视频文件路径 - video_theme: 视频主题 - custom_prompt: 自定义提示词 - skip_seconds: 
跳过开始的秒数 - threshold: 差异���值 - vision_batch_size: 视觉处理批次大小 - vision_llm_provider: 视觉模型提供商 - progress_callback: 进度回调函数 - - Returns: - List[Dict]: 生成的视频脚本 - """ - if progress_callback is None: - progress_callback = lambda p, m: None - - try: - # 提取关键帧 - progress_callback(10, "正在提取关键帧...") - keyframe_files = await self._extract_keyframes( - video_path, - skip_seconds, - threshold - ) - - # 使用统一的 LLM 接口(支持所有 provider) - script = await self._process_with_llm( - keyframe_files, - video_theme, - custom_prompt, - vision_batch_size, - vision_llm_provider, - progress_callback - ) - - return json.loads(script) if isinstance(script, str) else script - - except Exception as e: - logger.exception("Generate script failed") - raise - - async def _extract_keyframes( - self, - video_path: str, - skip_seconds: int, - threshold: int - ) -> List[str]: - """提取视频关键帧""" - video_hash = utils.md5(video_path + str(os.path.getmtime(video_path))) - video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash) - - # 检查缓存 - keyframe_files = [] - if os.path.exists(video_keyframes_dir): - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if keyframe_files: - logger.info(f"Using cached keyframes: {video_keyframes_dir}") - return keyframe_files - - # 提取新的关键帧 - os.makedirs(video_keyframes_dir, exist_ok=True) - - try: - processor = video_processor.VideoProcessor(video_path) - processor.process_video_pipeline( - output_dir=video_keyframes_dir, - skip_seconds=skip_seconds, - threshold=threshold + vision_batch_size: int | None = None, + vision_llm_provider: str | None = None, + progress_callback: Callable[[float, str], None] | None = None, + ) -> list[dict[Any, Any]]: + callback = progress_callback or (lambda _p, _m: None) + if skip_seconds != 0 or threshold != 30: + logger.warning( + "ScriptGenerator documentary path received " + f"skip_seconds={skip_seconds} threshold={threshold}; " + 
"the shared documentary frame pipeline does not currently apply these parameters." ) - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - return keyframe_files - - except Exception as e: - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - raise - - async def _process_with_llm( - self, - keyframe_files: List[str], - video_theme: str, - custom_prompt: str, - vision_batch_size: int, - vision_llm_provider: str, - progress_callback: Callable[[float, str], None] - ) -> str: - """使用统一 LLM 接口处理视频帧""" - progress_callback(30, "正在初始化视觉分析器...") - - # 使用新的 LLM 迁移适配器(支持所有 provider) - from app.services.llm.migration_adapter import create_vision_analyzer - - # 获取配置 - text_provider = config.app.get('text_llm_provider', 'openai').lower() - vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key') - vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name') - vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url') - - if not vision_api_key or not vision_model: - raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型") - - # 创建统一的视觉分析器 - analyzer = create_vision_analyzer( - provider=vision_llm_provider, - api_key=vision_api_key, - model=vision_model, - base_url=vision_base_url + return await self.documentary_service.generate_documentary_script( + video_path=video_path, + video_theme=video_theme, + custom_prompt=custom_prompt, + frame_interval_input=frame_interval_input, + vision_batch_size=vision_batch_size, + vision_llm_provider=vision_llm_provider, + progress_callback=callback, ) - - progress_callback(40, "正在分析关键帧...") - - # 执行异步分析 - results = await analyzer.analyze_images( - images=keyframe_files, - prompt=config.app.get('vision_analysis_prompt'), - batch_size=vision_batch_size - ) - - progress_callback(60, "正在整理分析结果...") - - # 合并所有批次的分析结果 - frame_analysis = "" - 
prev_batch_files = None - - for result in results: - if 'error' in result: - logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") - continue - - batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size) - first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files) - - # 添加带时间戳的分��结果 - frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" - frame_analysis += result['response'] - frame_analysis += "\n" - - prev_batch_files = batch_files - - if not frame_analysis.strip(): - raise Exception("未能生成有效的帧分析结果") - - progress_callback(70, "正在生成脚本...") - - # 构建帧内容列表 - frame_content_list = [] - prev_batch_files = None - - for result in results: - if 'error' in result: - continue - - batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size) - _, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files) - - frame_content = { - "timestamp": timestamp_range, - "picture": result['response'], - "narration": "", - "OST": 2 - } - frame_content_list.append(frame_content) - prev_batch_files = batch_files - - if not frame_content_list: - raise Exception("没有有效的帧内容可以处理") - - progress_callback(90, "正在生成文案...") - - # 获取文本生��配置 - text_provider = config.app.get('text_llm_provider', 'gemini').lower() - text_api_key = config.app.get(f'text_{text_provider}_api_key') - text_model = config.app.get(f'text_{text_provider}_model_name') - text_base_url = config.app.get(f'text_{text_provider}_base_url') - - # 根据提供商类型选择合适的处理器 - if text_provider == 'gemini(openai)': - # 使用OpenAI兼容的Gemini代理 - from app.utils.script_generator import GeminiOpenAIGenerator - generator = GeminiOpenAIGenerator( - model_name=text_model, - api_key=text_api_key, - prompt=custom_prompt, - base_url=text_base_url - ) - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - base_url=text_base_url, - prompt=custom_prompt, - video_theme=video_theme - ) - processor.generator = generator 
- else: - # 使用标准处理器(包括原生Gemini) - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - base_url=text_base_url, - prompt=custom_prompt, - video_theme=video_theme - ) - - return processor.process_frames(frame_content_list) - - def _get_batch_files( - self, - keyframe_files: List[str], - result: Dict[str, Any], - batch_size: int - ) -> List[str]: - """获取当前批次的图片文件""" - batch_start = result['batch_index'] * batch_size - batch_end = min(batch_start + batch_size, len(keyframe_files)) - return keyframe_files[batch_start:batch_end] - - def _get_batch_timestamps( - self, - batch_files: List[str], - prev_batch_files: List[str] = None - ) -> tuple[str, str, str]: - """获取一批文件的时间戳范围,支持毫秒级精度""" - if not batch_files: - logger.warning("Empty batch files") - return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000" - - if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0: - first_frame = os.path.basename(prev_batch_files[-1]) - last_frame = os.path.basename(batch_files[0]) - else: - first_frame = os.path.basename(batch_files[0]) - last_frame = os.path.basename(batch_files[-1]) - - first_time = first_frame.split('_')[2].replace('.jpg', '') - last_time = last_frame.split('_')[2].replace('.jpg', '') - - def format_timestamp(time_str: str) -> str: - """将时间字符串转换为 HH:MM:SS,mmm 格式""" - try: - if len(time_str) < 4: - logger.warning(f"Invalid timestamp format: {time_str}") - return "00:00:00,000" - - # 处理毫秒部分 - if ',' in time_str: - time_part, ms_part = time_str.split(',') - ms = int(ms_part) - else: - time_part = time_str - ms = 0 - - # 处理时分秒 - parts = time_part.split(':') - if len(parts) == 3: # HH:MM:SS - h, m, s = map(int, parts) - elif len(parts) == 2: # MM:SS - h = 0 - m, s = map(int, parts) - else: # SS - h = 0 - m = 0 - s = int(parts[0]) - - # 处理进位 - if s >= 60: - m += s // 60 - s = s % 60 - if m >= 60: - h += m // 60 - m = m % 60 - - return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" - - except Exception as e: - 
logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}") - return "00:00:00,000" - - first_timestamp = format_timestamp(first_time) - last_timestamp = format_timestamp(last_time) - timestamp_range = f"{first_timestamp}-{last_timestamp}" - - return first_timestamp, last_timestamp, timestamp_range diff --git a/app/utils/utils.py b/app/utils/utils.py index d101dce..19dd46f 100644 --- a/app/utils/utils.py +++ b/app/utils/utils.py @@ -570,29 +570,39 @@ def temp_dir(sub_dir: str = ""): return d -def clear_keyframes_cache(video_path: str = None): +def clear_keyframes_cache(video_path: str = None, cache_scope: str = "keyframes"): """ 清理关键帧缓存 Args: video_path: 视频文件路径,如果指定则只清理该视频的缓存 + cache_scope: 缓存作用域目录,默认 keyframes """ try: - keyframes_dir = os.path.join(temp_dir(), "keyframes") - if not os.path.exists(keyframes_dir): + cache_dir = os.path.join(temp_dir(), cache_scope) + if not os.path.exists(cache_dir): return + import shutil + if video_path: - # 理指定视频的缓存 - video_hash = md5(video_path + str(os.path.getmtime(video_path))) - video_keyframes_dir = os.path.join(keyframes_dir, video_hash) - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - logger.info(f"已清理视频关键帧缓存: {video_path}") + # 清理指定视频的缓存(兼容前缀扩展键) + try: + video_mtime = os.path.getmtime(video_path) + except OSError: + video_mtime = 0 + video_hash = md5(video_path + str(video_mtime)) + for entry in os.listdir(cache_dir): + if not entry.startswith(video_hash): + continue + target_path = os.path.join(cache_dir, entry) + if os.path.isdir(target_path): + shutil.rmtree(target_path) + else: + os.remove(target_path) + logger.info(f"已清理视频关键帧缓存: {video_path}") else: # 清理所有缓存 - import shutil - shutil.rmtree(keyframes_dir) + shutil.rmtree(cache_dir) logger.info("已清理所有关键帧缓存") except Exception as e: diff --git a/app/utils/video_processor.py b/app/utils/video_processor.py index 6c46737..5a2c95a 100644 --- a/app/utils/video_processor.py +++ b/app/utils/video_processor.py @@ -185,6 +185,95 @@ class 
VideoProcessor: return frame_numbers + def extract_frames_by_interval_with_fallback(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]: + """ + 先尝试单次 ffmpeg 快路径抽帧,失败时回退到高兼容方案。 + """ + if interval_seconds <= 0: + raise ValueError("interval_seconds must be > 0") + + os.makedirs(output_dir, exist_ok=True) + + try: + return self._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds) + except Exception as exc: + logger.warning(f"快路径抽帧失败,回退到兼容模式: {exc}") + self._cleanup_fast_path_artifacts(output_dir) + self.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds) + return self._collect_extracted_frame_paths(output_dir) + + def _extract_frames_fast_path(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]: + """ + 使用单次 ffmpeg 命令按固定间隔抽帧,随后重命名为既有 keyframe 约定格式。 + """ + if interval_seconds <= 0: + raise ValueError("interval_seconds must be > 0") + + os.makedirs(output_dir, exist_ok=True) + raw_pattern = os.path.join(output_dir, "fastframe_%06d.jpg") + cmd = [ + "ffmpeg", + "-hide_banner", + "-loglevel", + "error", + "-i", + self.video_path, + "-vf", + f"fps=1/{interval_seconds}", + "-q:v", + "2", + "-start_number", + "0", + "-y", + raw_pattern, + ] + subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120) + + raw_files = sorted( + filename + for filename in os.listdir(output_dir) + if re.fullmatch(r"fastframe_\d{6}\.jpg", filename) + ) + if not raw_files: + raise RuntimeError("Fast-path extraction produced no frames") + + renamed_files: List[str] = [] + for index, filename in enumerate(raw_files): + timestamp = index * interval_seconds + frame_number = int(timestamp * self.fps) + token = self._format_timestamp_token(timestamp) + source_path = os.path.join(output_dir, filename) + target_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{token}.jpg") + os.replace(source_path, target_path) + renamed_files.append(target_path) + + return renamed_files + + 
@staticmethod + def _format_timestamp_token(timestamp: float) -> str: + hours = int(timestamp // 3600) + minutes = int((timestamp % 3600) // 60) + seconds = int(timestamp % 60) + milliseconds = int((timestamp % 1) * 1000) + return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}" + + @staticmethod + def _collect_extracted_frame_paths(output_dir: str) -> List[str]: + return sorted( + os.path.join(output_dir, name) + for name in os.listdir(output_dir) + if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name) + ) + + @staticmethod + def _cleanup_fast_path_artifacts(output_dir: str) -> None: + for name in os.listdir(output_dir): + if not re.fullmatch(r"fastframe_\d{6}\.jpg", name): + continue + artifact_path = os.path.join(output_dir, name) + if os.path.isfile(artifact_path): + os.remove(artifact_path) + def _extract_single_frame_optimized(self, timestamp: float, output_path: str, use_hw_accel: bool, hwaccel_type: str) -> bool: """ diff --git a/config.example.toml b/config.example.toml index f226c34..781aaa6 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,5 +1,5 @@ [app] - project_version="0.7.6" + project_version="0.7.8" # LLM API 超时配置(秒) llm_vision_timeout = 120 # 视觉模型基础超时时间 @@ -152,3 +152,6 @@ # 大模型单次处理的关键帧数量 vision_batch_size = 10 + + # 视觉批处理最大并发批次数(OpenAI 兼容 provider) + vision_max_concurrency = 2 diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..b08f2df --- /dev/null +++ b/conftest.py @@ -0,0 +1,11 @@ +"""Pytest collection rules for the repository. + +These files are executable smoke-check scripts that live next to the LLM +implementation for convenience. They require live credentials or manual +execution semantics, so keep them out of the default automated test suite. 
+""" + +collect_ignore = [ + "app/services/llm/test_llm_service.py", + "app/services/llm/test_openai_compatible_integration.py", +] diff --git a/project_version b/project_version index 11d9d6c..e7c7d3c 100644 --- a/project_version +++ b/project_version @@ -1 +1 @@ -0.7.7 \ No newline at end of file +0.7.8 diff --git a/tests/test_documentary_frame_analysis_service.py b/tests/test_documentary_frame_analysis_service.py new file mode 100644 index 0000000..a444068 --- /dev/null +++ b/tests/test_documentary_frame_analysis_service.py @@ -0,0 +1,275 @@ +import unittest +import os +from tempfile import TemporaryDirectory +from unittest.mock import patch + +from app.services.documentary.frame_analysis_models import DocumentaryAnalysisConfig +from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService +from app.utils import utils + + +class DocumentaryFrameAnalysisServiceTests(unittest.TestCase): + def test_build_analysis_prompt_formats_real_frame_count(self): + service = DocumentaryFrameAnalysisService() + + prompt = service._build_analysis_prompt(frame_count=3) + + self.assertIn("我提供了 3 张视频帧", prompt) + self.assertNotIn("%s", prompt) + self.assertIn("frame_observations", prompt) + self.assertIn("overall_activity_summary", prompt) + + def test_parse_failed_batch_keeps_raw_response_and_time_range(self): + service = DocumentaryFrameAnalysisService() + + batch = service._build_failed_batch_result( + batch_index=2, + raw_response="not-json", + error_message="JSON decode failed", + frame_paths=["/tmp/keyframe_000000_000000000.jpg"], + time_range="00:00:00,000-00:00:03,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual("not-json", batch.raw_response) + self.assertEqual("00:00:00,000-00:00:03,000", batch.time_range) + self.assertTrue(batch.fallback_summary) + + def test_parse_failed_batch_uses_non_empty_fallback_when_raw_response_is_empty(self): + service = DocumentaryFrameAnalysisService() + + batch = 
service._build_failed_batch_result( + batch_index=3, + raw_response="", + error_message="Empty model response", + frame_paths=["/tmp/keyframe_000001_000001000.jpg"], + time_range="00:00:03,000-00:00:06,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual("", batch.raw_response) + self.assertTrue(batch.fallback_summary) + + def test_failed_batch_result_uses_prompt_contract_field_names(self): + service = DocumentaryFrameAnalysisService() + + batch = service._build_failed_batch_result( + batch_index=4, + raw_response="not-json", + error_message="JSON decode failed", + frame_paths=["/tmp/keyframe_000002_000002000.jpg"], + time_range="00:00:06,000-00:00:09,000", + ) + + self.assertEqual([], batch.frame_observations) + self.assertEqual("", batch.overall_activity_summary) + self.assertFalse(hasattr(batch, "observations")) + self.assertFalse(hasattr(batch, "summary")) + + def test_parse_batch_returns_failed_result_when_json_is_invalid(self): + service = DocumentaryFrameAnalysisService() + + batch = service._parse_batch_response( + batch_index=0, + raw_response="plain text", + frame_paths=["/tmp/keyframe_000000_000000000.jpg"], + time_range="00:00:00,000-00:00:03,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual("plain text", batch.raw_response) + self.assertEqual(["/tmp/keyframe_000000_000000000.jpg"], batch.frame_paths) + self.assertEqual([], batch.frame_observations) + self.assertEqual("", batch.overall_activity_summary) + + def test_parse_batch_returns_failed_result_for_empty_json_object(self): + service = DocumentaryFrameAnalysisService() + + batch = service._parse_batch_response( + batch_index=0, + raw_response="{}", + frame_paths=["/tmp/keyframe_000000_000000000.jpg"], + time_range="00:00:00,000-00:00:03,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual("{}", batch.raw_response) + self.assertIn("frame_observations", batch.error_message) + + def 
test_parse_batch_returns_failed_result_when_observations_are_too_short(self): + service = DocumentaryFrameAnalysisService() + raw_response = """ +{ + "frame_observations": [ + {"observation": "第一帧画面"} + ], + "overall_activity_summary": "只有一条帧观察" +} +""".strip() + + batch = service._parse_batch_response( + batch_index=1, + raw_response=raw_response, + frame_paths=[ + "/tmp/keyframe_000000_000000000.jpg", + "/tmp/keyframe_000075_000003000.jpg", + ], + time_range="00:00:00,000-00:00:06,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual(raw_response, batch.raw_response) + self.assertIn("frame_observations", batch.error_message) + + def test_parse_batch_parses_code_fenced_json_into_structured_result(self): + service = DocumentaryFrameAnalysisService() + raw_response = """```json +{ + "frame_observations": [ + {"observation": "第一帧画面"}, + {"observation": "第二帧画面"} + ], + "overall_activity_summary": "人物从房间走到街道" +} +```""" + + batch = service._parse_batch_response( + batch_index=1, + raw_response=raw_response, + frame_paths=[ + "/tmp/keyframe_000000_000000000.jpg", + "/tmp/keyframe_000075_000003000.jpg", + ], + time_range="00:00:00,000-00:00:06,000", + ) + + self.assertEqual("success", batch.status) + self.assertEqual( + [ + { + "frame_path": "/tmp/keyframe_000000_000000000.jpg", + "timestamp": "", + "observation": "第一帧画面", + }, + { + "frame_path": "/tmp/keyframe_000075_000003000.jpg", + "timestamp": "", + "observation": "第二帧画面", + }, + ], + batch.frame_observations, + ) + self.assertEqual("人物从房间走到街道", batch.overall_activity_summary) + self.assertEqual("", batch.fallback_summary) + + def test_parse_batch_preserves_frames_when_summary_is_missing(self): + service = DocumentaryFrameAnalysisService() + raw_response = """ +{ + "frame_observations": [ + {"observation": "第一帧画面"}, + {"observation": "第二帧画面"} + ] +} +""".strip() + + batch = service._parse_batch_response( + batch_index=2, + raw_response=raw_response, + frame_paths=[ + 
"/tmp/keyframe_000000_000000000.jpg", + "/tmp/keyframe_000075_000003000.jpg", + ], + time_range="00:00:00,000-00:00:06,000", + ) + + self.assertEqual("success", batch.status) + self.assertEqual(2, len(batch.frame_observations)) + self.assertEqual("", batch.overall_activity_summary) + + def test_cache_key_changes_when_interval_changes(self): + service = DocumentaryFrameAnalysisService() + + with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0): + key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2) + + self.assertNotEqual(key_a, key_b) + + def test_cache_key_changes_when_model_changes(self): + service = DocumentaryFrameAnalysisService() + + with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0): + key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + key_b = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-b", 10, 2) + + self.assertNotEqual(key_a, key_b) + + def test_cache_key_starts_with_legacy_video_hash_prefix(self): + service = DocumentaryFrameAnalysisService() + + with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0): + key = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + + expected_prefix = utils.md5("video.mp4" + "123.0") + self.assertTrue(key.startswith(expected_prefix)) + + def test_clear_keyframes_cache_respects_scope_and_prefix_match(self): + with TemporaryDirectory() as temp_root: + service = DocumentaryFrameAnalysisService() + analysis_dir = os.path.join(temp_root, "analysis") + os.makedirs(analysis_dir, exist_ok=True) + + with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0): + target_key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + target_key_b = 
service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2) + keep_key = service._build_cache_key("other.mp4", 3.0, "prompt-v1", "model-a", 10, 2) + + target_dir_a = os.path.join(analysis_dir, target_key_a) + target_dir_b = os.path.join(analysis_dir, target_key_b) + keep_dir = os.path.join(analysis_dir, keep_key) + os.makedirs(target_dir_a, exist_ok=True) + os.makedirs(target_dir_b, exist_ok=True) + os.makedirs(keep_dir, exist_ok=True) + + with patch("app.utils.utils.temp_dir", return_value=temp_root), patch( + "app.utils.utils.os.path.getmtime", return_value=123.0 + ): + utils.clear_keyframes_cache(video_path="video.mp4", cache_scope="analysis") + + self.assertFalse(os.path.exists(target_dir_a)) + self.assertFalse(os.path.exists(target_dir_b)) + self.assertTrue(os.path.exists(keep_dir)) + + +class DocumentaryAnalysisConfigTests(unittest.TestCase): + def test_config_rejects_non_positive_frame_interval(self): + with self.assertRaises(ValueError): + DocumentaryAnalysisConfig( + video_path="/tmp/demo.mp4", + frame_interval_seconds=0, + vision_batch_size=5, + vision_llm_provider="openai", + vision_model_name="gpt-4o-mini", + ) + + def test_config_rejects_non_positive_batch_size(self): + with self.assertRaises(ValueError): + DocumentaryAnalysisConfig( + video_path="/tmp/demo.mp4", + frame_interval_seconds=5, + vision_batch_size=0, + vision_llm_provider="openai", + vision_model_name="gpt-4o-mini", + ) + + def test_config_rejects_non_positive_max_concurrency(self): + with self.assertRaises(ValueError): + DocumentaryAnalysisConfig( + video_path="/tmp/demo.mp4", + frame_interval_seconds=5, + vision_batch_size=5, + vision_llm_provider="openai", + vision_model_name="gpt-4o-mini", + max_concurrency=0, + ) diff --git a/tests/test_generate_narration_script_documentary_unittest.py b/tests/test_generate_narration_script_documentary_unittest.py new file mode 100644 index 0000000..edb6963 --- /dev/null +++ b/tests/test_generate_narration_script_documentary_unittest.py 
# --- tests/test_generate_narration_script_documentary_unittest.py ---
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from app.services.generate_narration_script import parse_frame_analysis_to_markdown


class GenerateNarrationMarkdownTests(unittest.TestCase):
    """parse_frame_analysis_to_markdown: fallback summaries and chronological ordering."""

    @staticmethod
    def _batch(index, time_range, summary, fallback, timestamp, observation):
        # Shape mirrors one batch entry of the frame-analysis artifact.
        return {
            "batch_index": index,
            "time_range": time_range,
            "overall_activity_summary": summary,
            "fallback_summary": fallback,
            "frame_observations": [
                {"timestamp": timestamp, "observation": observation}
            ],
        }

    def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self):
        # Batches are deliberately listed out of chronological order; the
        # earlier batch has no model summary and must use the raw fallback.
        artifact = {
            "batches": [
                self._batch(1, "00:00:03,000-00:00:06,000", "人物转身跑向远处", "", "00:00:03,000", "人物突然回头"),
                self._batch(0, "00:00:00,000-00:00:03,000", "", "原始响应回退摘要", "00:00:00,000", "镜头里有一只猫"),
            ]
        }

        with TemporaryDirectory() as workdir:
            artifact_file = Path(workdir) / "frame-analysis.json"
            artifact_file.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8")
            markdown = parse_frame_analysis_to_markdown(str(artifact_file))

        # Every observation and summary (including the fallback) survives.
        for expected_snippet in ("原始响应回退摘要", "镜头里有一只猫", "人物转身跑向远处", "人物突然回头"):
            self.assertIn(expected_snippet, markdown)

        # Output is sorted by time range, not by input order.
        earlier_index = markdown.find("00:00:00,000-00:00:03,000")
        later_index = markdown.find("00:00:03,000-00:00:06,000")
        self.assertNotEqual(-1, earlier_index)
        self.assertNotEqual(-1, later_index)
        self.assertLess(earlier_index, later_index)


if __name__ == "__main__":
    unittest.main()


# --- tests/test_generate_script_docu_unittest.py ---
import unittest

from webui.tools.generate_script_docu import _normalize_progress_value


class GenerateScriptDocuProgressTests(unittest.TestCase):
    """_normalize_progress_value coercion rules (test methods continue on the next source line)."""
    def test_normalize_progress_rounds_percentage_float_to_valid_streamlit_int(self):
        # A value already on the 0-100 scale is only rounded, never rescaled.
        self.assertEqual(43, _normalize_progress_value(43.125))

    def test_normalize_progress_converts_ratio_float_to_percentage_int(self):
        # A 0.0-1.0 ratio is promoted to a percentage before rounding.
        self.assertEqual(43, _normalize_progress_value(0.43125))

    def test_normalize_progress_clamps_out_of_range_values(self):
        self.assertEqual(0, _normalize_progress_value(-5))
        self.assertEqual(100, _normalize_progress_value(101))


if __name__ == "__main__":
    unittest.main()


# --- tests/test_script_service_documentary_unittest.py ---
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import AsyncMock, patch

from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.services.script_service import ScriptGenerator


class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
    """ScriptGenerator delegates documentary work to the shared frame-analysis service."""

    async def test_generate_script_forwards_explicit_values_to_shared_service(self):
        expected_script = [
            {
                "timestamp": "00:00:00,000-00:00:03,000",
                "picture": "批次描述",
                "narration": "这里是解说词",
                "OST": 2,
            }
        ]
        callback = lambda _percent, _message: None

        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
            service = service_cls.return_value
            service.generate_documentary_script = AsyncMock(return_value=expected_script)
            generator = ScriptGenerator()

            result = await generator.generate_script(
                video_path="demo.mp4",
                video_theme="荒野生存",
                custom_prompt="请聚焦生存动作",
                frame_interval_input=3,
                vision_batch_size=6,
                vision_llm_provider="openai",
                progress_callback=callback,
            )

        self.assertEqual(expected_script, result)
        self.assertTrue(result[0]["narration"])
        service.generate_documentary_script.assert_awaited_once()
        called_kwargs = service.generate_documentary_script.await_args.kwargs
        self.assertEqual("demo.mp4", called_kwargs["video_path"])
        self.assertEqual(3, called_kwargs["frame_interval_input"])
        self.assertEqual(6, called_kwargs["vision_batch_size"])
        self.assertEqual("openai", called_kwargs["vision_llm_provider"])
        self.assertEqual("荒野生存", called_kwargs["video_theme"])
        self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
        # The callback must be forwarded by identity, not wrapped.
        self.assertIs(called_kwargs["progress_callback"], callback)

    async def test_generate_script_forwards_unset_values_as_none(self):
        # Unset tuning knobs must reach the service as None so that the
        # service (not the generator) resolves the configured defaults.
        expected_script = [
            {
                "timestamp": "00:00:00,000-00:00:03,000",
                "picture": "批次描述",
                "narration": "这里是解说词",
                "OST": 2,
            }
        ]
        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
            service = service_cls.return_value
            service.generate_documentary_script = AsyncMock(return_value=expected_script)
            generator = ScriptGenerator()

            await generator.generate_script(video_path="demo.mp4")

        called_kwargs = service.generate_documentary_script.await_args.kwargs
        self.assertIsNone(called_kwargs["frame_interval_input"])
        self.assertIsNone(called_kwargs["vision_batch_size"])
        self.assertIsNone(called_kwargs["vision_llm_provider"])

    async def test_generate_script_warns_when_skip_seconds_or_threshold_are_non_default(self):
        # Legacy parameters are accepted for interface compatibility but are
        # not applied by the documentary pipeline; a warning must say so.
        expected_script = [
            {
                "timestamp": "00:00:00,000-00:00:03,000",
                "picture": "批次描述",
                "narration": "这里是解说词",
                "OST": 2,
            }
        ]
        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls, patch(
            "app.services.script_service.logger.warning"
        ) as warning:
            service = service_cls.return_value
            service.generate_documentary_script = AsyncMock(return_value=expected_script)
            generator = ScriptGenerator()
            await generator.generate_script(
                video_path="demo.mp4",
                skip_seconds=2,
                threshold=20,
            )

        warning.assert_called_once()
        warning_message = warning.call_args.args[0]
        self.assertIn("skip_seconds", warning_message)
        self.assertIn("threshold", warning_message)
        self.assertIn("does not currently apply", warning_message)


class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyncioTestCase):
    """Narration-script generation on top of a mocked frame-analysis artifact."""

    async def test_generate_documentary_script_returns_final_narrated_items(self):
        service = DocumentaryFrameAnalysisService()
        analysis_payload = {
            "batches": [
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "",
                    "fallback_summary": "回退摘要",
                    "frame_observations": [
                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                    ],
                }
            ]
        }

        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")

            with patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            ), patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                {
                    "text_llm_provider": "openai",
                    "text_openai_api_key": "test-key",
                    "text_openai_model_name": "test-model",
                    "text_openai_base_url": "https://example.com/v1",
                },
            ), patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
            ):
                result = await service.generate_documentary_script(video_path="demo.mp4")

        self.assertEqual(1, len(result))
        self.assertEqual("00:00:00,000-00:00:03,000", result[0]["timestamp"])
        self.assertEqual("镜头里有一只猫", result[0]["picture"])
        self.assertEqual("一只猫警觉地望向镜头。", result[0]["narration"])
        # OST == 2 keeps both the original audio track and the narration.
        self.assertEqual(2, result[0]["OST"])

    async def test_generate_documentary_script_raises_when_narration_json_is_malformed(self):
        service = DocumentaryFrameAnalysisService()
        analysis_payload = {
            "batches": [
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "测试摘要",
                    "fallback_summary": "",
                    "frame_observations": [
                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                    ],
                }
            ]
        }

        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")

            with patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            ), patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                {
                    "text_llm_provider": "openai",
                    "text_openai_api_key": "test-key",
                    "text_openai_model_name": "test-model",
                    "text_openai_base_url": "https://example.com/v1",
                },
            ), patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                return_value="malformed narration payload",
            ):
                with self.assertRaises(Exception) as ctx:
                    await service.generate_documentary_script(video_path="demo.mp4")

        self.assertIn("解说文案格式错误", str(ctx.exception))
        self.assertIn("items", str(ctx.exception))

    def test_parse_narration_items_recovers_from_common_json_damage(self):
        # Explanatory prose, a ```json fence, doubled braces and trailing
        # commas are typical LLM output damage the parser must repair.
        service = DocumentaryFrameAnalysisService()
        damaged_payload = """
解释文字
```json
{{
    "items": [
        {{
            "timestamp": "00:00:00,000-00:00:03,000",
            "picture": "镜头里有一只猫",
            "narration": "一只猫警觉地望向镜头。",
        }},
    ],
}}
```
补充文字
""".strip()

        parsed_items = service._parse_narration_items(damaged_payload)

        self.assertEqual(1, len(parsed_items))
        self.assertEqual("00:00:00,000-00:00:03,000", parsed_items[0]["timestamp"])
        self.assertEqual("镜头里有一只猫", parsed_items[0]["picture"])
        self.assertEqual("一只猫警觉地望向镜头。", parsed_items[0]["narration"])

    def test_parse_narration_items_raises_for_unrecoverable_payload(self):
        service = DocumentaryFrameAnalysisService()

        with self.assertRaises(ValueError) as ctx:
            service._parse_narration_items("not-json-at-all ::: ???")

        self.assertIn("解说文案格式错误", str(ctx.exception))
        self.assertIn("items", str(ctx.exception))

    async def test_generate_documentary_script_includes_theme_and_custom_prompt_for_narration(self):
        service = DocumentaryFrameAnalysisService()
        analysis_payload = {
            "batches": [
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "测试摘要",
                    "fallback_summary": "",
                    "frame_observations": [
                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                    ],
                }
            ]
        }

        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")

            with patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            ), patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                {
                    "text_llm_provider": "openai",
                    "text_openai_api_key": "test-key",
                    "text_openai_model_name": "test-model",
                    "text_openai_base_url": "https://example.com/v1",
                },
            ), patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
            ) as mocked_generate:
                await service.generate_documentary_script(
                    video_path="demo.mp4",
                    video_theme="野生动物纪录片",
                    custom_prompt="重点描述危险信号",
                )

        # Theme and custom prompt must be embedded in the narration input.
        narration_input = mocked_generate.call_args.args[0]
        self.assertIn("## 创作上下文", narration_input)
        self.assertIn("视频主题:野生动物纪录片", narration_input)
        self.assertIn("补充创作要求:重点描述危险信号", narration_input)

    async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self):
        # An explicitly passed empty base_url must win over the configured
        # value — "" means "use the provider default", not "fall back".
        service = DocumentaryFrameAnalysisService()

        with patch.dict(
            "app.services.documentary.frame_analysis_service.config.app",
            {
                "vision_llm_provider": "openai",
                "vision_openai_api_key": "config-key",
                "vision_openai_model_name": "config-model",
                "vision_openai_base_url": "https://config.example/v1",
            },
        ), patch(
            "app.services.documentary.frame_analysis_service.os.path.exists",
            return_value=True,
        ), patch.object(
            service,
            "_load_or_extract_keyframes",
            return_value=["/tmp/keyframe_000001_000000100.jpg"],
        ), patch.object(
            service,
            "_analyze_batches",
            AsyncMock(return_value=[]),
        ), patch.object(
            service,
            "_save_analysis_artifact",
            return_value="/tmp/frame_analysis_test.json",
        ), patch.object(
            service,
            "_build_video_clip_json",
            return_value=[],
        ), patch(
            "app.services.documentary.frame_analysis_service.create_vision_analyzer",
            return_value=object(),
        ) as mocked_create_analyzer:
            await service.analyze_video(
                video_path="/tmp/demo.mp4",
                vision_api_key="explicit-key",
                vision_model_name="explicit-model",
                vision_base_url="",
            )

        called_kwargs = mocked_create_analyzer.call_args.kwargs
        self.assertEqual("openai", called_kwargs["provider"])
        self.assertEqual("explicit-key", called_kwargs["api_key"])
        self.assertEqual("explicit-model", called_kwargs["model"])
        self.assertEqual("", called_kwargs["base_url"])


if __name__ == "__main__":
    unittest.main()


# --- tests/test_video_processor_documentary_unittest.py ---
import os
import unittest
from tempfile import TemporaryDirectory
from unittest.mock import patch

from app.utils.video_processor import VideoProcessor


class VideoProcessorDocumentaryTests(unittest.TestCase):
    """Fast-path / fallback behaviour of interval-based keyframe extraction."""

    @patch.object(VideoProcessor, "_extract_frames_fast_path", return_value=["a.jpg"])
    def test_extract_frames_by_interval_prefers_fast_path(self, fast_path):
        # __new__ bypasses __init__ so no real video file is opened.
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        result = processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=3.0)

        self.assertEqual(["a.jpg"], result)
        fast_path.assert_called_once_with("/tmp/out", interval_seconds=3.0)

    def test_extract_frames_by_interval_falls_back_to_ultra_compatible(self):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        with TemporaryDirectory() as output_dir:
            expected_frame_path = os.path.join(output_dir, "keyframe_000000_000000000.jpg")

            def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
                # Simulates the fallback writing a keyframe file on disk.
                with open(expected_frame_path, "wb") as frame_file:
                    frame_file.write(b"frame")
                return [0]

            with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=RuntimeError("fast path failed")) as fast_path, patch.object(
                VideoProcessor,
                "extract_frames_by_interval_ultra_compatible",
                side_effect=ultra_compatible_fallback,
                autospec=True,
            ) as fallback:
                result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)

            self.assertEqual([expected_frame_path], result)
            fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
            fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)

    def test_extract_frames_by_interval_rejects_non_positive_interval(self):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        with patch.object(VideoProcessor, "extract_frames_by_interval_ultra_compatible", autospec=True) as fallback:
            with self.assertRaises(ValueError):
                processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=0)

        # Validation must fire before any extraction attempt.
        fallback.assert_not_called()

    def test_extract_frames_by_interval_fallback_cleans_partial_fast_path_artifacts(self):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        with TemporaryDirectory() as output_dir:
            stale_fastframe = os.path.join(output_dir, "fastframe_000000.jpg")
            expected_keyframe = os.path.join(output_dir, "keyframe_000000_000000000.jpg")

            def fast_path_with_partial_output(_output_dir, interval_seconds=5.0):
                # Leaves a partial artifact behind before failing.
                with open(stale_fastframe, "wb") as frame_file:
                    frame_file.write(b"stale")
                raise RuntimeError("simulated fast-path failure")

            def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
                with open(expected_keyframe, "wb") as frame_file:
                    frame_file.write(b"frame")
                return [0]

            with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=fast_path_with_partial_output) as fast_path, patch.object(
                VideoProcessor,
                "extract_frames_by_interval_ultra_compatible",
                side_effect=ultra_compatible_fallback,
                autospec=True,
            ) as fallback:
                result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)

            self.assertEqual([expected_keyframe], result)
            # Partial fast-path output must not leak into the final frame set.
            self.assertFalse(os.path.exists(stale_fastframe))
            fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
            fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)


# --- webui/tools/generate_script_docu.py ---
# 纪录片脚本生成
import asyncio
import json
import time
import traceback

import streamlit as st
from loguru import logger

from app.config import config
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService

# _normalize_progress_value and generate_script_docu follow on the next source line.
_normalize_progress_value(progress: float | int) -> int: + """Normalize mixed progress inputs to Streamlit's 0-100 integer range.""" + try: + value = float(progress) + except (TypeError, ValueError): + return 0 + + if 0.0 <= value <= 1.0: + value *= 100 + + return max(0, min(100, int(round(value)))) def generate_script_docu(params): """ - 生成 纪录片 视频脚本 + 生成纪录片视频脚本。 要求: 原视频无字幕无配音 适合场景: 纪录片、动物搞笑解说、荒野建造等 """ @@ -23,419 +34,72 @@ def generate_script_docu(params): status_text = st.empty() def update_progress(progress: float, message: str = ""): - progress_bar.progress(progress) + normalized_progress = _normalize_progress_value(progress) + progress_bar.progress(normalized_progress) if message: status_text.text(f"🎬 {message}") else: - status_text.text(f"📊 进度: {progress}%") + status_text.text(f"📊 进度: {normalized_progress}%") try: with st.spinner("正在生成脚本..."): if not params.video_origin_path: st.error("请先选择视频文件") return - """ - 1. 提取键帧 - """ - update_progress(10, "正在提取关键帧...") - # 创建临时目录用于存储关键帧 - keyframes_dir = os.path.join(utils.temp_dir(), "keyframes") - video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path))) - video_keyframes_dir = os.path.join(keyframes_dir, video_hash) - - # 检查是否已经提取过关键帧 - keyframe_files = [] - if os.path.exists(video_keyframes_dir): - # 取已有的关键帧文件 - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if keyframe_files: - logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}") - st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)} 帧") - update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧") - - # 如果没有缓存的关键帧,则进行提取 - if not keyframe_files: - try: - # 确保目录存在 - os.makedirs(video_keyframes_dir, exist_ok=True) - - # 初始化视频处理器 - processor = video_processor.VideoProcessor(params.video_origin_path) - - # 显示视频信息 - st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}秒") - - # 
处理视频并提取关键帧 - 直接使用超级兼容性方案 - update_progress(15, "正在提取关键帧(使用超级兼容性方案)...") - - try: - # 使用优化的关键帧提取方法 - processor.extract_frames_by_interval_ultra_compatible( - output_dir=video_keyframes_dir, - interval_seconds=st.session_state.get('frame_interval_input'), - ) - except Exception as extract_error: - logger.error(f"关键帧提取失败: {extract_error}") - - # 提供详细的错误信息和解决建议 - error_msg = str(extract_error) - if "权限" in error_msg or "permission" in error_msg.lower(): - suggestion = "建议:检查输出目录权限,或更换输出位置" - elif "空间" in error_msg or "space" in error_msg.lower(): - suggestion = "建议:检查磁盘空间是否足够" - else: - suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式" - - raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}") - - # 获取所有关键文件路径 - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if not keyframe_files: - # 检查目录中是否有其他文件 - all_files = os.listdir(video_keyframes_dir) - logger.error(f"关键帧目录内容: {all_files}") - raise Exception("未提取到任何关键帧文件,请检查视频文件格式") - - update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧") - st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧") - - except Exception as e: - # 如果提取失败,清理创建的目录 - try: - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - except Exception as cleanup_err: - logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}") - - raise Exception(f"关键帧提取失败: {str(e)}") - - """ - 2. 
视觉分析(批量分析每一帧) - """ - # 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值 vision_llm_provider = ( - st.session_state.get('vision_llm_provider') or - config.app.get('vision_llm_provider', 'openai') + st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai") ).lower() - - logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析") - - try: - # ===================初始化视觉分析器=================== - update_progress(30, "正在初始化视觉分析器...") - - # 使用统一的配置键格式获取配置(支持所有 provider) - vision_api_key = ( - st.session_state.get(f'vision_{vision_llm_provider}_api_key') or - config.app.get(f'vision_{vision_llm_provider}_api_key') - ) - vision_model = ( - st.session_state.get(f'vision_{vision_llm_provider}_model_name') or - config.app.get(f'vision_{vision_llm_provider}_model_name') - ) - vision_base_url = ( - st.session_state.get(f'vision_{vision_llm_provider}_base_url') or - config.app.get(f'vision_{vision_llm_provider}_base_url', '') + vision_api_key = ( + st.session_state.get(f"vision_{vision_llm_provider}_api_key") + or config.app.get(f"vision_{vision_llm_provider}_api_key") + ) + vision_model = ( + st.session_state.get(f"vision_{vision_llm_provider}_model_name") + or config.app.get(f"vision_{vision_llm_provider}_model_name") + ) + vision_base_url = ( + st.session_state.get(f"vision_{vision_llm_provider}_base_url") + or config.app.get(f"vision_{vision_llm_provider}_base_url", "") + ) + if not vision_api_key or not vision_model: + raise ValueError( + f"未配置 {vision_llm_provider} 的 API Key 或模型名称。" + f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name" ) - # 验证必需配置 - if not vision_api_key or not vision_model: - raise ValueError( - f"未配置 {vision_llm_provider} 的 API Key 或模型名称。" - f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name" - ) + frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get( + "frame_interval_input", 3 + ) + vision_batch_size = 
st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10) + vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get( + "vision_max_concurrency", 2 + ) - # 创建视觉分析器实例(使用统一接口) - llm_params = { - "vision_provider": vision_llm_provider, - "vision_api_key": vision_api_key, - "vision_model_name": vision_model, - "vision_base_url": vision_base_url, - } - - logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}") - - analyzer = create_vision_analyzer( - provider=vision_llm_provider, - api_key=vision_api_key, - model=vision_model, - base_url=vision_base_url + update_progress(10, "正在提取关键帧...") + service = DocumentaryFrameAnalysisService() + script_items = asyncio.run( + service.generate_documentary_script( + video_path=params.video_origin_path, + video_theme=st.session_state.get("video_theme", ""), + custom_prompt=st.session_state.get("custom_prompt", ""), + frame_interval_input=frame_interval_input, + vision_batch_size=vision_batch_size, + vision_llm_provider=vision_llm_provider, + progress_callback=update_progress, + vision_api_key=vision_api_key, + vision_model_name=vision_model, + vision_base_url=vision_base_url, + max_concurrency=vision_max_concurrency, ) + ) - update_progress(40, "正在分析关键帧...") - - # ===================创建异步事件循环=================== - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # 执行异步分析 - vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size") - vision_analysis_prompt = """ -我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。 - -首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。 -然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。 - -请务必使用 JSON 格式输出你的结果。JSON 结构应如下: -{ - "frame_observations": [ - { - "frame_number": 1, // 或其他标识帧的方式 - "observation": "描述每张视频帧中的主要内容、人物、动作和场景。" - }, - // ... 更多帧的观察 ... 
- ], - "overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。" -} - -请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素 - -请只返回 JSON 字符串,不要包含任何其他解释性文字。 - """ - results = loop.run_until_complete( - analyzer.analyze_images( - images=keyframe_files, - prompt=vision_analysis_prompt, - batch_size=vision_batch_size - ) - ) - loop.close() - - """ - 3. 处理分析结果(格式化为 json 数据) - """ - # ===================处理分析结果=================== - update_progress(60, "正在整理分析结果...") - - # 合并所有批次的分析结果 - frame_analysis = "" - merged_frame_observations = [] # 合并所有批次的帧观察 - overall_activity_summaries = [] # 合并所有批次的整体总结 - prev_batch_files = None - frame_counter = 1 # 初始化帧计数器,用于给所有帧分配连续的序号 - - # 确保分析目录存在 - analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis") - os.makedirs(analysis_dir, exist_ok=True) - origin_res = os.path.join(analysis_dir, "frame_analysis.json") - with open(origin_res, 'w', encoding='utf-8') as f: - json.dump(results, f, ensure_ascii=False, indent=2) - - # 开始处理 - for result in results: - if 'error' in result: - logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") - continue - - # 获取当前批次的文件列表 - batch_files = get_batch_files(keyframe_files, result, vision_batch_size) - - # 获取批次的时间戳范围 - first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files) - - # 解析响应中的JSON数据 - response_text = result['response'] - try: - # 处理可能包含```json```格式的响应 - if "```json" in response_text: - json_content = response_text.split("```json")[1].split("```")[0].strip() - elif "```" in response_text: - json_content = response_text.split("```")[1].split("```")[0].strip() - else: - json_content = response_text.strip() - - response_data = json.loads(json_content) - - # 提取frame_observations和overall_activity_summary - if "frame_observations" in response_data: - frame_obs = response_data["frame_observations"] - overall_summary = response_data.get("overall_activity_summary", "") - - # 添加时间戳信息到每个帧观察 - for i, obs in enumerate(frame_obs): - if i 
< len(batch_files): - # 从文件名中提取时间戳 - file_path = batch_files[i] - file_name = os.path.basename(file_path) - # 提取时间戳字符串 (格式如: keyframe_000675_000027000.jpg) - # 格式解析: keyframe_帧序号_毫秒时间戳.jpg - timestamp_parts = file_name.split('_') - if len(timestamp_parts) >= 3: - timestamp_str = timestamp_parts[-1].split('.')[0] - try: - # 修正时间戳解析逻辑 - # 格式为000100000,表示00:01:00,000,即1分钟 - # 需要按照对应位数进行解析: - # 前两位是小时,中间两位是分钟,后面是秒和毫秒 - if len(timestamp_str) >= 9: # 确保格式正确 - hours = int(timestamp_str[0:2]) - minutes = int(timestamp_str[2:4]) - seconds = int(timestamp_str[4:6]) - milliseconds = int(timestamp_str[6:9]) - - # 计算总秒数 - timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 - formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳 - else: - # 兼容旧的解析方式 - timestamp_seconds = int(timestamp_str) / 1000 # 转换为秒 - formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳 - except ValueError: - logger.warning(f"无法解析时间戳: {timestamp_str}") - timestamp_seconds = 0 - formatted_time = "00:00:00,000" - else: - logger.warning(f"文件名格式不符合预期: {file_name}") - timestamp_seconds = 0 - formatted_time = "00:00:00,000" - - # 添加额外信息到帧观察 - obs["frame_path"] = file_path - obs["timestamp"] = formatted_time - obs["timestamp_seconds"] = timestamp_seconds - obs["batch_index"] = result['batch_index'] - - # 使用全局递增的帧计数器替换原始的frame_number - if "frame_number" in obs: - obs["original_frame_number"] = obs["frame_number"] # 保留原始编号作为参考 - obs["frame_number"] = frame_counter # 赋值连续的帧编号 - frame_counter += 1 # 增加帧计数器 - - # 添加到合并列表 - merged_frame_observations.append(obs) - - # 添加批次整体总结信息 - if overall_summary: - # 从文件名中提取时间戳数值 - first_time_str = first_timestamp.split('_')[-1].split('.')[0] - last_time_str = last_timestamp.split('_')[-1].split('.')[0] - - # 转换为毫秒并计算持续时间(秒) - try: - # 修正解析逻辑,与上面相同的方式解析时间戳 - if len(first_time_str) >= 9 and len(last_time_str) >= 9: - # 解析第一个时间戳 - first_hours = int(first_time_str[0:2]) - first_minutes = int(first_time_str[2:4]) - first_seconds = 
int(first_time_str[4:6]) - first_ms = int(first_time_str[6:9]) - first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000 - - # 解析第二个时间戳 - last_hours = int(last_time_str[0:2]) - last_minutes = int(last_time_str[2:4]) - last_seconds = int(last_time_str[4:6]) - last_ms = int(last_time_str[6:9]) - last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000 - - batch_duration = last_time_seconds - first_time_seconds - else: - # 兼容旧的解析方式 - first_time_ms = int(first_time_str) - last_time_ms = int(last_time_str) - batch_duration = (last_time_ms - first_time_ms) / 1000 - except ValueError: - # 使用 utils.time_to_seconds 函数处理格式化的时间戳 - first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ',')) - last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ',')) - batch_duration = last_time_seconds - first_time_seconds - - overall_activity_summaries.append({ - "batch_index": result['batch_index'], - "time_range": f"{first_timestamp}-{last_timestamp}", - "duration_seconds": batch_duration, - "summary": overall_summary - }) - except Exception as e: - logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}") - # 添加原始响应作为回退 - frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" - frame_analysis += response_text - frame_analysis += "\n" - - # 更新上一个批次的文件 - prev_batch_files = batch_files - - # 将合并后的结果转为JSON字符串 - merged_results = { - "frame_observations": merged_frame_observations, - "overall_activity_summaries": overall_activity_summaries - } - - # 使用当前时间创建文件名 - now = datetime.now() - timestamp_str = now.strftime("%Y%m%d_%H%M") - - # 保存完整的分析结果为JSON - analysis_filename = f"frame_analysis_{timestamp_str}.json" - analysis_json_path = os.path.join(analysis_dir, analysis_filename) - with open(analysis_json_path, 'w', encoding='utf-8') as f: - json.dump(merged_results, f, ensure_ascii=False, indent=2) - logger.info(f"分析结果已保存到: 
{analysis_json_path}") - - """ - 4. 生成文案 - """ - logger.info("开始生成解说文案") - update_progress(80, "正在生成解说文案...") - from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration - # 从配置中获取文本生成相关配置 - text_provider = config.app.get('text_llm_provider', 'gemini').lower() - text_api_key = config.app.get(f'text_{text_provider}_api_key') - text_model = config.app.get(f'text_{text_provider}_model_name') - text_base_url = config.app.get(f'text_{text_provider}_base_url') - llm_params.update({ - "text_provider": text_provider, - "text_api_key": text_api_key, - "text_model_name": text_model, - "text_base_url": text_base_url - }) - # 整理帧分析数据 - markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) - - # 生成解说文案 - narration = generate_narration( - markdown_output, - text_api_key, - base_url=text_base_url, - model=text_model - ) - - # 使用增强的JSON解析器 - from webui.tools.generate_short_summary import parse_and_fix_json - narration_data = parse_and_fix_json(narration) - - if not narration_data or 'items' not in narration_data: - logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...") - raise Exception("解说文案格式错误,无法解析JSON或缺少items字段") - - narration_dict = narration_data['items'] - # 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音 - narration_dict = [{**item, "OST": 2} for item in narration_dict] - logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段") - # 结果转换为JSON字符串 - script = json.dumps(narration_dict, ensure_ascii=False, indent=2) - - except Exception as e: - logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}") - raise Exception(f"分析失败: {str(e)}") - - if script is None: - st.error("生成脚本失败,请检查日志") - st.stop() - logger.info(f"纪录片解说脚本生成完成") + logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段") + script = json.dumps(script_items, ensure_ascii=False, indent=2) if isinstance(script, list): - st.session_state['video_clip_json'] = script + st.session_state["video_clip_json"] = script elif isinstance(script, str): - 
st.session_state['video_clip_json'] = json.loads(script) + st.session_state["video_clip_json"] = json.loads(script) update_progress(100, "脚本生成完成") time.sleep(0.1)