Merge pull request #236 from linyqh/codex/refactor-documentary-frame-analysis-pipeline

refactor(documentary): centralize frame analysis pipeline
This commit is contained in:
viccy 2026-04-03 13:16:02 +08:00 committed by GitHub
commit be653c5748
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2034 additions and 791 deletions

8
.gitignore vendored
View File

@ -39,9 +39,15 @@ bug清单.md
task.md
.claude/*
.serena/*
.worktrees/
# OpenSpec: 忽略活动的变更提案,但保留归档和规范
openspec/*
AGENTS.md
CLAUDE.md
tests/*
tests/*
!tests/test_documentary_frame_analysis_service.py
!tests/test_video_processor_documentary_unittest.py
!tests/test_script_service_documentary_unittest.py
!tests/test_generate_narration_script_documentary_unittest.py
!tests/test_generate_script_docu_unittest.py

View File

@ -33,6 +33,7 @@ NarratoAI is an automated video narration tool that provides an all-in-one solut
</div>
## Latest News
- 2026.04.03 Released version 0.7.8, refactored the documentary frame-analysis pipeline with a shared service and improved extraction, caching, vision batching, and narration generation
- 2025.05.11 Released new version 0.6.0, supports **short drama commentary** and optimized editing process
- 2025.03.06 Released new version 0.5.2, supports DeepSeek R1 and DeepSeek V3 models for short drama mixing
- 2024.12.16 Released new version 0.3.9, supports Alibaba Qwen2-VL model for video understanding; supports short drama mixing

View File

@ -33,6 +33,7 @@ NarratoAI 是一个自动化影视解说工具基于LLM实现文案撰写、
本项目仅供学习和研究使用,不得商用。如需商业授权,请联系作者。
## 最新资讯
- 2026.04.03 发布新版本 0.7.8,重构纪录片逐帧分析链路,统一共享服务并优化抽帧、缓存、视觉并发与文案生成流程
- 2026.03.27 发布新版本 0.7.7,出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路
- 2025.11.20 发布新版本 0.7.5,新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持
- 2025.10.15 发布新版本 0.7.3,升级大模型供应商管理能力

View File

@ -0,0 +1,13 @@
from app.services.documentary.frame_analysis_models import (
DocumentaryAnalysisConfig,
FrameBatchResult,
)
from app.services.documentary.frame_analysis_service import (
DocumentaryFrameAnalysisService,
)
__all__ = [
"DocumentaryAnalysisConfig",
"FrameBatchResult",
"DocumentaryFrameAnalysisService",
]

View File

@ -0,0 +1,33 @@
from dataclasses import dataclass, field
@dataclass(slots=True)
class DocumentaryAnalysisConfig:
video_path: str
frame_interval_seconds: float
vision_batch_size: int
vision_llm_provider: str
vision_model_name: str
custom_prompt: str = ""
max_concurrency: int = 2
def __post_init__(self) -> None:
if self.frame_interval_seconds <= 0:
raise ValueError("frame_interval_seconds must be > 0")
if self.vision_batch_size <= 0:
raise ValueError("vision_batch_size must be > 0")
if self.max_concurrency <= 0:
raise ValueError("max_concurrency must be > 0")
@dataclass(slots=True)
class FrameBatchResult:
    """Outcome of analysing one batch of keyframes with the vision model.

    Produced by DocumentaryFrameAnalysisService: a batch is either parsed
    successfully or recorded as failed with a fallback summary.
    """

    # Zero-based position of the batch within the original keyframe ordering.
    batch_index: int
    # Outcome marker; the service sets "success" or "failed".
    status: str
    # "HH:MM:SS,mmm-HH:MM:SS,mmm" span covered by the batch's frames.
    time_range: str
    # Verbatim text returned by the vision model for this batch.
    raw_response: str
    # Paths of the keyframe images submitted in this batch.
    frame_paths: list[str] = field(default_factory=list)
    # Per-frame dicts (frame_path/timestamp/observation) parsed from the reply.
    frame_observations: list[dict] = field(default_factory=list)
    # Model-provided whole-batch summary (empty on failure).
    overall_activity_summary: str = ""
    # Best-effort summary used when analysis or parsing failed.
    fallback_summary: str = ""
    # Error description for failed batches; empty on success.
    error_message: str = ""
View File

@ -0,0 +1,761 @@
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Callable
from loguru import logger
from app.config import config
from app.services.documentary.frame_analysis_models import FrameBatchResult
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
from app.services.llm.migration_adapter import create_vision_analyzer
from app.utils import utils, video_processor
class DocumentaryFrameAnalysisService:
    """Shared pipeline for documentary video frame analysis.

    Extracts keyframes from a video, sends them to a vision model in
    batches, persists the per-batch analysis artifact, and can feed the
    artifact through the text model to produce a narration script.
    """

    # Prompt sent to the vision model for each keyframe batch.  It demands a
    # strict JSON reply with exactly {frame_count} frame_observations entries
    # plus an overall summary; doubled braces are literal JSON braces escaped
    # for str.format().
    PROMPT_TEMPLATE = """
我提供了 {frame_count} 张视频帧它们按时间顺序排列代表一个连续的视频片段
首先请详细描述每一帧的关键视觉信息包含主要内容人物动作和场景
然后基于所有帧的分析请用简洁的语言总结整个视频片段中发生的主要活动或事件流程
请务必使用 JSON 格式输出
JSON 必须包含以下键
- frame_observations: 数组且长度必须为 {frame_count}
- overall_activity_summary: 字符串描述整个批次主要活动
示例结构
{{
"frame_observations": [
{{"timestamp": "00:00:00,000", "observation": "画面描述"}}
],
"overall_activity_summary": "本批次主要活动总结"
}}
请务必不要遗漏视频帧我提供了 {frame_count} 张视频帧frame_observations 必须包含 {frame_count} 个元素
请只返回 JSON 字符串不要附加解释文字
""".strip()
    async def generate_documentary_script(
        self,
        *,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int | float | None = None,
        vision_batch_size: int | None = None,
        vision_llm_provider: str | None = None,
        progress_callback: Callable[[float, str], None] | None = None,
        vision_api_key: str | None = None,
        vision_model_name: str | None = None,
        vision_base_url: str | None = None,
        max_concurrency: int | None = None,
    ) -> list[dict]:
        """Run the full pipeline: frame analysis followed by narration generation.

        Delegates the vision stage to analyze_video(), then feeds the saved
        analysis artifact through the configured text model to produce a
        narration script.

        Returns:
            List of narration item dicts, each with "OST" forced to 2.
            NOTE(review): assumes downstream consumers treat OST=2 as
            narration-only segments — confirm against the editor code.

        Raises:
            ValueError: if the text model is unconfigured or the narration
                reply cannot be parsed into items.
        """
        # No-op progress sink when the caller did not supply one.
        progress = progress_callback or (lambda _p, _m: None)
        analysis_result = await self.analyze_video(
            video_path=video_path,
            video_theme=video_theme,
            custom_prompt=custom_prompt,
            frame_interval_input=frame_interval_input,
            vision_batch_size=vision_batch_size,
            vision_llm_provider=vision_llm_provider,
            progress_callback=progress_callback,
            vision_api_key=vision_api_key,
            vision_model_name=vision_model_name,
            vision_base_url=vision_base_url,
            max_concurrency=max_concurrency,
        )
        analysis_json_path = analysis_result["analysis_json_path"]
        progress(80, "正在生成解说文案...")
        # Text-model credentials are always read from app config; unlike the
        # vision stage there are no per-call overrides.
        text_provider = config.app.get("text_llm_provider", "openai").lower()
        text_api_key = config.app.get(f"text_{text_provider}_api_key")
        text_model = config.app.get(f"text_{text_provider}_model_name")
        text_base_url = config.app.get(f"text_{text_provider}_base_url")
        if not text_api_key or not text_model:
            raise ValueError(
                f"未配置 {text_provider} 的文本模型参数。"
                f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name"
            )
        # Render the persisted artifact to markdown and attach the optional
        # creative context before asking for narration.
        markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
        narration_input = self._build_narration_input(
            markdown_output=markdown_output,
            video_theme=video_theme,
            custom_prompt=custom_prompt,
        )
        narration_raw = generate_narration(
            narration_input,
            text_api_key,
            base_url=text_base_url,
            model=text_model,
        )
        narration_items = self._parse_narration_items(narration_raw)
        final_script = [{**item, "OST": 2} for item in narration_items]
        progress(100, "脚本生成完成")
        return final_script
    async def analyze_video(
        self,
        *,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int | float | None = None,
        vision_batch_size: int | None = None,
        vision_llm_provider: str | None = None,
        progress_callback: Callable[[float, str], None] | None = None,
        vision_api_key: str | None = None,
        vision_model_name: str | None = None,
        vision_base_url: str | None = None,
        max_concurrency: int | None = None,
    ) -> dict[str, Any]:
        """Extract keyframes, analyse them in batches, and persist the artifact.

        Explicit vision_* arguments override the values in app config.
        Progress is reported via progress_callback in the 10-75 range.

        Returns:
            Dict with "analysis_json_path", "analysis_artifact",
            "video_clip_json" and "keyframe_files".

        Raises:
            FileNotFoundError: if video_path does not exist.
            ValueError: if the vision provider credentials are missing.
            RuntimeError: if no keyframes/batches could be built.
        """
        progress = progress_callback or (lambda _p, _m: None)
        if not video_path or not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        frame_interval_seconds = self._resolve_frame_interval(frame_interval_input)
        batch_size = self._resolve_batch_size(vision_batch_size)
        concurrency = self._resolve_max_concurrency(max_concurrency)
        # `is not None` (rather than truthiness) lets an explicit empty string
        # act as a deliberate override of the configured value.
        provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()
        api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key")
        model_name = (
            vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name")
        )
        base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "")
        if not api_key or not model_name:
            raise ValueError(
                f"未配置 {provider} 的 API Key 或模型名称。"
                f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
            )
        progress(10, "正在提取关键帧...")
        keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds)
        progress(25, f"关键帧准备完成,共 {len(keyframe_files)} 帧")
        progress(30, "正在初始化视觉分析器...")
        analyzer = create_vision_analyzer(
            provider=provider,
            api_key=api_key,
            model=model_name,
            base_url=base_url,
        )
        batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size)
        if not batches:
            raise RuntimeError("未能构建任何关键帧批次")
        progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...")
        batch_results = await self._analyze_batches(
            analyzer=analyzer,
            batches=batches,
            custom_prompt=custom_prompt,
            video_theme=video_theme,
            max_concurrency=concurrency,
            progress_callback=progress,
        )
        progress(65, "正在整理分析结果...")
        # Batches are ordered by start time before being persisted so the
        # artifact and clip list read chronologically.
        sorted_batches = self._sort_batch_results(batch_results)
        artifact = self._build_analysis_artifact(
            sorted_batches,
            video_path=video_path,
            frame_interval_seconds=frame_interval_seconds,
            vision_batch_size=batch_size,
            vision_llm_provider=provider,
            vision_model_name=model_name,
            max_concurrency=concurrency,
        )
        analysis_json_path = self._save_analysis_artifact(artifact)
        video_clip_json = self._build_video_clip_json(sorted_batches)
        progress(75, "逐帧分析完成")
        return {
            "analysis_json_path": analysis_json_path,
            "analysis_artifact": artifact,
            "video_clip_json": video_clip_json,
            "keyframe_files": keyframe_files,
        }
def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]:
parsed = self._repair_narration_payload(narration_raw)
items: list[dict[str, Any]] = []
if isinstance(parsed, dict):
raw_items = parsed.get("items")
if isinstance(raw_items, list):
items = [item for item in raw_items if isinstance(item, dict)]
if not items:
raise ValueError("解说文案格式错误无法解析JSON或缺少items字段")
return items
def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str:
context_lines: list[str] = []
if (video_theme or "").strip():
context_lines.append(f"视频主题:{video_theme.strip()}")
if (custom_prompt or "").strip():
context_lines.append(f"补充创作要求:{custom_prompt.strip()}")
if not context_lines:
return markdown_output
context_block = "\n".join(f"- {line}" for line in context_lines)
return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n"
    def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None:
        """Best-effort recovery of a JSON object from the narration model reply.

        Tries progressively more aggressive candidates: the raw text, a
        brace-unescaped copy, any ```json fenced block, then the outermost
        {...} span; as a last resort applies textual JSON repairs.  Returns
        the parsed dict, or None when nothing parses.
        """
        def load_json_candidate(payload: str) -> dict[str, Any] | None:
            # Only a dict counts as success; scalars and lists are rejected.
            try:
                parsed = json.loads(payload)
                return parsed if isinstance(parsed, dict) else None
            except Exception:
                return None
        cleaned = (narration_raw or "").strip()
        if not cleaned:
            return None
        candidates: list[str] = [cleaned]
        # Models sometimes echo the doubled braces from the prompt template.
        candidates.append(cleaned.replace("{{", "{").replace("}}", "}"))
        json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL)
        if json_block:
            candidates.append(json_block.group(1).strip())
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start >= 0 and end > start:
            candidates.append(cleaned[start : end + 1])
        for candidate in candidates:
            parsed = load_json_candidate(candidate)
            if parsed is not None:
                return parsed
        # Last resort: textual repair of near-JSON output.
        fixed = cleaned.replace("{{", "{").replace("}}", "}")
        fixed_start = fixed.find("{")
        fixed_end = fixed.rfind("}")
        if fixed_start >= 0 and fixed_end > fixed_start:
            fixed = fixed[fixed_start : fixed_end + 1]
        fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE)  # drop '#' comment lines
        fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE)  # drop '//' comment lines
        fixed = re.sub(r",\s*}", "}", fixed)  # trailing commas in objects
        fixed = re.sub(r",\s*]", "]", fixed)  # trailing commas in arrays
        fixed = re.sub(r"'([^']*)'\s*:", r'"\1":', fixed)  # single-quoted keys
        fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed)  # bare keys
        fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed)  # collapse doubled quotes
        return load_json_candidate(fixed)
def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
interval = frame_interval_input
if interval in (None, ""):
interval = config.frames.get("frame_interval_input", 3)
try:
value = float(interval)
except (TypeError, ValueError):
value = 3.0
if value <= 0:
raise ValueError("frame_interval_input must be > 0")
return value
def _resolve_batch_size(self, vision_batch_size: int | None) -> int:
size = vision_batch_size or config.frames.get("vision_batch_size", 10)
try:
value = int(size)
except (TypeError, ValueError):
value = 10
if value <= 0:
raise ValueError("vision_batch_size must be > 0")
return value
def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2)
try:
parsed = int(value)
except (TypeError, ValueError):
parsed = 1
return max(1, parsed)
    def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]:
        """Return keyframe image paths for the video, extracting on a cache miss.

        Frames are cached under <temp>/keyframes/<cache_key>/ so repeat runs on
        an unchanged video (same path, mtime and interval) reuse the previous
        extraction.

        Raises:
            RuntimeError: if extraction produced no frames.
        """
        keyframes_root = os.path.join(utils.temp_dir(), "keyframes")
        os.makedirs(keyframes_root, exist_ok=True)
        cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds)
        cache_dir = os.path.join(keyframes_root, cache_key)
        os.makedirs(cache_dir, exist_ok=True)
        cached_files = self._collect_keyframe_paths(cache_dir)
        if cached_files:
            logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)} 帧")
            return cached_files
        processor = video_processor.VideoProcessor(video_path)
        extracted = processor.extract_frames_by_interval_with_fallback(
            output_dir=cache_dir,
            interval_seconds=frame_interval_seconds,
        )
        # Normalise to sorted .jpg paths; re-scan the cache dir as a fallback
        # in case the extractor wrote files without returning their paths.
        keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg"))
        if not keyframe_files:
            keyframe_files = self._collect_keyframe_paths(cache_dir)
        if not keyframe_files:
            raise RuntimeError("未提取到任何关键帧")
        logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)} 帧")
        return keyframe_files
def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str:
try:
video_mtime = os.path.getmtime(video_path)
except OSError:
video_mtime = 0
legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
payload = "|".join(
[
str(video_path),
str(video_mtime),
str(frame_interval_seconds),
"documentary-keyframes-v2",
]
)
return f"{legacy_prefix}_{utils.md5(payload)}"
@staticmethod
def _collect_keyframe_paths(cache_dir: str) -> list[str]:
if not os.path.exists(cache_dir):
return []
return sorted(
os.path.join(cache_dir, name)
for name in os.listdir(cache_dir)
if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
)
@staticmethod
def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)]
    async def _analyze_batches(
        self,
        *,
        analyzer: Any,
        batches: list[list[str]],
        custom_prompt: str,
        video_theme: str,
        max_concurrency: int,
        progress_callback: Callable[[float, str], None],
    ) -> list[FrameBatchResult]:
        """Analyse all keyframe batches concurrently through the vision analyzer.

        Parallelism is bounded by a semaphore; failures are captured as
        "failed" FrameBatchResult entries rather than raised, so one bad batch
        does not abort the run.  Progress is reported in the 40-65 range.
        Results are returned in batch order (asyncio.gather preserves order).
        """
        semaphore = asyncio.Semaphore(max(1, max_concurrency))
        total = len(batches)
        done = 0
        done_lock = asyncio.Lock()
        # Precompute each batch's time range sequentially, because a batch may
        # borrow its start from the previous batch (see _get_batch_timestamps).
        batch_time_ranges: list[str] = []
        previous_batch_files: list[str] | None = None
        for batch_files in batches:
            _, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files)
            batch_time_ranges.append(time_range)
            previous_batch_files = batch_files

        async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult:
            nonlocal done
            prompt = self._build_batch_prompt(
                frame_count=len(frame_paths),
                video_theme=video_theme,
                custom_prompt=custom_prompt,
            )
            try:
                async with semaphore:
                    # One analyzer call per batch; intra-call concurrency is
                    # pinned to 1 so the semaphore fully governs parallelism.
                    raw_results = await analyzer.analyze_images(
                        images=frame_paths,
                        prompt=prompt,
                        batch_size=max(1, len(frame_paths)),
                        max_concurrency=1,
                    )
                raw_response, error_message = self._extract_batch_response(raw_results)
                if error_message:
                    return self._build_failed_batch_result(
                        batch_index=batch_index,
                        raw_response=raw_response,
                        error_message=error_message,
                        frame_paths=frame_paths,
                        time_range=time_range,
                    )
                return self._parse_batch_response(
                    batch_index=batch_index,
                    raw_response=raw_response,
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            except Exception as exc:
                # Any unexpected failure degrades to a failed batch result.
                return self._build_failed_batch_result(
                    batch_index=batch_index,
                    raw_response="",
                    error_message=str(exc),
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            finally:
                # Count completions under a lock so the reported progress is
                # monotonic across concurrent tasks.
                async with done_lock:
                    done += 1
                    progress = 40 + (done / max(1, total)) * 25
                    progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...")

        tasks = [
            run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index])
            for index, batch_files in enumerate(batches)
        ]
        return await asyncio.gather(*tasks)
def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str:
prompt = self._build_analysis_prompt(frame_count=frame_count)
extra_lines: list[str] = []
if (video_theme or "").strip():
extra_lines.append(f"视频主题:{video_theme.strip()}")
if (custom_prompt or "").strip():
extra_lines.append(custom_prompt.strip())
if not extra_lines:
return prompt
extras = "\n".join(f"- {line}" for line in extra_lines)
return f"{prompt}\n\n补充分析要求:\n{extras}"
def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]:
if not raw_results:
return "", "Batch response is empty"
first_result = raw_results[0] if isinstance(raw_results, list) else raw_results
if isinstance(first_result, dict):
raw_response = str(first_result.get("response", "") or "")
error_message = str(first_result.get("error", "") or "")
if error_message:
if not raw_response:
raw_response = error_message
return raw_response, error_message
if not raw_response.strip():
return raw_response, "Batch response is empty"
return raw_response, ""
raw_response = str(first_result or "")
if not raw_response.strip():
return raw_response, "Batch response is empty"
return raw_response, ""
def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]:
return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index))
    def _build_analysis_artifact(
        self,
        batch_results: list[FrameBatchResult],
        *,
        video_path: str,
        frame_interval_seconds: float,
        vision_batch_size: int,
        vision_llm_provider: str,
        vision_model_name: str,
        max_concurrency: int,
    ) -> dict[str, Any]:
        """Assemble the persistable analysis artifact from batch results.

        Emits both the new per-batch structure ("batches") and flattened
        "frame_observations"/"overall_activity_summaries" lists so the old
        parser shape keeps working.
        """
        sorted_batches = self._sort_batch_results(batch_results)
        batch_dicts: list[dict[str, Any]] = []
        frame_observations: list[dict[str, Any]] = []
        overall_activity_summaries: list[dict[str, Any]] = []
        for batch in sorted_batches:
            batch_payload = {
                "batch_index": batch.batch_index,
                "status": batch.status,
                "time_range": batch.time_range,
                "raw_response": batch.raw_response,
                "frame_paths": list(batch.frame_paths),
                "frame_observations": list(batch.frame_observations),
                "overall_activity_summary": batch.overall_activity_summary,
                "fallback_summary": batch.fallback_summary,
                "error_message": batch.error_message,
            }
            batch_dicts.append(batch_payload)
            # Flattened copies are annotated with their batch for the legacy shape.
            for observation in batch.frame_observations:
                observation_payload = dict(observation)
                observation_payload["batch_index"] = batch.batch_index
                observation_payload["time_range"] = batch.time_range
                frame_observations.append(observation_payload)
            # Prefer the model summary; fall back to the failure summary.
            summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip()
            if summary_text:
                overall_activity_summaries.append(
                    {
                        "batch_index": batch.batch_index,
                        "time_range": batch.time_range,
                        "summary": summary_text,
                    }
                )
        return {
            "artifact_version": "documentary-frame-analysis-v2",
            "generated_at": datetime.now().isoformat(),
            "video_path": video_path,
            "frame_interval_seconds": frame_interval_seconds,
            "vision_batch_size": vision_batch_size,
            "vision_llm_provider": vision_llm_provider,
            "vision_model_name": vision_model_name,
            "vision_max_concurrency": max_concurrency,
            "batches": batch_dicts,
            # Backward-compatible flattened structure for the legacy parser.
            "frame_observations": frame_observations,
            "overall_activity_summaries": overall_activity_summaries,
        }
    def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str:
        """Write the artifact as pretty-printed UTF-8 JSON and return its path.

        Files are named by timestamp; a numeric suffix avoids clobbering when
        multiple artifacts are saved within the same second.
        """
        analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
        os.makedirs(analysis_dir, exist_ok=True)
        filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        file_path = os.path.join(analysis_dir, filename)
        suffix = 1
        while os.path.exists(file_path):
            filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json"
            file_path = os.path.join(analysis_dir, filename)
            suffix += 1
        with open(file_path, "w", encoding="utf-8") as fp:
            json.dump(artifact, fp, ensure_ascii=False, indent=2)
        logger.info(f"分析结果已保存到: {file_path}")
        return file_path
def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]:
clips: list[dict] = []
for batch in self._sort_batch_results(batch_results):
picture = self._build_batch_picture(batch)
clips.append(
{
"timestamp": batch.time_range,
"picture": picture,
"narration": "",
"OST": 2,
}
)
return clips
def _build_batch_picture(self, batch: FrameBatchResult) -> str:
summary = (batch.overall_activity_summary or "").strip()
if summary:
return summary
fallback = (batch.fallback_summary or "").strip()
if fallback:
return fallback
observation_lines = []
for frame in batch.frame_observations:
timestamp = str(frame.get("timestamp", "") or "").strip()
observation = str(frame.get("observation", "") or "").strip()
if timestamp and observation:
observation_lines.append(f"{timestamp}: {observation}")
elif observation:
observation_lines.append(observation)
if observation_lines:
return " ".join(observation_lines)
raw_response = (batch.raw_response or "").strip()
if raw_response:
return raw_response[:200]
return "该批次分析失败,未返回可用描述。"
def _time_range_sort_key(self, time_range: str) -> tuple[int, str]:
start = (time_range or "").split("-", 1)[0].strip()
return self._timestamp_to_milliseconds(start), time_range
@staticmethod
def _timestamp_to_milliseconds(timestamp: str) -> int:
text = (timestamp or "").strip()
try:
if "," in text:
time_part, ms_part = text.split(",", 1)
milliseconds = int(ms_part)
else:
time_part = text
milliseconds = 0
parts = [int(part) for part in time_part.split(":") if part]
while len(parts) < 3:
parts.insert(0, 0)
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
except Exception:
return 0
def _get_batch_timestamps(
self,
batch_files: list[str],
prev_batch_files: list[str] | None = None,
) -> tuple[str, str, str]:
if not batch_files:
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
if len(batch_files) == 1 and prev_batch_files:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
else:
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
first_timestamp = self._timestamp_from_keyframe_name(first_frame)
last_timestamp = self._timestamp_from_keyframe_name(last_frame)
return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}"
def _timestamp_from_keyframe_name(self, filename: str) -> str:
match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename)
if not match:
return "00:00:00,000"
token = match.group(1)
hours = int(token[0:2])
minutes = int(token[2:4])
seconds = int(token[4:6])
milliseconds = int(token[6:9])
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _build_analysis_prompt(self, frame_count: int) -> str:
return self.PROMPT_TEMPLATE.format(frame_count=frame_count)
    def _build_failed_batch_result(
        self,
        *,
        batch_index: int,
        raw_response: str,
        error_message: str,
        frame_paths: list[str],
        time_range: str,
    ) -> FrameBatchResult:
        """Package a failed batch as a FrameBatchResult.

        The fallback summary is the first 200 characters of the raw reply, or
        a synthetic failure message when the reply was empty.
        """
        fallback_summary = (raw_response or "").strip()[:200]
        if not fallback_summary:
            fallback_summary = f"Batch {batch_index} analysis failed: {error_message or 'unknown error'}"
        return FrameBatchResult(
            batch_index=batch_index,
            status="failed",
            time_range=time_range,
            raw_response=raw_response,
            # Copy so later mutation of the caller's list cannot alter the result.
            frame_paths=list(frame_paths),
            fallback_summary=fallback_summary,
            error_message=error_message,
        )
    def _build_cache_key(
        self,
        video_path: str,
        interval_seconds: float,
        prompt_version: str,
        model_name: str,
        batch_size: int,
        max_concurrency: int,
    ) -> str:
        """Build a cache key covering every input that affects analysis output.

        NOTE(review): nothing in this class calls this method; it may exist for
        external callers or future analysis-result caching — confirm before
        removing.
        """
        try:
            video_mtime = os.path.getmtime(video_path)
        except OSError:
            # Missing/unreadable file still yields a deterministic key.
            video_mtime = 0
        # Older path+mtime hash kept as a prefix, mirroring _build_keyframe_cache_key.
        legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
        payload = "|".join(
            [
                str(video_path),
                str(video_mtime),
                str(interval_seconds),
                str(prompt_version),
                str(model_name),
                str(batch_size),
                str(max_concurrency),
                "documentary-frame-analysis-v2",
            ]
        )
        return f"{legacy_prefix}_{utils.md5(payload)}"
def _strip_code_fence(self, response_text: str) -> str:
cleaned = (response_text or "").strip()
cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
cleaned = re.sub(r"\s*```$", "", cleaned)
return cleaned.strip()
    def _parse_batch_response(
        self,
        *,
        batch_index: int,
        raw_response: str,
        frame_paths: list[str],
        time_range: str,
    ) -> FrameBatchResult:
        """Parse one batch's raw model reply into a FrameBatchResult.

        Any JSON or contract violation downgrades the batch to a failed
        result instead of raising.
        """
        try:
            payload = json.loads(self._strip_code_fence(raw_response))
        except Exception as exc:
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response=raw_response,
                error_message=str(exc),
                frame_paths=frame_paths,
                time_range=time_range,
            )
        validation_error = self._validate_batch_payload_contract(payload, expected_frame_count=len(frame_paths))
        if validation_error:
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response=raw_response,
                error_message=validation_error,
                frame_paths=frame_paths,
                time_range=time_range,
            )
        raw_observations = payload["frame_observations"]
        frame_observations: list[dict] = []
        # Pair observations with frames positionally; observations beyond
        # len(frame_paths) are dropped, missing ones become empty entries.
        for index, frame_path in enumerate(frame_paths):
            entry = raw_observations[index] if index < len(raw_observations) else {}
            if isinstance(entry, dict):
                observation = str(entry.get("observation", "") or "")
                timestamp = str(entry.get("timestamp", "") or "")
            else:
                # Tolerate bare strings in place of observation objects.
                observation = str(entry or "")
                timestamp = ""
            frame_observations.append(
                {
                    "frame_path": frame_path,
                    "timestamp": timestamp,
                    "observation": observation,
                }
            )
        # Coerce the summary to a string without failing on odd types.
        raw_summary = payload.get("overall_activity_summary", "")
        if isinstance(raw_summary, str):
            summary = raw_summary
        elif raw_summary is None:
            summary = ""
        else:
            summary = str(raw_summary)
        return FrameBatchResult(
            batch_index=batch_index,
            status="success",
            time_range=time_range,
            raw_response=raw_response,
            frame_paths=list(frame_paths),
            frame_observations=frame_observations,
            overall_activity_summary=summary,
        )
def _validate_batch_payload_contract(self, payload: object, *, expected_frame_count: int) -> str:
if not isinstance(payload, dict):
return "Batch response JSON payload must be an object"
if "frame_observations" not in payload or not isinstance(payload["frame_observations"], list):
return "Batch response must include frame_observations as a list"
if len(payload["frame_observations"]) < expected_frame_count:
return (
"Batch response frame_observations length is shorter than provided frame_paths: "
f"{len(payload['frame_observations'])} < {expected_frame_count}"
)
return ""

View File

@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path):
with open(json_file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
# 初始化Markdown字符串
def time_to_milliseconds(time_text):
time_text = (time_text or "").strip()
if not time_text:
return 0
try:
if "," in time_text:
hhmmss, ms = time_text.split(",", 1)
milliseconds = int(ms)
else:
hhmmss = time_text
milliseconds = 0
parts = [int(part) for part in hhmmss.split(":") if part]
while len(parts) < 3:
parts.insert(0, 0)
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
except Exception:
return 0
def batch_sort_key(batch):
time_range = batch.get("time_range", "")
start = time_range.split("-", 1)[0].strip()
return time_to_milliseconds(start), batch.get("batch_index", 0)
markdown = ""
# 获取总结和帧观察数据
# 新结构:按批次保存完整分析产物
if isinstance(data.get("batches"), list):
ordered_batches = sorted(data.get("batches", []), key=batch_sort_key)
for i, batch in enumerate(ordered_batches, 1):
time_range = batch.get("time_range", "")
summary = (
batch.get("overall_activity_summary")
or batch.get("summary")
or batch.get("fallback_summary")
or ""
)
observations = batch.get("frame_observations") or batch.get("observations") or []
markdown += f"## 片段 {i}\n"
markdown += f"- 时间范围:{time_range}\n"
markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n"
markdown += "- 详细描述:\n"
for frame in observations:
timestamp = frame.get("timestamp", "")
observation = frame.get("observation", "")
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
markdown += "\n"
return markdown
# 兼容旧结构
summaries = data.get('overall_activity_summaries', [])
frame_observations = data.get('frame_observations', [])
# 按批次组织数据
batch_frames = {}
for frame in frame_observations:
batch_index = frame.get('batch_index')
if batch_index not in batch_frames:
batch_frames[batch_index] = []
batch_frames[batch_index].append(frame)
# 生成Markdown内容
for i, summary in enumerate(summaries, 1):
batch_index = summary.get('batch_index')
time_range = summary.get('time_range', '')
batch_summary = summary.get('summary', '')
markdown += f"## 片段 {i}\n"
markdown += f"- 时间范围:{time_range}\n"
# 添加片段描述
markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else f"- 片段描述:\n"
markdown += "- 详细描述:\n"
# 添加该批次的帧观察详情
frames = batch_frames.get(batch_index, [])
for frame in frames:
timestamp = frame.get('timestamp', '')
observation = frame.get('observation', '')
# 直接使用原始文本,不进行分割
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
markdown += "\n"
return markdown
except Exception as e:

View File

@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider):
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
max_concurrency: int = 1,
**kwargs) -> List[str]:
"""
分析图片并返回结果
@ -116,6 +117,7 @@ class VisionModelProvider(BaseLLMProvider):
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
batch_size: 批处理大小
max_concurrency: 最大并发批次数实现支持时生效
**kwargs: 其他参数
Returns:

View File

@ -5,7 +5,6 @@
"""
import asyncio
import json
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
@ -13,6 +12,7 @@ from loguru import logger
from .unified_service import UnifiedLLMService
from .exceptions import LLMServiceError
from .manager import LLMServiceManager
# 导入新的提示词管理系统
from app.services.prompts import PromptManager
@ -110,41 +110,11 @@ class LegacyLLMAdapter:
temperature=1.5,
response_format="json"
)
# 使用增强的JSON解析器
from webui.tools.generate_short_summary import parse_and_fix_json
parsed_result = parse_and_fix_json(result)
if not parsed_result:
logger.error("无法解析LLM返回的JSON数据")
# 返回一个基本的JSON结构而不是错误字符串
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "解析失败请检查LLM输出",
"narration": "解说文案生成失败,请重试"
}
]
}, ensure_ascii=False)
# 确保返回的是JSON字符串
return json.dumps(parsed_result, ensure_ascii=False)
return result if isinstance(result, str) else str(result)
except Exception as e:
logger.error(f"生成解说文案失败: {str(e)}")
# 返回一个基本的JSON结构而不是错误字符串
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "生成失败",
"narration": f"解说文案生成失败: {str(e)}"
}
]
}, ensure_ascii=False)
raise
class VisionAnalyzerAdapter:
@ -155,11 +125,29 @@ class VisionAnalyzerAdapter:
self.api_key = api_key
self.model = model
self.base_url = base_url
    def _build_provider_with_explicit_settings(self):
        """Instantiate the vision provider class using this adapter's explicit
        api_key/model/base_url instead of globally configured credentials.

        Raises:
            LLMServiceError: if the named provider is not registered.
        """
        provider_name = (self.provider or "").lower()
        # Lazily register all providers on first use.
        if not LLMServiceManager.is_registered():
            from .providers import register_all_providers
            register_all_providers()
        # NOTE(review): reaches into LLMServiceManager._vision_providers (a
        # private mapping) — consider exposing a public accessor.
        provider_class = LLMServiceManager._vision_providers.get(provider_name)
        if provider_class is None:
            raise LLMServiceError(f"视觉模型提供商未注册: {provider_name}")
        return provider_class(
            api_key=self.api_key,
            model_name=self.model,
            base_url=self.base_url,
        )
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10) -> List[Dict[str, Any]]:
batch_size: int = 10,
max_concurrency: int = 1) -> List[Dict[str, Any]]:
"""
分析图片 - 兼容原有接口
@ -167,17 +155,20 @@ class VisionAnalyzerAdapter:
images: 图片列表
prompt: 分析提示词
batch_size: 批处理大小
max_concurrency: 最大并发批次数
Returns:
分析结果列表格式与旧实现兼容
"""
try:
# 使用统一服务分析图片
results = await UnifiedLLMService.analyze_images(
provider = self._build_provider_with_explicit_settings()
results = await provider.analyze_images(
images=images,
prompt=prompt,
provider=self.provider,
batch_size=batch_size
batch_size=batch_size,
max_concurrency=max_concurrency,
api_key=self.api_key,
api_base=self.base_url,
)
# 转换为旧格式以保持向后兼容性

View File

@ -4,6 +4,7 @@ OpenAI 兼容提供商实现
使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口支持文本和视觉模型
"""
import asyncio
import io
import base64
import re
@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider)
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
max_concurrency: int = 1,
**kwargs,
) -> List[str]:
logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
processed_images = self._prepare_images(images)
results: List[str] = []
if not processed_images:
return []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i : i + batch_size]
logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt, **kwargs)
results.append(result)
except Exception as exc:
logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}")
results.append(f"批次处理失败: {exc}")
bounded_concurrency = max(1, int(max_concurrency))
semaphore = asyncio.Semaphore(bounded_concurrency)
batches = [
(index // batch_size, processed_images[index : index + batch_size])
for index in range(0, len(processed_images), batch_size)
]
return results
async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张图片")
async with semaphore:
try:
result = await self._analyze_batch(batch, prompt, **kwargs)
return batch_index, result
except Exception as exc:
logger.error(f"批次 {batch_index + 1} 处理失败: {exc}")
return batch_index, f"批次处理失败: {exc}"
completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches))
completed.sort(key=lambda item: item[0])
return [result for _, result in completed]
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
content = [{"type": "text", "text": prompt}]

View File

@ -1,10 +1,14 @@
"""OpenAI 兼容 provider 的最小回归测试。"""
import asyncio
import unittest
from unittest.mock import patch
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.manager import LLMServiceManager
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
from app.services.llm.providers import register_all_providers
@ -63,5 +67,128 @@ class OpenAICompatManagerTests(unittest.TestCase):
self.assertEqual("https://new.example/v1", provider.base_url)
class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
    """Async regression tests for batched image analysis under a concurrency cap."""

    async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
        # Later batches finish first on purpose (reversed sleep delays); the
        # provider must still return results in original batch order.
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)

        async def fake_analyze_batch(batch, prompt, **kwargs):
            delays = {"a": 0.03, "c": 0.01, "e": 0.0}
            await asyncio.sleep(delays[batch[0]])
            return f"batch-{batch[0]}"

        provider._analyze_batch = fake_analyze_batch

        result = await provider.analyze_images(
            images=["a", "b", "c", "d", "e", "f"],
            prompt="prompt",
            batch_size=2,
            max_concurrency=2,
        )

        self.assertEqual(["batch-a", "batch-c", "batch-e"], result)

    async def test_analyze_images_respects_max_concurrency_limit(self):
        # Count batches running simultaneously to verify the semaphore cap.
        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
        provider._prepare_images = lambda images: list(images)

        in_flight = 0
        max_in_flight = 0

        async def fake_analyze_batch(batch, prompt, **kwargs):
            nonlocal in_flight, max_in_flight
            in_flight += 1
            max_in_flight = max(max_in_flight, in_flight)
            await asyncio.sleep(0.02)
            in_flight -= 1
            return f"batch-{batch[0]}"

        provider._analyze_batch = fake_analyze_batch

        result = await provider.analyze_images(
            images=["a", "b", "c", "d", "e", "f"],
            prompt="prompt",
            batch_size=1,
            max_concurrency=2,
        )

        self.assertEqual(6, len(result))
        # With 6 single-image batches and max_concurrency=2, exactly 2 should
        # have been in flight at the peak.
        self.assertEqual(2, max_in_flight)
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
    """Verify the adapter honors explicitly passed credentials over global config."""

    class _CapturingVisionProvider:
        # Class-level capture slots: record the most recent constructor args
        # and analyze_images kwargs so the test can inspect them afterwards.
        last_init: tuple[str, str, str | None] | None = None
        last_call_kwargs: dict | None = None

        def __init__(self, api_key: str, model_name: str, base_url: str | None = None):
            self.api_key = api_key
            self.model_name = model_name
            self.base_url = base_url
            ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url)

        async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs):
            ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs)
            return [f"{self.model_name}|{self.api_key}|{self.base_url}"]

    def setUp(self):
        # Isolate provider registrations and snapshot the global config so the
        # test can mutate it freely and restore it afterwards.
        _reset_manager_state()
        self._original_app = dict(config.app)

    def tearDown(self):
        _reset_manager_state()
        config.app.clear()
        config.app.update(self._original_app)

    async def test_adapter_uses_explicit_settings_instead_of_global_config(self):
        LLMServiceManager.register_vision_provider("openai", self._CapturingVisionProvider)
        # Global config deliberately differs from the explicit adapter settings;
        # only the explicit ones should reach the provider.
        config.app["vision_openai_api_key"] = "config-key"
        config.app["vision_openai_model_name"] = "config-model"
        config.app["vision_openai_base_url"] = "https://config.example/v1"

        adapter = VisionAnalyzerAdapter(
            provider="openai",
            api_key="explicit-key",
            model="explicit-model",
            base_url="https://explicit.example/v1",
        )

        result = await adapter.analyze_images(
            images=["/tmp/keyframe_000001_000000100.jpg"],
            prompt="描述画面",
            batch_size=1,
            max_concurrency=1,
        )

        self.assertEqual(
            ("explicit-key", "explicit-model", "https://explicit.example/v1"),
            self._CapturingVisionProvider.last_init,
        )
        self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"])
        self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"])
        self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"])
class LegacyNarrationAdapterBehaviorTests(unittest.TestCase):
    """Guard against the legacy adapter fabricating JSON around unparseable output."""

    def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self):
        # An unrecoverable model response must pass through untouched rather
        # than being replaced by a synthetic {"items": [...]} placeholder.
        raw_payload = "not-json-at-all ::: ???"
        with patch(
            "app.services.llm.migration_adapter.PromptManager.get_prompt",
            return_value="prompt",
        ), patch(
            "app.services.llm.migration_adapter._run_async_safely",
            return_value=raw_payload,
        ):
            result = LegacyLLMAdapter.generate_narration(
                markdown_content="markdown",
                api_key="test-key",
                base_url="https://example.com/v1",
                model="test-model",
            )

        self.assertEqual(raw_payload, result)
        # A fabricated fallback would contain an "items" key.
        self.assertNotIn('"items"', result)
if __name__ == "__main__":
unittest.main()

View File

@ -1,324 +1,40 @@
import os
import json
import time
import asyncio
import requests
from app.utils import video_processor
from loguru import logger
from typing import List, Dict, Any, Callable
from typing import Any, Callable
from app.utils import utils, gemini_analyzer, video_processor
from app.utils.script_generator import ScriptProcessor
from app.config import config
from loguru import logger
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
class ScriptGenerator:
def __init__(self):
self.temp_dir = utils.temp_dir()
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None):
self.documentary_service = documentary_service or DocumentaryFrameAnalysisService()
async def generate_script(
self,
video_path: str,
video_theme: str = "",
custom_prompt: str = "",
frame_interval_input: int = 5,
frame_interval_input: int | None = None,
skip_seconds: int = 0,
threshold: int = 30,
vision_batch_size: int = 5,
vision_llm_provider: str = "gemini",
progress_callback: Callable[[float, str], None] = None
) -> List[Dict[Any, Any]]:
"""
生成视频脚本的核心逻辑
Args:
video_path: 视频文件路径
video_theme: 视频主题
custom_prompt: 自定义提示词
skip_seconds: 跳过开始的秒数
threshold: 差异阈值
vision_batch_size: 视觉处理批次大小
vision_llm_provider: 视觉模型提供商
progress_callback: 进度回调函数
Returns:
List[Dict]: 生成的视频脚本
"""
if progress_callback is None:
progress_callback = lambda p, m: None
try:
# 提取关键帧
progress_callback(10, "正在提取关键帧...")
keyframe_files = await self._extract_keyframes(
video_path,
skip_seconds,
threshold
)
# 使用统一的 LLM 接口(支持所有 provider
script = await self._process_with_llm(
keyframe_files,
video_theme,
custom_prompt,
vision_batch_size,
vision_llm_provider,
progress_callback
)
return json.loads(script) if isinstance(script, str) else script
except Exception as e:
logger.exception("Generate script failed")
raise
async def _extract_keyframes(
self,
video_path: str,
skip_seconds: int,
threshold: int
) -> List[str]:
"""提取视频关键帧"""
video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)
# 检查缓存
keyframe_files = []
if os.path.exists(video_keyframes_dir):
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"Using cached keyframes: {video_keyframes_dir}")
return keyframe_files
# 提取新的关键帧
os.makedirs(video_keyframes_dir, exist_ok=True)
try:
processor = video_processor.VideoProcessor(video_path)
processor.process_video_pipeline(
output_dir=video_keyframes_dir,
skip_seconds=skip_seconds,
threshold=threshold
vision_batch_size: int | None = None,
vision_llm_provider: str | None = None,
progress_callback: Callable[[float, str], None] | None = None,
) -> list[dict[Any, Any]]:
callback = progress_callback or (lambda _p, _m: None)
if skip_seconds != 0 or threshold != 30:
logger.warning(
"ScriptGenerator documentary path received "
f"skip_seconds={skip_seconds} threshold={threshold}; "
"the shared documentary frame pipeline does not currently apply these parameters."
)
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
return keyframe_files
except Exception as e:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
raise
async def _process_with_llm(
self,
keyframe_files: List[str],
video_theme: str,
custom_prompt: str,
vision_batch_size: int,
vision_llm_provider: str,
progress_callback: Callable[[float, str], None]
) -> str:
"""使用统一 LLM 接口处理视频帧"""
progress_callback(30, "正在初始化视觉分析器...")
# 使用新的 LLM 迁移适配器(支持所有 provider
from app.services.llm.migration_adapter import create_vision_analyzer
# 获取配置
text_provider = config.app.get('text_llm_provider', 'openai').lower()
vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')
if not vision_api_key or not vision_model:
raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型")
# 创建统一的视觉分析器
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
return await self.documentary_service.generate_documentary_script(
video_path=video_path,
video_theme=video_theme,
custom_prompt=custom_prompt,
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=callback,
)
progress_callback(40, "正在分析关键帧...")
# 执行异步分析
results = await analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=vision_batch_size
)
progress_callback(60, "正在整理分析结果...")
# 合并所有批次的分析结果
frame_analysis = ""
prev_batch_files = None
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)
# 添加带时间戳的分析结果
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response']
frame_analysis += "\n"
prev_batch_files = batch_files
if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果")
progress_callback(70, "正在生成脚本...")
# 构建帧内容列表
frame_content_list = []
prev_batch_files = None
for result in results:
if 'error' in result:
continue
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
_, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
frame_content = {
"timestamp": timestamp_range,
"picture": result['response'],
"narration": "",
"OST": 2
}
frame_content_list.append(frame_content)
prev_batch_files = batch_files
if not frame_content_list:
raise Exception("没有有效的帧内容可以处理")
progress_callback(90, "正在生成文案...")
# 获取文本生成配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 根据提供商类型选择合适的处理器
if text_provider == 'gemini(openai)':
# 使用OpenAI兼容的Gemini代理
from app.utils.script_generator import GeminiOpenAIGenerator
generator = GeminiOpenAIGenerator(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url
)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
processor.generator = generator
else:
# 使用标准处理器包括原生Gemini
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
return processor.process_frames(frame_content_list)
def _get_batch_files(
self,
keyframe_files: List[str],
result: Dict[str, Any],
batch_size: int
) -> List[str]:
"""获取当前批次的图片文件"""
batch_start = result['batch_index'] * batch_size
batch_end = min(batch_start + batch_size, len(keyframe_files))
return keyframe_files[batch_start:batch_end]
def _get_batch_timestamps(
self,
batch_files: List[str],
prev_batch_files: List[str] = None
) -> tuple[str, str, str]:
"""获取一批文件的时间戳范围,支持毫秒级精度"""
if not batch_files:
logger.warning("Empty batch files")
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
else:
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
first_time = first_frame.split('_')[2].replace('.jpg', '')
last_time = last_frame.split('_')[2].replace('.jpg', '')
def format_timestamp(time_str: str) -> str:
"""将时间字符串转换为 HH:MM:SS,mmm 格式"""
try:
if len(time_str) < 4:
logger.warning(f"Invalid timestamp format: {time_str}")
return "00:00:00,000"
# 处理毫秒部分
if ',' in time_str:
time_part, ms_part = time_str.split(',')
ms = int(ms_part)
else:
time_part = time_str
ms = 0
# 处理时分秒
parts = time_part.split(':')
if len(parts) == 3: # HH:MM:SS
h, m, s = map(int, parts)
elif len(parts) == 2: # MM:SS
h = 0
m, s = map(int, parts)
else: # SS
h = 0
m = 0
s = int(parts[0])
# 处理进位
if s >= 60:
m += s // 60
s = s % 60
if m >= 60:
h += m // 60
m = m % 60
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
except Exception as e:
logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
return "00:00:00,000"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
return first_timestamp, last_timestamp, timestamp_range

View File

@ -570,29 +570,39 @@ def temp_dir(sub_dir: str = ""):
return d
def clear_keyframes_cache(video_path: str = None):
def clear_keyframes_cache(video_path: str = None, cache_scope: str = "keyframes"):
"""
清理关键帧缓存
Args:
video_path: 视频文件路径如果指定则只清理该视频的缓存
cache_scope: 缓存作用域目录默认 keyframes
"""
try:
keyframes_dir = os.path.join(temp_dir(), "keyframes")
if not os.path.exists(keyframes_dir):
cache_dir = os.path.join(temp_dir(), cache_scope)
if not os.path.exists(cache_dir):
return
import shutil
if video_path:
# 清理指定视频的缓存
video_hash = md5(video_path + str(os.path.getmtime(video_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
logger.info(f"已清理视频关键帧缓存: {video_path}")
# 清理指定视频的缓存(兼容前缀扩展键)
try:
video_mtime = os.path.getmtime(video_path)
except OSError:
video_mtime = 0
video_hash = md5(video_path + str(video_mtime))
for entry in os.listdir(cache_dir):
if not entry.startswith(video_hash):
continue
target_path = os.path.join(cache_dir, entry)
if os.path.isdir(target_path):
shutil.rmtree(target_path)
else:
os.remove(target_path)
logger.info(f"已清理视频关键帧缓存: {video_path}")
else:
# 清理所有缓存
import shutil
shutil.rmtree(keyframes_dir)
shutil.rmtree(cache_dir)
logger.info("已清理所有关键帧缓存")
except Exception as e:

View File

@ -185,6 +185,95 @@ class VideoProcessor:
return frame_numbers
def extract_frames_by_interval_with_fallback(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
    """Extract frames at a fixed interval, preferring the single-pass ffmpeg
    fast path and falling back to the high-compatibility extractor on failure.

    Args:
        output_dir: directory receiving the extracted keyframe images.
        interval_seconds: sampling interval in seconds; must be positive.

    Returns:
        Sorted list of extracted keyframe file paths.

    Raises:
        ValueError: if interval_seconds is not positive.
    """
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be > 0")
    os.makedirs(output_dir, exist_ok=True)
    try:
        frames = self._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds)
    except Exception as error:
        # Remove partial fast-path output before retrying compatibly.
        logger.warning(f"快路径抽帧失败,回退到兼容模式: {error}")
        self._cleanup_fast_path_artifacts(output_dir)
        self.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds)
        frames = self._collect_extracted_frame_paths(output_dir)
    return frames
def _extract_frames_fast_path(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
    """Extract interval frames with one ffmpeg invocation, then rename them to
    the established ``keyframe_<frameno>_<HHMMSSmmm>.jpg`` convention.

    Args:
        output_dir: directory receiving the frames.
        interval_seconds: sampling interval in seconds; must be positive.

    Returns:
        Renamed keyframe paths in chronological order.

    Raises:
        ValueError: if interval_seconds is not positive.
        RuntimeError: if ffmpeg produced no frames.
    """
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be > 0")
    os.makedirs(output_dir, exist_ok=True)

    raw_pattern = os.path.join(output_dir, "fastframe_%06d.jpg")
    command = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        self.video_path,
        # fps=1/N samples one frame every N seconds.
        "-vf",
        f"fps=1/{interval_seconds}",
        "-q:v",
        "2",
        "-start_number",
        "0",
        "-y",
        raw_pattern,
    ]
    subprocess.run(command, capture_output=True, text=True, check=True, timeout=120)

    produced = sorted(
        name
        for name in os.listdir(output_dir)
        if re.fullmatch(r"fastframe_\d{6}\.jpg", name)
    )
    if not produced:
        raise RuntimeError("Fast-path extraction produced no frames")

    renamed: List[str] = []
    for position, name in enumerate(produced):
        stamp = position * interval_seconds
        # Frame number is derived from the timestamp; assumes self.fps is the
        # source video frame rate.
        frame_no = int(stamp * self.fps)
        token = self._format_timestamp_token(stamp)
        destination = os.path.join(output_dir, f"keyframe_{frame_no:06d}_{token}.jpg")
        os.replace(os.path.join(output_dir, name), destination)
        renamed.append(destination)
    return renamed
@staticmethod
def _format_timestamp_token(timestamp: float) -> str:
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
milliseconds = int((timestamp % 1) * 1000)
return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
@staticmethod
def _collect_extracted_frame_paths(output_dir: str) -> List[str]:
return sorted(
os.path.join(output_dir, name)
for name in os.listdir(output_dir)
if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
)
@staticmethod
def _cleanup_fast_path_artifacts(output_dir: str) -> None:
for name in os.listdir(output_dir):
if not re.fullmatch(r"fastframe_\d{6}\.jpg", name):
continue
artifact_path = os.path.join(output_dir, name)
if os.path.isfile(artifact_path):
os.remove(artifact_path)
def _extract_single_frame_optimized(self, timestamp: float, output_path: str,
use_hw_accel: bool, hwaccel_type: str) -> bool:
"""

View File

@ -1,5 +1,5 @@
[app]
project_version="0.7.6"
project_version="0.7.8"
# LLM API 超时配置(秒)
llm_vision_timeout = 120 # 视觉模型基础超时时间
@ -152,3 +152,6 @@
# 大模型单次处理的关键帧数量
vision_batch_size = 10
# 视觉批处理最大并发批次数OpenAI 兼容 provider
vision_max_concurrency = 2

11
conftest.py Normal file
View File

@ -0,0 +1,11 @@
"""Pytest collection rules for the repository.
These files are executable smoke-check scripts that live next to the LLM
implementation for convenience. They require live credentials or manual
execution semantics, so keep them out of the default automated test suite.
"""
collect_ignore = [
"app/services/llm/test_llm_service.py",
"app/services/llm/test_openai_compatible_integration.py",
]

View File

@ -1 +1 @@
0.7.7
0.7.8

View File

@ -0,0 +1,275 @@
import unittest
import os
from tempfile import TemporaryDirectory
from unittest.mock import patch
from app.services.documentary.frame_analysis_models import DocumentaryAnalysisConfig
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.utils import utils
class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
    """Unit tests for prompt building, batch parsing, and cache-key behavior of
    DocumentaryFrameAnalysisService."""

    def test_build_analysis_prompt_formats_real_frame_count(self):
        # The prompt must interpolate the real frame count and reference the
        # JSON contract field names, with no leftover %s placeholder.
        service = DocumentaryFrameAnalysisService()
        prompt = service._build_analysis_prompt(frame_count=3)
        self.assertIn("我提供了 3 张视频帧", prompt)
        self.assertNotIn("%s", prompt)
        self.assertIn("frame_observations", prompt)
        self.assertIn("overall_activity_summary", prompt)

    def test_parse_failed_batch_keeps_raw_response_and_time_range(self):
        service = DocumentaryFrameAnalysisService()
        batch = service._build_failed_batch_result(
            batch_index=2,
            raw_response="not-json",
            error_message="JSON decode failed",
            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
            time_range="00:00:00,000-00:00:03,000",
        )
        self.assertEqual("failed", batch.status)
        self.assertEqual("not-json", batch.raw_response)
        self.assertEqual("00:00:00,000-00:00:03,000", batch.time_range)
        self.assertTrue(batch.fallback_summary)

    def test_parse_failed_batch_uses_non_empty_fallback_when_raw_response_is_empty(self):
        # Even an empty model response must yield a non-empty fallback summary.
        service = DocumentaryFrameAnalysisService()
        batch = service._build_failed_batch_result(
            batch_index=3,
            raw_response="",
            error_message="Empty model response",
            frame_paths=["/tmp/keyframe_000001_000001000.jpg"],
            time_range="00:00:03,000-00:00:06,000",
        )
        self.assertEqual("failed", batch.status)
        self.assertEqual("", batch.raw_response)
        self.assertTrue(batch.fallback_summary)

    def test_failed_batch_result_uses_prompt_contract_field_names(self):
        # Field names must match the prompt contract exactly; legacy names
        # ("observations"/"summary") must not exist on the result.
        service = DocumentaryFrameAnalysisService()
        batch = service._build_failed_batch_result(
            batch_index=4,
            raw_response="not-json",
            error_message="JSON decode failed",
            frame_paths=["/tmp/keyframe_000002_000002000.jpg"],
            time_range="00:00:06,000-00:00:09,000",
        )
        self.assertEqual([], batch.frame_observations)
        self.assertEqual("", batch.overall_activity_summary)
        self.assertFalse(hasattr(batch, "observations"))
        self.assertFalse(hasattr(batch, "summary"))

    def test_parse_batch_returns_failed_result_when_json_is_invalid(self):
        service = DocumentaryFrameAnalysisService()
        batch = service._parse_batch_response(
            batch_index=0,
            raw_response="plain text",
            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
            time_range="00:00:00,000-00:00:03,000",
        )
        self.assertEqual("failed", batch.status)
        self.assertEqual("plain text", batch.raw_response)
        self.assertEqual(["/tmp/keyframe_000000_000000000.jpg"], batch.frame_paths)
        self.assertEqual([], batch.frame_observations)
        self.assertEqual("", batch.overall_activity_summary)

    def test_parse_batch_returns_failed_result_for_empty_json_object(self):
        # Valid JSON that lacks the required keys is still a failure.
        service = DocumentaryFrameAnalysisService()
        batch = service._parse_batch_response(
            batch_index=0,
            raw_response="{}",
            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
            time_range="00:00:00,000-00:00:03,000",
        )
        self.assertEqual("failed", batch.status)
        self.assertEqual("{}", batch.raw_response)
        self.assertIn("frame_observations", batch.error_message)

    def test_parse_batch_returns_failed_result_when_observations_are_too_short(self):
        # Fewer observations than frames in the batch is treated as a failure.
        service = DocumentaryFrameAnalysisService()
        raw_response = """
{
"frame_observations": [
{"observation": "第一帧画面"}
],
"overall_activity_summary": "只有一条帧观察"
}
""".strip()
        batch = service._parse_batch_response(
            batch_index=1,
            raw_response=raw_response,
            frame_paths=[
                "/tmp/keyframe_000000_000000000.jpg",
                "/tmp/keyframe_000075_000003000.jpg",
            ],
            time_range="00:00:00,000-00:00:06,000",
        )
        self.assertEqual("failed", batch.status)
        self.assertEqual(raw_response, batch.raw_response)
        self.assertIn("frame_observations", batch.error_message)

    def test_parse_batch_parses_code_fenced_json_into_structured_result(self):
        # Responses wrapped in a ```json fence must still parse, and each
        # observation is paired positionally with its frame path.
        service = DocumentaryFrameAnalysisService()
        raw_response = """```json
{
"frame_observations": [
{"observation": "第一帧画面"},
{"observation": "第二帧画面"}
],
"overall_activity_summary": "人物从房间走到街道"
}
```"""
        batch = service._parse_batch_response(
            batch_index=1,
            raw_response=raw_response,
            frame_paths=[
                "/tmp/keyframe_000000_000000000.jpg",
                "/tmp/keyframe_000075_000003000.jpg",
            ],
            time_range="00:00:00,000-00:00:06,000",
        )
        self.assertEqual("success", batch.status)
        self.assertEqual(
            [
                {
                    "frame_path": "/tmp/keyframe_000000_000000000.jpg",
                    "timestamp": "",
                    "observation": "第一帧画面",
                },
                {
                    "frame_path": "/tmp/keyframe_000075_000003000.jpg",
                    "timestamp": "",
                    "observation": "第二帧画面",
                },
            ],
            batch.frame_observations,
        )
        self.assertEqual("人物从房间走到街道", batch.overall_activity_summary)
        self.assertEqual("", batch.fallback_summary)

    def test_parse_batch_preserves_frames_when_summary_is_missing(self):
        # A missing summary is tolerated as long as observations are complete.
        service = DocumentaryFrameAnalysisService()
        raw_response = """
{
"frame_observations": [
{"observation": "第一帧画面"},
{"observation": "第二帧画面"}
]
}
""".strip()
        batch = service._parse_batch_response(
            batch_index=2,
            raw_response=raw_response,
            frame_paths=[
                "/tmp/keyframe_000000_000000000.jpg",
                "/tmp/keyframe_000075_000003000.jpg",
            ],
            time_range="00:00:00,000-00:00:06,000",
        )
        self.assertEqual("success", batch.status)
        self.assertEqual(2, len(batch.frame_observations))
        self.assertEqual("", batch.overall_activity_summary)

    def test_cache_key_changes_when_interval_changes(self):
        # Different sampling intervals must never share a cache entry.
        service = DocumentaryFrameAnalysisService()
        with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
            key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
            key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
        self.assertNotEqual(key_a, key_b)

    def test_cache_key_changes_when_model_changes(self):
        service = DocumentaryFrameAnalysisService()
        with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
            key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
            key_b = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-b", 10, 2)
        self.assertNotEqual(key_a, key_b)

    def test_cache_key_starts_with_legacy_video_hash_prefix(self):
        # Keys stay prefixed with the legacy md5(path + mtime) hash so prefix
        # matching in clear_keyframes_cache keeps working.
        service = DocumentaryFrameAnalysisService()
        with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0):
            key = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
        expected_prefix = utils.md5("video.mp4" + "123.0")
        self.assertTrue(key.startswith(expected_prefix))

    def test_clear_keyframes_cache_respects_scope_and_prefix_match(self):
        # Clearing one video removes all of its extended cache keys (any
        # interval) inside the given scope, but leaves other videos untouched.
        with TemporaryDirectory() as temp_root:
            service = DocumentaryFrameAnalysisService()
            analysis_dir = os.path.join(temp_root, "analysis")
            os.makedirs(analysis_dir, exist_ok=True)

            with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0):
                target_key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
                target_key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
                keep_key = service._build_cache_key("other.mp4", 3.0, "prompt-v1", "model-a", 10, 2)

            target_dir_a = os.path.join(analysis_dir, target_key_a)
            target_dir_b = os.path.join(analysis_dir, target_key_b)
            keep_dir = os.path.join(analysis_dir, keep_key)
            os.makedirs(target_dir_a, exist_ok=True)
            os.makedirs(target_dir_b, exist_ok=True)
            os.makedirs(keep_dir, exist_ok=True)

            with patch("app.utils.utils.temp_dir", return_value=temp_root), patch(
                "app.utils.utils.os.path.getmtime", return_value=123.0
            ):
                utils.clear_keyframes_cache(video_path="video.mp4", cache_scope="analysis")

            self.assertFalse(os.path.exists(target_dir_a))
            self.assertFalse(os.path.exists(target_dir_b))
            self.assertTrue(os.path.exists(keep_dir))
class DocumentaryAnalysisConfigTests(unittest.TestCase):
    """Constructor-validation tests for DocumentaryAnalysisConfig."""

    def test_config_rejects_non_positive_frame_interval(self):
        with self.assertRaises(ValueError):
            DocumentaryAnalysisConfig(
                video_path="/tmp/demo.mp4",
                frame_interval_seconds=0,
                vision_batch_size=5,
                vision_llm_provider="openai",
                vision_model_name="gpt-4o-mini",
            )

    def test_config_rejects_non_positive_batch_size(self):
        with self.assertRaises(ValueError):
            DocumentaryAnalysisConfig(
                video_path="/tmp/demo.mp4",
                frame_interval_seconds=5,
                vision_batch_size=0,
                vision_llm_provider="openai",
                vision_model_name="gpt-4o-mini",
            )

    def test_config_rejects_non_positive_max_concurrency(self):
        with self.assertRaises(ValueError):
            DocumentaryAnalysisConfig(
                video_path="/tmp/demo.mp4",
                frame_interval_seconds=5,
                vision_batch_size=5,
                vision_llm_provider="openai",
                vision_model_name="gpt-4o-mini",
                max_concurrency=0,
            )

View File

@ -0,0 +1,58 @@
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from app.services.generate_narration_script import parse_frame_analysis_to_markdown
class GenerateNarrationMarkdownTests(unittest.TestCase):
    """Tests for converting a frame-analysis JSON artifact into markdown."""

    def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self):
        # Batches are intentionally listed out of order; the markdown must be
        # time-sorted, and a batch lacking overall_activity_summary must still
        # appear via its fallback_summary.
        artifact = {
            "batches": [
                {
                    "batch_index": 1,
                    "time_range": "00:00:03,000-00:00:06,000",
                    "overall_activity_summary": "人物转身跑向远处",
                    "fallback_summary": "",
                    "frame_observations": [
                        {
                            "timestamp": "00:00:03,000",
                            "observation": "人物突然回头",
                        }
                    ],
                },
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "",
                    "fallback_summary": "原始响应回退摘要",
                    "frame_observations": [
                        {
                            "timestamp": "00:00:00,000",
                            "observation": "镜头里有一只猫",
                        }
                    ],
                },
            ]
        }
        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame-analysis.json"
            analysis_path.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8")

            markdown = parse_frame_analysis_to_markdown(str(analysis_path))

        first_range_index = markdown.find("00:00:00,000-00:00:03,000")
        second_range_index = markdown.find("00:00:03,000-00:00:06,000")
        self.assertIn("原始响应回退摘要", markdown)
        self.assertIn("镜头里有一只猫", markdown)
        self.assertIn("人物转身跑向远处", markdown)
        self.assertIn("人物突然回头", markdown)
        self.assertNotEqual(-1, first_range_index)
        self.assertNotEqual(-1, second_range_index)
        # Earlier time range must appear first in the rendered markdown.
        self.assertLess(first_range_index, second_range_index)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,19 @@
import unittest
from webui.tools.generate_script_docu import _normalize_progress_value
class GenerateScriptDocuProgressTests(unittest.TestCase):
    """Tests for normalizing progress values into Streamlit's 0-100 integer range."""

    def test_normalize_progress_rounds_percentage_float_to_valid_streamlit_int(self):
        self.assertEqual(43, _normalize_progress_value(43.125))

    def test_normalize_progress_converts_ratio_float_to_percentage_int(self):
        # Ratio-style inputs (0.43125) scale to the same percentage as 43.125.
        self.assertEqual(43, _normalize_progress_value(0.43125))

    def test_normalize_progress_clamps_out_of_range_values(self):
        self.assertEqual(0, _normalize_progress_value(-5))
        self.assertEqual(100, _normalize_progress_value(101))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,316 @@
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import AsyncMock, patch
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.services.script_service import ScriptGenerator
class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
    """Tests that ScriptGenerator delegates to the shared documentary service."""

    async def test_generate_script_forwards_explicit_values_to_shared_service(self):
        expected_script = [
            {
                "timestamp": "00:00:00,000-00:00:03,000",
                "picture": "批次描述",
                "narration": "这里是解说词",
                "OST": 2,
            }
        ]
        callback = lambda _percent, _message: None

        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
            service = service_cls.return_value
            service.generate_documentary_script = AsyncMock(return_value=expected_script)

            generator = ScriptGenerator()
            result = await generator.generate_script(
                video_path="demo.mp4",
                video_theme="荒野生存",
                custom_prompt="请聚焦生存动作",
                frame_interval_input=3,
                vision_batch_size=6,
                vision_llm_provider="openai",
                progress_callback=callback,
            )

        self.assertEqual(expected_script, result)
        self.assertTrue(result[0]["narration"])
        service.generate_documentary_script.assert_awaited_once()
        # Every explicit argument must be forwarded unchanged, including the
        # exact progress-callback object.
        called_kwargs = service.generate_documentary_script.await_args.kwargs
        self.assertEqual("demo.mp4", called_kwargs["video_path"])
        self.assertEqual(3, called_kwargs["frame_interval_input"])
        self.assertEqual(6, called_kwargs["vision_batch_size"])
        self.assertEqual("openai", called_kwargs["vision_llm_provider"])
        self.assertEqual("荒野生存", called_kwargs["video_theme"])
        self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
        self.assertIs(called_kwargs["progress_callback"], callback)

    async def test_generate_script_forwards_unset_values_as_none(self):
        # Unset tuning parameters pass through as None so the service can
        # apply its own defaults.
        expected_script = [
            {
                "timestamp": "00:00:00,000-00:00:03,000",
                "picture": "批次描述",
                "narration": "这里是解说词",
                "OST": 2,
            }
        ]

        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
            service = service_cls.return_value
            service.generate_documentary_script = AsyncMock(return_value=expected_script)

            generator = ScriptGenerator()
            await generator.generate_script(video_path="demo.mp4")

        called_kwargs = service.generate_documentary_script.await_args.kwargs
        self.assertIsNone(called_kwargs["frame_interval_input"])
        self.assertIsNone(called_kwargs["vision_batch_size"])
        self.assertIsNone(called_kwargs["vision_llm_provider"])

    async def test_generate_script_warns_when_skip_seconds_or_threshold_are_non_default(self):
        # The shared pipeline ignores skip_seconds/threshold, so passing
        # non-default values must emit exactly one explanatory warning.
        expected_script = [
            {
                "timestamp": "00:00:00,000-00:00:03,000",
                "picture": "批次描述",
                "narration": "这里是解说词",
                "OST": 2,
            }
        ]

        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls, patch(
            "app.services.script_service.logger.warning"
        ) as warning:
            service = service_cls.return_value
            service.generate_documentary_script = AsyncMock(return_value=expected_script)

            generator = ScriptGenerator()
            await generator.generate_script(
                video_path="demo.mp4",
                skip_seconds=2,
                threshold=20,
            )

        warning.assert_called_once()
        warning_message = warning.call_args.args[0]
        self.assertIn("skip_seconds", warning_message)
        self.assertIn("threshold", warning_message)
        self.assertIn("does not currently apply", warning_message)
class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyncioTestCase):
    """Exercise DocumentaryFrameAnalysisService end-to-end with all I/O mocked.

    The heavy stages (frame analysis, text-LLM config, narration call) are
    replaced by patches so only the service's orchestration, parsing and
    argument-forwarding logic runs.
    """

    async def test_generate_documentary_script_returns_final_narrated_items(self):
        """A valid narration JSON from the LLM becomes script items tagged OST=2."""
        service = DocumentaryFrameAnalysisService()
        # Minimal analysis artifact: one batch with an empty overall summary so
        # the fallback_summary path is available to the service.
        analysis_payload = {
            "batches": [
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "",
                    "fallback_summary": "回退摘要",
                    "frame_observations": [
                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                    ],
                }
            ]
        }
        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
            # Stub the analysis step, inject text-LLM config, and canned narration output.
            with patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            ), patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                {
                    "text_llm_provider": "openai",
                    "text_openai_api_key": "test-key",
                    "text_openai_model_name": "test-model",
                    "text_openai_base_url": "https://example.com/v1",
                },
            ), patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
            ):
                result = await service.generate_documentary_script(video_path="demo.mp4")
        self.assertEqual(1, len(result))
        self.assertEqual("00:00:00,000-00:00:03,000", result[0]["timestamp"])
        self.assertEqual("镜头里有一只猫", result[0]["picture"])
        self.assertEqual("一只猫警觉地望向镜头。", result[0]["narration"])
        # OST=2 marks "keep original sound + narration" — added by the service.
        self.assertEqual(2, result[0]["OST"])

    async def test_generate_documentary_script_raises_when_narration_json_is_malformed(self):
        """Unparseable narration output must surface as an error mentioning the missing structure."""
        service = DocumentaryFrameAnalysisService()
        analysis_payload = {
            "batches": [
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "测试摘要",
                    "fallback_summary": "",
                    "frame_observations": [
                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                    ],
                }
            ]
        }
        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
            with patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            ), patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                {
                    "text_llm_provider": "openai",
                    "text_openai_api_key": "test-key",
                    "text_openai_model_name": "test-model",
                    "text_openai_base_url": "https://example.com/v1",
                },
            ), patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                # Deliberately not JSON at all.
                return_value="malformed narration payload",
            ):
                with self.assertRaises(Exception) as ctx:
                    await service.generate_documentary_script(video_path="demo.mp4")
        self.assertIn("解说文案格式错误", str(ctx.exception))
        self.assertIn("items", str(ctx.exception))

    def test_parse_narration_items_recovers_from_common_json_damage(self):
        """Fixture mimics typical LLM damage: surrounding prose, a ```json fence,
        doubled braces and trailing commas — the parser is expected to recover."""
        service = DocumentaryFrameAnalysisService()
        damaged_payload = """
        解释文字
        ```json
        {{
            "items": [
                {{
                    "timestamp": "00:00:00,000-00:00:03,000",
                    "picture": "镜头里有一只猫",
                    "narration": "一只猫警觉地望向镜头。",
                }},
            ],
        }}
        ```
        补充文字
        """.strip()
        parsed_items = service._parse_narration_items(damaged_payload)
        self.assertEqual(1, len(parsed_items))
        self.assertEqual("00:00:00,000-00:00:03,000", parsed_items[0]["timestamp"])
        self.assertEqual("镜头里有一只猫", parsed_items[0]["picture"])
        self.assertEqual("一只猫警觉地望向镜头。", parsed_items[0]["narration"])

    def test_parse_narration_items_raises_for_unrecoverable_payload(self):
        """Hopelessly non-JSON input raises ValueError with a descriptive message."""
        service = DocumentaryFrameAnalysisService()
        with self.assertRaises(ValueError) as ctx:
            service._parse_narration_items("not-json-at-all ::: ???")
        self.assertIn("解说文案格式错误", str(ctx.exception))
        self.assertIn("items", str(ctx.exception))

    async def test_generate_documentary_script_includes_theme_and_custom_prompt_for_narration(self):
        """video_theme and custom_prompt must appear in the prompt passed to generate_narration."""
        service = DocumentaryFrameAnalysisService()
        analysis_payload = {
            "batches": [
                {
                    "batch_index": 0,
                    "time_range": "00:00:00,000-00:00:03,000",
                    "overall_activity_summary": "测试摘要",
                    "fallback_summary": "",
                    "frame_observations": [
                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                    ],
                }
            ]
        }
        with TemporaryDirectory() as temp_dir:
            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
            with patch.object(
                DocumentaryFrameAnalysisService,
                "analyze_video",
                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
            ), patch.dict(
                "app.services.documentary.frame_analysis_service.config.app",
                {
                    "text_llm_provider": "openai",
                    "text_openai_api_key": "test-key",
                    "text_openai_model_name": "test-model",
                    "text_openai_base_url": "https://example.com/v1",
                },
            ), patch(
                "app.services.documentary.frame_analysis_service.generate_narration",
                return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
            ) as mocked_generate:
                await service.generate_documentary_script(
                    video_path="demo.mp4",
                    video_theme="野生动物纪录片",
                    custom_prompt="重点描述危险信号",
                )
        # First positional argument of generate_narration is the assembled prompt.
        narration_input = mocked_generate.call_args.args[0]
        self.assertIn("## 创作上下文", narration_input)
        self.assertIn("视频主题:野生动物纪录片", narration_input)
        self.assertIn("补充创作要求:重点描述危险信号", narration_input)

    async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self):
        """An explicitly supplied empty vision_base_url must NOT fall back to the config value."""
        service = DocumentaryFrameAnalysisService()
        # Config holds different values than the explicit arguments, so any
        # fallback would be visible in the create_vision_analyzer kwargs.
        with patch.dict(
            "app.services.documentary.frame_analysis_service.config.app",
            {
                "vision_llm_provider": "openai",
                "vision_openai_api_key": "config-key",
                "vision_openai_model_name": "config-model",
                "vision_openai_base_url": "https://config.example/v1",
            },
        ), patch(
            "app.services.documentary.frame_analysis_service.os.path.exists",
            return_value=True,
        ), patch.object(
            service,
            "_load_or_extract_keyframes",
            return_value=["/tmp/keyframe_000001_000000100.jpg"],
        ), patch.object(
            service,
            "_analyze_batches",
            AsyncMock(return_value=[]),
        ), patch.object(
            service,
            "_save_analysis_artifact",
            return_value="/tmp/frame_analysis_test.json",
        ), patch.object(
            service,
            "_build_video_clip_json",
            return_value=[],
        ), patch(
            "app.services.documentary.frame_analysis_service.create_vision_analyzer",
            return_value=object(),
        ) as mocked_create_analyzer:
            await service.analyze_video(
                video_path="/tmp/demo.mp4",
                vision_api_key="explicit-key",
                vision_model_name="explicit-model",
                vision_base_url="",
            )
        called_kwargs = mocked_create_analyzer.call_args.kwargs
        self.assertEqual("openai", called_kwargs["provider"])
        self.assertEqual("explicit-key", called_kwargs["api_key"])
        self.assertEqual("explicit-model", called_kwargs["model"])
        self.assertEqual("", called_kwargs["base_url"])
# Allow running this test module directly (e.g. `python <module>.py`).
if __name__ == "__main__":
    unittest.main()

View File

@ -0,0 +1,91 @@
import os
import unittest
from tempfile import TemporaryDirectory
from unittest.mock import patch
from app.utils.video_processor import VideoProcessor
class VideoProcessorDocumentaryTests(unittest.TestCase):
    """Cover the fast-path/fallback dispatch of extract_frames_by_interval_with_fallback."""

    @staticmethod
    def _make_processor():
        """Build a VideoProcessor without running __init__ (avoids opening a real video).

        The method under test only reads video_path/duration/fps, so a bare
        instance with those attributes is sufficient.
        """
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0
        return processor

    @patch.object(VideoProcessor, "_extract_frames_fast_path", return_value=["a.jpg"])
    def test_extract_frames_by_interval_prefers_fast_path(self, fast_path):
        """When the fast path succeeds its result is returned directly."""
        processor = self._make_processor()
        result = processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=3.0)
        self.assertEqual(["a.jpg"], result)
        fast_path.assert_called_once_with("/tmp/out", interval_seconds=3.0)

    def test_extract_frames_by_interval_falls_back_to_ultra_compatible(self):
        """A failing fast path triggers the ultra-compatible extractor, whose keyframes are returned."""
        processor = self._make_processor()
        with TemporaryDirectory() as output_dir:
            expected_frame_path = os.path.join(output_dir, "keyframe_000000_000000000.jpg")

            def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
                # Simulate the fallback extractor writing one keyframe to disk.
                with open(expected_frame_path, "wb") as frame_file:
                    frame_file.write(b"frame")
                return [0]

            with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=RuntimeError("fast path failed")) as fast_path, patch.object(
                VideoProcessor,
                "extract_frames_by_interval_ultra_compatible",
                side_effect=ultra_compatible_fallback,
                autospec=True,
            ) as fallback:
                result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)

            self.assertEqual([expected_frame_path], result)
            fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
            # autospec=True means the instance is passed through as the first argument.
            fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)

    def test_extract_frames_by_interval_rejects_non_positive_interval(self):
        """interval_seconds <= 0 raises ValueError before any extractor runs."""
        processor = self._make_processor()
        with patch.object(VideoProcessor, "extract_frames_by_interval_ultra_compatible", autospec=True) as fallback:
            with self.assertRaises(ValueError):
                processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=0)
            fallback.assert_not_called()

    def test_extract_frames_by_interval_fallback_cleans_partial_fast_path_artifacts(self):
        """Partial fastframe files left by a failed fast path are removed before returning the fallback result."""
        processor = self._make_processor()
        with TemporaryDirectory() as output_dir:
            stale_fastframe = os.path.join(output_dir, "fastframe_000000.jpg")
            expected_keyframe = os.path.join(output_dir, "keyframe_000000_000000000.jpg")

            def fast_path_with_partial_output(_output_dir, interval_seconds=5.0):
                # Leave a stale artifact behind, then fail, to exercise the cleanup path.
                with open(stale_fastframe, "wb") as frame_file:
                    frame_file.write(b"stale")
                raise RuntimeError("simulated fast-path failure")

            def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
                with open(expected_keyframe, "wb") as frame_file:
                    frame_file.write(b"frame")
                return [0]

            with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=fast_path_with_partial_output) as fast_path, patch.object(
                VideoProcessor,
                "extract_frames_by_interval_ultra_compatible",
                side_effect=ultra_compatible_fallback,
                autospec=True,
            ) as fallback:
                result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)

            self.assertEqual([expected_keyframe], result)
            self.assertFalse(os.path.exists(stale_fastframe))
            fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
            fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)

View File

@ -1,21 +1,32 @@
# 纪录片脚本生成
import os
import asyncio
import json
import time
import asyncio
import traceback
import streamlit as st
from loguru import logger
from datetime import datetime
from app.config import config
from app.utils import utils, video_processor
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
def _normalize_progress_value(progress: float | int) -> int:
"""Normalize mixed progress inputs to Streamlit's 0-100 integer range."""
try:
value = float(progress)
except (TypeError, ValueError):
return 0
if 0.0 <= value <= 1.0:
value *= 100
return max(0, min(100, int(round(value))))
def generate_script_docu(params):
"""
生成 纪录片 视频脚本
生成纪录片视频脚本
要求: 原视频无字幕无配音
适合场景: 纪录片动物搞笑解说荒野建造等
"""
@ -23,419 +34,72 @@ def generate_script_docu(params):
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
progress_bar.progress(progress)
normalized_progress = _normalize_progress_value(progress)
progress_bar.progress(normalized_progress)
if message:
status_text.text(f"🎬 {message}")
else:
status_text.text(f"📊 进度: {progress}%")
status_text.text(f"📊 进度: {normalized_progress}%")
try:
with st.spinner("正在生成脚本..."):
if not params.video_origin_path:
st.error("请先选择视频文件")
return
"""
1. 提取键帧
"""
update_progress(10, "正在提取关键帧...")
# 创建临时目录用于存储关键帧
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
# 检查是否已经提取过关键帧
keyframe_files = []
if os.path.exists(video_keyframes_dir):
# 取已有的关键帧文件
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)}")
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)}")
# 如果没有缓存的关键帧,则进行提取
if not keyframe_files:
try:
# 确保目录存在
os.makedirs(video_keyframes_dir, exist_ok=True)
# 初始化视频处理器
processor = video_processor.VideoProcessor(params.video_origin_path)
# 显示视频信息
st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}")
# 处理视频并提取关键帧 - 直接使用超级兼容性方案
update_progress(15, "正在提取关键帧(使用超级兼容性方案)...")
try:
# 使用优化的关键帧提取方法
processor.extract_frames_by_interval_ultra_compatible(
output_dir=video_keyframes_dir,
interval_seconds=st.session_state.get('frame_interval_input'),
)
except Exception as extract_error:
logger.error(f"关键帧提取失败: {extract_error}")
# 提供详细的错误信息和解决建议
error_msg = str(extract_error)
if "权限" in error_msg or "permission" in error_msg.lower():
suggestion = "建议:检查输出目录权限,或更换输出位置"
elif "空间" in error_msg or "space" in error_msg.lower():
suggestion = "建议:检查磁盘空间是否足够"
else:
suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式"
raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}")
# 获取所有关键文件路径
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if not keyframe_files:
# 检查目录中是否有其他文件
all_files = os.listdir(video_keyframes_dir)
logger.error(f"关键帧目录内容: {all_files}")
raise Exception("未提取到任何关键帧文件,请检查视频文件格式")
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)}")
st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧")
except Exception as e:
# 如果提取失败,清理创建的目录
try:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取失败: {str(e)}")
"""
2. 视觉分析(批量分析每一帧)
"""
# 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值
vision_llm_provider = (
st.session_state.get('vision_llm_provider') or
config.app.get('vision_llm_provider', 'openai')
st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
).lower()
logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")
try:
# ===================初始化视觉分析器===================
update_progress(30, "正在初始化视觉分析器...")
# 使用统一的配置键格式获取配置(支持所有 provider
vision_api_key = (
st.session_state.get(f'vision_{vision_llm_provider}_api_key') or
config.app.get(f'vision_{vision_llm_provider}_api_key')
)
vision_model = (
st.session_state.get(f'vision_{vision_llm_provider}_model_name') or
config.app.get(f'vision_{vision_llm_provider}_model_name')
)
vision_base_url = (
st.session_state.get(f'vision_{vision_llm_provider}_base_url') or
config.app.get(f'vision_{vision_llm_provider}_base_url', '')
vision_api_key = (
st.session_state.get(f"vision_{vision_llm_provider}_api_key")
or config.app.get(f"vision_{vision_llm_provider}_api_key")
)
vision_model = (
st.session_state.get(f"vision_{vision_llm_provider}_model_name")
or config.app.get(f"vision_{vision_llm_provider}_model_name")
)
vision_base_url = (
st.session_state.get(f"vision_{vision_llm_provider}_base_url")
or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
)
if not vision_api_key or not vision_model:
raise ValueError(
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
)
# 验证必需配置
if not vision_api_key or not vision_model:
raise ValueError(
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
)
frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
"frame_interval_input", 3
)
vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
"vision_max_concurrency", 2
)
# 创建视觉分析器实例(使用统一接口)
llm_params = {
"vision_provider": vision_llm_provider,
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url,
}
logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}")
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
update_progress(10, "正在提取关键帧...")
service = DocumentaryFrameAnalysisService()
script_items = asyncio.run(
service.generate_documentary_script(
video_path=params.video_origin_path,
video_theme=st.session_state.get("video_theme", ""),
custom_prompt=st.session_state.get("custom_prompt", ""),
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=update_progress,
vision_api_key=vision_api_key,
vision_model_name=vision_model,
vision_base_url=vision_base_url,
max_concurrency=vision_max_concurrency,
)
)
update_progress(40, "正在分析关键帧...")
# ===================创建异步事件循环===================
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 执行异步分析
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
vision_analysis_prompt = """
我提供了 %s 张视频帧它们按时间顺序排列代表一个连续的视频片段请仔细分析每一帧的内容并关注帧与帧之间的变化以理解整个片段的活动
首先请详细描述每一帧的关键视觉信息包含主要内容人物动作和场景
然后基于所有帧的分析请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程
请务必使用 JSON 格式输出你的结果JSON 结构应如下
{
"frame_observations": [
{
"frame_number": 1, // 或其他标识帧的方式
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
},
// ... 更多帧的观察 ...
],
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
}
请务必不要遗漏视频帧我提供了 %s 张视频帧frame_observations 必须包含 %s 个元素
请只返回 JSON 字符串不要包含任何其他解释性文字
"""
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=vision_analysis_prompt,
batch_size=vision_batch_size
)
)
loop.close()
"""
3. 处理分析结果格式化为 json 数据
"""
# ===================处理分析结果===================
update_progress(60, "正在整理分析结果...")
# 合并所有批次的分析结果
frame_analysis = ""
merged_frame_observations = [] # 合并所有批次的帧观察
overall_activity_summaries = [] # 合并所有批次的整体总结
prev_batch_files = None
frame_counter = 1 # 初始化帧计数器,用于给所有帧分配连续的序号
# 确保分析目录存在
analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
os.makedirs(analysis_dir, exist_ok=True)
origin_res = os.path.join(analysis_dir, "frame_analysis.json")
with open(origin_res, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
# 开始处理
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
# 获取当前批次的文件列表
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
# 获取批次的时间戳范围
first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
# 解析响应中的JSON数据
response_text = result['response']
try:
# 处理可能包含```json```格式的响应
if "```json" in response_text:
json_content = response_text.split("```json")[1].split("```")[0].strip()
elif "```" in response_text:
json_content = response_text.split("```")[1].split("```")[0].strip()
else:
json_content = response_text.strip()
response_data = json.loads(json_content)
# 提取frame_observations和overall_activity_summary
if "frame_observations" in response_data:
frame_obs = response_data["frame_observations"]
overall_summary = response_data.get("overall_activity_summary", "")
# 添加时间戳信息到每个帧观察
for i, obs in enumerate(frame_obs):
if i < len(batch_files):
# 从文件名中提取时间戳
file_path = batch_files[i]
file_name = os.path.basename(file_path)
# 提取时间戳字符串 (格式如: keyframe_000675_000027000.jpg)
# 格式解析: keyframe_帧序号_毫秒时间戳.jpg
timestamp_parts = file_name.split('_')
if len(timestamp_parts) >= 3:
timestamp_str = timestamp_parts[-1].split('.')[0]
try:
# 修正时间戳解析逻辑
# 格式为000100000表示00:01:00,000即1分钟
# 需要按照对应位数进行解析:
# 前两位是小时,中间两位是分钟,后面是秒和毫秒
if len(timestamp_str) >= 9: # 确保格式正确
hours = int(timestamp_str[0:2])
minutes = int(timestamp_str[2:4])
seconds = int(timestamp_str[4:6])
milliseconds = int(timestamp_str[6:9])
# 计算总秒数
timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳
else:
# 兼容旧的解析方式
timestamp_seconds = int(timestamp_str) / 1000 # 转换为秒
formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳
except ValueError:
logger.warning(f"无法解析时间戳: {timestamp_str}")
timestamp_seconds = 0
formatted_time = "00:00:00,000"
else:
logger.warning(f"文件名格式不符合预期: {file_name}")
timestamp_seconds = 0
formatted_time = "00:00:00,000"
# 添加额外信息到帧观察
obs["frame_path"] = file_path
obs["timestamp"] = formatted_time
obs["timestamp_seconds"] = timestamp_seconds
obs["batch_index"] = result['batch_index']
# 使用全局递增的帧计数器替换原始的frame_number
if "frame_number" in obs:
obs["original_frame_number"] = obs["frame_number"] # 保留原始编号作为参考
obs["frame_number"] = frame_counter # 赋值连续的帧编号
frame_counter += 1 # 增加帧计数器
# 添加到合并列表
merged_frame_observations.append(obs)
# 添加批次整体总结信息
if overall_summary:
# 从文件名中提取时间戳数值
first_time_str = first_timestamp.split('_')[-1].split('.')[0]
last_time_str = last_timestamp.split('_')[-1].split('.')[0]
# 转换为毫秒并计算持续时间(秒)
try:
# 修正解析逻辑,与上面相同的方式解析时间戳
if len(first_time_str) >= 9 and len(last_time_str) >= 9:
# 解析第一个时间戳
first_hours = int(first_time_str[0:2])
first_minutes = int(first_time_str[2:4])
first_seconds = int(first_time_str[4:6])
first_ms = int(first_time_str[6:9])
first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000
# 解析第二个时间戳
last_hours = int(last_time_str[0:2])
last_minutes = int(last_time_str[2:4])
last_seconds = int(last_time_str[4:6])
last_ms = int(last_time_str[6:9])
last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000
batch_duration = last_time_seconds - first_time_seconds
else:
# 兼容旧的解析方式
first_time_ms = int(first_time_str)
last_time_ms = int(last_time_str)
batch_duration = (last_time_ms - first_time_ms) / 1000
except ValueError:
# 使用 utils.time_to_seconds 函数处理格式化的时间戳
first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ','))
last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ','))
batch_duration = last_time_seconds - first_time_seconds
overall_activity_summaries.append({
"batch_index": result['batch_index'],
"time_range": f"{first_timestamp}-{last_timestamp}",
"duration_seconds": batch_duration,
"summary": overall_summary
})
except Exception as e:
logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}")
# 添加原始响应作为回退
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += response_text
frame_analysis += "\n"
# 更新上一个批次的文件
prev_batch_files = batch_files
# 将合并后的结果转为JSON字符串
merged_results = {
"frame_observations": merged_frame_observations,
"overall_activity_summaries": overall_activity_summaries
}
# 使用当前时间创建文件名
now = datetime.now()
timestamp_str = now.strftime("%Y%m%d_%H%M")
# 保存完整的分析结果为JSON
analysis_filename = f"frame_analysis_{timestamp_str}.json"
analysis_json_path = os.path.join(analysis_dir, analysis_filename)
with open(analysis_json_path, 'w', encoding='utf-8') as f:
json.dump(merged_results, f, ensure_ascii=False, indent=2)
logger.info(f"分析结果已保存到: {analysis_json_path}")
"""
4. 生成文案
"""
logger.info("开始生成解说文案")
update_progress(80, "正在生成解说文案...")
from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration
# 从配置中获取文本生成相关配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
llm_params.update({
"text_provider": text_provider,
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url
})
# 整理帧分析数据
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
# 生成解说文案
narration = generate_narration(
markdown_output,
text_api_key,
base_url=text_base_url,
model=text_model
)
# 使用增强的JSON解析器
from webui.tools.generate_short_summary import parse_and_fix_json
narration_data = parse_and_fix_json(narration)
if not narration_data or 'items' not in narration_data:
logger.error(f"解说文案JSON解析失败原始内容: {narration[:200]}...")
raise Exception("解说文案格式错误无法解析JSON或缺少items字段")
narration_dict = narration_data['items']
# 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音
narration_dict = [{**item, "OST": 2} for item in narration_dict]
logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段")
# 结果转换为JSON字符串
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
except Exception as e:
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
raise Exception(f"分析失败: {str(e)}")
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
logger.info(f"纪录片解说脚本生成完成")
logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段")
script = json.dumps(script_items, ensure_ascii=False, indent=2)
if isinstance(script, list):
st.session_state['video_clip_json'] = script
st.session_state["video_clip_json"] = script
elif isinstance(script, str):
st.session_state['video_clip_json'] = json.loads(script)
st.session_state["video_clip_json"] = json.loads(script)
update_progress(100, "脚本生成完成")
time.sleep(0.1)