mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
refactor(documentary): route adapters through shared analysis service
This commit is contained in:
parent
df034d104b
commit
ac63fea953
2
.gitignore
vendored
2
.gitignore
vendored
@ -48,3 +48,5 @@ CLAUDE.md
|
||||
tests/*
|
||||
!tests/test_documentary_frame_analysis_service.py
|
||||
!tests/test_video_processor_documentary_unittest.py
|
||||
!tests/test_script_service_documentary_unittest.py
|
||||
!tests/test_generate_narration_script_documentary_unittest.py
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.utils import utils
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.documentary.frame_analysis_models import FrameBatchResult
|
||||
from app.services.llm.migration_adapter import create_vision_analyzer
|
||||
from app.utils import utils, video_processor
|
||||
|
||||
|
||||
class DocumentaryFrameAnalysisService:
|
||||
@ -23,8 +30,488 @@ JSON 必须包含以下键:
|
||||
"overall_activity_summary": "本批次主要活动总结"
|
||||
}}
|
||||
请务必不要遗漏视频帧,我提供了 {frame_count} 张视频帧,frame_observations 必须包含 {frame_count} 个元素
|
||||
请只返回 JSON 字符串,不要附加解释文字。
|
||||
""".strip()
|
||||
|
||||
async def generate_documentary_script(
    self,
    *,
    video_path: str,
    video_theme: str = "",
    custom_prompt: str = "",
    frame_interval_input: int | float = 3,
    vision_batch_size: int = 10,
    vision_llm_provider: str = "openai",
    progress_callback: Callable[[float, str], None] | None = None,
    vision_api_key: str | None = None,
    vision_model_name: str | None = None,
    vision_base_url: str | None = None,
    max_concurrency: int | None = None,
) -> list[dict]:
    """Run the full frame analysis and return only the clip-list JSON.

    Thin convenience wrapper around :meth:`analyze_video`; all keyword
    arguments are forwarded unchanged and only the ``video_clip_json``
    entry of the analysis result is returned.
    """
    result = await self.analyze_video(
        video_path=video_path,
        video_theme=video_theme,
        custom_prompt=custom_prompt,
        frame_interval_input=frame_interval_input,
        vision_batch_size=vision_batch_size,
        vision_llm_provider=vision_llm_provider,
        progress_callback=progress_callback,
        vision_api_key=vision_api_key,
        vision_model_name=vision_model_name,
        vision_base_url=vision_base_url,
        max_concurrency=max_concurrency,
    )
    return result["video_clip_json"]
|
||||
|
||||
async def analyze_video(
    self,
    *,
    video_path: str,
    video_theme: str = "",
    custom_prompt: str = "",
    frame_interval_input: int | float = 3,
    vision_batch_size: int = 10,
    vision_llm_provider: str = "openai",
    progress_callback: Callable[[float, str], None] | None = None,
    vision_api_key: str | None = None,
    vision_model_name: str | None = None,
    vision_base_url: str | None = None,
    max_concurrency: int | None = None,
) -> dict[str, Any]:
    """Run the full keyframe analysis pipeline for one video.

    Steps: extract (or reuse cached) keyframes, analyze them in batches
    with the configured vision LLM, persist the analysis artifact as JSON,
    and derive the clip list.

    Returns a dict with keys:
        - "analysis_json_path": path of the saved artifact file
        - "analysis_artifact": the artifact dict itself
        - "video_clip_json": clip list built from the sorted batch results
        - "keyframe_files": the keyframe image paths that were analyzed

    Raises:
        FileNotFoundError: when ``video_path`` is empty or missing.
        ValueError: when the provider's API key or model name is unset.
        RuntimeError: when no keyframe batches could be built.
    """
    # No-op sink so the rest of the method can call progress unconditionally.
    progress = progress_callback or (lambda _p, _m: None)

    if not video_path or not os.path.exists(video_path):
        raise FileNotFoundError(f"视频文件不存在: {video_path}")

    frame_interval_seconds = self._resolve_frame_interval(frame_interval_input)
    batch_size = self._resolve_batch_size(vision_batch_size)
    concurrency = self._resolve_max_concurrency(max_concurrency)
    provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()

    # Explicit arguments win over the per-provider config entries.
    api_key = vision_api_key or config.app.get(f"vision_{provider}_api_key")
    model_name = vision_model_name or config.app.get(f"vision_{provider}_model_name")
    base_url = vision_base_url or config.app.get(f"vision_{provider}_base_url", "")
    if not api_key or not model_name:
        raise ValueError(
            f"未配置 {provider} 的 API Key 或模型名称。"
            f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
        )

    progress(10, "正在提取关键帧...")
    keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds)
    progress(25, f"关键帧准备完成,共 {len(keyframe_files)} 帧")

    progress(30, "正在初始化视觉分析器...")
    analyzer = create_vision_analyzer(
        provider=provider,
        api_key=api_key,
        model=model_name,
        base_url=base_url,
    )

    batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size)
    if not batches:
        raise RuntimeError("未能构建任何关键帧批次")

    # Batch analysis reports progress in the 40-65 range via the callback.
    progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...")
    batch_results = await self._analyze_batches(
        analyzer=analyzer,
        batches=batches,
        custom_prompt=custom_prompt,
        video_theme=video_theme,
        max_concurrency=concurrency,
        progress_callback=progress,
    )

    progress(65, "正在整理分析结果...")
    # Sort once and reuse for both the artifact and the clip JSON so their
    # batch ordering agrees.
    sorted_batches = self._sort_batch_results(batch_results)
    artifact = self._build_analysis_artifact(
        sorted_batches,
        video_path=video_path,
        frame_interval_seconds=frame_interval_seconds,
        vision_batch_size=batch_size,
        vision_llm_provider=provider,
        vision_model_name=model_name,
        max_concurrency=concurrency,
    )
    analysis_json_path = self._save_analysis_artifact(artifact)
    video_clip_json = self._build_video_clip_json(sorted_batches)

    progress(75, "逐帧分析完成")
    return {
        "analysis_json_path": analysis_json_path,
        "analysis_artifact": artifact,
        "video_clip_json": video_clip_json,
        "keyframe_files": keyframe_files,
    }
|
||||
|
||||
def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
|
||||
interval = frame_interval_input
|
||||
if interval in (None, ""):
|
||||
interval = config.frames.get("frame_interval_input", 3)
|
||||
try:
|
||||
value = float(interval)
|
||||
except (TypeError, ValueError):
|
||||
value = 3.0
|
||||
if value <= 0:
|
||||
raise ValueError("frame_interval_input must be > 0")
|
||||
return value
|
||||
|
||||
def _resolve_batch_size(self, vision_batch_size: int | None) -> int:
|
||||
size = vision_batch_size or config.frames.get("vision_batch_size", 10)
|
||||
try:
|
||||
value = int(size)
|
||||
except (TypeError, ValueError):
|
||||
value = 10
|
||||
if value <= 0:
|
||||
raise ValueError("vision_batch_size must be > 0")
|
||||
return value
|
||||
|
||||
def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
|
||||
value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2)
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
parsed = 1
|
||||
return max(1, parsed)
|
||||
|
||||
def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]:
    """Return keyframe image paths, extracting them only on a cache miss.

    Keyframes live under ``<temp>/keyframes/<cache_key>``; a non-empty
    cache directory short-circuits extraction entirely.

    Raises RuntimeError when extraction produces no frames at all.
    """
    keyframes_root = os.path.join(utils.temp_dir(), "keyframes")
    os.makedirs(keyframes_root, exist_ok=True)
    cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds)
    cache_dir = os.path.join(keyframes_root, cache_key)
    os.makedirs(cache_dir, exist_ok=True)

    cached_files = self._collect_keyframe_paths(cache_dir)
    if cached_files:
        logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)} 帧")
        return cached_files

    processor = video_processor.VideoProcessor(video_path)
    extracted = processor.extract_frames_by_interval_with_fallback(
        output_dir=cache_dir,
        interval_seconds=frame_interval_seconds,
    )
    # Prefer the paths the extractor reports; fall back to scanning the
    # cache dir in case files were written without being returned.
    keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg"))
    if not keyframe_files:
        keyframe_files = self._collect_keyframe_paths(cache_dir)
    if not keyframe_files:
        raise RuntimeError("未提取到任何关键帧")

    logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)} 帧")
    return keyframe_files
|
||||
|
||||
def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str:
    """Derive the cache directory name from video identity and interval.

    The mtime is part of the key so editing the video invalidates the
    cache; a missing file hashes with mtime 0.
    """
    try:
        mtime = os.path.getmtime(video_path)
    except OSError:
        mtime = 0

    # Keep the historical path+mtime hash as a prefix so old cache
    # directories remain recognizable.
    legacy_prefix = utils.md5(f"{video_path}{mtime}")
    key_parts = [
        str(video_path),
        str(mtime),
        str(frame_interval_seconds),
        "documentary-keyframes-v2",
    ]
    suffix = utils.md5("|".join(key_parts))
    return f"{legacy_prefix}_{suffix}"
|
||||
|
||||
@staticmethod
|
||||
def _collect_keyframe_paths(cache_dir: str) -> list[str]:
|
||||
if not os.path.exists(cache_dir):
|
||||
return []
|
||||
return sorted(
|
||||
os.path.join(cache_dir, name)
|
||||
for name in os.listdir(cache_dir)
|
||||
if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
|
||||
return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)]
|
||||
|
||||
async def _analyze_batches(
    self,
    *,
    analyzer: Any,
    batches: list[list[str]],
    custom_prompt: str,
    video_theme: str,
    max_concurrency: int,
    progress_callback: Callable[[float, str], None],
) -> list[FrameBatchResult]:
    """Analyze every keyframe batch concurrently with a vision analyzer.

    A semaphore bounds concurrent analyzer calls to ``max_concurrency``.
    Each batch failure is captured as a failed FrameBatchResult rather than
    propagating, so one bad batch never aborts the whole run. Progress is
    reported in the 40-65 range as batches finish. Results are returned in
    task order (same as ``batches``), not completion order.
    """
    semaphore = asyncio.Semaphore(max(1, max_concurrency))
    total = len(batches)
    done = 0
    # Lock serializes the done counter / progress emission across tasks.
    done_lock = asyncio.Lock()

    # Precompute each batch's time range sequentially: the range for a
    # single-frame batch depends on the previous batch's last frame.
    batch_time_ranges: list[str] = []
    previous_batch_files: list[str] | None = None
    for batch_files in batches:
        _, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files)
        batch_time_ranges.append(time_range)
        previous_batch_files = batch_files

    async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult:
        # Analyze one batch; always returns a FrameBatchResult, never raises.
        nonlocal done
        prompt = self._build_batch_prompt(
            frame_count=len(frame_paths),
            video_theme=video_theme,
            custom_prompt=custom_prompt,
        )
        try:
            async with semaphore:
                # The whole batch goes through in one analyzer call;
                # inner concurrency stays at 1 because this coroutine
                # already provides the parallelism.
                raw_results = await analyzer.analyze_images(
                    images=frame_paths,
                    prompt=prompt,
                    batch_size=max(1, len(frame_paths)),
                    max_concurrency=1,
                )
            raw_response, error_message = self._extract_batch_response(raw_results)
            if error_message:
                return self._build_failed_batch_result(
                    batch_index=batch_index,
                    raw_response=raw_response,
                    error_message=error_message,
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            return self._parse_batch_response(
                batch_index=batch_index,
                raw_response=raw_response,
                frame_paths=frame_paths,
                time_range=time_range,
            )
        except Exception as exc:
            # Convert unexpected analyzer errors into a failed batch result.
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response="",
                error_message=str(exc),
                frame_paths=frame_paths,
                time_range=time_range,
            )
        finally:
            async with done_lock:
                done += 1
                progress = 40 + (done / max(1, total)) * 25
                progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...")

    tasks = [
        run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index])
        for index, batch_files in enumerate(batches)
    ]
    return await asyncio.gather(*tasks)
|
||||
|
||||
def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str:
    """Compose the base analysis prompt plus optional theme/custom lines."""
    base = self._build_analysis_prompt(frame_count=frame_count)

    additions: list[str] = []
    theme = (video_theme or "").strip()
    if theme:
        additions.append(f"视频主题:{theme}")
    custom = (custom_prompt or "").strip()
    if custom:
        additions.append(custom)

    if not additions:
        return base
    bullet_block = "\n".join(f"- {item}" for item in additions)
    return f"{base}\n\n补充分析要求:\n{bullet_block}"
|
||||
|
||||
def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]:
|
||||
if not raw_results:
|
||||
return "", "Batch response is empty"
|
||||
|
||||
first_result = raw_results[0] if isinstance(raw_results, list) else raw_results
|
||||
if isinstance(first_result, dict):
|
||||
raw_response = str(first_result.get("response", "") or "")
|
||||
error_message = str(first_result.get("error", "") or "")
|
||||
if error_message:
|
||||
if not raw_response:
|
||||
raw_response = error_message
|
||||
return raw_response, error_message
|
||||
if not raw_response.strip():
|
||||
return raw_response, "Batch response is empty"
|
||||
return raw_response, ""
|
||||
|
||||
raw_response = str(first_result or "")
|
||||
if not raw_response.strip():
|
||||
return raw_response, "Batch response is empty"
|
||||
return raw_response, ""
|
||||
|
||||
def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]:
    """Order batches by start timestamp, breaking ties by batch index."""
    def order_key(batch: FrameBatchResult) -> tuple:
        return self._time_range_sort_key(batch.time_range), batch.batch_index

    return sorted(batch_results, key=order_key)
|
||||
|
||||
def _build_analysis_artifact(
    self,
    batch_results: list[FrameBatchResult],
    *,
    video_path: str,
    frame_interval_seconds: float,
    vision_batch_size: int,
    vision_llm_provider: str,
    vision_model_name: str,
    max_concurrency: int,
) -> dict[str, Any]:
    """Assemble the persistable analysis artifact dict.

    The artifact carries the run parameters, one serialized entry per
    batch, and two flattened views ("frame_observations" and
    "overall_activity_summaries") kept for backward compatibility with
    the legacy parser structure.
    """
    # Re-sort defensively so the artifact order is stable regardless of
    # what the caller passed in.
    sorted_batches = self._sort_batch_results(batch_results)

    batch_dicts: list[dict[str, Any]] = []
    frame_observations: list[dict[str, Any]] = []
    overall_activity_summaries: list[dict[str, Any]] = []

    for batch in sorted_batches:
        # Copy the per-batch lists so artifact mutation cannot touch the
        # original FrameBatchResult objects.
        batch_payload = {
            "batch_index": batch.batch_index,
            "status": batch.status,
            "time_range": batch.time_range,
            "raw_response": batch.raw_response,
            "frame_paths": list(batch.frame_paths),
            "frame_observations": list(batch.frame_observations),
            "overall_activity_summary": batch.overall_activity_summary,
            "fallback_summary": batch.fallback_summary,
            "error_message": batch.error_message,
        }
        batch_dicts.append(batch_payload)

        # Flattened per-frame view, each observation tagged with its batch.
        for observation in batch.frame_observations:
            observation_payload = dict(observation)
            observation_payload["batch_index"] = batch.batch_index
            observation_payload["time_range"] = batch.time_range
            frame_observations.append(observation_payload)

        # Prefer the model summary; fall back to the heuristic one.
        summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip()
        if summary_text:
            overall_activity_summaries.append(
                {
                    "batch_index": batch.batch_index,
                    "time_range": batch.time_range,
                    "summary": summary_text,
                }
            )

    return {
        "artifact_version": "documentary-frame-analysis-v2",
        "generated_at": datetime.now().isoformat(),
        "video_path": video_path,
        "frame_interval_seconds": frame_interval_seconds,
        "vision_batch_size": vision_batch_size,
        "vision_llm_provider": vision_llm_provider,
        "vision_model_name": vision_model_name,
        "vision_max_concurrency": max_concurrency,
        "batches": batch_dicts,
        # Backward-compatible structure for the legacy parser.
        "frame_observations": frame_observations,
        "overall_activity_summaries": overall_activity_summaries,
    }
|
||||
|
||||
def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str:
    """Persist the artifact as pretty-printed UTF-8 JSON; return the path.

    Collides with an existing same-second file by appending a two-digit
    counter to the timestamped name.
    """
    target_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
    os.makedirs(target_dir, exist_ok=True)

    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    file_path = os.path.join(target_dir, f"frame_analysis_{stamp}.json")
    counter = 1
    while os.path.exists(file_path):
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        file_path = os.path.join(target_dir, f"frame_analysis_{stamp}_{counter:02d}.json")
        counter += 1

    with open(file_path, "w", encoding="utf-8") as fp:
        json.dump(artifact, fp, ensure_ascii=False, indent=2)
    logger.info(f"分析结果已保存到: {file_path}")
    return file_path
|
||||
|
||||
def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]:
    """Convert sorted batch results into the clip-list JSON structure.

    Each clip carries the batch's time range, its best textual
    description, an empty narration slot, and OST mode 2.
    """
    return [
        {
            "timestamp": batch.time_range,
            "picture": self._build_batch_picture(batch),
            "narration": "",
            "OST": 2,
        }
        for batch in self._sort_batch_results(batch_results)
    ]
|
||||
|
||||
def _build_batch_picture(self, batch: FrameBatchResult) -> str:
|
||||
summary = (batch.overall_activity_summary or "").strip()
|
||||
if summary:
|
||||
return summary
|
||||
|
||||
fallback = (batch.fallback_summary or "").strip()
|
||||
if fallback:
|
||||
return fallback
|
||||
|
||||
observation_lines = []
|
||||
for frame in batch.frame_observations:
|
||||
timestamp = str(frame.get("timestamp", "") or "").strip()
|
||||
observation = str(frame.get("observation", "") or "").strip()
|
||||
if timestamp and observation:
|
||||
observation_lines.append(f"{timestamp}: {observation}")
|
||||
elif observation:
|
||||
observation_lines.append(observation)
|
||||
if observation_lines:
|
||||
return " ".join(observation_lines)
|
||||
|
||||
raw_response = (batch.raw_response or "").strip()
|
||||
if raw_response:
|
||||
return raw_response[:200]
|
||||
return "该批次分析失败,未返回可用描述。"
|
||||
|
||||
def _time_range_sort_key(self, time_range: str) -> tuple[int, str]:
    """Sort key for a 'start-end' range: (start in ms, raw range string)."""
    text = time_range or ""
    start_part = text.split("-", 1)[0].strip()
    return self._timestamp_to_milliseconds(start_part), time_range
|
||||
|
||||
@staticmethod
|
||||
def _timestamp_to_milliseconds(timestamp: str) -> int:
|
||||
text = (timestamp or "").strip()
|
||||
try:
|
||||
if "," in text:
|
||||
time_part, ms_part = text.split(",", 1)
|
||||
milliseconds = int(ms_part)
|
||||
else:
|
||||
time_part = text
|
||||
milliseconds = 0
|
||||
|
||||
parts = [int(part) for part in time_part.split(":") if part]
|
||||
while len(parts) < 3:
|
||||
parts.insert(0, 0)
|
||||
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
|
||||
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _get_batch_timestamps(
|
||||
self,
|
||||
batch_files: list[str],
|
||||
prev_batch_files: list[str] | None = None,
|
||||
) -> tuple[str, str, str]:
|
||||
if not batch_files:
|
||||
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
|
||||
|
||||
if len(batch_files) == 1 and prev_batch_files:
|
||||
first_frame = os.path.basename(prev_batch_files[-1])
|
||||
last_frame = os.path.basename(batch_files[0])
|
||||
else:
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
first_timestamp = self._timestamp_from_keyframe_name(first_frame)
|
||||
last_timestamp = self._timestamp_from_keyframe_name(last_frame)
|
||||
return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}"
|
||||
|
||||
def _timestamp_from_keyframe_name(self, filename: str) -> str:
|
||||
match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename)
|
||||
if not match:
|
||||
return "00:00:00,000"
|
||||
token = match.group(1)
|
||||
hours = int(token[0:2])
|
||||
minutes = int(token[2:4])
|
||||
seconds = int(token[4:6])
|
||||
milliseconds = int(token[6:9])
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
|
||||
|
||||
def _build_analysis_prompt(self, frame_count: int) -> str:
|
||||
return self.PROMPT_TEMPLATE.format(frame_count=frame_count)
|
||||
|
||||
|
||||
@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path):
|
||||
with open(json_file_path, 'r', encoding='utf-8') as file:
|
||||
data = json.load(file)
|
||||
|
||||
# 初始化Markdown字符串
|
||||
def time_to_milliseconds(time_text):
|
||||
time_text = (time_text or "").strip()
|
||||
if not time_text:
|
||||
return 0
|
||||
try:
|
||||
if "," in time_text:
|
||||
hhmmss, ms = time_text.split(",", 1)
|
||||
milliseconds = int(ms)
|
||||
else:
|
||||
hhmmss = time_text
|
||||
milliseconds = 0
|
||||
|
||||
parts = [int(part) for part in hhmmss.split(":") if part]
|
||||
while len(parts) < 3:
|
||||
parts.insert(0, 0)
|
||||
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
|
||||
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def batch_sort_key(batch):
|
||||
time_range = batch.get("time_range", "")
|
||||
start = time_range.split("-", 1)[0].strip()
|
||||
return time_to_milliseconds(start), batch.get("batch_index", 0)
|
||||
|
||||
markdown = ""
|
||||
|
||||
# 获取总结和帧观察数据
|
||||
|
||||
# 新结构:按批次保存完整分析产物
|
||||
if isinstance(data.get("batches"), list):
|
||||
ordered_batches = sorted(data.get("batches", []), key=batch_sort_key)
|
||||
|
||||
for i, batch in enumerate(ordered_batches, 1):
|
||||
time_range = batch.get("time_range", "")
|
||||
summary = (
|
||||
batch.get("overall_activity_summary")
|
||||
or batch.get("summary")
|
||||
or batch.get("fallback_summary")
|
||||
or ""
|
||||
)
|
||||
observations = batch.get("frame_observations") or batch.get("observations") or []
|
||||
|
||||
markdown += f"## 片段 {i}\n"
|
||||
markdown += f"- 时间范围:{time_range}\n"
|
||||
markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n"
|
||||
markdown += "- 详细描述:\n"
|
||||
|
||||
for frame in observations:
|
||||
timestamp = frame.get("timestamp", "")
|
||||
observation = frame.get("observation", "")
|
||||
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
|
||||
|
||||
markdown += "\n"
|
||||
|
||||
return markdown
|
||||
|
||||
# 兼容旧结构
|
||||
summaries = data.get('overall_activity_summaries', [])
|
||||
frame_observations = data.get('frame_observations', [])
|
||||
|
||||
# 按批次组织数据
|
||||
|
||||
batch_frames = {}
|
||||
for frame in frame_observations:
|
||||
batch_index = frame.get('batch_index')
|
||||
if batch_index not in batch_frames:
|
||||
batch_frames[batch_index] = []
|
||||
batch_frames[batch_index].append(frame)
|
||||
|
||||
# 生成Markdown内容
|
||||
|
||||
for i, summary in enumerate(summaries, 1):
|
||||
batch_index = summary.get('batch_index')
|
||||
time_range = summary.get('time_range', '')
|
||||
batch_summary = summary.get('summary', '')
|
||||
|
||||
|
||||
markdown += f"## 片段 {i}\n"
|
||||
markdown += f"- 时间范围:{time_range}\n"
|
||||
|
||||
# 添加片段描述
|
||||
markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else f"- 片段描述:\n"
|
||||
|
||||
markdown += "- 详细描述:\n"
|
||||
|
||||
# 添加该批次的帧观察详情
|
||||
|
||||
frames = batch_frames.get(batch_index, [])
|
||||
for frame in frames:
|
||||
timestamp = frame.get('timestamp', '')
|
||||
observation = frame.get('observation', '')
|
||||
|
||||
# 直接使用原始文本,不进行分割
|
||||
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
|
||||
|
||||
|
||||
markdown += "\n"
|
||||
|
||||
|
||||
return markdown
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -1,22 +1,12 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import requests
|
||||
from app.utils import video_processor
|
||||
from loguru import logger
|
||||
from typing import List, Dict, Any, Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.utils import utils, gemini_analyzer, video_processor
|
||||
from app.utils.script_generator import ScriptProcessor
|
||||
from app.config import config
|
||||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||
|
||||
|
||||
class ScriptGenerator:
|
||||
def __init__(self):
|
||||
self.temp_dir = utils.temp_dir()
|
||||
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
|
||||
|
||||
def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None):
|
||||
self.documentary_service = documentary_service or DocumentaryFrameAnalysisService()
|
||||
|
||||
async def generate_script(
|
||||
self,
|
||||
video_path: str,
|
||||
@ -27,298 +17,18 @@ class ScriptGenerator:
|
||||
threshold: int = 30,
|
||||
vision_batch_size: int = 5,
|
||||
vision_llm_provider: str = "gemini",
|
||||
progress_callback: Callable[[float, str], None] = None
|
||||
) -> List[Dict[Any, Any]]:
|
||||
"""
|
||||
生成视频脚本的核心逻辑
|
||||
|
||||
Args:
|
||||
video_path: 视频文件路径
|
||||
video_theme: 视频主题
|
||||
custom_prompt: 自定义提示词
|
||||
skip_seconds: 跳过开始的秒数
|
||||
threshold: 差异阈值
|
||||
vision_batch_size: 视觉处理批次大小
|
||||
vision_llm_provider: 视觉模型提供商
|
||||
progress_callback: 进度回调函数
|
||||
|
||||
Returns:
|
||||
List[Dict]: 生成的视频脚本
|
||||
"""
|
||||
if progress_callback is None:
|
||||
progress_callback = lambda p, m: None
|
||||
|
||||
try:
|
||||
# 提取关键帧
|
||||
progress_callback(10, "正在提取关键帧...")
|
||||
keyframe_files = await self._extract_keyframes(
|
||||
video_path,
|
||||
skip_seconds,
|
||||
threshold
|
||||
)
|
||||
|
||||
# 使用统一的 LLM 接口(支持所有 provider)
|
||||
script = await self._process_with_llm(
|
||||
keyframe_files,
|
||||
video_theme,
|
||||
custom_prompt,
|
||||
vision_batch_size,
|
||||
vision_llm_provider,
|
||||
progress_callback
|
||||
)
|
||||
|
||||
return json.loads(script) if isinstance(script, str) else script
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Generate script failed")
|
||||
raise
|
||||
|
||||
async def _extract_keyframes(
    self,
    video_path: str,
    skip_seconds: int,
    threshold: int
) -> List[str]:
    """Extract video keyframes, reusing a per-video cache directory.

    The cache key hashes path + mtime so editing the video invalidates
    the cache. On extraction failure the partially-populated cache dir
    is removed before the exception propagates, so a bad run is never
    mistaken for a valid cache later.
    """
    video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
    video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)

    # Check the cache first.
    keyframe_files = []
    if os.path.exists(video_keyframes_dir):
        for filename in sorted(os.listdir(video_keyframes_dir)):
            if filename.endswith('.jpg'):
                keyframe_files.append(os.path.join(video_keyframes_dir, filename))

        if keyframe_files:
            logger.info(f"Using cached keyframes: {video_keyframes_dir}")
            return keyframe_files

    # Extract fresh keyframes.
    os.makedirs(video_keyframes_dir, exist_ok=True)

    try:
        processor = video_processor.VideoProcessor(video_path)
        processor.process_video_pipeline(
            output_dir=video_keyframes_dir,
            skip_seconds=skip_seconds,
            threshold=threshold
        )

        # Collect whatever .jpg frames the pipeline produced, in name order.
        for filename in sorted(os.listdir(video_keyframes_dir)):
            if filename.endswith('.jpg'):
                keyframe_files.append(os.path.join(video_keyframes_dir, filename))

        return keyframe_files

    except Exception as e:
        # Clean up so a failed extraction does not poison the cache.
        if os.path.exists(video_keyframes_dir):
            import shutil
            shutil.rmtree(video_keyframes_dir)
        raise
|
||||
|
||||
async def _process_with_llm(
|
||||
self,
|
||||
keyframe_files: List[str],
|
||||
video_theme: str,
|
||||
custom_prompt: str,
|
||||
vision_batch_size: int,
|
||||
vision_llm_provider: str,
|
||||
progress_callback: Callable[[float, str], None]
|
||||
) -> str:
|
||||
"""使用统一 LLM 接口处理视频帧"""
|
||||
progress_callback(30, "正在初始化视觉分析器...")
|
||||
|
||||
# 使用新的 LLM 迁移适配器(支持所有 provider)
|
||||
from app.services.llm.migration_adapter import create_vision_analyzer
|
||||
|
||||
# 获取配置
|
||||
text_provider = config.app.get('text_llm_provider', 'openai').lower()
|
||||
vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
|
||||
vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
|
||||
vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')
|
||||
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型")
|
||||
|
||||
# 创建统一的视觉分析器
|
||||
analyzer = create_vision_analyzer(
|
||||
provider=vision_llm_provider,
|
||||
api_key=vision_api_key,
|
||||
model=vision_model,
|
||||
base_url=vision_base_url
|
||||
progress_callback: Callable[[float, str], None] | None = None,
|
||||
) -> list[dict[Any, Any]]:
|
||||
callback = progress_callback or (lambda _p, _m: None)
|
||||
return await self.documentary_service.generate_documentary_script(
|
||||
video_path=video_path,
|
||||
video_theme=video_theme,
|
||||
custom_prompt=custom_prompt,
|
||||
frame_interval_input=frame_interval_input,
|
||||
vision_batch_size=vision_batch_size,
|
||||
vision_llm_provider=vision_llm_provider,
|
||||
progress_callback=callback,
|
||||
# 历史参数保留在签名中以兼容调用方;共享逐帧分析当前不使用这两个参数。
|
||||
# skip_seconds=skip_seconds,
|
||||
# threshold=threshold,
|
||||
)
|
||||
|
||||
progress_callback(40, "正在分析关键帧...")
|
||||
|
||||
# 执行异步分析
|
||||
results = await analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
batch_size=vision_batch_size
|
||||
)
|
||||
|
||||
progress_callback(60, "正在整理分析结果...")
|
||||
|
||||
# 合并所有批次的分析结果
|
||||
frame_analysis = ""
|
||||
prev_batch_files = None
|
||||
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
|
||||
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
# 添加带时间戳的分析结果
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += result['response']
|
||||
frame_analysis += "\n"
|
||||
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_analysis.strip():
|
||||
raise Exception("未能生成有效的帧分析结果")
|
||||
|
||||
progress_callback(70, "正在生成脚本...")
|
||||
|
||||
# 构建帧内容列表
|
||||
frame_content_list = []
|
||||
prev_batch_files = None
|
||||
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
_, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
frame_content = {
|
||||
"timestamp": timestamp_range,
|
||||
"picture": result['response'],
|
||||
"narration": "",
|
||||
"OST": 2
|
||||
}
|
||||
frame_content_list.append(frame_content)
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_content_list:
|
||||
raise Exception("没有有效的帧内容可以处理")
|
||||
|
||||
progress_callback(90, "正在生成文案...")
|
||||
|
||||
# 获取文本生成配置
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
|
||||
# 根据提供商类型选择合适的处理器
|
||||
if text_provider == 'gemini(openai)':
|
||||
# 使用OpenAI兼容的Gemini代理
|
||||
from app.utils.script_generator import GeminiOpenAIGenerator
|
||||
generator = GeminiOpenAIGenerator(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt=custom_prompt,
|
||||
base_url=text_base_url
|
||||
)
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
base_url=text_base_url,
|
||||
prompt=custom_prompt,
|
||||
video_theme=video_theme
|
||||
)
|
||||
processor.generator = generator
|
||||
else:
|
||||
# 使用标准处理器(包括原生Gemini)
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
base_url=text_base_url,
|
||||
prompt=custom_prompt,
|
||||
video_theme=video_theme
|
||||
)
|
||||
|
||||
return processor.process_frames(frame_content_list)
|
||||
|
||||
def _get_batch_files(
|
||||
self,
|
||||
keyframe_files: List[str],
|
||||
result: Dict[str, Any],
|
||||
batch_size: int
|
||||
) -> List[str]:
|
||||
"""获取当前批次的图片文件"""
|
||||
batch_start = result['batch_index'] * batch_size
|
||||
batch_end = min(batch_start + batch_size, len(keyframe_files))
|
||||
return keyframe_files[batch_start:batch_end]
|
||||
|
||||
def _get_batch_timestamps(
|
||||
self,
|
||||
batch_files: List[str],
|
||||
prev_batch_files: List[str] = None
|
||||
) -> tuple[str, str, str]:
|
||||
"""获取一批文件的时间戳范围,支持毫秒级精度"""
|
||||
if not batch_files:
|
||||
logger.warning("Empty batch files")
|
||||
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
|
||||
|
||||
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
|
||||
first_frame = os.path.basename(prev_batch_files[-1])
|
||||
last_frame = os.path.basename(batch_files[0])
|
||||
else:
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
first_time = first_frame.split('_')[2].replace('.jpg', '')
|
||||
last_time = last_frame.split('_')[2].replace('.jpg', '')
|
||||
|
||||
def format_timestamp(time_str: str) -> str:
|
||||
"""将时间字符串转换为 HH:MM:SS,mmm 格式"""
|
||||
try:
|
||||
if len(time_str) < 4:
|
||||
logger.warning(f"Invalid timestamp format: {time_str}")
|
||||
return "00:00:00,000"
|
||||
|
||||
# 处理毫秒部分
|
||||
if ',' in time_str:
|
||||
time_part, ms_part = time_str.split(',')
|
||||
ms = int(ms_part)
|
||||
else:
|
||||
time_part = time_str
|
||||
ms = 0
|
||||
|
||||
# 处理时分秒
|
||||
parts = time_part.split(':')
|
||||
if len(parts) == 3: # HH:MM:SS
|
||||
h, m, s = map(int, parts)
|
||||
elif len(parts) == 2: # MM:SS
|
||||
h = 0
|
||||
m, s = map(int, parts)
|
||||
else: # SS
|
||||
h = 0
|
||||
m = 0
|
||||
s = int(parts[0])
|
||||
|
||||
# 处理进位
|
||||
if s >= 60:
|
||||
m += s // 60
|
||||
s = s % 60
|
||||
if m >= 60:
|
||||
h += m // 60
|
||||
m = m % 60
|
||||
|
||||
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
|
||||
return "00:00:00,000"
|
||||
|
||||
first_timestamp = format_timestamp(first_time)
|
||||
last_timestamp = format_timestamp(last_time)
|
||||
timestamp_range = f"{first_timestamp}-{last_timestamp}"
|
||||
|
||||
return first_timestamp, last_timestamp, timestamp_range
|
||||
|
||||
58
tests/test_generate_narration_script_documentary_unittest.py
Normal file
58
tests/test_generate_narration_script_documentary_unittest.py
Normal file
@ -0,0 +1,58 @@
|
||||
import json
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from app.services.generate_narration_script import parse_frame_analysis_to_markdown
|
||||
|
||||
|
||||
class GenerateNarrationMarkdownTests(unittest.TestCase):
|
||||
def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self):
|
||||
artifact = {
|
||||
"batches": [
|
||||
{
|
||||
"batch_index": 1,
|
||||
"time_range": "00:00:03,000-00:00:06,000",
|
||||
"overall_activity_summary": "人物转身跑向远处",
|
||||
"fallback_summary": "",
|
||||
"frame_observations": [
|
||||
{
|
||||
"timestamp": "00:00:03,000",
|
||||
"observation": "人物突然回头",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"batch_index": 0,
|
||||
"time_range": "00:00:00,000-00:00:03,000",
|
||||
"overall_activity_summary": "",
|
||||
"fallback_summary": "原始响应回退摘要",
|
||||
"frame_observations": [
|
||||
{
|
||||
"timestamp": "00:00:00,000",
|
||||
"observation": "镜头里有一只猫",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
analysis_path = Path(temp_dir) / "frame-analysis.json"
|
||||
analysis_path.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8")
|
||||
markdown = parse_frame_analysis_to_markdown(str(analysis_path))
|
||||
|
||||
first_range_index = markdown.find("00:00:00,000-00:00:03,000")
|
||||
second_range_index = markdown.find("00:00:03,000-00:00:06,000")
|
||||
|
||||
self.assertIn("原始响应回退摘要", markdown)
|
||||
self.assertIn("镜头里有一只猫", markdown)
|
||||
self.assertIn("人物转身跑向远处", markdown)
|
||||
self.assertIn("人物突然回头", markdown)
|
||||
self.assertNotEqual(-1, first_range_index)
|
||||
self.assertNotEqual(-1, second_range_index)
|
||||
self.assertLess(first_range_index, second_range_index)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
51
tests/test_script_service_documentary_unittest.py
Normal file
51
tests/test_script_service_documentary_unittest.py
Normal file
@ -0,0 +1,51 @@
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.services.script_service import ScriptGenerator
|
||||
|
||||
|
||||
class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_generate_script_passes_frame_interval_to_shared_service(self):
|
||||
expected_script = [
|
||||
{
|
||||
"timestamp": "00:00:00,000-00:00:03,000",
|
||||
"picture": "批次描述",
|
||||
"narration": "",
|
||||
"OST": 2,
|
||||
}
|
||||
]
|
||||
progress = []
|
||||
|
||||
def progress_callback(percent, message):
|
||||
progress.append((percent, message))
|
||||
|
||||
with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
|
||||
service = service_cls.return_value
|
||||
service.generate_documentary_script = AsyncMock(return_value=expected_script)
|
||||
generator = ScriptGenerator()
|
||||
|
||||
result = await generator.generate_script(
|
||||
video_path="demo.mp4",
|
||||
video_theme="荒野生存",
|
||||
custom_prompt="请聚焦生存动作",
|
||||
frame_interval_input=3,
|
||||
vision_batch_size=6,
|
||||
vision_llm_provider="openai",
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
self.assertEqual(expected_script, result)
|
||||
service.generate_documentary_script.assert_awaited_once()
|
||||
called_kwargs = service.generate_documentary_script.await_args.kwargs
|
||||
self.assertEqual("demo.mp4", called_kwargs["video_path"])
|
||||
self.assertEqual(3, called_kwargs["frame_interval_input"])
|
||||
self.assertEqual(6, called_kwargs["vision_batch_size"])
|
||||
self.assertEqual("openai", called_kwargs["vision_llm_provider"])
|
||||
self.assertEqual("荒野生存", called_kwargs["video_theme"])
|
||||
self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
|
||||
self.assertIs(called_kwargs["progress_callback"], progress_callback)
|
||||
self.assertEqual([], progress)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -1,21 +1,21 @@
|
||||
# 纪录片脚本生成
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from app.config import config
|
||||
from app.utils import utils, video_processor
|
||||
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps
|
||||
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
|
||||
from webui.tools.generate_short_summary import parse_and_fix_json
|
||||
|
||||
|
||||
def generate_script_docu(params):
|
||||
"""
|
||||
生成 纪录片 视频脚本
|
||||
生成纪录片视频脚本。
|
||||
要求: 原视频无字幕无配音
|
||||
适合场景: 纪录片、动物搞笑解说、荒野建造等
|
||||
"""
|
||||
@ -34,408 +34,83 @@ def generate_script_docu(params):
|
||||
if not params.video_origin_path:
|
||||
st.error("请先选择视频文件")
|
||||
return
|
||||
"""
|
||||
1. 提取键帧
|
||||
"""
|
||||
update_progress(10, "正在提取关键帧...")
|
||||
|
||||
# 创建临时目录用于存储关键帧
|
||||
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
|
||||
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
|
||||
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
|
||||
|
||||
# 检查是否已经提取过关键帧
|
||||
keyframe_files = []
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
# 取已有的关键帧文件
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
|
||||
if keyframe_files:
|
||||
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
|
||||
st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)} 帧")
|
||||
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧")
|
||||
|
||||
# 如果没有缓存的关键帧,则进行提取
|
||||
if not keyframe_files:
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||
|
||||
# 初始化视频处理器
|
||||
processor = video_processor.VideoProcessor(params.video_origin_path)
|
||||
|
||||
# 显示视频信息
|
||||
st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}秒")
|
||||
|
||||
# 处理视频并提取关键帧 - 直接使用超级兼容性方案
|
||||
update_progress(15, "正在提取关键帧(使用超级兼容性方案)...")
|
||||
|
||||
try:
|
||||
# 使用优化的关键帧提取方法
|
||||
processor.extract_frames_by_interval_ultra_compatible(
|
||||
output_dir=video_keyframes_dir,
|
||||
interval_seconds=st.session_state.get('frame_interval_input'),
|
||||
)
|
||||
except Exception as extract_error:
|
||||
logger.error(f"关键帧提取失败: {extract_error}")
|
||||
|
||||
# 提供详细的错误信息和解决建议
|
||||
error_msg = str(extract_error)
|
||||
if "权限" in error_msg or "permission" in error_msg.lower():
|
||||
suggestion = "建议:检查输出目录权限,或更换输出位置"
|
||||
elif "空间" in error_msg or "space" in error_msg.lower():
|
||||
suggestion = "建议:检查磁盘空间是否足够"
|
||||
else:
|
||||
suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式"
|
||||
|
||||
raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}")
|
||||
|
||||
# 获取所有关键文件路径
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
|
||||
if not keyframe_files:
|
||||
# 检查目录中是否有其他文件
|
||||
all_files = os.listdir(video_keyframes_dir)
|
||||
logger.error(f"关键帧目录内容: {all_files}")
|
||||
raise Exception("未提取到任何关键帧文件,请检查视频文件格式")
|
||||
|
||||
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧")
|
||||
st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧")
|
||||
|
||||
except Exception as e:
|
||||
# 如果提取失败,清理创建的目录
|
||||
try:
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
import shutil
|
||||
shutil.rmtree(video_keyframes_dir)
|
||||
except Exception as cleanup_err:
|
||||
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
|
||||
|
||||
raise Exception(f"关键帧提取失败: {str(e)}")
|
||||
|
||||
"""
|
||||
2. 视觉分析(批量分析每一帧)
|
||||
"""
|
||||
# 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值
|
||||
vision_llm_provider = (
|
||||
st.session_state.get('vision_llm_provider') or
|
||||
config.app.get('vision_llm_provider', 'openai')
|
||||
st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
|
||||
).lower()
|
||||
|
||||
logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")
|
||||
|
||||
try:
|
||||
# ===================初始化视觉分析器===================
|
||||
update_progress(30, "正在初始化视觉分析器...")
|
||||
|
||||
# 使用统一的配置键格式获取配置(支持所有 provider)
|
||||
vision_api_key = (
|
||||
st.session_state.get(f'vision_{vision_llm_provider}_api_key') or
|
||||
config.app.get(f'vision_{vision_llm_provider}_api_key')
|
||||
)
|
||||
vision_model = (
|
||||
st.session_state.get(f'vision_{vision_llm_provider}_model_name') or
|
||||
config.app.get(f'vision_{vision_llm_provider}_model_name')
|
||||
)
|
||||
vision_base_url = (
|
||||
st.session_state.get(f'vision_{vision_llm_provider}_base_url') or
|
||||
config.app.get(f'vision_{vision_llm_provider}_base_url', '')
|
||||
vision_api_key = (
|
||||
st.session_state.get(f"vision_{vision_llm_provider}_api_key")
|
||||
or config.app.get(f"vision_{vision_llm_provider}_api_key")
|
||||
)
|
||||
vision_model = (
|
||||
st.session_state.get(f"vision_{vision_llm_provider}_model_name")
|
||||
or config.app.get(f"vision_{vision_llm_provider}_model_name")
|
||||
)
|
||||
vision_base_url = (
|
||||
st.session_state.get(f"vision_{vision_llm_provider}_base_url")
|
||||
or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
|
||||
)
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError(
|
||||
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
|
||||
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
|
||||
)
|
||||
|
||||
# 验证必需配置
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError(
|
||||
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
|
||||
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
|
||||
)
|
||||
frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
|
||||
"frame_interval_input", 3
|
||||
)
|
||||
vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
|
||||
vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
|
||||
"vision_max_concurrency", 2
|
||||
)
|
||||
|
||||
# 创建视觉分析器实例(使用统一接口)
|
||||
llm_params = {
|
||||
"vision_provider": vision_llm_provider,
|
||||
"vision_api_key": vision_api_key,
|
||||
"vision_model_name": vision_model,
|
||||
"vision_base_url": vision_base_url,
|
||||
}
|
||||
|
||||
logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}")
|
||||
|
||||
analyzer = create_vision_analyzer(
|
||||
provider=vision_llm_provider,
|
||||
api_key=vision_api_key,
|
||||
model=vision_model,
|
||||
base_url=vision_base_url
|
||||
update_progress(10, "正在提取关键帧...")
|
||||
service = DocumentaryFrameAnalysisService()
|
||||
analysis_result = asyncio.run(
|
||||
service.analyze_video(
|
||||
video_path=params.video_origin_path,
|
||||
video_theme=st.session_state.get("video_theme", ""),
|
||||
custom_prompt=st.session_state.get("custom_prompt", ""),
|
||||
frame_interval_input=frame_interval_input,
|
||||
vision_batch_size=vision_batch_size,
|
||||
vision_llm_provider=vision_llm_provider,
|
||||
progress_callback=update_progress,
|
||||
vision_api_key=vision_api_key,
|
||||
vision_model_name=vision_model,
|
||||
vision_base_url=vision_base_url,
|
||||
max_concurrency=vision_max_concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
update_progress(40, "正在分析关键帧...")
|
||||
analysis_json_path = analysis_result["analysis_json_path"]
|
||||
update_progress(80, "正在生成解说文案...")
|
||||
|
||||
# ===================创建异步事件循环===================
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
text_provider = config.app.get("text_llm_provider", "gemini").lower()
|
||||
text_api_key = config.app.get(f"text_{text_provider}_api_key")
|
||||
text_model = config.app.get(f"text_{text_provider}_model_name")
|
||||
text_base_url = config.app.get(f"text_{text_provider}_base_url")
|
||||
|
||||
# 执行异步分析
|
||||
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
||||
vision_analysis_prompt = """
|
||||
我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
|
||||
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
|
||||
narration = generate_narration(
|
||||
markdown_output,
|
||||
text_api_key,
|
||||
base_url=text_base_url,
|
||||
model=text_model,
|
||||
)
|
||||
narration_data = parse_and_fix_json(narration)
|
||||
|
||||
首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
|
||||
然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
|
||||
if not narration_data or "items" not in narration_data:
|
||||
logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...")
|
||||
raise Exception("解说文案格式错误,无法解析JSON或缺少items字段")
|
||||
|
||||
请务必使用 JSON 格式输出你的结果。JSON 结构应如下:
|
||||
{
|
||||
"frame_observations": [
|
||||
{
|
||||
"frame_number": 1, // 或其他标识帧的方式
|
||||
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
|
||||
},
|
||||
// ... 更多帧的观察 ...
|
||||
],
|
||||
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
|
||||
}
|
||||
narration_dict = [{**item, "OST": 2} for item in narration_data["items"]]
|
||||
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
|
||||
|
||||
请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素
|
||||
|
||||
请只返回 JSON 字符串,不要包含任何其他解释性文字。
|
||||
"""
|
||||
results = loop.run_until_complete(
|
||||
analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=vision_analysis_prompt,
|
||||
batch_size=vision_batch_size
|
||||
)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
"""
|
||||
3. 处理分析结果(格式化为 json 数据)
|
||||
"""
|
||||
# ===================处理分析结果===================
|
||||
update_progress(60, "正在整理分析结果...")
|
||||
|
||||
# 合并所有批次的分析结果
|
||||
frame_analysis = ""
|
||||
merged_frame_observations = [] # 合并所有批次的帧观察
|
||||
overall_activity_summaries = [] # 合并所有批次的整体总结
|
||||
prev_batch_files = None
|
||||
frame_counter = 1 # 初始化帧计数器,用于给所有帧分配连续的序号
|
||||
|
||||
# 确保分析目录存在
|
||||
analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
|
||||
os.makedirs(analysis_dir, exist_ok=True)
|
||||
origin_res = os.path.join(analysis_dir, "frame_analysis.json")
|
||||
with open(origin_res, 'w', encoding='utf-8') as f:
|
||||
json.dump(results, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 开始处理
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
|
||||
# 获取当前批次的文件列表
|
||||
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
|
||||
# 获取批次的时间戳范围
|
||||
first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
# 解析响应中的JSON数据
|
||||
response_text = result['response']
|
||||
try:
|
||||
# 处理可能包含```json```格式的响应
|
||||
if "```json" in response_text:
|
||||
json_content = response_text.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in response_text:
|
||||
json_content = response_text.split("```")[1].split("```")[0].strip()
|
||||
else:
|
||||
json_content = response_text.strip()
|
||||
|
||||
response_data = json.loads(json_content)
|
||||
|
||||
# 提取frame_observations和overall_activity_summary
|
||||
if "frame_observations" in response_data:
|
||||
frame_obs = response_data["frame_observations"]
|
||||
overall_summary = response_data.get("overall_activity_summary", "")
|
||||
|
||||
# 添加时间戳信息到每个帧观察
|
||||
for i, obs in enumerate(frame_obs):
|
||||
if i < len(batch_files):
|
||||
# 从文件名中提取时间戳
|
||||
file_path = batch_files[i]
|
||||
file_name = os.path.basename(file_path)
|
||||
# 提取时间戳字符串 (格式如: keyframe_000675_000027000.jpg)
|
||||
# 格式解析: keyframe_帧序号_毫秒时间戳.jpg
|
||||
timestamp_parts = file_name.split('_')
|
||||
if len(timestamp_parts) >= 3:
|
||||
timestamp_str = timestamp_parts[-1].split('.')[0]
|
||||
try:
|
||||
# 修正时间戳解析逻辑
|
||||
# 格式为000100000,表示00:01:00,000,即1分钟
|
||||
# 需要按照对应位数进行解析:
|
||||
# 前两位是小时,中间两位是分钟,后面是秒和毫秒
|
||||
if len(timestamp_str) >= 9: # 确保格式正确
|
||||
hours = int(timestamp_str[0:2])
|
||||
minutes = int(timestamp_str[2:4])
|
||||
seconds = int(timestamp_str[4:6])
|
||||
milliseconds = int(timestamp_str[6:9])
|
||||
|
||||
# 计算总秒数
|
||||
timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
|
||||
formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳
|
||||
else:
|
||||
# 兼容旧的解析方式
|
||||
timestamp_seconds = int(timestamp_str) / 1000 # 转换为秒
|
||||
formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析时间戳: {timestamp_str}")
|
||||
timestamp_seconds = 0
|
||||
formatted_time = "00:00:00,000"
|
||||
else:
|
||||
logger.warning(f"文件名格式不符合预期: {file_name}")
|
||||
timestamp_seconds = 0
|
||||
formatted_time = "00:00:00,000"
|
||||
|
||||
# 添加额外信息到帧观察
|
||||
obs["frame_path"] = file_path
|
||||
obs["timestamp"] = formatted_time
|
||||
obs["timestamp_seconds"] = timestamp_seconds
|
||||
obs["batch_index"] = result['batch_index']
|
||||
|
||||
# 使用全局递增的帧计数器替换原始的frame_number
|
||||
if "frame_number" in obs:
|
||||
obs["original_frame_number"] = obs["frame_number"] # 保留原始编号作为参考
|
||||
obs["frame_number"] = frame_counter # 赋值连续的帧编号
|
||||
frame_counter += 1 # 增加帧计数器
|
||||
|
||||
# 添加到合并列表
|
||||
merged_frame_observations.append(obs)
|
||||
|
||||
# 添加批次整体总结信息
|
||||
if overall_summary:
|
||||
# 从文件名中提取时间戳数值
|
||||
first_time_str = first_timestamp.split('_')[-1].split('.')[0]
|
||||
last_time_str = last_timestamp.split('_')[-1].split('.')[0]
|
||||
|
||||
# 转换为毫秒并计算持续时间(秒)
|
||||
try:
|
||||
# 修正解析逻辑,与上面相同的方式解析时间戳
|
||||
if len(first_time_str) >= 9 and len(last_time_str) >= 9:
|
||||
# 解析第一个时间戳
|
||||
first_hours = int(first_time_str[0:2])
|
||||
first_minutes = int(first_time_str[2:4])
|
||||
first_seconds = int(first_time_str[4:6])
|
||||
first_ms = int(first_time_str[6:9])
|
||||
first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000
|
||||
|
||||
# 解析第二个时间戳
|
||||
last_hours = int(last_time_str[0:2])
|
||||
last_minutes = int(last_time_str[2:4])
|
||||
last_seconds = int(last_time_str[4:6])
|
||||
last_ms = int(last_time_str[6:9])
|
||||
last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000
|
||||
|
||||
batch_duration = last_time_seconds - first_time_seconds
|
||||
else:
|
||||
# 兼容旧的解析方式
|
||||
first_time_ms = int(first_time_str)
|
||||
last_time_ms = int(last_time_str)
|
||||
batch_duration = (last_time_ms - first_time_ms) / 1000
|
||||
except ValueError:
|
||||
# 使用 utils.time_to_seconds 函数处理格式化的时间戳
|
||||
first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ','))
|
||||
last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ','))
|
||||
batch_duration = last_time_seconds - first_time_seconds
|
||||
|
||||
overall_activity_summaries.append({
|
||||
"batch_index": result['batch_index'],
|
||||
"time_range": f"{first_timestamp}-{last_timestamp}",
|
||||
"duration_seconds": batch_duration,
|
||||
"summary": overall_summary
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}")
|
||||
# 添加原始响应作为回退
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += response_text
|
||||
frame_analysis += "\n"
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
|
||||
# 将合并后的结果转为JSON字符串
|
||||
merged_results = {
|
||||
"frame_observations": merged_frame_observations,
|
||||
"overall_activity_summaries": overall_activity_summaries
|
||||
}
|
||||
|
||||
# 使用当前时间创建文件名
|
||||
now = datetime.now()
|
||||
timestamp_str = now.strftime("%Y%m%d_%H%M")
|
||||
|
||||
# 保存完整的分析结果为JSON
|
||||
analysis_filename = f"frame_analysis_{timestamp_str}.json"
|
||||
analysis_json_path = os.path.join(analysis_dir, analysis_filename)
|
||||
with open(analysis_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(merged_results, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"分析结果已保存到: {analysis_json_path}")
|
||||
|
||||
"""
|
||||
4. 生成文案
|
||||
"""
|
||||
logger.info("开始生成解说文案")
|
||||
update_progress(80, "正在生成解说文案...")
|
||||
from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration
|
||||
# 从配置中获取文本生成相关配置
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
llm_params.update({
|
||||
"text_provider": text_provider,
|
||||
"text_api_key": text_api_key,
|
||||
"text_model_name": text_model,
|
||||
"text_base_url": text_base_url
|
||||
})
|
||||
# 整理帧分析数据
|
||||
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
|
||||
|
||||
# 生成解说文案
|
||||
narration = generate_narration(
|
||||
markdown_output,
|
||||
text_api_key,
|
||||
base_url=text_base_url,
|
||||
model=text_model
|
||||
)
|
||||
|
||||
# 使用增强的JSON解析器
|
||||
from webui.tools.generate_short_summary import parse_and_fix_json
|
||||
narration_data = parse_and_fix_json(narration)
|
||||
|
||||
if not narration_data or 'items' not in narration_data:
|
||||
logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...")
|
||||
raise Exception("解说文案格式错误,无法解析JSON或缺少items字段")
|
||||
|
||||
narration_dict = narration_data['items']
|
||||
# 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音
|
||||
narration_dict = [{**item, "OST": 2} for item in narration_dict]
|
||||
logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段")
|
||||
# 结果转换为JSON字符串
|
||||
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
|
||||
raise Exception(f"分析失败: {str(e)}")
|
||||
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
st.stop()
|
||||
logger.info(f"纪录片解说脚本生成完成")
|
||||
logger.info(f"纪录片解说脚本生成完成,共 {len(narration_dict)} 个片段")
|
||||
if isinstance(script, list):
|
||||
st.session_state['video_clip_json'] = script
|
||||
st.session_state["video_clip_json"] = script
|
||||
elif isinstance(script, str):
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
st.session_state["video_clip_json"] = json.loads(script)
|
||||
update_progress(100, "脚本生成完成")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user