mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
324 lines
12 KiB
Python
324 lines
12 KiB
Python
import os
|
||
import json
|
||
import time
|
||
import asyncio
|
||
import requests
|
||
from app.utils import video_processor
|
||
from loguru import logger
|
||
from typing import List, Dict, Any, Callable
|
||
|
||
from app.utils import utils, gemini_analyzer, video_processor
|
||
from app.utils.script_generator import ScriptProcessor
|
||
from app.config import config
|
||
|
||
|
||
class ScriptGenerator:
|
||
def __init__(self):
|
||
self.temp_dir = utils.temp_dir()
|
||
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
|
||
|
||
async def generate_script(
|
||
self,
|
||
video_path: str,
|
||
video_theme: str = "",
|
||
custom_prompt: str = "",
|
||
frame_interval_input: int = 5,
|
||
skip_seconds: int = 0,
|
||
threshold: int = 30,
|
||
vision_batch_size: int = 5,
|
||
vision_llm_provider: str = "gemini",
|
||
progress_callback: Callable[[float, str], None] = None
|
||
) -> List[Dict[Any, Any]]:
|
||
"""
|
||
生成视频脚本的核心逻辑
|
||
|
||
Args:
|
||
video_path: 视频文件路径
|
||
video_theme: 视频主题
|
||
custom_prompt: 自定义提示词
|
||
skip_seconds: 跳过开始的秒数
|
||
threshold: 差异<E5B7AE><E5BC82><EFBFBD>值
|
||
vision_batch_size: 视觉处理批次大小
|
||
vision_llm_provider: 视觉模型提供商
|
||
progress_callback: 进度回调函数
|
||
|
||
Returns:
|
||
List[Dict]: 生成的视频脚本
|
||
"""
|
||
if progress_callback is None:
|
||
progress_callback = lambda p, m: None
|
||
|
||
try:
|
||
# 提取关键帧
|
||
progress_callback(10, "正在提取关键帧...")
|
||
keyframe_files = await self._extract_keyframes(
|
||
video_path,
|
||
skip_seconds,
|
||
threshold
|
||
)
|
||
|
||
# 使用统一的 LLM 接口(支持所有 provider)
|
||
script = await self._process_with_llm(
|
||
keyframe_files,
|
||
video_theme,
|
||
custom_prompt,
|
||
vision_batch_size,
|
||
vision_llm_provider,
|
||
progress_callback
|
||
)
|
||
|
||
return json.loads(script) if isinstance(script, str) else script
|
||
|
||
except Exception as e:
|
||
logger.exception("Generate script failed")
|
||
raise
|
||
|
||
async def _extract_keyframes(
|
||
self,
|
||
video_path: str,
|
||
skip_seconds: int,
|
||
threshold: int
|
||
) -> List[str]:
|
||
"""提取视频关键帧"""
|
||
video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
|
||
video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)
|
||
|
||
# 检查缓存
|
||
keyframe_files = []
|
||
if os.path.exists(video_keyframes_dir):
|
||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||
if filename.endswith('.jpg'):
|
||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||
|
||
if keyframe_files:
|
||
logger.info(f"Using cached keyframes: {video_keyframes_dir}")
|
||
return keyframe_files
|
||
|
||
# 提取新的关键帧
|
||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||
|
||
try:
|
||
processor = video_processor.VideoProcessor(video_path)
|
||
processor.process_video_pipeline(
|
||
output_dir=video_keyframes_dir,
|
||
skip_seconds=skip_seconds,
|
||
threshold=threshold
|
||
)
|
||
|
||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||
if filename.endswith('.jpg'):
|
||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||
|
||
return keyframe_files
|
||
|
||
except Exception as e:
|
||
if os.path.exists(video_keyframes_dir):
|
||
import shutil
|
||
shutil.rmtree(video_keyframes_dir)
|
||
raise
|
||
|
||
async def _process_with_llm(
|
||
self,
|
||
keyframe_files: List[str],
|
||
video_theme: str,
|
||
custom_prompt: str,
|
||
vision_batch_size: int,
|
||
vision_llm_provider: str,
|
||
progress_callback: Callable[[float, str], None]
|
||
) -> str:
|
||
"""使用统一 LLM 接口处理视频帧"""
|
||
progress_callback(30, "正在初始化视觉分析器...")
|
||
|
||
# 使用新的 LLM 迁移适配器(支持所有 provider)
|
||
from app.services.llm.migration_adapter import create_vision_analyzer
|
||
|
||
# 获取配置
|
||
text_provider = config.app.get('text_llm_provider', 'litellm').lower()
|
||
vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
|
||
vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
|
||
vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')
|
||
|
||
if not vision_api_key or not vision_model:
|
||
raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型")
|
||
|
||
# 创建统一的视觉分析器
|
||
analyzer = create_vision_analyzer(
|
||
provider=vision_llm_provider,
|
||
api_key=vision_api_key,
|
||
model=vision_model,
|
||
base_url=vision_base_url
|
||
)
|
||
|
||
progress_callback(40, "正在分析关键帧...")
|
||
|
||
# 执行异步分析
|
||
results = await analyzer.analyze_images(
|
||
images=keyframe_files,
|
||
prompt=config.app.get('vision_analysis_prompt'),
|
||
batch_size=vision_batch_size
|
||
)
|
||
|
||
progress_callback(60, "正在整理分析结果...")
|
||
|
||
# 合并所有批次的分析结果
|
||
frame_analysis = ""
|
||
prev_batch_files = None
|
||
|
||
for result in results:
|
||
if 'error' in result:
|
||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||
continue
|
||
|
||
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
|
||
first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)
|
||
|
||
# 添加带时间戳的分<E79A84><E58886>结果
|
||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||
frame_analysis += result['response']
|
||
frame_analysis += "\n"
|
||
|
||
prev_batch_files = batch_files
|
||
|
||
if not frame_analysis.strip():
|
||
raise Exception("未能生成有效的帧分析结果")
|
||
|
||
progress_callback(70, "正在生成脚本...")
|
||
|
||
# 构建帧内容列表
|
||
frame_content_list = []
|
||
prev_batch_files = None
|
||
|
||
for result in results:
|
||
if 'error' in result:
|
||
continue
|
||
|
||
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
|
||
_, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
|
||
|
||
frame_content = {
|
||
"timestamp": timestamp_range,
|
||
"picture": result['response'],
|
||
"narration": "",
|
||
"OST": 2
|
||
}
|
||
frame_content_list.append(frame_content)
|
||
prev_batch_files = batch_files
|
||
|
||
if not frame_content_list:
|
||
raise Exception("没有有效的帧内容可以处理")
|
||
|
||
progress_callback(90, "正在生成文案...")
|
||
|
||
# 获取文本生<E69CAC><E7949F>配置
|
||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||
|
||
# 根据提供商类型选择合适的处理器
|
||
if text_provider == 'gemini(openai)':
|
||
# 使用OpenAI兼容的Gemini代理
|
||
from app.utils.script_generator import GeminiOpenAIGenerator
|
||
generator = GeminiOpenAIGenerator(
|
||
model_name=text_model,
|
||
api_key=text_api_key,
|
||
prompt=custom_prompt,
|
||
base_url=text_base_url
|
||
)
|
||
processor = ScriptProcessor(
|
||
model_name=text_model,
|
||
api_key=text_api_key,
|
||
base_url=text_base_url,
|
||
prompt=custom_prompt,
|
||
video_theme=video_theme
|
||
)
|
||
processor.generator = generator
|
||
else:
|
||
# 使用标准处理器(包括原生Gemini)
|
||
processor = ScriptProcessor(
|
||
model_name=text_model,
|
||
api_key=text_api_key,
|
||
base_url=text_base_url,
|
||
prompt=custom_prompt,
|
||
video_theme=video_theme
|
||
)
|
||
|
||
return processor.process_frames(frame_content_list)
|
||
|
||
def _get_batch_files(
|
||
self,
|
||
keyframe_files: List[str],
|
||
result: Dict[str, Any],
|
||
batch_size: int
|
||
) -> List[str]:
|
||
"""获取当前批次的图片文件"""
|
||
batch_start = result['batch_index'] * batch_size
|
||
batch_end = min(batch_start + batch_size, len(keyframe_files))
|
||
return keyframe_files[batch_start:batch_end]
|
||
|
||
def _get_batch_timestamps(
|
||
self,
|
||
batch_files: List[str],
|
||
prev_batch_files: List[str] = None
|
||
) -> tuple[str, str, str]:
|
||
"""获取一批文件的时间戳范围,支持毫秒级精度"""
|
||
if not batch_files:
|
||
logger.warning("Empty batch files")
|
||
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
|
||
|
||
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
|
||
first_frame = os.path.basename(prev_batch_files[-1])
|
||
last_frame = os.path.basename(batch_files[0])
|
||
else:
|
||
first_frame = os.path.basename(batch_files[0])
|
||
last_frame = os.path.basename(batch_files[-1])
|
||
|
||
first_time = first_frame.split('_')[2].replace('.jpg', '')
|
||
last_time = last_frame.split('_')[2].replace('.jpg', '')
|
||
|
||
def format_timestamp(time_str: str) -> str:
|
||
"""将时间字符串转换为 HH:MM:SS,mmm 格式"""
|
||
try:
|
||
if len(time_str) < 4:
|
||
logger.warning(f"Invalid timestamp format: {time_str}")
|
||
return "00:00:00,000"
|
||
|
||
# 处理毫秒部分
|
||
if ',' in time_str:
|
||
time_part, ms_part = time_str.split(',')
|
||
ms = int(ms_part)
|
||
else:
|
||
time_part = time_str
|
||
ms = 0
|
||
|
||
# 处理时分秒
|
||
parts = time_part.split(':')
|
||
if len(parts) == 3: # HH:MM:SS
|
||
h, m, s = map(int, parts)
|
||
elif len(parts) == 2: # MM:SS
|
||
h = 0
|
||
m, s = map(int, parts)
|
||
else: # SS
|
||
h = 0
|
||
m = 0
|
||
s = int(parts[0])
|
||
|
||
# 处理进位
|
||
if s >= 60:
|
||
m += s // 60
|
||
s = s % 60
|
||
if m >= 60:
|
||
h += m // 60
|
||
m = m % 60
|
||
|
||
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
||
|
||
except Exception as e:
|
||
logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
|
||
return "00:00:00,000"
|
||
|
||
first_timestamp = format_timestamp(first_time)
|
||
last_timestamp = format_timestamp(last_time)
|
||
timestamp_range = f"{first_timestamp}-{last_timestamp}"
|
||
|
||
return first_timestamp, last_timestamp, timestamp_range |