import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Callable

from loguru import logger

from app.config import config
from app.services.documentary.frame_analysis_models import FrameBatchResult
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
from app.services.llm.migration_adapter import create_vision_analyzer
from app.utils import utils, video_processor


class DocumentaryFrameAnalysisService:
    """Analyzes a video frame-by-frame with a vision LLM and turns the result
    into a documentary-style narration script.

    Pipeline: extract keyframes (cached on disk) -> batch frames and send each
    batch to a vision model -> persist a JSON analysis artifact -> feed a
    markdown rendering of that artifact to a text LLM that writes narration.
    """

    # Prompt sent with each batch of frames. {frame_count} is substituted via
    # str.format; literal JSON braces are escaped as {{ }}.
    PROMPT_TEMPLATE = """
我提供了 {frame_count} 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。
首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
然后,基于所有帧的分析,请用简洁的语言总结整个视频片段中发生的主要活动或事件流程。
请务必使用 JSON 格式输出。
JSON 必须包含以下键:
- frame_observations: 数组,且长度必须为 {frame_count}
- overall_activity_summary: 字符串,描述整个批次主要活动
示例结构:
{{
  "frame_observations": [
    {{"timestamp": "00:00:00,000", "observation": "画面描述"}}
  ],
  "overall_activity_summary": "本批次主要活动总结"
}}
请务必不要遗漏视频帧,我提供了 {frame_count} 张视频帧,frame_observations 必须包含 {frame_count} 个元素
请只返回 JSON 字符串,不要附加解释文字。
""".strip()

    async def generate_documentary_script(
        self,
        *,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int | float | None = None,
        vision_batch_size: int | None = None,
        vision_llm_provider: str | None = None,
        progress_callback: Callable[[float, str], None] | None = None,
        vision_api_key: str | None = None,
        vision_model_name: str | None = None,
        vision_base_url: str | None = None,
        max_concurrency: int | None = None,
    ) -> list[dict]:
        """Run the full pipeline and return the final narration script.

        Returns a list of script items (dicts) in which every item is forced
        to ``OST == 2`` (narration-only segments).

        Raises:
            FileNotFoundError: if ``video_path`` does not exist.
            ValueError: if the text-LLM configuration is missing or the
                narration output cannot be parsed.
        """
        progress = progress_callback or (lambda _p, _m: None)
        analysis_result = await self.analyze_video(
            video_path=video_path,
            video_theme=video_theme,
            custom_prompt=custom_prompt,
            frame_interval_input=frame_interval_input,
            vision_batch_size=vision_batch_size,
            vision_llm_provider=vision_llm_provider,
            progress_callback=progress_callback,
            vision_api_key=vision_api_key,
            vision_model_name=vision_model_name,
            vision_base_url=vision_base_url,
            max_concurrency=max_concurrency,
        )
        analysis_json_path = analysis_result["analysis_json_path"]

        progress(80, "正在生成解说文案...")
        # Text-LLM settings are resolved from config keyed by provider name.
        text_provider = config.app.get("text_llm_provider", "openai").lower()
        text_api_key = config.app.get(f"text_{text_provider}_api_key")
        text_model = config.app.get(f"text_{text_provider}_model_name")
        text_base_url = config.app.get(f"text_{text_provider}_base_url")
        if not text_api_key or not text_model:
            raise ValueError(
                f"未配置 {text_provider} 的文本模型参数。"
                f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name"
            )

        markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
        narration_input = self._build_narration_input(
            markdown_output=markdown_output,
            video_theme=video_theme,
            custom_prompt=custom_prompt,
        )
        narration_raw = generate_narration(
            narration_input,
            text_api_key,
            base_url=text_base_url,
            model=text_model,
        )
        narration_items = self._parse_narration_items(narration_raw)
        # OST=2 marks every generated segment as pure narration.
        final_script = [{**item, "OST": 2} for item in narration_items]
        progress(100, "脚本生成完成")
        return final_script

    async def analyze_video(
        self,
        *,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int | float | None = None,
        vision_batch_size: int | None = None,
        vision_llm_provider: str | None = None,
        progress_callback: Callable[[float, str], None] | None = None,
        vision_api_key: str | None = None,
        vision_model_name: str | None = None,
        vision_base_url: str | None = None,
        max_concurrency: int | None = None,
    ) -> dict[str, Any]:
        """Extract keyframes and analyze them with the vision model.

        Returns a dict with keys ``analysis_json_path`` (saved artifact),
        ``analysis_artifact`` (the artifact dict itself), ``video_clip_json``
        (per-batch clip descriptions) and ``keyframe_files``.

        Raises:
            FileNotFoundError: if ``video_path`` does not exist.
            ValueError: if vision-LLM credentials are missing.
            RuntimeError: if no keyframes/batches could be produced.
        """
        progress = progress_callback or (lambda _p, _m: None)
        if not video_path or not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        frame_interval_seconds = self._resolve_frame_interval(frame_interval_input)
        batch_size = self._resolve_batch_size(vision_batch_size)
        concurrency = self._resolve_max_concurrency(max_concurrency)

        # Explicit arguments win over config; `is not None` keeps empty strings
        # passed by the caller from silently falling back to config.
        provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()
        api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key")
        model_name = (
            vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name")
        )
        base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "")
        if not api_key or not model_name:
            raise ValueError(
                f"未配置 {provider} 的 API Key 或模型名称。"
                f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
            )

        progress(10, "正在提取关键帧...")
        keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds)
        progress(25, f"关键帧准备完成,共 {len(keyframe_files)} 帧")

        progress(30, "正在初始化视觉分析器...")
        analyzer = create_vision_analyzer(
            provider=provider,
            api_key=api_key,
            model=model_name,
            base_url=base_url,
        )

        batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size)
        if not batches:
            raise RuntimeError("未能构建任何关键帧批次")

        progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...")
        batch_results = await self._analyze_batches(
            analyzer=analyzer,
            batches=batches,
            custom_prompt=custom_prompt,
            video_theme=video_theme,
            max_concurrency=concurrency,
            progress_callback=progress,
        )

        progress(65, "正在整理分析结果...")
        sorted_batches = self._sort_batch_results(batch_results)
        artifact = self._build_analysis_artifact(
            sorted_batches,
            video_path=video_path,
            frame_interval_seconds=frame_interval_seconds,
            vision_batch_size=batch_size,
            vision_llm_provider=provider,
            vision_model_name=model_name,
            max_concurrency=concurrency,
        )
        analysis_json_path = self._save_analysis_artifact(artifact)
        video_clip_json = self._build_video_clip_json(sorted_batches)
        progress(75, "逐帧分析完成")
        return {
            "analysis_json_path": analysis_json_path,
            "analysis_artifact": artifact,
            "video_clip_json": video_clip_json,
            "keyframe_files": keyframe_files,
        }

    def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]:
        """Parse the text-LLM narration output into a list of item dicts.

        Raises:
            ValueError: if no JSON object with a non-empty ``items`` list of
                dicts can be recovered from the raw response.
        """
        parsed = self._repair_narration_payload(narration_raw)
        items: list[dict[str, Any]] = []
        if isinstance(parsed, dict):
            raw_items = parsed.get("items")
            if isinstance(raw_items, list):
                # Drop any non-dict entries the model may have emitted.
                items = [item for item in raw_items if isinstance(item, dict)]
        if not items:
            raise ValueError("解说文案格式错误,无法解析JSON或缺少items字段")
        return items

    def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str:
        """Append theme / extra-requirement context to the analysis markdown."""
        context_lines: list[str] = []
        if (video_theme or "").strip():
            context_lines.append(f"视频主题:{video_theme.strip()}")
        if (custom_prompt or "").strip():
            context_lines.append(f"补充创作要求:{custom_prompt.strip()}")
        if not context_lines:
            return markdown_output
        context_block = "\n".join(f"- {line}" for line in context_lines)
        return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n"

    def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None:
        """Best-effort extraction of a JSON object from LLM output.

        Tries, in order: the raw text, doubled-brace unescaping, a fenced
        ```json block, and the outermost ``{...}`` span. If all fail, applies
        aggressive repairs (strip comments, trailing commas, quote keys) and
        tries once more. Returns ``None`` when nothing parses to a dict.
        """
        def load_json_candidate(payload: str) -> dict[str, Any] | None:
            try:
                parsed = json.loads(payload)
                return parsed if isinstance(parsed, dict) else None
            except Exception:
                return None

        cleaned = (narration_raw or "").strip()
        if not cleaned:
            return None
        candidates: list[str] = [cleaned]
        # Models sometimes echo the template's escaped braces literally.
        candidates.append(cleaned.replace("{{", "{").replace("}}", "}"))
        json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL)
        if json_block:
            candidates.append(json_block.group(1).strip())
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start >= 0 and end > start:
            candidates.append(cleaned[start : end + 1])
        for candidate in candidates:
            parsed = load_json_candidate(candidate)
            if parsed is not None:
                return parsed

        # Last resort: repair common JSON syntax mistakes, then re-parse.
        fixed = cleaned.replace("{{", "{").replace("}}", "}")
        fixed_start = fixed.find("{")
        fixed_end = fixed.rfind("}")
        if fixed_start >= 0 and fixed_end > fixed_start:
            fixed = fixed[fixed_start : fixed_end + 1]
        fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE)
        fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE)
        fixed = re.sub(r",\s*}", "}", fixed)
        fixed = re.sub(r",\s*]", "]", fixed)
        fixed = re.sub(r"'([^']*)'\s*:", r'"\1":', fixed)
        fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed)
        fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed)
        return load_json_candidate(fixed)

    def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
        """Resolve the keyframe sampling interval in seconds (> 0).

        Falls back to config, then to 3.0 on unparseable input; an explicit
        non-positive value raises ``ValueError``.
        """
        interval = frame_interval_input
        if interval in (None, ""):
            interval = config.frames.get("frame_interval_input", 3)
        try:
            value = float(interval)
        except (TypeError, ValueError):
            value = 3.0
        if value <= 0:
            raise ValueError("frame_interval_input must be > 0")
        return value

    def _resolve_batch_size(self, vision_batch_size: int | None) -> int:
        """Resolve how many frames go into one vision-LLM request (> 0).

        Note: uses an explicit ``is None`` check so an invalid explicit 0 is
        reported (raises) instead of silently falling back to the config
        default, matching `_resolve_frame_interval`.
        """
        size = vision_batch_size if vision_batch_size is not None else config.frames.get("vision_batch_size", 10)
        try:
            value = int(size)
        except (TypeError, ValueError):
            value = 10
        if value <= 0:
            raise ValueError("vision_batch_size must be > 0")
        return value

    def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
        """Resolve the max number of concurrent vision requests (>= 1)."""
        value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2)
        try:
            parsed = int(value)
        except (TypeError, ValueError):
            parsed = 1
        return max(1, parsed)

    def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]:
        """Return sorted keyframe paths, reusing the on-disk cache when present.

        Raises:
            RuntimeError: if extraction produced no frames.
        """
        keyframes_root = os.path.join(utils.temp_dir(), "keyframes")
        os.makedirs(keyframes_root, exist_ok=True)
        cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds)
        cache_dir = os.path.join(keyframes_root, cache_key)
        os.makedirs(cache_dir, exist_ok=True)

        cached_files = self._collect_keyframe_paths(cache_dir)
        if cached_files:
            logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)} 帧")
            return cached_files

        processor = video_processor.VideoProcessor(video_path)
        extracted = processor.extract_frames_by_interval_with_fallback(
            output_dir=cache_dir,
            interval_seconds=frame_interval_seconds,
        )
        keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg"))
        if not keyframe_files:
            # The processor may write files without returning them; re-scan.
            keyframe_files = self._collect_keyframe_paths(cache_dir)
        if not keyframe_files:
            raise RuntimeError("未提取到任何关键帧")
        logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)} 帧")
        return keyframe_files

    def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str:
        """Build a cache directory name from video identity + interval.

        Keeps the legacy md5(path+mtime) prefix so older cache dirs remain
        discoverable next to the new versioned key.
        """
        try:
            video_mtime = os.path.getmtime(video_path)
        except OSError:
            video_mtime = 0
        legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
        payload = "|".join(
            [
                str(video_path),
                str(video_mtime),
                str(frame_interval_seconds),
                "documentary-keyframes-v2",
            ]
        )
        return f"{legacy_prefix}_{utils.md5(payload)}"

    @staticmethod
    def _collect_keyframe_paths(cache_dir: str) -> list[str]:
        """List cached keyframe files matching the canonical naming pattern."""
        if not os.path.exists(cache_dir):
            return []
        return sorted(
            os.path.join(cache_dir, name)
            for name in os.listdir(cache_dir)
            if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
        )

    @staticmethod
    def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
        """Split the ordered keyframe list into consecutive batches."""
        return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)]

    async def _analyze_batches(
        self,
        *,
        analyzer: Any,
        batches: list[list[str]],
        custom_prompt: str,
        video_theme: str,
        max_concurrency: int,
        progress_callback: Callable[[float, str], None],
    ) -> list[FrameBatchResult]:
        """Analyze all batches concurrently (bounded by a semaphore).

        Failures are captured per batch as failed `FrameBatchResult`s instead
        of aborting the whole run. Progress is reported in the 40-65% range.
        """
        semaphore = asyncio.Semaphore(max(1, max_concurrency))
        total = len(batches)
        done = 0
        done_lock = asyncio.Lock()

        # Precompute each batch's time range; a singleton batch borrows the
        # previous batch's last frame so its range is not zero-width.
        batch_time_ranges: list[str] = []
        previous_batch_files: list[str] | None = None
        for batch_files in batches:
            _, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files)
            batch_time_ranges.append(time_range)
            previous_batch_files = batch_files

        async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult:
            nonlocal done
            prompt = self._build_batch_prompt(
                frame_count=len(frame_paths),
                video_theme=video_theme,
                custom_prompt=custom_prompt,
            )
            try:
                async with semaphore:
                    raw_results = await analyzer.analyze_images(
                        images=frame_paths,
                        prompt=prompt,
                        batch_size=max(1, len(frame_paths)),
                        max_concurrency=1,
                    )
                raw_response, error_message = self._extract_batch_response(raw_results)
                if error_message:
                    return self._build_failed_batch_result(
                        batch_index=batch_index,
                        raw_response=raw_response,
                        error_message=error_message,
                        frame_paths=frame_paths,
                        time_range=time_range,
                    )
                return self._parse_batch_response(
                    batch_index=batch_index,
                    raw_response=raw_response,
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            except Exception as exc:
                return self._build_failed_batch_result(
                    batch_index=batch_index,
                    raw_response="",
                    error_message=str(exc),
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            finally:
                async with done_lock:
                    done += 1
                    progress = 40 + (done / max(1, total)) * 25
                    progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...")

        tasks = [
            run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index])
            for index, batch_files in enumerate(batches)
        ]
        return await asyncio.gather(*tasks)

    def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str:
        """Compose the per-batch prompt plus optional theme/custom context."""
        prompt = self._build_analysis_prompt(frame_count=frame_count)
        extra_lines: list[str] = []
        if (video_theme or "").strip():
            extra_lines.append(f"视频主题:{video_theme.strip()}")
        if (custom_prompt or "").strip():
            extra_lines.append(custom_prompt.strip())
        if not extra_lines:
            return prompt
        extras = "\n".join(f"- {line}" for line in extra_lines)
        return f"{prompt}\n\n补充分析要求:\n{extras}"

    def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]:
        """Normalize the analyzer's result into ``(raw_response, error_message)``.

        Accepts a list of dicts, a single dict, or a plain value; an empty
        error string means success.
        """
        if not raw_results:
            return "", "Batch response is empty"
        first_result = raw_results[0] if isinstance(raw_results, list) else raw_results
        if isinstance(first_result, dict):
            raw_response = str(first_result.get("response", "") or "")
            error_message = str(first_result.get("error", "") or "")
            if error_message:
                # Keep the error text as the response for debugging artifacts.
                if not raw_response:
                    raw_response = error_message
                return raw_response, error_message
            if not raw_response.strip():
                return raw_response, "Batch response is empty"
            return raw_response, ""
        raw_response = str(first_result or "")
        if not raw_response.strip():
            return raw_response, "Batch response is empty"
        return raw_response, ""

    def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]:
        """Sort batches chronologically; batch_index breaks timestamp ties."""
        return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index))

    def _build_analysis_artifact(
        self,
        batch_results: list[FrameBatchResult],
        *,
        video_path: str,
        frame_interval_seconds: float,
        vision_batch_size: int,
        vision_llm_provider: str,
        vision_model_name: str,
        max_concurrency: int,
    ) -> dict[str, Any]:
        """Assemble the serializable analysis artifact from batch results.

        Besides the per-batch payloads it also emits flattened
        ``frame_observations`` / ``overall_activity_summaries`` lists for
        older parsers.
        """
        sorted_batches = self._sort_batch_results(batch_results)
        batch_dicts: list[dict[str, Any]] = []
        frame_observations: list[dict[str, Any]] = []
        overall_activity_summaries: list[dict[str, Any]] = []
        for batch in sorted_batches:
            batch_payload = {
                "batch_index": batch.batch_index,
                "status": batch.status,
                "time_range": batch.time_range,
                "raw_response": batch.raw_response,
                "frame_paths": list(batch.frame_paths),
                "frame_observations": list(batch.frame_observations),
                "overall_activity_summary": batch.overall_activity_summary,
                "fallback_summary": batch.fallback_summary,
                "error_message": batch.error_message,
            }
            batch_dicts.append(batch_payload)
            for observation in batch.frame_observations:
                observation_payload = dict(observation)
                observation_payload["batch_index"] = batch.batch_index
                observation_payload["time_range"] = batch.time_range
                frame_observations.append(observation_payload)
            summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip()
            if summary_text:
                overall_activity_summaries.append(
                    {
                        "batch_index": batch.batch_index,
                        "time_range": batch.time_range,
                        "summary": summary_text,
                    }
                )
        return {
            "artifact_version": "documentary-frame-analysis-v2",
            "generated_at": datetime.now().isoformat(),
            "video_path": video_path,
            "frame_interval_seconds": frame_interval_seconds,
            "vision_batch_size": vision_batch_size,
            "vision_llm_provider": vision_llm_provider,
            "vision_model_name": vision_model_name,
            "vision_max_concurrency": max_concurrency,
            "batches": batch_dicts,
            # Backward compatibility with the legacy parser structure.
            "frame_observations": frame_observations,
            "overall_activity_summaries": overall_activity_summaries,
        }

    def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str:
        """Write the artifact JSON to storage and return its file path.

        Appends a numeric suffix when two artifacts land within the same
        second to avoid overwriting.
        """
        analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
        os.makedirs(analysis_dir, exist_ok=True)
        filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        file_path = os.path.join(analysis_dir, filename)
        suffix = 1
        while os.path.exists(file_path):
            filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json"
            file_path = os.path.join(analysis_dir, filename)
            suffix += 1
        with open(file_path, "w", encoding="utf-8") as fp:
            json.dump(artifact, fp, ensure_ascii=False, indent=2)
        logger.info(f"分析结果已保存到: {file_path}")
        return file_path

    def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]:
        """Build per-batch clip entries (narration left blank for later fill)."""
        clips: list[dict] = []
        for batch in self._sort_batch_results(batch_results):
            picture = self._build_batch_picture(batch)
            clips.append(
                {
                    "timestamp": batch.time_range,
                    "picture": picture,
                    "narration": "",
                    "OST": 2,
                }
            )
        return clips

    def _build_batch_picture(self, batch: FrameBatchResult) -> str:
        """Pick the best available textual description for a batch.

        Preference order: overall summary -> fallback summary -> joined
        per-frame observations -> truncated raw response -> fixed failure
        message.
        """
        summary = (batch.overall_activity_summary or "").strip()
        if summary:
            return summary
        fallback = (batch.fallback_summary or "").strip()
        if fallback:
            return fallback
        observation_lines = []
        for frame in batch.frame_observations:
            timestamp = str(frame.get("timestamp", "") or "").strip()
            observation = str(frame.get("observation", "") or "").strip()
            if timestamp and observation:
                observation_lines.append(f"{timestamp}: {observation}")
            elif observation:
                observation_lines.append(observation)
        if observation_lines:
            return " ".join(observation_lines)
        raw_response = (batch.raw_response or "").strip()
        if raw_response:
            return raw_response[:200]
        return "该批次分析失败,未返回可用描述。"

    def _time_range_sort_key(self, time_range: str) -> tuple[int, str]:
        """Sort key for ``HH:MM:SS,mmm-HH:MM:SS,mmm`` ranges (by start time)."""
        start = (time_range or "").split("-", 1)[0].strip()
        return self._timestamp_to_milliseconds(start), time_range

    @staticmethod
    def _timestamp_to_milliseconds(timestamp: str) -> int:
        """Convert an SRT-style ``HH:MM:SS,mmm`` timestamp to milliseconds.

        Tolerates missing components (``SS`` / ``MM:SS``); any parse failure
        yields 0 so sorting never raises.
        """
        text = (timestamp or "").strip()
        try:
            if "," in text:
                time_part, ms_part = text.split(",", 1)
                milliseconds = int(ms_part)
            else:
                time_part = text
                milliseconds = 0
            parts = [int(part) for part in time_part.split(":") if part]
            while len(parts) < 3:
                parts.insert(0, 0)
            hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
            return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
        except Exception:
            return 0

    def _get_batch_timestamps(
        self,
        batch_files: list[str],
        prev_batch_files: list[str] | None = None,
    ) -> tuple[str, str, str]:
        """Return (first, last, "first-last") timestamps for a batch.

        A single-frame batch borrows the previous batch's final frame as its
        start so the resulting range covers real elapsed time.
        """
        if not batch_files:
            return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
        if len(batch_files) == 1 and prev_batch_files:
            first_frame = os.path.basename(prev_batch_files[-1])
            last_frame = os.path.basename(batch_files[0])
        else:
            first_frame = os.path.basename(batch_files[0])
            last_frame = os.path.basename(batch_files[-1])
        first_timestamp = self._timestamp_from_keyframe_name(first_frame)
        last_timestamp = self._timestamp_from_keyframe_name(last_frame)
        return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}"

    def _timestamp_from_keyframe_name(self, filename: str) -> str:
        """Decode ``keyframe_NNNNNN_HHMMSSmmm.jpg`` into ``HH:MM:SS,mmm``.

        Unrecognized names map to ``00:00:00,000`` rather than raising.
        """
        match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename)
        if not match:
            return "00:00:00,000"
        token = match.group(1)
        hours = int(token[0:2])
        minutes = int(token[2:4])
        seconds = int(token[4:6])
        milliseconds = int(token[6:9])
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

    def _build_analysis_prompt(self, frame_count: int) -> str:
        """Render the vision prompt template for the given frame count."""
        return self.PROMPT_TEMPLATE.format(frame_count=frame_count)

    def _build_failed_batch_result(
        self,
        *,
        batch_index: int,
        raw_response: str,
        error_message: str,
        frame_paths: list[str],
        time_range: str,
    ) -> FrameBatchResult:
        """Create a ``failed`` batch result carrying a short fallback summary."""
        fallback_summary = (raw_response or "").strip()[:200]
        if not fallback_summary:
            fallback_summary = f"Batch {batch_index} analysis failed: {error_message or 'unknown error'}"
        return FrameBatchResult(
            batch_index=batch_index,
            status="failed",
            time_range=time_range,
            raw_response=raw_response,
            frame_paths=list(frame_paths),
            fallback_summary=fallback_summary,
            error_message=error_message,
        )

    def _build_cache_key(
        self,
        video_path: str,
        interval_seconds: float,
        prompt_version: str,
        model_name: str,
        batch_size: int,
        max_concurrency: int,
    ) -> str:
        """Build a cache key covering every analysis-affecting parameter.

        Like `_build_keyframe_cache_key`, keeps the legacy md5(path+mtime)
        prefix for backward compatibility with old cache entries.
        """
        try:
            video_mtime = os.path.getmtime(video_path)
        except OSError:
            video_mtime = 0
        legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
        payload = "|".join(
            [
                str(video_path),
                str(video_mtime),
                str(interval_seconds),
                str(prompt_version),
                str(model_name),
                str(batch_size),
                str(max_concurrency),
                "documentary-frame-analysis-v2",
            ]
        )
        return f"{legacy_prefix}_{utils.md5(payload)}"

    def _strip_code_fence(self, response_text: str) -> str:
        """Remove a surrounding markdown code fence from the model output."""
        cleaned = (response_text or "").strip()
        cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
        cleaned = re.sub(r"\s*```$", "", cleaned)
        return cleaned.strip()

    def _parse_batch_response(
        self,
        *,
        batch_index: int,
        raw_response: str,
        frame_paths: list[str],
        time_range: str,
    ) -> FrameBatchResult:
        """Parse and validate a batch's JSON response into a result object.

        Any JSON or contract failure yields a failed result rather than an
        exception, so one bad batch never aborts the whole analysis.
        """
        try:
            payload = json.loads(self._strip_code_fence(raw_response))
        except Exception as exc:
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response=raw_response,
                error_message=str(exc),
                frame_paths=frame_paths,
                time_range=time_range,
            )
        validation_error = self._validate_batch_payload_contract(payload, expected_frame_count=len(frame_paths))
        if validation_error:
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response=raw_response,
                error_message=validation_error,
                frame_paths=frame_paths,
                time_range=time_range,
            )
        raw_observations = payload["frame_observations"]
        frame_observations: list[dict] = []
        for index, frame_path in enumerate(frame_paths):
            # Observations beyond the provided frames are ignored; missing or
            # non-dict entries degrade to empty/plain-string observations.
            entry = raw_observations[index] if index < len(raw_observations) else {}
            if isinstance(entry, dict):
                observation = str(entry.get("observation", "") or "")
                timestamp = str(entry.get("timestamp", "") or "")
            else:
                observation = str(entry or "")
                timestamp = ""
            frame_observations.append(
                {
                    "frame_path": frame_path,
                    "timestamp": timestamp,
                    "observation": observation,
                }
            )
        raw_summary = payload.get("overall_activity_summary", "")
        if isinstance(raw_summary, str):
            summary = raw_summary
        elif raw_summary is None:
            summary = ""
        else:
            summary = str(raw_summary)
        return FrameBatchResult(
            batch_index=batch_index,
            status="success",
            time_range=time_range,
            raw_response=raw_response,
            frame_paths=list(frame_paths),
            frame_observations=frame_observations,
            overall_activity_summary=summary,
        )

    def _validate_batch_payload_contract(self, payload: object, *, expected_frame_count: int) -> str:
        """Validate the batch JSON contract; return "" or an error message.

        Extra observations are tolerated; fewer than expected is an error.
        """
        if not isinstance(payload, dict):
            return "Batch response JSON payload must be an object"
        if "frame_observations" not in payload or not isinstance(payload["frame_observations"], list):
            return "Batch response must include frame_observations as a list"
        if len(payload["frame_observations"]) < expected_frame_count:
            return (
                "Batch response frame_observations length is shorter than provided frame_paths: "
                f"{len(payload['frame_observations'])} < {expected_frame_count}"
            )
        return ""