From ac63fea95398eb2a9017aba7f6f62d4f37715f43 Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 3 Apr 2026 02:24:30 +0800 Subject: [PATCH] refactor(documentary): route adapters through shared analysis service --- .gitignore | 2 + .../documentary/frame_analysis_service.py | 489 +++++++++++++++++- app/services/generate_narration_script.py | 78 ++- app/services/script_service.py | 328 +----------- ...e_narration_script_documentary_unittest.py | 58 +++ ...est_script_service_documentary_unittest.py | 51 ++ webui/tools/generate_script_docu.py | 461 +++-------------- 7 files changed, 747 insertions(+), 720 deletions(-) create mode 100644 tests/test_generate_narration_script_documentary_unittest.py create mode 100644 tests/test_script_service_documentary_unittest.py diff --git a/.gitignore b/.gitignore index 9cf0620..4032202 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,5 @@ CLAUDE.md tests/* !tests/test_documentary_frame_analysis_service.py !tests/test_video_processor_documentary_unittest.py +!tests/test_script_service_documentary_unittest.py +!tests/test_generate_narration_script_documentary_unittest.py diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py index 2813e9b..3cc4cc3 100644 --- a/app/services/documentary/frame_analysis_service.py +++ b/app/services/documentary/frame_analysis_service.py @@ -1,9 +1,16 @@ +import asyncio import json import os import re +from datetime import datetime +from typing import Any, Callable -from app.utils import utils +from loguru import logger + +from app.config import config from app.services.documentary.frame_analysis_models import FrameBatchResult +from app.services.llm.migration_adapter import create_vision_analyzer +from app.utils import utils, video_processor class DocumentaryFrameAnalysisService: @@ -23,8 +30,488 @@ JSON 必须包含以下键: "overall_activity_summary": "本批次主要活动总结" }} 请务必不要遗漏视频帧,我提供了 {frame_count} 张视频帧,frame_observations 必须包含 {frame_count} 个元素 +请只返回 JSON 字符串,不要附加解释文字。 """.strip() + async def generate_documentary_script( + self, + *, + video_path: str, + video_theme: str = "", + custom_prompt: str = "", + frame_interval_input: int | float = 3, + vision_batch_size: int = 10, + vision_llm_provider: str = "openai", + progress_callback: Callable[[float, str], None] | None = None, + vision_api_key: str | None = None, + vision_model_name: str | None = None, + vision_base_url: str | None = None, + max_concurrency: int | None = None, + ) -> list[dict]: + analysis_result = await self.analyze_video( + video_path=video_path, + video_theme=video_theme, + custom_prompt=custom_prompt, + frame_interval_input=frame_interval_input, + vision_batch_size=vision_batch_size, + vision_llm_provider=vision_llm_provider, + progress_callback=progress_callback, + vision_api_key=vision_api_key, + vision_model_name=vision_model_name, + vision_base_url=vision_base_url, + max_concurrency=max_concurrency, + ) + return analysis_result["video_clip_json"] + + async def analyze_video( + self, + *, + video_path: str, + video_theme: str = "", + custom_prompt: str = "", + frame_interval_input: int | float = 3, + vision_batch_size: int = 10, + vision_llm_provider: str = "openai", + progress_callback: Callable[[float, str], None] | None = None, + vision_api_key: str | None = None, + vision_model_name: str | None = None, + vision_base_url: str | None = None, + max_concurrency: int | None = None, + ) -> dict[str, Any]: + progress = progress_callback or (lambda _p, _m: None) + + if not video_path or not os.path.exists(video_path): + raise FileNotFoundError(f"视频文件不存在: {video_path}") + + frame_interval_seconds = self._resolve_frame_interval(frame_interval_input) + batch_size = self._resolve_batch_size(vision_batch_size) + concurrency = self._resolve_max_concurrency(max_concurrency) + provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower() + + api_key = vision_api_key or config.app.get(f"vision_{provider}_api_key") + model_name = vision_model_name or config.app.get(f"vision_{provider}_model_name") + base_url = vision_base_url or config.app.get(f"vision_{provider}_base_url", "") + if not api_key or not model_name: + raise ValueError( + f"未配置 {provider} 的 API Key 或模型名称。" + f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name" + ) + + progress(10, "正在提取关键帧...") + keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds) + progress(25, f"关键帧准备完成,共 {len(keyframe_files)} 帧") + + progress(30, "正在初始化视觉分析器...") + analyzer = create_vision_analyzer( + provider=provider, + api_key=api_key, + model=model_name, + base_url=base_url, + ) + + batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size) + if not batches: + raise RuntimeError("未能构建任何关键帧批次") + + progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...") + batch_results = await self._analyze_batches( + analyzer=analyzer, + batches=batches, + custom_prompt=custom_prompt, + video_theme=video_theme, + max_concurrency=concurrency, + progress_callback=progress, + ) + + progress(65, "正在整理分析结果...") + sorted_batches = self._sort_batch_results(batch_results) + artifact = self._build_analysis_artifact( + sorted_batches, + video_path=video_path, + frame_interval_seconds=frame_interval_seconds, + vision_batch_size=batch_size, + vision_llm_provider=provider, + vision_model_name=model_name, + max_concurrency=concurrency, + ) + analysis_json_path = self._save_analysis_artifact(artifact) + video_clip_json = self._build_video_clip_json(sorted_batches) + + progress(75, "逐帧分析完成") + return { + "analysis_json_path": analysis_json_path, + "analysis_artifact": artifact, + "video_clip_json": video_clip_json, + "keyframe_files": keyframe_files, + } + + def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float: + interval = frame_interval_input + if interval in (None, ""): + interval = config.frames.get("frame_interval_input", 3) + try: + value = float(interval) + except (TypeError, ValueError): + value = 3.0 + if value <= 0: + raise ValueError("frame_interval_input must be > 0") + return value + + def _resolve_batch_size(self, vision_batch_size: int | None) -> int: + size = vision_batch_size or config.frames.get("vision_batch_size", 10) + try: + value = int(size) + except (TypeError, ValueError): + value = 10 + if value <= 0: + raise ValueError("vision_batch_size must be > 0") + return value + + def _resolve_max_concurrency(self, max_concurrency: int | None) -> int: + value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2) + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = 1 + return max(1, parsed) + + def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]: + keyframes_root = os.path.join(utils.temp_dir(), "keyframes") + os.makedirs(keyframes_root, exist_ok=True) + cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds) + cache_dir = os.path.join(keyframes_root, cache_key) + os.makedirs(cache_dir, exist_ok=True) + + cached_files = self._collect_keyframe_paths(cache_dir) + if cached_files: + logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)} 帧") + return cached_files + + processor = video_processor.VideoProcessor(video_path) + extracted = processor.extract_frames_by_interval_with_fallback( + output_dir=cache_dir, + interval_seconds=frame_interval_seconds, + ) + keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg")) + if not keyframe_files: + keyframe_files = self._collect_keyframe_paths(cache_dir) + if not keyframe_files: + raise RuntimeError("未提取到任何关键帧") + + logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)} 帧") + return keyframe_files + + def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str: + try: + video_mtime = os.path.getmtime(video_path) + except OSError: + video_mtime = 0 + + legacy_prefix = utils.md5(f"{video_path}{video_mtime}") + payload = "|".join( + [ + str(video_path), + str(video_mtime), + str(frame_interval_seconds), + "documentary-keyframes-v2", + ] + ) + return f"{legacy_prefix}_{utils.md5(payload)}" + + @staticmethod + def _collect_keyframe_paths(cache_dir: str) -> list[str]: + if not os.path.exists(cache_dir): + return [] + return sorted( + os.path.join(cache_dir, name) + for name in os.listdir(cache_dir) + if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name) + ) + + @staticmethod + def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]: + return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)] + + async def _analyze_batches( + self, + *, + analyzer: Any, + batches: list[list[str]], + custom_prompt: str, + video_theme: str, + max_concurrency: int, + progress_callback: Callable[[float, str], None], + ) -> list[FrameBatchResult]: + semaphore = asyncio.Semaphore(max(1, max_concurrency)) + total = len(batches) + done = 0 + done_lock = asyncio.Lock() + + batch_time_ranges: list[str] = [] + previous_batch_files: list[str] | None = None + for batch_files in batches: + _, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files) + batch_time_ranges.append(time_range) + previous_batch_files = batch_files + + async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult: + nonlocal done + prompt = self._build_batch_prompt( + frame_count=len(frame_paths), + video_theme=video_theme, + custom_prompt=custom_prompt, + ) + try: + async with semaphore: + raw_results = await analyzer.analyze_images( + images=frame_paths, + prompt=prompt, + batch_size=max(1, len(frame_paths)), + max_concurrency=1, + ) + raw_response, error_message = self._extract_batch_response(raw_results) + if error_message: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response=raw_response, + error_message=error_message, + frame_paths=frame_paths, + time_range=time_range, + ) + return self._parse_batch_response( + batch_index=batch_index, + raw_response=raw_response, + frame_paths=frame_paths, + time_range=time_range, + ) + except Exception as exc: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response="", + error_message=str(exc), + frame_paths=frame_paths, + time_range=time_range, + ) + finally: + async with done_lock: + done += 1 + progress = 40 + (done / max(1, total)) * 25 + progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...") + + tasks = [ + run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index]) + for index, batch_files in enumerate(batches) + ] + return await asyncio.gather(*tasks) + + def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str: + prompt = self._build_analysis_prompt(frame_count=frame_count) + extra_lines: list[str] = [] + if (video_theme or "").strip(): + extra_lines.append(f"视频主题:{video_theme.strip()}") + if (custom_prompt or "").strip(): + extra_lines.append(custom_prompt.strip()) + if not extra_lines: + return prompt + + extras = "\n".join(f"- {line}" for line in extra_lines) + return f"{prompt}\n\n补充分析要求:\n{extras}" + + def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]: + if not raw_results: + return "", "Batch response is empty" + + first_result = raw_results[0] if isinstance(raw_results, list) else raw_results + if isinstance(first_result, dict): + raw_response = str(first_result.get("response", "") or "") + error_message = str(first_result.get("error", "") or "") + if error_message: + if not raw_response: + raw_response = error_message + return raw_response, error_message + if not raw_response.strip(): + return raw_response, "Batch response is empty" + return raw_response, "" + + raw_response = str(first_result or "") + if not raw_response.strip(): + return raw_response, "Batch response is empty" + return raw_response, "" + + def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]: + return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index)) + + def _build_analysis_artifact( + self, + batch_results: list[FrameBatchResult], + *, + video_path: str, + frame_interval_seconds: float, + vision_batch_size: int, + vision_llm_provider: str, + vision_model_name: str, + max_concurrency: int, + ) -> dict[str, Any]: + sorted_batches = self._sort_batch_results(batch_results) + + batch_dicts: list[dict[str, Any]] = [] + frame_observations: list[dict[str, Any]] = [] + overall_activity_summaries: list[dict[str, Any]] = [] + + for batch in sorted_batches: + batch_payload = { + "batch_index": batch.batch_index, + "status": batch.status, + "time_range": batch.time_range, + "raw_response": batch.raw_response, + "frame_paths": list(batch.frame_paths), + "frame_observations": list(batch.frame_observations), + "overall_activity_summary": batch.overall_activity_summary, + "fallback_summary": batch.fallback_summary, + "error_message": batch.error_message, + } + batch_dicts.append(batch_payload) + + for observation in batch.frame_observations: + observation_payload = dict(observation) + observation_payload["batch_index"] = batch.batch_index + observation_payload["time_range"] = batch.time_range + frame_observations.append(observation_payload) + + summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip() + if summary_text: + overall_activity_summaries.append( + { + "batch_index": batch.batch_index, + "time_range": batch.time_range, + "summary": summary_text, + } + ) + + return { + "artifact_version": "documentary-frame-analysis-v2", + "generated_at": datetime.now().isoformat(), + "video_path": video_path, + "frame_interval_seconds": frame_interval_seconds, + "vision_batch_size": vision_batch_size, + "vision_llm_provider": vision_llm_provider, + "vision_model_name": vision_model_name, + "vision_max_concurrency": max_concurrency, + "batches": batch_dicts, + # 向后兼容旧解析器结构 + "frame_observations": frame_observations, + "overall_activity_summaries": overall_activity_summaries, + } + + def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str: + analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis") + os.makedirs(analysis_dir, exist_ok=True) + + filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + file_path = os.path.join(analysis_dir, filename) + suffix = 1 + while os.path.exists(file_path): + filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json" + file_path = os.path.join(analysis_dir, filename) + suffix += 1 + + with open(file_path, "w", encoding="utf-8") as fp: + json.dump(artifact, fp, ensure_ascii=False, indent=2) + logger.info(f"分析结果已保存到: {file_path}") + return file_path + + def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]: + clips: list[dict] = [] + for batch in self._sort_batch_results(batch_results): + picture = self._build_batch_picture(batch) + clips.append( + { + "timestamp": batch.time_range, + "picture": picture, + "narration": "", + "OST": 2, + } + ) + return clips + + def _build_batch_picture(self, batch: FrameBatchResult) -> str: + summary = (batch.overall_activity_summary or "").strip() + if summary: + return summary + + fallback = (batch.fallback_summary or "").strip() + if fallback: + return fallback + + observation_lines = [] + for frame in batch.frame_observations: + timestamp = str(frame.get("timestamp", "") or "").strip() + observation = str(frame.get("observation", "") or "").strip() + if timestamp and observation: + observation_lines.append(f"{timestamp}: {observation}") + elif observation: + observation_lines.append(observation) + if observation_lines: + return " ".join(observation_lines) + + raw_response = (batch.raw_response or "").strip() + if raw_response: + return raw_response[:200] + return "该批次分析失败,未返回可用描述。" + + def _time_range_sort_key(self, time_range: str) -> tuple[int, str]: + start = (time_range or "").split("-", 1)[0].strip() + return self._timestamp_to_milliseconds(start), time_range + + @staticmethod + def _timestamp_to_milliseconds(timestamp: str) -> int: + text = (timestamp or "").strip() + try: + if "," in text: + time_part, ms_part = text.split(",", 1) + milliseconds = int(ms_part) + else: + time_part = text + milliseconds = 0 + + parts = [int(part) for part in time_part.split(":") if part] + while len(parts) < 3: + parts.insert(0, 0) + hours, minutes, seconds = parts[-3], parts[-2], parts[-1] + return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds + except Exception: + return 0 + + def _get_batch_timestamps( + self, + batch_files: list[str], + prev_batch_files: list[str] | None = None, + ) -> tuple[str, str, str]: + if not batch_files: + return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000" + + if len(batch_files) == 1 and prev_batch_files: + first_frame = os.path.basename(prev_batch_files[-1]) + last_frame = os.path.basename(batch_files[0]) + else: + first_frame = os.path.basename(batch_files[0]) + last_frame = os.path.basename(batch_files[-1]) + + first_timestamp = self._timestamp_from_keyframe_name(first_frame) + last_timestamp = self._timestamp_from_keyframe_name(last_frame) + return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}" + + def _timestamp_from_keyframe_name(self, filename: str) -> str: + match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename) + if not match: + return "00:00:00,000" + token = match.group(1) + hours = int(token[0:2]) + minutes = int(token[2:4]) + seconds = int(token[4:6]) + milliseconds = int(token[6:9]) + return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}" + def _build_analysis_prompt(self, frame_count: int) -> str: return self.PROMPT_TEMPLATE.format(frame_count=frame_count) diff --git a/app/services/generate_narration_script.py b/app/services/generate_narration_script.py index 80fcf1a..e18622b 100644 --- a/app/services/generate_narration_script.py +++ b/app/services/generate_narration_script.py @@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path): with open(json_file_path, 'r', encoding='utf-8') as file: data = json.load(file) - # 初始化Markdown字符串 + def time_to_milliseconds(time_text): + time_text = (time_text or "").strip() + if not time_text: + return 0 + try: + if "," in time_text: + hhmmss, ms = time_text.split(",", 1) + milliseconds = int(ms) + else: + hhmmss = time_text + milliseconds = 0 + + parts = [int(part) for part in hhmmss.split(":") if part] + while len(parts) < 3: + parts.insert(0, 0) + hours, minutes, seconds = parts[-3], parts[-2], parts[-1] + return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds + except Exception: + return 0 + + def batch_sort_key(batch): + time_range = batch.get("time_range", "") + start = time_range.split("-", 1)[0].strip() + return time_to_milliseconds(start), batch.get("batch_index", 0) + markdown = "" - - # 获取总结和帧观察数据 + + # 新结构:按批次保存完整分析产物 + if isinstance(data.get("batches"), list): + ordered_batches = sorted(data.get("batches", []), key=batch_sort_key) + + for i, batch in enumerate(ordered_batches, 1): + time_range = batch.get("time_range", "") + summary = ( + batch.get("overall_activity_summary") + or batch.get("summary") + or batch.get("fallback_summary") + or "" + ) + observations = batch.get("frame_observations") or batch.get("observations") or [] + + markdown += f"## 片段 {i}\n" + markdown += f"- 时间范围:{time_range}\n" + markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n" + markdown += "- 详细描述:\n" + + for frame in observations: + timestamp = frame.get("timestamp", "") + observation = frame.get("observation", "") + markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n" + + markdown += "\n" + + return markdown + + # 兼容旧结构 summaries = data.get('overall_activity_summaries', []) frame_observations = data.get('frame_observations', []) - - # 按批次组织数据 + batch_frames = {} for frame in frame_observations: batch_index = frame.get('batch_index') if batch_index not in batch_frames: batch_frames[batch_index] = [] batch_frames[batch_index].append(frame) - - # 生成Markdown内容 + for i, summary in enumerate(summaries, 1): batch_index = summary.get('batch_index') time_range = summary.get('time_range', '') batch_summary = summary.get('summary', '') - + markdown += f"## 片段 {i}\n" markdown += f"- 时间范围:{time_range}\n" - - # 添加片段描述 markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else f"- 片段描述:\n" - markdown += "- 详细描述:\n" - - # 添加该批次的帧观察详情 + frames = batch_frames.get(batch_index, []) for frame in frames: timestamp = frame.get('timestamp', '') observation = frame.get('observation', '') - - # 直接使用原始文本,不进行分割 markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n" - + markdown += "\n" - + return markdown except Exception as e: diff --git a/app/services/script_service.py b/app/services/script_service.py index 34a17a6..61c36a7 100644 --- a/app/services/script_service.py +++ b/app/services/script_service.py @@ -1,22 +1,12 @@ -import os -import json -import time -import asyncio -import requests -from app.utils import video_processor -from loguru import logger -from typing import List, Dict, Any, Callable +from typing import Any, Callable -from app.utils import utils, gemini_analyzer, video_processor -from app.utils.script_generator import ScriptProcessor -from app.config import config +from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService class ScriptGenerator: - def __init__(self): - self.temp_dir = utils.temp_dir() - self.keyframes_dir = os.path.join(self.temp_dir, "keyframes") - + def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None): + self.documentary_service = documentary_service or DocumentaryFrameAnalysisService() + async def generate_script( self, video_path: str, @@ -27,298 +17,18 @@ class ScriptGenerator: threshold: int = 30, vision_batch_size: int = 5, vision_llm_provider: str = "gemini", - progress_callback: Callable[[float, str], None] = None - ) -> List[Dict[Any, Any]]: - """ - 生成视频脚本的核心逻辑 - - Args: - video_path: 视频文件路径 - video_theme: 视频主题 - custom_prompt: 自定义提示词 - skip_seconds: 跳过开始的秒数 - threshold: 差异���值 - vision_batch_size: 视觉处理批次大小 - vision_llm_provider: 视觉模型提供商 - progress_callback: 进度回调函数 - - Returns: - List[Dict]: 生成的视频脚本 - """ - if progress_callback is None: - progress_callback = lambda p, m: None - - try: - # 提取关键帧 - progress_callback(10, "正在提取关键帧...") - keyframe_files = await self._extract_keyframes( - video_path, - skip_seconds, - threshold - ) - - # 使用统一的 LLM 接口(支持所有 provider) - script = await self._process_with_llm( - keyframe_files, - video_theme, - custom_prompt, - vision_batch_size, - vision_llm_provider, - progress_callback - ) - - return json.loads(script) if isinstance(script, str) else script - - except Exception as e: - logger.exception("Generate script failed") - raise - - async def _extract_keyframes( - self, - video_path: str, - skip_seconds: int, - threshold: int - ) -> List[str]: - """提取视频关键帧""" - video_hash = utils.md5(video_path + str(os.path.getmtime(video_path))) - video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash) - - # 检查缓存 - keyframe_files = [] - if os.path.exists(video_keyframes_dir): - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if keyframe_files: - logger.info(f"Using cached keyframes: {video_keyframes_dir}") - return keyframe_files - - # 提取新的关键帧 - os.makedirs(video_keyframes_dir, exist_ok=True) - - try: - processor = video_processor.VideoProcessor(video_path) - processor.process_video_pipeline( - output_dir=video_keyframes_dir, - skip_seconds=skip_seconds, - threshold=threshold - ) - - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - return keyframe_files - - except Exception as e: - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - raise - - async def _process_with_llm( - self, - keyframe_files: List[str], - video_theme: str, - custom_prompt: str, - vision_batch_size: int, - vision_llm_provider: str, - progress_callback: Callable[[float, str], None] - ) -> str: - """使用统一 LLM 接口处理视频帧""" - progress_callback(30, "正在初始化视觉分析器...") - - # 使用新的 LLM 迁移适配器(支持所有 provider) - from app.services.llm.migration_adapter import create_vision_analyzer - - # 获取配置 - text_provider = config.app.get('text_llm_provider', 'openai').lower() - vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key') - vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name') - vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url') - - if not vision_api_key or not vision_model: - raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型") - - # 创建统一的视觉分析器 - analyzer = create_vision_analyzer( - provider=vision_llm_provider, - api_key=vision_api_key, - model=vision_model, - base_url=vision_base_url + progress_callback: Callable[[float, str], None] | None = None, + ) -> list[dict[Any, Any]]: + callback = progress_callback or (lambda _p, _m: None) + return await self.documentary_service.generate_documentary_script( + video_path=video_path, + video_theme=video_theme, + custom_prompt=custom_prompt, + frame_interval_input=frame_interval_input, + vision_batch_size=vision_batch_size, + vision_llm_provider=vision_llm_provider, + progress_callback=callback, + # 历史参数保留在签名中以兼容调用方;共享逐帧分析当前不使用这两个参数。 + # skip_seconds=skip_seconds, + # threshold=threshold, ) - - progress_callback(40, "正在分析关键帧...") - - # 执行异步分析 - results = await analyzer.analyze_images( - images=keyframe_files, - prompt=config.app.get('vision_analysis_prompt'), - batch_size=vision_batch_size - ) - - progress_callback(60, "正在整理分析结果...") - - # 合并所有批次的分析结果 - frame_analysis = "" - prev_batch_files = None - - for result in results: - if 'error' in result: - logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") - continue - - batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size) - first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files) - - # 添加带时间戳的分��结果 - frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" - frame_analysis += result['response'] - frame_analysis += "\n" - - prev_batch_files = batch_files - - if not frame_analysis.strip(): - raise Exception("未能生成有效的帧分析结果") - - progress_callback(70, "正在生成脚本...") - - # 构建帧内容列表 - frame_content_list = [] - prev_batch_files = None - - for result in results: - if 'error' in result: - continue - - batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size) - _, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files) - - frame_content = { - "timestamp": timestamp_range, - "picture": result['response'], - "narration": "", - "OST": 2 - } - frame_content_list.append(frame_content) - prev_batch_files = batch_files - - if not frame_content_list: - raise Exception("没有有效的帧内容可以处理") - - progress_callback(90, "正在生成文案...") - - # 获取文本生��配置 - text_provider = config.app.get('text_llm_provider', 'gemini').lower() - text_api_key = config.app.get(f'text_{text_provider}_api_key') - text_model = config.app.get(f'text_{text_provider}_model_name') - text_base_url = config.app.get(f'text_{text_provider}_base_url') - - # 根据提供商类型选择合适的处理器 - if text_provider == 'gemini(openai)': - # 使用OpenAI兼容的Gemini代理 - from app.utils.script_generator import GeminiOpenAIGenerator - generator = GeminiOpenAIGenerator( - model_name=text_model, - api_key=text_api_key, - prompt=custom_prompt, - base_url=text_base_url - ) - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - base_url=text_base_url, - prompt=custom_prompt, - video_theme=video_theme - ) - processor.generator = generator - else: - # 使用标准处理器(包括原生Gemini) - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - base_url=text_base_url, - prompt=custom_prompt, - video_theme=video_theme - ) - - return processor.process_frames(frame_content_list) - - def _get_batch_files( - self, - keyframe_files: List[str], - result: Dict[str, Any], - batch_size: int - ) -> List[str]: - """获取当前批次的图片文件""" - batch_start = result['batch_index'] * batch_size - batch_end = min(batch_start + batch_size, len(keyframe_files)) - return keyframe_files[batch_start:batch_end] - - def _get_batch_timestamps( - self, - batch_files: List[str], - prev_batch_files: List[str] = None - ) -> tuple[str, str, str]: - """获取一批文件的时间戳范围,支持毫秒级精度""" - if not batch_files: - logger.warning("Empty batch files") - return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000" - - if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0: - first_frame = os.path.basename(prev_batch_files[-1]) - last_frame = os.path.basename(batch_files[0]) - else: - first_frame = os.path.basename(batch_files[0]) - last_frame = os.path.basename(batch_files[-1]) - - first_time = first_frame.split('_')[2].replace('.jpg', '') - last_time = last_frame.split('_')[2].replace('.jpg', '') - - def format_timestamp(time_str: str) -> str: - """将时间字符串转换为 HH:MM:SS,mmm 格式""" - try: - if len(time_str) < 4: - logger.warning(f"Invalid timestamp format: {time_str}") - return "00:00:00,000" - - # 处理毫秒部分 - if ',' in time_str: - time_part, ms_part = time_str.split(',') - ms = int(ms_part) - else: - time_part = time_str - ms = 0 - - # 处理时分秒 - parts = time_part.split(':') - if len(parts) == 3: # HH:MM:SS - h, m, s = map(int, parts) - elif len(parts) == 2: # MM:SS - h = 0 - m, s = map(int, parts) - else: # SS - h = 0 - m = 0 - s = int(parts[0]) - - # 处理进位 - if s >= 60: - m += s // 60 - s = s % 60 - if m >= 60: - h += m // 60 - m = m % 60 - - return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" - - except Exception as e: - logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}") - return "00:00:00,000" - - first_timestamp = format_timestamp(first_time) - last_timestamp = format_timestamp(last_time) - timestamp_range = f"{first_timestamp}-{last_timestamp}" - - return first_timestamp, last_timestamp, timestamp_range diff --git a/tests/test_generate_narration_script_documentary_unittest.py b/tests/test_generate_narration_script_documentary_unittest.py new file mode 100644 index 0000000..edb6963 --- /dev/null +++ b/tests/test_generate_narration_script_documentary_unittest.py @@ -0,0 +1,58 @@ +import json +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from app.services.generate_narration_script import parse_frame_analysis_to_markdown + + +class GenerateNarrationMarkdownTests(unittest.TestCase): + def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self): + artifact = { + "batches": [ + { + "batch_index": 1, + "time_range": "00:00:03,000-00:00:06,000", + "overall_activity_summary": "人物转身跑向远处", + "fallback_summary": "", + "frame_observations": [ + { + "timestamp": "00:00:03,000", + "observation": "人物突然回头", + } + ], + }, + { + "batch_index": 0, + "time_range": "00:00:00,000-00:00:03,000", + "overall_activity_summary": "", + "fallback_summary": "原始响应回退摘要", + "frame_observations": [ + { + "timestamp": "00:00:00,000", + "observation": "镜头里有一只猫", + } + ], + }, + ] + } + + with TemporaryDirectory() as temp_dir: + analysis_path = Path(temp_dir) / "frame-analysis.json" + analysis_path.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8") + markdown = parse_frame_analysis_to_markdown(str(analysis_path)) + + first_range_index = markdown.find("00:00:00,000-00:00:03,000") + second_range_index = markdown.find("00:00:03,000-00:00:06,000") + + self.assertIn("原始响应回退摘要", markdown) + self.assertIn("镜头里有一只猫", markdown) + self.assertIn("人物转身跑向远处", markdown) + self.assertIn("人物突然回头", markdown) + self.assertNotEqual(-1, first_range_index) + self.assertNotEqual(-1, second_range_index) + self.assertLess(first_range_index, second_range_index) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_script_service_documentary_unittest.py b/tests/test_script_service_documentary_unittest.py new file mode 100644 index 0000000..d1fcd70 --- /dev/null +++ b/tests/test_script_service_documentary_unittest.py @@ -0,0 +1,51 @@ +import unittest +from unittest.mock import AsyncMock, patch + +from app.services.script_service import ScriptGenerator + + +class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase): + async def test_generate_script_passes_frame_interval_to_shared_service(self): + expected_script = [ + { + "timestamp": "00:00:00,000-00:00:03,000", + "picture": "批次描述", + "narration": "", + "OST": 2, + } + ] + progress = [] + + def progress_callback(percent, message): + progress.append((percent, message)) + + with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls: + service = service_cls.return_value + service.generate_documentary_script = AsyncMock(return_value=expected_script) + generator = ScriptGenerator() + + result = await generator.generate_script( + video_path="demo.mp4", + video_theme="荒野生存", + custom_prompt="请聚焦生存动作", + frame_interval_input=3, + vision_batch_size=6, + vision_llm_provider="openai", + progress_callback=progress_callback, + ) + + self.assertEqual(expected_script, result) + service.generate_documentary_script.assert_awaited_once() + called_kwargs = service.generate_documentary_script.await_args.kwargs + self.assertEqual("demo.mp4", called_kwargs["video_path"]) + self.assertEqual(3, called_kwargs["frame_interval_input"]) + self.assertEqual(6, called_kwargs["vision_batch_size"]) + self.assertEqual("openai", called_kwargs["vision_llm_provider"]) + self.assertEqual("荒野生存", called_kwargs["video_theme"]) + self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"]) + self.assertIs(called_kwargs["progress_callback"], progress_callback) + self.assertEqual([], progress) + + +if __name__ == "__main__": + unittest.main() diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py index 9f51c01..18fba78 100644 --- a/webui/tools/generate_script_docu.py +++ b/webui/tools/generate_script_docu.py @@ -1,21 +1,21 @@ # 纪录片脚本生成 -import os +import asyncio import json import time -import asyncio import traceback + import streamlit as st from loguru import logger -from datetime import datetime from app.config import config -from app.utils import utils, video_processor -from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps +from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService +from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown +from webui.tools.generate_short_summary import parse_and_fix_json def generate_script_docu(params): """ - 生成 纪录片 视频脚本 + 生成纪录片视频脚本。 要求: 原视频无字幕无配音 适合场景: 纪录片、动物搞笑解说、荒野建造等 """ @@ -34,408 +34,83 @@ def generate_script_docu(params): if not params.video_origin_path: st.error("请先选择视频文件") return - """ - 1. 提取键帧 - """ - update_progress(10, "正在提取关键帧...") - # 创建临时目录用于存储关键帧 - keyframes_dir = os.path.join(utils.temp_dir(), "keyframes") - video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path))) - video_keyframes_dir = os.path.join(keyframes_dir, video_hash) - - # 检查是否已经提取过关键帧 - keyframe_files = [] - if os.path.exists(video_keyframes_dir): - # 取已有的关键帧文件 - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if keyframe_files: - logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}") - st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)} 帧") - update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧") - - # 如果没有缓存的关键帧,则进行提取 - if not keyframe_files: - try: - # 确保目录存在 - os.makedirs(video_keyframes_dir, exist_ok=True) - - # 初始化视频处理器 - processor = video_processor.VideoProcessor(params.video_origin_path) - - # 显示视频信息 - st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}秒") - - # 处理视频并提取关键帧 - 直接使用超级兼容性方案 - update_progress(15, "正在提取关键帧(使用超级兼容性方案)...") - - try: - # 使用优化的关键帧提取方法 - processor.extract_frames_by_interval_ultra_compatible( - output_dir=video_keyframes_dir, - interval_seconds=st.session_state.get('frame_interval_input'), - ) - except Exception as extract_error: - logger.error(f"关键帧提取失败: {extract_error}") - - # 提供详细的错误信息和解决建议 - error_msg = str(extract_error) - if "权限" in error_msg or "permission" in error_msg.lower(): - suggestion = "建议:检查输出目录权限,或更换输出位置" - elif "空间" in error_msg or "space" in error_msg.lower(): - suggestion = "建议:检查磁盘空间是否足够" - else: - suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式" - - raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}") - - # 获取所有关键文件路径 - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if not keyframe_files: - # 检查目录中是否有其他文件 - all_files = os.listdir(video_keyframes_dir) - logger.error(f"关键帧目录内容: {all_files}") - raise Exception("未提取到任何关键帧文件,请检查视频文件格式") - - update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧") - st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧") - - except Exception as e: - # 如果提取失败,清理创建的目录 - try: - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - except Exception as cleanup_err: - logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}") - - raise Exception(f"关键帧提取失败: {str(e)}") - - """ - 2. 视觉分析(批量分析每一帧) - """ - # 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值 vision_llm_provider = ( - st.session_state.get('vision_llm_provider') or - config.app.get('vision_llm_provider', 'openai') + st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai") ).lower() - - logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析") - - try: - # ===================初始化视觉分析器=================== - update_progress(30, "正在初始化视觉分析器...") - - # 使用统一的配置键格式获取配置(支持所有 provider) - vision_api_key = ( - st.session_state.get(f'vision_{vision_llm_provider}_api_key') or - config.app.get(f'vision_{vision_llm_provider}_api_key') - ) - vision_model = ( - st.session_state.get(f'vision_{vision_llm_provider}_model_name') or - config.app.get(f'vision_{vision_llm_provider}_model_name') - ) - vision_base_url = ( - st.session_state.get(f'vision_{vision_llm_provider}_base_url') or - config.app.get(f'vision_{vision_llm_provider}_base_url', '') + vision_api_key = ( + st.session_state.get(f"vision_{vision_llm_provider}_api_key") + or config.app.get(f"vision_{vision_llm_provider}_api_key") + ) + vision_model = ( + st.session_state.get(f"vision_{vision_llm_provider}_model_name") + or config.app.get(f"vision_{vision_llm_provider}_model_name") + ) + vision_base_url = ( + st.session_state.get(f"vision_{vision_llm_provider}_base_url") + or config.app.get(f"vision_{vision_llm_provider}_base_url", "") + ) + if not vision_api_key or not vision_model: + raise ValueError( + f"未配置 {vision_llm_provider} 的 API Key 或模型名称。" + f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name" ) - # 验证必需配置 - if not vision_api_key or not vision_model: - raise ValueError( - f"未配置 {vision_llm_provider} 的 API Key 或模型名称。" - f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name" - ) + frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get( + "frame_interval_input", 3 + ) + vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10) + vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get( + "vision_max_concurrency", 2 + ) - # 创建视觉分析器实例(使用统一接口) - llm_params = { - "vision_provider": vision_llm_provider, - "vision_api_key": vision_api_key, - "vision_model_name": vision_model, - "vision_base_url": vision_base_url, - } - - logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}") - - analyzer = create_vision_analyzer( - provider=vision_llm_provider, - api_key=vision_api_key, - model=vision_model, - base_url=vision_base_url + update_progress(10, "正在提取关键帧...") + service = DocumentaryFrameAnalysisService() + analysis_result = asyncio.run( + service.analyze_video( + video_path=params.video_origin_path, + video_theme=st.session_state.get("video_theme", ""), + custom_prompt=st.session_state.get("custom_prompt", ""), + frame_interval_input=frame_interval_input, + vision_batch_size=vision_batch_size, + vision_llm_provider=vision_llm_provider, + progress_callback=update_progress, + vision_api_key=vision_api_key, + vision_model_name=vision_model, + vision_base_url=vision_base_url, + max_concurrency=vision_max_concurrency, ) + ) - update_progress(40, "正在分析关键帧...") + analysis_json_path = analysis_result["analysis_json_path"] + update_progress(80, "正在生成解说文案...") - # ===================创建异步事件循环=================== - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + text_provider = config.app.get("text_llm_provider", "gemini").lower() + text_api_key = config.app.get(f"text_{text_provider}_api_key") + text_model = config.app.get(f"text_{text_provider}_model_name") + text_base_url = config.app.get(f"text_{text_provider}_base_url") - # 执行异步分析 - vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size") - vision_analysis_prompt = """ -我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。 + markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) + narration = generate_narration( + markdown_output, + text_api_key, + base_url=text_base_url, + model=text_model, + ) + narration_data = parse_and_fix_json(narration) -首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。 -然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。 + if not narration_data or "items" not in narration_data: + logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...") + raise Exception("解说文案格式错误,无法解析JSON或缺少items字段") -请务必使用 JSON 格式输出你的结果。JSON 结构应如下: -{ - "frame_observations": [ - { - "frame_number": 1, // 或其他标识帧的方式 - "observation": "描述每张视频帧中的主要内容、人物、动作和场景。" - }, - // ... 更多帧的观察 ... - ], - "overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。" -} + narration_dict = [{**item, "OST": 2} for item in narration_data["items"]] + script = json.dumps(narration_dict, ensure_ascii=False, indent=2) -请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素 - -请只返回 JSON 字符串,不要包含任何其他解释性文字。 - """ - results = loop.run_until_complete( - analyzer.analyze_images( - images=keyframe_files, - prompt=vision_analysis_prompt, - batch_size=vision_batch_size - ) - ) - loop.close() - - """ - 3. 处理分析结果(格式化为 json 数据) - """ - # ===================处理分析结果=================== - update_progress(60, "正在整理分析结果...") - - # 合并所有批次的分析结果 - frame_analysis = "" - merged_frame_observations = [] # 合并所有批次的帧观察 - overall_activity_summaries = [] # 合并所有批次的整体总结 - prev_batch_files = None - frame_counter = 1 # 初始化帧计数器,用于给所有帧分配连续的序号 - - # 确保分析目录存在 - analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis") - os.makedirs(analysis_dir, exist_ok=True) - origin_res = os.path.join(analysis_dir, "frame_analysis.json") - with open(origin_res, 'w', encoding='utf-8') as f: - json.dump(results, f, ensure_ascii=False, indent=2) - - # 开始处理 - for result in results: - if 'error' in result: - logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") - continue - - # 获取当前批次的文件列表 - batch_files = get_batch_files(keyframe_files, result, vision_batch_size) - - # 获取批次的时间戳范围 - first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files) - - # 解析响应中的JSON数据 - response_text = result['response'] - try: - # 处理可能包含```json```格式的响应 - if "```json" in response_text: - json_content = response_text.split("```json")[1].split("```")[0].strip() - elif "```" in response_text: - json_content = response_text.split("```")[1].split("```")[0].strip() - else: - json_content = response_text.strip() - - response_data = json.loads(json_content) - - # 提取frame_observations和overall_activity_summary - if "frame_observations" in response_data: - frame_obs = response_data["frame_observations"] - overall_summary = response_data.get("overall_activity_summary", "") - - # 添加时间戳信息到每个帧观察 - for i, obs in enumerate(frame_obs): - if i < len(batch_files): - # 从文件名中提取时间戳 - file_path = batch_files[i] - file_name = os.path.basename(file_path) - # 提取时间戳字符串 (格式如: keyframe_000675_000027000.jpg) - # 格式解析: keyframe_帧序号_毫秒时间戳.jpg - timestamp_parts = file_name.split('_') - if len(timestamp_parts) >= 3: - timestamp_str = timestamp_parts[-1].split('.')[0] - try: - # 修正时间戳解析逻辑 - # 格式为000100000,表示00:01:00,000,即1分钟 - # 需要按照对应位数进行解析: - # 前两位是小时,中间两位是分钟,后面是秒和毫秒 - if len(timestamp_str) >= 9: # 确保格式正确 - hours = int(timestamp_str[0:2]) - minutes = int(timestamp_str[2:4]) - seconds = int(timestamp_str[4:6]) - milliseconds = int(timestamp_str[6:9]) - - # 计算总秒数 - timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 - formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳 - else: - # 兼容旧的解析方式 - timestamp_seconds = int(timestamp_str) / 1000 # 转换为秒 - formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳 - except ValueError: - logger.warning(f"无法解析时间戳: {timestamp_str}") - timestamp_seconds = 0 - formatted_time = "00:00:00,000" - else: - logger.warning(f"文件名格式不符合预期: {file_name}") - timestamp_seconds = 0 - formatted_time = "00:00:00,000" - - # 添加额外信息到帧观察 - obs["frame_path"] = file_path - obs["timestamp"] = formatted_time - obs["timestamp_seconds"] = timestamp_seconds - obs["batch_index"] = result['batch_index'] - - # 使用全局递增的帧计数器替换原始的frame_number - if "frame_number" in obs: - obs["original_frame_number"] = obs["frame_number"] # 保留原始编号作为参考 - obs["frame_number"] = frame_counter # 赋值连续的帧编号 - frame_counter += 1 # 增加帧计数器 - - # 添加到合并列表 - merged_frame_observations.append(obs) - - # 添加批次整体总结信息 - if overall_summary: - # 从文件名中提取时间戳数值 - first_time_str = first_timestamp.split('_')[-1].split('.')[0] - last_time_str = last_timestamp.split('_')[-1].split('.')[0] - - # 转换为毫秒并计算持续时间(秒) - try: - # 修正解析逻辑,与上面相同的方式解析时间戳 - if len(first_time_str) >= 9 and len(last_time_str) >= 9: - # 解析第一个时间戳 - first_hours = int(first_time_str[0:2]) - first_minutes = int(first_time_str[2:4]) - first_seconds = int(first_time_str[4:6]) - first_ms = int(first_time_str[6:9]) - first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000 - - # 解析第二个时间戳 - last_hours = int(last_time_str[0:2]) - last_minutes = int(last_time_str[2:4]) - last_seconds = int(last_time_str[4:6]) - last_ms = int(last_time_str[6:9]) - last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000 - - batch_duration = last_time_seconds - first_time_seconds - else: - # 兼容旧的解析方式 - first_time_ms = int(first_time_str) - last_time_ms = int(last_time_str) - batch_duration = (last_time_ms - first_time_ms) / 1000 - except ValueError: - # 使用 utils.time_to_seconds 函数处理格式化的时间戳 - first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ',')) - last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ',')) - batch_duration = last_time_seconds - first_time_seconds - - overall_activity_summaries.append({ - "batch_index": result['batch_index'], - "time_range": f"{first_timestamp}-{last_timestamp}", - "duration_seconds": batch_duration, - "summary": overall_summary - }) - except Exception as e: - logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}") - # 添加原始响应作为回退 - frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" - frame_analysis += response_text - frame_analysis += "\n" - - # 更新上一个批次的文件 - prev_batch_files = batch_files - - # 将合并后的结果转为JSON字符串 - merged_results = { - "frame_observations": merged_frame_observations, - "overall_activity_summaries": overall_activity_summaries - } - - # 使用当前时间创建文件名 - now = datetime.now() - timestamp_str = now.strftime("%Y%m%d_%H%M") - - # 保存完整的分析结果为JSON - analysis_filename = f"frame_analysis_{timestamp_str}.json" - analysis_json_path = os.path.join(analysis_dir, analysis_filename) - with open(analysis_json_path, 'w', encoding='utf-8') as f: - json.dump(merged_results, f, ensure_ascii=False, indent=2) - logger.info(f"分析结果已保存到: {analysis_json_path}") - - """ - 4. 生成文案 - """ - logger.info("开始生成解说文案") - update_progress(80, "正在生成解说文案...") - from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration - # 从配置中获取文本生成相关配置 - text_provider = config.app.get('text_llm_provider', 'gemini').lower() - text_api_key = config.app.get(f'text_{text_provider}_api_key') - text_model = config.app.get(f'text_{text_provider}_model_name') - text_base_url = config.app.get(f'text_{text_provider}_base_url') - llm_params.update({ - "text_provider": text_provider, - "text_api_key": text_api_key, - "text_model_name": text_model, - "text_base_url": text_base_url - }) - # 整理帧分析数据 - markdown_output = parse_frame_analysis_to_markdown(analysis_json_path) - - # 生成解说文案 - narration = generate_narration( - markdown_output, - text_api_key, - base_url=text_base_url, - model=text_model - ) - - # 使用增强的JSON解析器 - from webui.tools.generate_short_summary import parse_and_fix_json - narration_data = parse_and_fix_json(narration) - - if not narration_data or 'items' not in narration_data: - logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...") - raise Exception("解说文案格式错误,无法解析JSON或缺少items字段") - - narration_dict = narration_data['items'] - # 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音 - narration_dict = [{**item, "OST": 2} for item in narration_dict] - logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段") - # 结果转换为JSON字符串 - script = json.dumps(narration_dict, ensure_ascii=False, indent=2) - - except Exception as e: - logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}") - raise Exception(f"分析失败: {str(e)}") - - if script is None: - st.error("生成脚本失败,请检查日志") - st.stop() - logger.info(f"纪录片解说脚本生成完成") + logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段") if isinstance(script, list): - st.session_state['video_clip_json'] = script + st.session_state["video_clip_json"] = script elif isinstance(script, str): - st.session_state['video_clip_json'] = json.loads(script) + st.session_state["video_clip_json"] = json.loads(script) update_progress(100, "脚本生成完成") time.sleep(0.1)