refactor(documentary): route adapters through shared analysis service

linyq 2026-04-03 02:24:30 +08:00
parent df034d104b
commit ac63fea953
7 changed files with 747 additions and 720 deletions

.gitignore vendored

@@ -48,3 +48,5 @@ CLAUDE.md
tests/*
!tests/test_documentary_frame_analysis_service.py
!tests/test_video_processor_documentary_unittest.py
!tests/test_script_service_documentary_unittest.py
!tests/test_generate_narration_script_documentary_unittest.py

View File

@@ -1,9 +1,16 @@
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Callable
from app.utils import utils
from loguru import logger
from app.config import config
from app.services.documentary.frame_analysis_models import FrameBatchResult
from app.services.llm.migration_adapter import create_vision_analyzer
from app.utils import utils, video_processor
class DocumentaryFrameAnalysisService:
@@ -23,8 +30,488 @@ JSON 必须包含以下键:
"overall_activity_summary": "本批次主要活动总结"
}}
请务必不要遗漏视频帧,我提供了 {frame_count} 张视频帧,frame_observations 必须包含 {frame_count} 个元素。
请只返回 JSON 字符串,不要附加解释文字。
""".strip()
async def generate_documentary_script(
self,
*,
video_path: str,
video_theme: str = "",
custom_prompt: str = "",
frame_interval_input: int | float = 3,
vision_batch_size: int = 10,
vision_llm_provider: str = "openai",
progress_callback: Callable[[float, str], None] | None = None,
vision_api_key: str | None = None,
vision_model_name: str | None = None,
vision_base_url: str | None = None,
max_concurrency: int | None = None,
) -> list[dict]:
analysis_result = await self.analyze_video(
video_path=video_path,
video_theme=video_theme,
custom_prompt=custom_prompt,
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=progress_callback,
vision_api_key=vision_api_key,
vision_model_name=vision_model_name,
vision_base_url=vision_base_url,
max_concurrency=max_concurrency,
)
return analysis_result["video_clip_json"]
async def analyze_video(
self,
*,
video_path: str,
video_theme: str = "",
custom_prompt: str = "",
frame_interval_input: int | float = 3,
vision_batch_size: int = 10,
vision_llm_provider: str = "openai",
progress_callback: Callable[[float, str], None] | None = None,
vision_api_key: str | None = None,
vision_model_name: str | None = None,
vision_base_url: str | None = None,
max_concurrency: int | None = None,
) -> dict[str, Any]:
progress = progress_callback or (lambda _p, _m: None)
if not video_path or not os.path.exists(video_path):
raise FileNotFoundError(f"视频文件不存在: {video_path}")
frame_interval_seconds = self._resolve_frame_interval(frame_interval_input)
batch_size = self._resolve_batch_size(vision_batch_size)
concurrency = self._resolve_max_concurrency(max_concurrency)
provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()
api_key = vision_api_key or config.app.get(f"vision_{provider}_api_key")
model_name = vision_model_name or config.app.get(f"vision_{provider}_model_name")
base_url = vision_base_url or config.app.get(f"vision_{provider}_base_url", "")
if not api_key or not model_name:
raise ValueError(
f"未配置 {provider} 的 API Key 或模型名称。"
f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
)
progress(10, "正在提取关键帧...")
keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds)
progress(25, f"关键帧准备完成,共 {len(keyframe_files)}")
progress(30, "正在初始化视觉分析器...")
analyzer = create_vision_analyzer(
provider=provider,
api_key=api_key,
model=model_name,
base_url=base_url,
)
batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size)
if not batches:
raise RuntimeError("未能构建任何关键帧批次")
progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...")
batch_results = await self._analyze_batches(
analyzer=analyzer,
batches=batches,
custom_prompt=custom_prompt,
video_theme=video_theme,
max_concurrency=concurrency,
progress_callback=progress,
)
progress(65, "正在整理分析结果...")
sorted_batches = self._sort_batch_results(batch_results)
artifact = self._build_analysis_artifact(
sorted_batches,
video_path=video_path,
frame_interval_seconds=frame_interval_seconds,
vision_batch_size=batch_size,
vision_llm_provider=provider,
vision_model_name=model_name,
max_concurrency=concurrency,
)
analysis_json_path = self._save_analysis_artifact(artifact)
video_clip_json = self._build_video_clip_json(sorted_batches)
progress(75, "逐帧分析完成")
return {
"analysis_json_path": analysis_json_path,
"analysis_artifact": artifact,
"video_clip_json": video_clip_json,
"keyframe_files": keyframe_files,
}
def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
interval = frame_interval_input
if interval in (None, ""):
interval = config.frames.get("frame_interval_input", 3)
try:
value = float(interval)
except (TypeError, ValueError):
value = 3.0
if value <= 0:
raise ValueError("frame_interval_input must be > 0")
return value
def _resolve_batch_size(self, vision_batch_size: int | None) -> int:
size = vision_batch_size or config.frames.get("vision_batch_size", 10)
try:
value = int(size)
except (TypeError, ValueError):
value = 10
if value <= 0:
raise ValueError("vision_batch_size must be > 0")
return value
def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2)
try:
parsed = int(value)
except (TypeError, ValueError):
parsed = 1
return max(1, parsed)
def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]:
keyframes_root = os.path.join(utils.temp_dir(), "keyframes")
os.makedirs(keyframes_root, exist_ok=True)
cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds)
cache_dir = os.path.join(keyframes_root, cache_key)
os.makedirs(cache_dir, exist_ok=True)
cached_files = self._collect_keyframe_paths(cache_dir)
if cached_files:
logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)}")
return cached_files
processor = video_processor.VideoProcessor(video_path)
extracted = processor.extract_frames_by_interval_with_fallback(
output_dir=cache_dir,
interval_seconds=frame_interval_seconds,
)
keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg"))
if not keyframe_files:
keyframe_files = self._collect_keyframe_paths(cache_dir)
if not keyframe_files:
raise RuntimeError("未提取到任何关键帧")
logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)}")
return keyframe_files
def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str:
try:
video_mtime = os.path.getmtime(video_path)
except OSError:
video_mtime = 0
legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
payload = "|".join(
[
str(video_path),
str(video_mtime),
str(frame_interval_seconds),
"documentary-keyframes-v2",
]
)
return f"{legacy_prefix}_{utils.md5(payload)}"
@staticmethod
def _collect_keyframe_paths(cache_dir: str) -> list[str]:
if not os.path.exists(cache_dir):
return []
return sorted(
os.path.join(cache_dir, name)
for name in os.listdir(cache_dir)
if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
)
@staticmethod
def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)]
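# Example: 7 keyframes with batch_size=3 chunk into [[f0, f1, f2], [f3, f4, f5], [f6]]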
async def _analyze_batches(
self,
*,
analyzer: Any,
batches: list[list[str]],
custom_prompt: str,
video_theme: str,
max_concurrency: int,
progress_callback: Callable[[float, str], None],
) -> list[FrameBatchResult]:
semaphore = asyncio.Semaphore(max(1, max_concurrency))
total = len(batches)
done = 0
done_lock = asyncio.Lock()
batch_time_ranges: list[str] = []
previous_batch_files: list[str] | None = None
for batch_files in batches:
_, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files)
batch_time_ranges.append(time_range)
previous_batch_files = batch_files
async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult:
nonlocal done
prompt = self._build_batch_prompt(
frame_count=len(frame_paths),
video_theme=video_theme,
custom_prompt=custom_prompt,
)
try:
async with semaphore:
raw_results = await analyzer.analyze_images(
images=frame_paths,
prompt=prompt,
batch_size=max(1, len(frame_paths)),
max_concurrency=1,
)
raw_response, error_message = self._extract_batch_response(raw_results)
if error_message:
return self._build_failed_batch_result(
batch_index=batch_index,
raw_response=raw_response,
error_message=error_message,
frame_paths=frame_paths,
time_range=time_range,
)
return self._parse_batch_response(
batch_index=batch_index,
raw_response=raw_response,
frame_paths=frame_paths,
time_range=time_range,
)
except Exception as exc:
return self._build_failed_batch_result(
batch_index=batch_index,
raw_response="",
error_message=str(exc),
frame_paths=frame_paths,
time_range=time_range,
)
finally:
async with done_lock:
done += 1
progress = 40 + (done / max(1, total)) * 25
progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...")
tasks = [
run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index])
for index, batch_files in enumerate(batches)
]
return await asyncio.gather(*tasks)
def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str:
prompt = self._build_analysis_prompt(frame_count=frame_count)
extra_lines: list[str] = []
if (video_theme or "").strip():
extra_lines.append(f"视频主题:{video_theme.strip()}")
if (custom_prompt or "").strip():
extra_lines.append(custom_prompt.strip())
if not extra_lines:
return prompt
extras = "\n".join(f"- {line}" for line in extra_lines)
return f"{prompt}\n\n补充分析要求:\n{extras}"
def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]:
if not raw_results:
return "", "Batch response is empty"
first_result = raw_results[0] if isinstance(raw_results, list) else raw_results
if isinstance(first_result, dict):
raw_response = str(first_result.get("response", "") or "")
error_message = str(first_result.get("error", "") or "")
if error_message:
if not raw_response:
raw_response = error_message
return raw_response, error_message
if not raw_response.strip():
return raw_response, "Batch response is empty"
return raw_response, ""
raw_response = str(first_result or "")
if not raw_response.strip():
return raw_response, "Batch response is empty"
return raw_response, ""
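# Normalization summary: the analyzer result is assumed to be a list of
# {"response": ..., "error": ...} dicts, a bare dict, or a plain string;
# blank payloads resolve to the "Batch response is empty" error.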
def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]:
return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index))
def _build_analysis_artifact(
self,
batch_results: list[FrameBatchResult],
*,
video_path: str,
frame_interval_seconds: float,
vision_batch_size: int,
vision_llm_provider: str,
vision_model_name: str,
max_concurrency: int,
) -> dict[str, Any]:
sorted_batches = self._sort_batch_results(batch_results)
batch_dicts: list[dict[str, Any]] = []
frame_observations: list[dict[str, Any]] = []
overall_activity_summaries: list[dict[str, Any]] = []
for batch in sorted_batches:
batch_payload = {
"batch_index": batch.batch_index,
"status": batch.status,
"time_range": batch.time_range,
"raw_response": batch.raw_response,
"frame_paths": list(batch.frame_paths),
"frame_observations": list(batch.frame_observations),
"overall_activity_summary": batch.overall_activity_summary,
"fallback_summary": batch.fallback_summary,
"error_message": batch.error_message,
}
batch_dicts.append(batch_payload)
for observation in batch.frame_observations:
observation_payload = dict(observation)
observation_payload["batch_index"] = batch.batch_index
observation_payload["time_range"] = batch.time_range
frame_observations.append(observation_payload)
summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip()
if summary_text:
overall_activity_summaries.append(
{
"batch_index": batch.batch_index,
"time_range": batch.time_range,
"summary": summary_text,
}
)
return {
"artifact_version": "documentary-frame-analysis-v2",
"generated_at": datetime.now().isoformat(),
"video_path": video_path,
"frame_interval_seconds": frame_interval_seconds,
"vision_batch_size": vision_batch_size,
"vision_llm_provider": vision_llm_provider,
"vision_model_name": vision_model_name,
"vision_max_concurrency": max_concurrency,
"batches": batch_dicts,
# Backwards compatibility with the legacy parser structure
"frame_observations": frame_observations,
"overall_activity_summaries": overall_activity_summaries,
}
def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str:
analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
os.makedirs(analysis_dir, exist_ok=True)
filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
file_path = os.path.join(analysis_dir, filename)
suffix = 1
while os.path.exists(file_path):
filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json"
file_path = os.path.join(analysis_dir, filename)
suffix += 1
with open(file_path, "w", encoding="utf-8") as fp:
json.dump(artifact, fp, ensure_ascii=False, indent=2)
logger.info(f"分析结果已保存到: {file_path}")
return file_path
def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]:
clips: list[dict] = []
for batch in self._sort_batch_results(batch_results):
picture = self._build_batch_picture(batch)
clips.append(
{
"timestamp": batch.time_range,
"picture": picture,
"narration": "",
"OST": 2,
}
)
return clips
def _build_batch_picture(self, batch: FrameBatchResult) -> str:
summary = (batch.overall_activity_summary or "").strip()
if summary:
return summary
fallback = (batch.fallback_summary or "").strip()
if fallback:
return fallback
observation_lines = []
for frame in batch.frame_observations:
timestamp = str(frame.get("timestamp", "") or "").strip()
observation = str(frame.get("observation", "") or "").strip()
if timestamp and observation:
observation_lines.append(f"{timestamp}: {observation}")
elif observation:
observation_lines.append(observation)
if observation_lines:
return " ".join(observation_lines)
raw_response = (batch.raw_response or "").strip()
if raw_response:
return raw_response[:200]
return "该批次分析失败,未返回可用描述。"
def _time_range_sort_key(self, time_range: str) -> tuple[int, str]:
start = (time_range or "").split("-", 1)[0].strip()
return self._timestamp_to_milliseconds(start), time_range
@staticmethod
def _timestamp_to_milliseconds(timestamp: str) -> int:
text = (timestamp or "").strip()
try:
if "," in text:
time_part, ms_part = text.split(",", 1)
milliseconds = int(ms_part)
else:
time_part = text
milliseconds = 0
parts = [int(part) for part in time_part.split(":") if part]
while len(parts) < 3:
parts.insert(0, 0)
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
except Exception:
return 0
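# Example: "00:01:02,500" -> ((0 * 3600 + 1 * 60 + 2) * 1000) + 500 = 62500;
# unparseable input deliberately falls back to 0 so the sort key stays well-defined.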
def _get_batch_timestamps(
self,
batch_files: list[str],
prev_batch_files: list[str] | None = None,
) -> tuple[str, str, str]:
if not batch_files:
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
if len(batch_files) == 1 and prev_batch_files:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
else:
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
first_timestamp = self._timestamp_from_keyframe_name(first_frame)
last_timestamp = self._timestamp_from_keyframe_name(last_frame)
return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}"
def _timestamp_from_keyframe_name(self, filename: str) -> str:
match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename)
if not match:
return "00:00:00,000"
token = match.group(1)
hours = int(token[0:2])
minutes = int(token[2:4])
seconds = int(token[4:6])
milliseconds = int(token[6:9])
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _build_analysis_prompt(self, frame_count: int) -> str:
return self.PROMPT_TEMPLATE.format(frame_count=frame_count)
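A minimal usage sketch of the shared service (hedged: the file name, theme, and callback are illustrative; the keyword arguments mirror the signature above):

import asyncio
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService

async def main() -> None:
    service = DocumentaryFrameAnalysisService()
    result = await service.analyze_video(
        video_path="input.mp4",  # assumed local file
        video_theme="荒野生存",
        frame_interval_input=3,
        vision_batch_size=10,
        vision_llm_provider="openai",
        progress_callback=lambda p, m: print(f"{p:.0f}% {m}"),
    )
    # The return dict carries the artifact path, the artifact itself,
    # the clip JSON, and the keyframe list (see analyze_video above).
    print(result["analysis_json_path"], len(result["video_clip_json"]))

asyncio.run(main())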

app/services/generate_narration_script.py

@@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path):
with open(json_file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
# Initialize the Markdown string
def time_to_milliseconds(time_text):
time_text = (time_text or "").strip()
if not time_text:
return 0
try:
if "," in time_text:
hhmmss, ms = time_text.split(",", 1)
milliseconds = int(ms)
else:
hhmmss = time_text
milliseconds = 0
parts = [int(part) for part in hhmmss.split(":") if part]
while len(parts) < 3:
parts.insert(0, 0)
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
except Exception:
return 0
def batch_sort_key(batch):
time_range = batch.get("time_range", "")
start = time_range.split("-", 1)[0].strip()
return time_to_milliseconds(start), batch.get("batch_index", 0)
markdown = ""
# Fetch summaries and frame-observation data
# New structure: the full analysis artifact is stored per batch
if isinstance(data.get("batches"), list):
ordered_batches = sorted(data.get("batches", []), key=batch_sort_key)
for i, batch in enumerate(ordered_batches, 1):
time_range = batch.get("time_range", "")
summary = (
batch.get("overall_activity_summary")
or batch.get("summary")
or batch.get("fallback_summary")
or ""
)
observations = batch.get("frame_observations") or batch.get("observations") or []
markdown += f"## 片段 {i}\n"
markdown += f"- 时间范围:{time_range}\n"
markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n"
markdown += "- 详细描述:\n"
for frame in observations:
timestamp = frame.get("timestamp", "")
observation = frame.get("observation", "")
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
markdown += "\n"
return markdown
# Fall back to the legacy structure
summaries = data.get('overall_activity_summaries', [])
frame_observations = data.get('frame_observations', [])
# Group the data by batch
batch_frames = {}
for frame in frame_observations:
batch_index = frame.get('batch_index')
if batch_index not in batch_frames:
batch_frames[batch_index] = []
batch_frames[batch_index].append(frame)
# Generate the Markdown content
for i, summary in enumerate(summaries, 1):
batch_index = summary.get('batch_index')
time_range = summary.get('time_range', '')
batch_summary = summary.get('summary', '')
markdown += f"## 片段 {i}\n"
markdown += f"- 时间范围:{time_range}\n"
# Add the clip description
markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else f"- 片段描述:\n"
markdown += "- 详细描述:\n"
# Add per-frame observation details for this batch
frames = batch_frames.get(batch_index, [])
for frame in frames:
timestamp = frame.get('timestamp', '')
observation = frame.get('observation', '')
# Use the raw text as-is, without splitting
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
markdown += "\n"
return markdown
except Exception as e:
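For orientation, each batch renders to a Markdown block of this shape (values illustrative, mirroring the fields handled above):

## 片段 1
- 时间范围:00:00:00,000-00:00:03,000
- 片段描述:本批次主要活动总结
- 详细描述:
  - 00:00:00,000: 镜头里有一只猫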

app/services/script_service.py

@@ -1,22 +1,12 @@
import os
import json
import time
import asyncio
import requests
from app.utils import video_processor
from loguru import logger
from typing import List, Dict, Any, Callable
from typing import Any, Callable
from app.utils import utils, gemini_analyzer, video_processor
from app.utils.script_generator import ScriptProcessor
from app.config import config
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
class ScriptGenerator:
def __init__(self):
self.temp_dir = utils.temp_dir()
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None):
self.documentary_service = documentary_service or DocumentaryFrameAnalysisService()
async def generate_script(
self,
video_path: str,
@@ -27,298 +17,18 @@ class ScriptGenerator:
threshold: int = 30,
vision_batch_size: int = 5,
vision_llm_provider: str = "gemini",
progress_callback: Callable[[float, str], None] = None
) -> List[Dict[Any, Any]]:
"""
生成视频脚本的核心逻辑
Args:
video_path: 视频文件路径
video_theme: 视频主题
custom_prompt: 自定义提示词
skip_seconds: 跳过开始的秒数
threshold: 差异<EFBFBD><EFBFBD><EFBFBD>
vision_batch_size: 视觉处理批次大小
vision_llm_provider: 视觉模型提供商
progress_callback: 进度回调函数
Returns:
List[Dict]: 生成的视频脚本
"""
if progress_callback is None:
progress_callback = lambda p, m: None
try:
# Extract keyframes
progress_callback(10, "正在提取关键帧...")
keyframe_files = await self._extract_keyframes(
video_path,
skip_seconds,
threshold
)
# Use the unified LLM interface (supports every provider)
script = await self._process_with_llm(
keyframe_files,
video_theme,
custom_prompt,
vision_batch_size,
vision_llm_provider,
progress_callback
)
return json.loads(script) if isinstance(script, str) else script
except Exception as e:
logger.exception("Generate script failed")
raise
async def _extract_keyframes(
self,
video_path: str,
skip_seconds: int,
threshold: int
) -> List[str]:
"""提取视频关键帧"""
video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)
# Check the cache
keyframe_files = []
if os.path.exists(video_keyframes_dir):
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"Using cached keyframes: {video_keyframes_dir}")
return keyframe_files
# Extract fresh keyframes
os.makedirs(video_keyframes_dir, exist_ok=True)
try:
processor = video_processor.VideoProcessor(video_path)
processor.process_video_pipeline(
output_dir=video_keyframes_dir,
skip_seconds=skip_seconds,
threshold=threshold
)
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
return keyframe_files
except Exception as e:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
raise
async def _process_with_llm(
self,
keyframe_files: List[str],
video_theme: str,
custom_prompt: str,
vision_batch_size: int,
vision_llm_provider: str,
progress_callback: Callable[[float, str], None]
) -> str:
"""使用统一 LLM 接口处理视频帧"""
progress_callback(30, "正在初始化视觉分析器...")
# Use the new LLM migration adapter (supports every provider)
from app.services.llm.migration_adapter import create_vision_analyzer
# Fetch configuration
text_provider = config.app.get('text_llm_provider', 'openai').lower()
vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')
if not vision_api_key or not vision_model:
raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型")
# Create the unified vision analyzer
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
progress_callback: Callable[[float, str], None] | None = None,
) -> list[dict[Any, Any]]:
callback = progress_callback or (lambda _p, _m: None)
return await self.documentary_service.generate_documentary_script(
video_path=video_path,
video_theme=video_theme,
custom_prompt=custom_prompt,
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=callback,
# Legacy parameters stay in the signature for caller compatibility; the shared
# frame analysis currently does not use these two.
# skip_seconds=skip_seconds,
# threshold=threshold,
)
progress_callback(40, "正在分析关键帧...")
# Run the async analysis
results = await analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=vision_batch_size
)
progress_callback(60, "正在整理分析结果...")
# Merge analysis results from every batch
frame_analysis = ""
prev_batch_files = None
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)
# Append the timestamped analysis result
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response']
frame_analysis += "\n"
prev_batch_files = batch_files
if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果")
progress_callback(70, "正在生成脚本...")
# Build the frame-content list
frame_content_list = []
prev_batch_files = None
for result in results:
if 'error' in result:
continue
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
_, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
frame_content = {
"timestamp": timestamp_range,
"picture": result['response'],
"narration": "",
"OST": 2
}
frame_content_list.append(frame_content)
prev_batch_files = batch_files
if not frame_content_list:
raise Exception("没有有效的帧内容可以处理")
progress_callback(90, "正在生成文案...")
# Fetch the text-generation configuration
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# Pick the right processor for the provider type
if text_provider == 'gemini(openai)':
# Use the OpenAI-compatible Gemini proxy
from app.utils.script_generator import GeminiOpenAIGenerator
generator = GeminiOpenAIGenerator(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url
)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
processor.generator = generator
else:
# Use the standard processor (including native Gemini)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
return processor.process_frames(frame_content_list)
def _get_batch_files(
self,
keyframe_files: List[str],
result: Dict[str, Any],
batch_size: int
) -> List[str]:
"""获取当前批次的图片文件"""
batch_start = result['batch_index'] * batch_size
batch_end = min(batch_start + batch_size, len(keyframe_files))
return keyframe_files[batch_start:batch_end]
def _get_batch_timestamps(
self,
batch_files: List[str],
prev_batch_files: List[str] = None
) -> tuple[str, str, str]:
"""获取一批文件的时间戳范围,支持毫秒级精度"""
if not batch_files:
logger.warning("Empty batch files")
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
else:
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
first_time = first_frame.split('_')[2].replace('.jpg', '')
last_time = last_frame.split('_')[2].replace('.jpg', '')
def format_timestamp(time_str: str) -> str:
"""将时间字符串转换为 HH:MM:SS,mmm 格式"""
try:
if len(time_str) < 4:
logger.warning(f"Invalid timestamp format: {time_str}")
return "00:00:00,000"
# Handle the milliseconds part
if ',' in time_str:
time_part, ms_part = time_str.split(',')
ms = int(ms_part)
else:
time_part = time_str
ms = 0
# Handle hours/minutes/seconds
parts = time_part.split(':')
if len(parts) == 3: # HH:MM:SS
h, m, s = map(int, parts)
elif len(parts) == 2: # MM:SS
h = 0
m, s = map(int, parts)
else: # SS
h = 0
m = 0
s = int(parts[0])
# Handle carry-over
if s >= 60:
m += s // 60
s = s % 60
if m >= 60:
h += m // 60
m = m % 60
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
except Exception as e:
logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
return "00:00:00,000"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
return first_timestamp, last_timestamp, timestamp_range
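After the refactor the generator is a thin facade over the shared service; a hedged calling sketch (values illustrative, awaited from an event loop, mirroring the unit test below):

generator = ScriptGenerator()
script = await generator.generate_script(
    video_path="demo.mp4",
    video_theme="荒野生存",
    custom_prompt="请聚焦生存动作",
    frame_interval_input=3,
    vision_batch_size=6,
    vision_llm_provider="openai",
)
# script is a list of {"timestamp", "picture", "narration", "OST"} clip dicts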

tests/test_generate_narration_script_documentary_unittest.py

@@ -0,0 +1,58 @@
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from app.services.generate_narration_script import parse_frame_analysis_to_markdown
class GenerateNarrationMarkdownTests(unittest.TestCase):
def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self):
artifact = {
"batches": [
{
"batch_index": 1,
"time_range": "00:00:03,000-00:00:06,000",
"overall_activity_summary": "人物转身跑向远处",
"fallback_summary": "",
"frame_observations": [
{
"timestamp": "00:00:03,000",
"observation": "人物突然回头",
}
],
},
{
"batch_index": 0,
"time_range": "00:00:00,000-00:00:03,000",
"overall_activity_summary": "",
"fallback_summary": "原始响应回退摘要",
"frame_observations": [
{
"timestamp": "00:00:00,000",
"observation": "镜头里有一只猫",
}
],
},
]
}
with TemporaryDirectory() as temp_dir:
analysis_path = Path(temp_dir) / "frame-analysis.json"
analysis_path.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8")
markdown = parse_frame_analysis_to_markdown(str(analysis_path))
first_range_index = markdown.find("00:00:00,000-00:00:03,000")
second_range_index = markdown.find("00:00:03,000-00:00:06,000")
self.assertIn("原始响应回退摘要", markdown)
self.assertIn("镜头里有一只猫", markdown)
self.assertIn("人物转身跑向远处", markdown)
self.assertIn("人物突然回头", markdown)
self.assertNotEqual(-1, first_range_index)
self.assertNotEqual(-1, second_range_index)
self.assertLess(first_range_index, second_range_index)
if __name__ == "__main__":
unittest.main()

tests/test_script_service_documentary_unittest.py

@@ -0,0 +1,51 @@
import unittest
from unittest.mock import AsyncMock, patch
from app.services.script_service import ScriptGenerator
class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
async def test_generate_script_passes_frame_interval_to_shared_service(self):
expected_script = [
{
"timestamp": "00:00:00,000-00:00:03,000",
"picture": "批次描述",
"narration": "",
"OST": 2,
}
]
progress = []
def progress_callback(percent, message):
progress.append((percent, message))
with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
service = service_cls.return_value
service.generate_documentary_script = AsyncMock(return_value=expected_script)
generator = ScriptGenerator()
result = await generator.generate_script(
video_path="demo.mp4",
video_theme="荒野生存",
custom_prompt="请聚焦生存动作",
frame_interval_input=3,
vision_batch_size=6,
vision_llm_provider="openai",
progress_callback=progress_callback,
)
self.assertEqual(expected_script, result)
service.generate_documentary_script.assert_awaited_once()
called_kwargs = service.generate_documentary_script.await_args.kwargs
self.assertEqual("demo.mp4", called_kwargs["video_path"])
self.assertEqual(3, called_kwargs["frame_interval_input"])
self.assertEqual(6, called_kwargs["vision_batch_size"])
self.assertEqual("openai", called_kwargs["vision_llm_provider"])
self.assertEqual("荒野生存", called_kwargs["video_theme"])
self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
self.assertIs(called_kwargs["progress_callback"], progress_callback)
self.assertEqual([], progress)
if __name__ == "__main__":
unittest.main()

webui/tools/generate_script_docu.py

@@ -1,21 +1,21 @@
# Documentary script generation
import os
import asyncio
import json
import time
import asyncio
import traceback
import streamlit as st
from loguru import logger
from datetime import datetime
from app.config import config
from app.utils import utils, video_processor
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
from webui.tools.generate_short_summary import parse_and_fix_json
def generate_script_docu(params):
"""
Generate a documentary video script
Requirement: the source video has no subtitles and no voice-over
Suitable for: documentaries, animal clips, comedic commentary, wilderness builds, etc.
"""
@@ -34,408 +34,83 @@ def generate_script_docu(params):
if not params.video_origin_path:
st.error("请先选择视频文件")
return
"""
1. Extract keyframes
"""
update_progress(10, "正在提取关键帧...")
# Create a temp directory for storing keyframes
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
# Check whether keyframes were already extracted
keyframe_files = []
if os.path.exists(video_keyframes_dir):
# Reuse existing keyframe files
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)}")
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)}")
# If no keyframes are cached, extract them
if not keyframe_files:
try:
# Ensure the directory exists
os.makedirs(video_keyframes_dir, exist_ok=True)
# Initialize the video processor
processor = video_processor.VideoProcessor(params.video_origin_path)
# Show video info
st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}")
# Process the video and extract keyframes - go straight to the ultra-compatibility path
update_progress(15, "正在提取关键帧(使用超级兼容性方案)...")
try:
# Use the optimized keyframe extraction method
processor.extract_frames_by_interval_ultra_compatible(
output_dir=video_keyframes_dir,
interval_seconds=st.session_state.get('frame_interval_input'),
)
except Exception as extract_error:
logger.error(f"关键帧提取失败: {extract_error}")
# Provide detailed error info and suggested fixes
error_msg = str(extract_error)
if "权限" in error_msg or "permission" in error_msg.lower():
suggestion = "建议:检查输出目录权限,或更换输出位置"
elif "空间" in error_msg or "space" in error_msg.lower():
suggestion = "建议:检查磁盘空间是否足够"
else:
suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式"
raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}")
# Collect all keyframe file paths
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if not keyframe_files:
# Check whether the directory contains other files
all_files = os.listdir(video_keyframes_dir)
logger.error(f"关键帧目录内容: {all_files}")
raise Exception("未提取到任何关键帧文件,请检查视频文件格式")
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)}")
st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧")
except Exception as e:
# On extraction failure, clean up the directory we created
try:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取失败: {str(e)}")
"""
2. Vision analysis (analyze each frame in batches)
"""
# Best practice: get() with a default, plus a config fallback
vision_llm_provider = (
st.session_state.get('vision_llm_provider') or
config.app.get('vision_llm_provider', 'openai')
st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
).lower()
logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")
try:
# =================== Initialize the vision analyzer ===================
update_progress(30, "正在初始化视觉分析器...")
# Read configuration via the unified key format (supports every provider)
vision_api_key = (
st.session_state.get(f'vision_{vision_llm_provider}_api_key') or
config.app.get(f'vision_{vision_llm_provider}_api_key')
)
vision_model = (
st.session_state.get(f'vision_{vision_llm_provider}_model_name') or
config.app.get(f'vision_{vision_llm_provider}_model_name')
)
vision_base_url = (
st.session_state.get(f'vision_{vision_llm_provider}_base_url') or
config.app.get(f'vision_{vision_llm_provider}_base_url', '')
vision_api_key = (
st.session_state.get(f"vision_{vision_llm_provider}_api_key")
or config.app.get(f"vision_{vision_llm_provider}_api_key")
)
vision_model = (
st.session_state.get(f"vision_{vision_llm_provider}_model_name")
or config.app.get(f"vision_{vision_llm_provider}_model_name")
)
vision_base_url = (
st.session_state.get(f"vision_{vision_llm_provider}_base_url")
or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
)
if not vision_api_key or not vision_model:
raise ValueError(
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
)
# Validate required configuration
if not vision_api_key or not vision_model:
raise ValueError(
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
)
frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
"frame_interval_input", 3
)
vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
"vision_max_concurrency", 2
)
# Create the vision analyzer instance (unified interface)
llm_params = {
"vision_provider": vision_llm_provider,
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url,
}
logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}")
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
update_progress(10, "正在提取关键帧...")
service = DocumentaryFrameAnalysisService()
analysis_result = asyncio.run(
service.analyze_video(
video_path=params.video_origin_path,
video_theme=st.session_state.get("video_theme", ""),
custom_prompt=st.session_state.get("custom_prompt", ""),
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=update_progress,
vision_api_key=vision_api_key,
vision_model_name=vision_model,
vision_base_url=vision_base_url,
max_concurrency=vision_max_concurrency,
)
)
update_progress(40, "正在分析关键帧...")
analysis_json_path = analysis_result["analysis_json_path"]
update_progress(80, "正在生成解说文案...")
# =================== Create the async event loop ===================
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
text_provider = config.app.get("text_llm_provider", "gemini").lower()
text_api_key = config.app.get(f"text_{text_provider}_api_key")
text_model = config.app.get(f"text_{text_provider}_model_name")
text_base_url = config.app.get(f"text_{text_provider}_base_url")
# Run the async analysis
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
vision_analysis_prompt = """
我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
narration = generate_narration(
markdown_output,
text_api_key,
base_url=text_base_url,
model=text_model,
)
narration_data = parse_and_fix_json(narration)
首先,请详细描述每一帧的关键视觉信息,包含主要内容、人物、动作和场景。
然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
if not narration_data or "items" not in narration_data:
logger.error(f"解说文案JSON解析失败原始内容: {narration[:200]}...")
raise Exception("解说文案格式错误无法解析JSON或缺少items字段")
请务必使用 JSON 格式输出你的结果,JSON 结构应如下:
{
"frame_observations": [
{
"frame_number": 1, // 或其他标识帧的方式
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
},
// ... 更多帧的观察 ...
],
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
}
narration_dict = [{**item, "OST": 2} for item in narration_data["items"]]
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素。
请只返回 JSON 字符串,不要包含任何其他解释性文字。
"""
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=vision_analysis_prompt,
batch_size=vision_batch_size
)
)
loop.close()
"""
3. Process the analysis results and format them as JSON
"""
# =================== Process analysis results ===================
update_progress(60, "正在整理分析结果...")
# Merge analysis results from every batch
frame_analysis = ""
merged_frame_observations = [] # Frame observations merged across batches
overall_activity_summaries = [] # Overall summaries merged across batches
prev_batch_files = None
frame_counter = 1 # Counter assigning consecutive numbers to all frames
# Ensure the analysis directory exists
analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
os.makedirs(analysis_dir, exist_ok=True)
origin_res = os.path.join(analysis_dir, "frame_analysis.json")
with open(origin_res, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
# Start processing
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
# Get the file list for the current batch
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
# Get the batch's timestamp range
first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
# Parse the JSON data in the response
response_text = result['response']
try:
# Handle responses that may be wrapped in ```json``` fences
if "```json" in response_text:
json_content = response_text.split("```json")[1].split("```")[0].strip()
elif "```" in response_text:
json_content = response_text.split("```")[1].split("```")[0].strip()
else:
json_content = response_text.strip()
response_data = json.loads(json_content)
# Extract frame_observations and overall_activity_summary
if "frame_observations" in response_data:
frame_obs = response_data["frame_observations"]
overall_summary = response_data.get("overall_activity_summary", "")
# Attach timestamp info to each frame observation
for i, obs in enumerate(frame_obs):
if i < len(batch_files):
# Extract the timestamp from the filename
file_path = batch_files[i]
file_name = os.path.basename(file_path)
# Extract the timestamp string (e.g. keyframe_000675_000027000.jpg)
# Naming scheme: keyframe_<frame index>_<timestamp token>.jpg
timestamp_parts = file_name.split('_')
if len(timestamp_parts) >= 3:
timestamp_str = timestamp_parts[-1].split('.')[0]
try:
# Corrected timestamp parsing:
# a token like 000100000 means 00:01:00,000, i.e. one minute;
# parse it by digit position:
# the first two digits are hours, the middle two are minutes, the rest are seconds and milliseconds
if len(timestamp_str) >= 9: # Ensure the expected format
hours = int(timestamp_str[0:2])
minutes = int(timestamp_str[2:4])
seconds = int(timestamp_str[4:6])
milliseconds = int(timestamp_str[6:9])
# Compute total seconds
timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
formatted_time = utils.format_time(timestamp_seconds) # Format the timestamp
else:
# Fall back to the legacy parsing scheme
timestamp_seconds = int(timestamp_str) / 1000 # Convert to seconds
formatted_time = utils.format_time(timestamp_seconds) # Format the timestamp
except ValueError:
logger.warning(f"无法解析时间戳: {timestamp_str}")
timestamp_seconds = 0
formatted_time = "00:00:00,000"
else:
logger.warning(f"文件名格式不符合预期: {file_name}")
timestamp_seconds = 0
formatted_time = "00:00:00,000"
# Attach extra info to the frame observation
obs["frame_path"] = file_path
obs["timestamp"] = formatted_time
obs["timestamp_seconds"] = timestamp_seconds
obs["batch_index"] = result['batch_index']
# Replace the original frame_number with a globally incrementing counter
if "frame_number" in obs:
obs["original_frame_number"] = obs["frame_number"] # Keep the original number for reference
obs["frame_number"] = frame_counter # Assign the consecutive frame number
frame_counter += 1 # Advance the frame counter
# Append to the merged list
merged_frame_observations.append(obs)
# Record the batch-level overall summary
if overall_summary:
# Extract the numeric timestamps from the filenames
first_time_str = first_timestamp.split('_')[-1].split('.')[0]
last_time_str = last_timestamp.split('_')[-1].split('.')[0]
# Convert to milliseconds and compute the duration in seconds
try:
# Corrected logic: parse the timestamps the same way as above
if len(first_time_str) >= 9 and len(last_time_str) >= 9:
# Parse the first timestamp
first_hours = int(first_time_str[0:2])
first_minutes = int(first_time_str[2:4])
first_seconds = int(first_time_str[4:6])
first_ms = int(first_time_str[6:9])
first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000
# Parse the second timestamp
last_hours = int(last_time_str[0:2])
last_minutes = int(last_time_str[2:4])
last_seconds = int(last_time_str[4:6])
last_ms = int(last_time_str[6:9])
last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000
batch_duration = last_time_seconds - first_time_seconds
else:
# Fall back to the legacy parsing scheme
first_time_ms = int(first_time_str)
last_time_ms = int(last_time_str)
batch_duration = (last_time_ms - first_time_ms) / 1000
except ValueError:
# Handle formatted timestamps via utils.time_to_seconds
first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ','))
last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ','))
batch_duration = last_time_seconds - first_time_seconds
overall_activity_summaries.append({
"batch_index": result['batch_index'],
"time_range": f"{first_timestamp}-{last_timestamp}",
"duration_seconds": batch_duration,
"summary": overall_summary
})
except Exception as e:
logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}")
# Append the raw response as a fallback
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += response_text
frame_analysis += "\n"
# Track the previous batch's files
prev_batch_files = batch_files
# Convert the merged results to a JSON string
merged_results = {
"frame_observations": merged_frame_observations,
"overall_activity_summaries": overall_activity_summaries
}
# Build a filename from the current time
now = datetime.now()
timestamp_str = now.strftime("%Y%m%d_%H%M")
# Save the full analysis results as JSON
analysis_filename = f"frame_analysis_{timestamp_str}.json"
analysis_json_path = os.path.join(analysis_dir, analysis_filename)
with open(analysis_json_path, 'w', encoding='utf-8') as f:
json.dump(merged_results, f, ensure_ascii=False, indent=2)
logger.info(f"分析结果已保存到: {analysis_json_path}")
"""
4. Generate the narration
"""
logger.info("开始生成解说文案")
update_progress(80, "正在生成解说文案...")
from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration
# Fetch text-generation settings from config
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
llm_params.update({
"text_provider": text_provider,
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url
})
# Organize the frame-analysis data
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
# Generate the narration
narration = generate_narration(
markdown_output,
text_api_key,
base_url=text_base_url,
model=text_model
)
# Use the enhanced JSON parser
from webui.tools.generate_short_summary import parse_and_fix_json
narration_data = parse_and_fix_json(narration)
if not narration_data or 'items' not in narration_data:
logger.error(f"解说文案JSON解析失败原始内容: {narration[:200]}...")
raise Exception("解说文案格式错误无法解析JSON或缺少items字段")
narration_dict = narration_data['items']
# Add OST: 2 to every item in narration_dict, meaning keep both the original audio and the voice-over
narration_dict = [{**item, "OST": 2} for item in narration_dict]
logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段")
# Convert the result to a JSON string
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
except Exception as e:
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
raise Exception(f"分析失败: {str(e)}")
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
logger.info(f"纪录片解说脚本生成完成")
logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段")
if isinstance(script, list):
st.session_state['video_clip_json'] = script
st.session_state["video_clip_json"] = script
elif isinstance(script, str):
st.session_state['video_clip_json'] = json.loads(script)
st.session_state["video_clip_json"] = json.loads(script)
update_progress(100, "脚本生成完成")
time.sleep(0.1)
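On success, st.session_state["video_clip_json"] holds the clip list; each item follows the shared shape (illustrative values):

[
    {
        "timestamp": "00:00:00,000-00:00:03,000",
        "picture": "镜头里有一只猫",
        "narration": "这里是生成的解说词",
        "OST": 2
    }
]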