Implement the Gemini video script generation logic

This commit is contained in:
linyqh 2024-11-09 18:18:57 +08:00
parent ec282adb1b
commit 8267a0b3eb
15 changed files with 1698 additions and 285 deletions

View File

@ -4,6 +4,7 @@ import glob
import random
from typing import List
from typing import Union
import traceback
from loguru import logger
from moviepy.editor import *
@ -145,7 +146,7 @@ def combine_videos(
return combined_video_path
def wrap_text(text, max_width, font="Arial", fontsize=60):
def wrap_text(text, max_width, font, fontsize=60):
# 创建字体对象
font = ImageFont.truetype(font, fontsize)
@ -158,7 +159,7 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
if width <= max_width:
return text, height
# logger.warning(f"wrapping text, max_width: {max_width}, text_width: {width}, text: {text}")
logger.debug(f"换行文本, 最大宽度: {max_width}, 文本宽度: {width}, 文本: {text}")
processed = True
@ -199,7 +200,7 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
_wrapped_lines_.append(_txt_)
result = "\n".join(_wrapped_lines_).strip()
height = len(_wrapped_lines_) * height
# logger.warning(f"wrapped text: {result}")
logger.debug(f"换行文本: {result}")
return result, height
@ -233,7 +234,7 @@ def generate_video_v2(
Returns:
"""
total_steps = 4  # 总步骤数
total_steps = 4
current_step = 0
def update_progress(step_name):
@ -506,7 +507,7 @@ def combine_clip_videos(combined_video_path: str,
temp_audiofile=os.path.join(output_dir, "temp-audio.m4a")
)
finally:
# 确保资源被正确释放
video_clip.close()
for clip in clips:
clip.close()

View File

@ -0,0 +1,399 @@
import os
import json
import traceback
from loguru import logger
import tiktoken
from typing import List, Dict
from datetime import datetime
from openai import OpenAI
import google.generativeai as genai
class BaseGenerator:
def __init__(self, model_name: str, api_key: str, prompt: str):
self.model_name = model_name
self.api_key = api_key
self.base_prompt = prompt
self.conversation_history = []
self.chunk_overlap = 50
self.last_chunk_ending = ""
def generate_script(self, scene_description: str, word_count: int) -> str:
raise NotImplementedError("Subclasses must implement generate_script method")
class OpenAIGenerator(BaseGenerator):
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
self.client = OpenAI(api_key=api_key)
self.max_tokens = 7000
try:
self.encoding = tiktoken.encoding_for_model(self.model_name)
except KeyError:
logger.info(f"警告:未找到模型 {self.model_name} 的专用编码器,使用认编码器")
self.encoding = tiktoken.get_encoding("cl100k_base")
def _count_tokens(self, messages: list) -> int:
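# Rough token accounting in the style of OpenAI's chat-format guidance:
# ~3 tokens of framing per message, ~1 extra token for the role name,
# plus 3 tokens priming the assistant's reply. An approximation, but
# adequate for trimming history against self.max_tokens.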
num_tokens = 0
for message in messages:
num_tokens += 3
for key, value in message.items():
num_tokens += len(self.encoding.encode(str(value)))
if key == "role":
num_tokens += 1
num_tokens += 3
return num_tokens
def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None:
base_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": new_user_prompt}
]
base_tokens = self._count_tokens(base_messages)
temp_history = []
current_tokens = base_tokens
for message in reversed(self.conversation_history):
message_tokens = self._count_tokens([message])
if current_tokens + message_tokens > self.max_tokens:
break
temp_history.insert(0, message)
current_tokens += message_tokens
self.conversation_history = temp_history
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
system_prompt, user_prompt = self._create_prompt(scene_description, word_count)
self._trim_conversation_history(system_prompt, user_prompt)
messages = [
{"role": "system", "content": system_prompt},
*self.conversation_history,
{"role": "user", "content": user_prompt}
]
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=0.7,
max_tokens=500,
top_p=0.9,
frequency_penalty=0.3,
presence_penalty=0.5
)
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
'').replace(
'\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.conversation_history.append({"role": "user", "content": user_prompt})
self.conversation_history.append({"role": "assistant", "content": generated_script})
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
def _create_prompt(self, scene_description: str, word_count: int) -> tuple:
system_prompt = self.base_prompt.format(word_count=word_count)
user_prompt = f"""上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述:{scene_description}
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
严格字数要求:{word_count}字,允许误差±5字。"""
return system_prompt, user_prompt
class GeminiGenerator(BaseGenerator):
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel(model_name)
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
prompt = f"""{self.base_prompt}
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述:{scene_description}
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
严格字数要求:{word_count}字,允许误差±5字。"""
response = self.model.generate_content(prompt)
generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
class QwenGenerator(BaseGenerator):
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
self.client = OpenAI(
api_key=api_key,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
prompt = f"""{self.base_prompt}
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述:{scene_description}
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
严格字数要求:{word_count}字,允许误差±5字。"""
messages = [
{"role": "system", "content": self.base_prompt},
{"role": "user", "content": prompt}
]
response = self.client.chat.completions.create(
model=self.model_name, # 如 "qwen-plus"
messages=messages,
temperature=0.7,
max_tokens=500,
top_p=0.9,
frequency_penalty=0.3,
presence_penalty=0.5
)
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
'').replace(
'\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
class MoonshotGenerator(BaseGenerator):
def __init__(self, model_name: str, api_key: str, prompt: str):
super().__init__(model_name, api_key, prompt)
self.client = OpenAI(
api_key=api_key,
base_url="https://api.moonshot.cn/v1"
)
def generate_script(self, scene_description: str, word_count: int) -> str:
max_attempts = 3
tolerance = 5
for attempt in range(max_attempts):
prompt = f"""{self.base_prompt}
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述:{scene_description}
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
严格字数要求:{word_count}字,允许误差±5字。"""
messages = [
{"role": "system", "content": self.base_prompt},
{"role": "user", "content": prompt}
]
response = self.client.chat.completions.create(
model=self.model_name, # 如 "moonshot-v1-8k"
messages=messages,
temperature=0.7,
max_tokens=500,
top_p=0.9,
frequency_penalty=0.3,
presence_penalty=0.5
)
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
'').replace(
'\n', '')
current_length = len(generated_script)
if abs(current_length - word_count) <= tolerance:
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
generated_script) > self.chunk_overlap else generated_script
return generated_script
return generated_script
class ScriptProcessor:
def __init__(self, model_name, api_key=None, prompt=None):
self.model_name = model_name
# 根据不同模型选择对应的环境变量
default_api_key = {
'gemini': 'GOOGLE_API_KEY',
'gpt': 'OPENAI_API_KEY',
'qwen': 'DASHSCOPE_API_KEY',
'moonshot': 'MOONSHOT_API_KEY'
}
api_key_env = next((v for k, v in default_api_key.items() if k in model_name.lower()), 'OPENAI_API_KEY')
self.api_key = api_key or os.getenv(api_key_env)
self.prompt = prompt or self._get_default_prompt()
# 根据模型名称选择对应的生成器
if 'gemini' in model_name.lower():
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
elif 'qwen' in model_name.lower():
self.generator = QwenGenerator(model_name, self.api_key, self.prompt)
elif 'moonshot' in model_name.lower():
self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt)
else:
self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt)
def _get_default_prompt(self) -> str:
return """你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让野外建造视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
目标受众:热爱野外生活、追求独特体验的18-35岁年轻人
文案风格:基于HKRR理论 + 段子手精神
主题:野外建造
创作核心理念:
1. 敢于用"温和的违反"制造笑点,但不能过于冒犯
2. 巧妙运用中国式幽默,让观众会心一笑
3. 保持轻松愉快的叙事基调
爆款内容四要素:
快乐元素 Happy:
1. 用调侃的语气描述建造过程中的"笨手笨脚"
2. 巧妙植入网络流行梗,增加内容的传播性
3. 适时自嘲,展现真实且有趣的一面
知识价值 Knowledge:
1. 用段子手的方式解释专业知识,比如"这根木头不是一般的木头,它比我前任还难搞..."
2. 把复杂的建造技巧转化为生动有趣的比喻
3. 在幽默中传递实用的野外生存技能
情感共鸣 Resonance:
1. 描述"真实但夸张"的建造困境
2. 把对自然的感悟融入俏皮话中
3. 用接地气的表达方式拉近与观众距离
节奏控制 Rhythm:
1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字
2. 像讲段子一样,注意铺垫和包袱的节奏
3. 确保每段都有笑点,但不强求
4. 段落结尾干净利落,不拖泥带水
连贯性要求:
1. 新生成的内容必须自然衔接上一段文案的结尾
2. 使用恰当的连接词和过渡语,确保叙事流畅
3. 保持人物视角和语气的一致性
4. 避免重复上一段已经提到的信息
5. 确保情节和建造过程的逻辑连续性
字数控制要求:
1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字
2. 如果内容过长,优先精简修饰性词语
3. 如果内容过短,可以适当增加细节描写
4. 保持文案结构完整,不因字数限制而牺牲内容质量
5. 确保每个笑点和包袱都得到完整表达
我会按顺序提供多段视频画面描述,请创作既搞笑又能火爆全网的口播文案。
记住:要敢于用"温和的违反"制造笑点,但要把握好尺度,让观众在轻松愉快中感受野外建造的乐趣。"""
def calculate_duration_and_word_count(self, time_range: str) -> int:
try:
start_str, end_str = time_range.split('-')
def time_to_seconds(time_str):
minutes, seconds = map(int, time_str.split(':'))
return minutes * 60 + seconds
start_seconds = time_to_seconds(start_str)
end_seconds = time_to_seconds(end_str)
duration = end_seconds - start_seconds
word_count = int(duration / 0.2)
return word_count
except Exception as e:
logger.info(f"时间格式转换错误: {traceback.format_exc()}")
return 100
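# Example: "01:00-01:30" spans 30 seconds; at 0.2 s per character this
# yields int(30 / 0.2) = 150 characters. On a parse failure the method
# falls back to 100 characters.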
def process_frames(self, frame_content_list: List[Dict]) -> List[Dict]:
for frame_content in frame_content_list:
word_count = self.calculate_duration_and_word_count(frame_content["timestamp"])
script = self.generator.generate_script(frame_content["picture"], word_count)
frame_content["narration"] = script
frame_content["OST"] = 2
logger.info(f"时间范围: {frame_content['timestamp']}, 建议字数: {word_count}")
logger.info(script)
self._save_results(frame_content_list)
return frame_content_list
def _save_results(self, frame_content_list: List[Dict]):
"""保存处理结果,并添加新的时间戳"""
try:
# 计算新的时间戳
current_time = 0 # 当前时间点(秒)
for frame in frame_content_list:
# 获取原始时间戳的持续时间
start_str, end_str = frame['timestamp'].split('-')
def time_to_seconds(time_str):
minutes, seconds = map(int, time_str.split(':'))
return minutes * 60 + seconds
# 计算当前片段的持续时间
start_seconds = time_to_seconds(start_str)
end_seconds = time_to_seconds(end_str)
duration = end_seconds - start_seconds
# 转换秒数为 MM:SS 格式
def seconds_to_time(seconds):
minutes = seconds // 60
remaining_seconds = seconds % 60
return f"{minutes:02d}:{remaining_seconds:02d}"
# 设置新的时间戳
new_start = seconds_to_time(current_time)
new_end = seconds_to_time(current_time + duration)
frame['new_timestamp'] = f"{new_start}-{new_end}"
# 更新当前时间点
current_time += duration
# 保存结果
file_name = f"storage/json/step2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w', encoding='utf-8') as file:
json.dump(frame_content_list, file, ensure_ascii=False, indent=4)
logger.info(f"保存脚本成功,总时长: {seconds_to_time(current_time)}")
except Exception as e:
logger.error(f"保存结果时发生错误: {str(e)}\n{traceback.format_exc()}")
raise
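A minimal usage sketch of ScriptProcessor (the model name, API key, and frame data below are illustrative placeholders, not values from this commit):

frames = [
    {"timestamp": "00:00-00:15", "picture": "主角在林间空地搭起木屋骨架"},
    {"timestamp": "00:15-00:40", "picture": "用藤蔓捆扎横梁并测试结构强度"},
]
processor = ScriptProcessor(model_name="gemini-1.5-pro", api_key="YOUR_KEY")
result = processor.process_frames(frames)
# each frame now carries a generated "narration" and an "OST" flag, and the
# full list is saved to storage/json/step2_<timestamp>.json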

View File

@ -56,7 +56,7 @@ def to_json(obj):
# 使用serialize函数处理输入对象
serialized_obj = serialize(obj)
# 序列化处理后的对象为JSON字符串
return json.dumps(serialized_obj, ensure_ascii=False, indent=4)
except Exception as e:
return None
@ -100,7 +100,7 @@ def task_dir(sub_dir: str = ""):
def font_dir(sub_dir: str = ""):
d = resource_dir(f"fonts")
d = resource_dir("fonts")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
@ -109,7 +109,7 @@ def font_dir(sub_dir: str = ""):
def song_dir(sub_dir: str = ""):
d = resource_dir(f"songs")
d = resource_dir("songs")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
@ -425,3 +425,102 @@ def cut_video(params, progress_callback=None):
except Exception as e:
logger.error(f"视频裁剪过程中发生错误: \n{traceback.format_exc()}")
raise
def temp_dir(sub_dir: str = ""):
"""
获取临时文件目录
Args:
sub_dir: 子目录名
Returns:
str: 临时文件目录路径
"""
d = os.path.join(storage_dir(), "temp")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
os.makedirs(d)
return d
def clear_keyframes_cache(video_path: str = None):
"""
清理关键帧缓存
Args:
video_path: 视频文件路径如果指定则只清理该视频的缓存
"""
try:
keyframes_dir = os.path.join(temp_dir(), "keyframes")
if not os.path.exists(keyframes_dir):
return
if video_path:
# 清理指定视频的缓存
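# The cache key is md5(path + mtime), matching the key used when the
# keyframes were extracted, so only this video's cache directory is removed.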
video_hash = md5(video_path + str(os.path.getmtime(video_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
logger.info(f"已清理视频关键帧缓存: {video_path}")
else:
# 清理所有缓存
import shutil
shutil.rmtree(keyframes_dir)
logger.info("已清理所有关键帧缓存")
except Exception as e:
logger.error(f"清理关键帧缓存失败: {e}")
def init_resources():
"""初始化资源文件"""
try:
# 创建字体目录
font_dir = os.path.join(root_dir(), "resource", "fonts")
os.makedirs(font_dir, exist_ok=True)
# 检查字体文件
font_files = [
("SourceHanSansCN-Regular.otf", "https://github.com/adobe-fonts/source-han-sans/raw/release/OTF/SimplifiedChinese/SourceHanSansSC-Regular.otf"),
("simhei.ttf", "C:/Windows/Fonts/simhei.ttf"), # Windows 黑体
("simkai.ttf", "C:/Windows/Fonts/simkai.ttf"), # Windows 楷体
("simsun.ttc", "C:/Windows/Fonts/simsun.ttc"), # Windows 宋体
]
# 优先使用系统字体
system_font_found = False
for font_name, source in font_files:
if not source.startswith("http") and os.path.exists(source):
target_path = os.path.join(font_dir, font_name)
if not os.path.exists(target_path):
import shutil
shutil.copy2(source, target_path)
logger.info(f"已复制系统字体: {font_name}")
system_font_found = True
break
# 如果没有找到系统字体,则下载思源黑体
if not system_font_found:
source_han_path = os.path.join(font_dir, "SourceHanSansCN-Regular.otf")
if not os.path.exists(source_han_path):
download_font(font_files[0][1], source_han_path)
except Exception as e:
logger.error(f"初始化资源文件失败: {e}")
def download_font(url: str, font_path: str):
"""下载字体文件"""
try:
logger.info(f"正在下载字体文件: {url}")
import requests
response = requests.get(url)
response.raise_for_status()
with open(font_path, 'wb') as f:
f.write(response.content)
logger.info(f"字体文件下载成功: {font_path}")
except Exception as e:
logger.error(f"下载字体文件失败: {e}")
raise

View File

@ -0,0 +1,209 @@
import cv2
import numpy as np
from sklearn.cluster import KMeans
import os
import re
from typing import List, Tuple, Generator
class VideoProcessor:
def __init__(self, video_path: str):
"""
初始化视频处理器
Args:
video_path: 视频文件路径
"""
if not os.path.exists(video_path):
raise FileNotFoundError(f"视频文件不存在: {video_path}")
self.video_path = video_path
self.cap = cv2.VideoCapture(video_path)
if not self.cap.isOpened():
raise RuntimeError(f"无法打开视频文件: {video_path}")
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
def __del__(self):
"""析构函数,确保视频资源被释放"""
if hasattr(self, 'cap'):
self.cap.release()
def preprocess_video(self) -> Generator[np.ndarray, None, None]:
"""
使用生成器方式读取视频帧
Yields:
np.ndarray: 视频帧
"""
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始
while self.cap.isOpened():
ret, frame = self.cap.read()
if not ret:
break
yield frame
def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
"""
使用帧差法检测镜头边界
Args:
frames: 视频帧列表
threshold: 差异阈值
Returns:
List[int]: 镜头边界帧的索引列表
"""
shot_boundaries = []
for i in range(1, len(frames)):
prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
if diff > threshold:
shot_boundaries.append(i)
return shot_boundaries
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
"""
从每个镜头中提取关键帧
Args:
frames: 视频帧列表
shot_boundaries: 镜头边界列表
Returns:
Tuple[List[np.ndarray], List[int]]: 关键帧列表和对应的帧索引
"""
keyframes = []
keyframe_indices = []
for i in range(len(shot_boundaries)):
start = shot_boundaries[i - 1] if i > 0 else 0
end = shot_boundaries[i]
shot_frames = frames[start:end]
frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
for frame in shot_frames])
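# With n_clusters=1, KMeans reduces to the mean feature vector of the
# shot; the frame closest to that centroid is kept as the keyframe.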
kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
keyframes.append(shot_frames[center_idx])
keyframe_indices.append(start + center_idx)
return keyframes, keyframe_indices
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
output_dir: str) -> None:
"""
保存关键帧到指定目录文件名格式为keyframe_帧序号_时间戳.jpg
Args:
keyframes: 关键帧列表
keyframe_indices: 关键帧索引列表
output_dir: 输出目录
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for keyframe, frame_idx in zip(keyframes, keyframe_indices):
# 计算时间戳(秒)
timestamp = frame_idx / self.fps
# 将时间戳转换为 HH:MM:SS 格式
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
# 构建新的文件名格式keyframe_帧序号_时间戳.jpg
output_path = os.path.join(output_dir,
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
cv2.imwrite(output_path, keyframe)
print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}")
def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
"""
根据指定的帧号提取帧
Args:
frame_numbers: 要提取的帧号列表
output_folder: 输出文件夹路径
"""
if not frame_numbers:
raise ValueError("未提供帧号列表")
if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
raise ValueError("存在无效的帧号")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
for frame_number in frame_numbers:
self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = self.cap.read()
if ret:
# 计算时间戳
timestamp = frame_number / self.fps
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
# 使用与关键帧相同的命名格式
output_path = os.path.join(output_folder,
f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
cv2.imwrite(output_path, frame)
print(f"已提取并保存帧 {frame_number}")
else:
print(f"无法读取帧 {frame_number}")
@staticmethod
def extract_numbers_from_folder(folder_path: str) -> List[int]:
"""
从文件夹中提取帧号
Args:
folder_path: 关键帧文件夹路径
Returns:
List[int]: 排序后的帧号列表
"""
files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
# 更新正则表达式以匹配新的文件名格式keyframe_000123_010534.jpg
pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
numbers = []
for f in files:
match = pattern.search(f)
if match:
numbers.append(int(match.group(1)))
return sorted(numbers)
def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
"""
处理视频并提取关键帧
Args:
output_dir: 输出目录
skip_seconds: 跳过视频开头的秒数
"""
# 计算要跳过的帧数
skip_frames = int(skip_seconds * self.fps)
# 获取所有帧
frames = list(self.preprocess_video())
# 跳过指定秒数的帧
frames = frames[skip_frames:]
if not frames:
raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
shot_boundaries = self.detect_shot_boundaries(frames)
keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
# 调整关键帧索引,加上跳过的帧数
adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
self.save_keyframes(keyframes, adjusted_indices, output_dir)
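A minimal usage sketch of VideoProcessor (paths are illustrative placeholders):

processor = VideoProcessor("videos/demo.mp4")
processor.process_video(output_dir="storage/temp/keyframes/demo", skip_seconds=1.0)
frame_numbers = VideoProcessor.extract_numbers_from_folder("storage/temp/keyframes/demo")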

View File

@ -0,0 +1,183 @@
import json
from typing import List, Union, Dict
import os
from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
from google.api_core import exceptions
import google.generativeai as genai
import PIL.Image
import traceback
class VisionAnalyzer:
"""视觉分析器类"""
def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
"""初始化视觉分析器"""
if not api_key:
raise ValueError("必须提供API密钥")
self.model_name = model_name
self.api_key = api_key
# 初始化配置
self._configure_client()
def _configure_client(self):
"""配置API客户端"""
genai.configure(api_key=self.api_key)
self.model = genai.GenerativeModel(self.model_name)
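# Retry policy: up to 3 attempts with exponential backoff (4-10 s),
# triggered only by ResourceExhausted, i.e. quota / rate-limit errors.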
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(exceptions.ResourceExhausted)
)
async def _generate_content_with_retry(self, prompt, batch):
"""使用重试机制的内部方法来调用 generate_content_async"""
try:
return await self.model.generate_content_async([prompt, *batch])
except exceptions.ResourceExhausted as e:
print(f"API配额限制: {str(e)}")
raise RetryError("API调用失败")
async def analyze_images(self,
images: Union[List[str], List[PIL.Image.Image]],
prompt: str,
batch_size: int = 5) -> List[Dict]:
"""批量分析多张图片"""
try:
# 加载图片
if isinstance(images[0], str):
logger.info("正在加载图片...")
images = self.load_images(images)
# 验证图片列表
if not images:
raise ValueError("图片列表为空")
# 验证每个图片对象
valid_images = []
for i, img in enumerate(images):
if not isinstance(img, PIL.Image.Image):
logger.error(f"无效的图片对象,索引 {i}: {type(img)}")
continue
valid_images.append(img)
if not valid_images:
raise ValueError("没有有效的图片对象")
images = valid_images
results = []
total_batches = (len(images) + batch_size - 1) // batch_size
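# (len + batch_size - 1) // batch_size is ceiling division, e.g.
# 12 images with batch_size 5 -> 3 batches.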
with tqdm(total=total_batches, desc="分析进度") as pbar:
for i in range(0, len(images), batch_size):
batch = images[i:i + batch_size]
retry_count = 0
while retry_count < 3:
try:
# 在每个批次处理前添加小延迟
if i > 0:
await asyncio.sleep(2)
# 确保每个批次的图片都是有效的
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
if not valid_batch:
raise ValueError(f"批次 {i // batch_size} 中没有有效的图片")
response = await self._generate_content_with_retry(prompt, valid_batch)
results.append({
'batch_index': i // batch_size,
'images_processed': len(valid_batch),
'response': response.text,
'model_used': self.model_name
})
break
except Exception as e:
retry_count += 1
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
if retry_count >= 3:
results.append({
'batch_index': i // batch_size,
'images_processed': len(batch),
'error': error_msg,
'model_used': self.model_name
})
else:
logger.info(f"批次 {i // batch_size} 处理失败等待60秒后重试...")
await asyncio.sleep(60)
pbar.update(1)
return results
except Exception as e:
error_msg = f"图片分析过程中发生错误: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
raise Exception(error_msg)
def save_results_to_txt(self, results: List[Dict], output_dir: str):
"""将分析结果保存到txt文件"""
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
for result in results:
if not result.get('image_paths'):
continue
response_text = result['response']
image_paths = result['image_paths']
img_name_start = Path(image_paths[0]).stem.split('_')[-1]
img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
txt_path = os.path.join(output_dir, f"frame_{img_name_start}_{img_name_end}.txt")
# 保存结果到txt文件
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(response_text.strip())
print(f"已保存分析结果到: {txt_path}")
def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
"""
加载多张图片
Args:
image_paths: 图片路径列表
Returns:
加载后的PIL Image对象列表
"""
images = []
failed_images = []
for img_path in image_paths:
try:
if not os.path.exists(img_path):
logger.error(f"图片文件不存在: {img_path}")
failed_images.append(img_path)
continue
img = PIL.Image.open(img_path)
# 确保图片被完全加载
img.load()
# 转换为RGB模式
if img.mode != 'RGB':
img = img.convert('RGB')
images.append(img)
except Exception as e:
logger.error(f"无法加载图片 {img_path}: {str(e)}")
failed_images.append(img_path)
if failed_images:
logger.warning(f"以下图片加载失败:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
if not images:
raise ValueError("没有成功加载任何图片")
return images
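A minimal usage sketch of VisionAnalyzer (the API key and image paths are illustrative placeholders):

analyzer = VisionAnalyzer(model_name="gemini-1.5-flash", api_key="YOUR_KEY")
results = asyncio.run(analyzer.analyze_images(
    images=["keyframes/keyframe_000001_000002.jpg"],
    prompt="描述画面内容",
    batch_size=5,
))
for r in results:
    print(r.get("response") or r.get("error"))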

View File

@ -1,5 +1,5 @@
[app]
project_version="0.2.2"
project_version="0.3.0"
# 支持视频理解的大模型提供商
# gemini
# qwen2-vl (待增加)

View File

@ -1,16 +1,14 @@
requests~=2.31.0
moviepy~=2.0.0.dev2
openai~=1.13.3
faster-whisper~=1.0.1
edge_tts~=6.1.15
uvicorn~=0.27.1
fastapi~=0.115.4
tomli~=2.0.1
streamlit~=1.39.0
streamlit~=1.40.0
loguru~=0.7.2
aiohttp~=3.10.10
urllib3~=2.2.1
pillow~=10.4.0
pydantic~=2.6.3
g4f~=0.3.0.4
dashscope~=1.15.0
@ -25,3 +23,12 @@ git-changelog~=2.5.2
watchdog==5.0.2
pydub==0.25.1
psutil>=5.9.0
opencv-python~=4.10.0.84
scikit-learn~=1.5.2
google-generativeai~=0.8.3
Pillow>=11.0.0
python-dotenv~=1.0.1
openai~=1.53.0
tqdm>=4.66.6
tenacity>=9.0.0
tiktoken==0.8.0

webui.py
View File

@ -4,9 +4,10 @@ import sys
from uuid import uuid4
from app.config import config
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
from webui.utils import cache, file_utils, performance
from webui.utils import cache, file_utils
from app.utils import utils
from app.models.schema import VideoClipParams, VideoAspect
from webui.utils.performance import PerformanceMonitor
# 初始化配置 - 必须是第一个 Streamlit 命令
st.set_page_config(
@ -34,6 +35,17 @@ def init_log():
_lvl = "DEBUG"
def format_record(record):
# 增加更多需要过滤的警告消息
ignore_messages = [
"Examining the path of torch.classes raised",
"torch.cuda.is_available()",
"CUDA initialization"
]
for msg in ignore_messages:
if msg in record["message"]:
return ""
file_path = record["file"].path
relative_path = os.path.relpath(file_path, config.root_dir)
record["file"].path = f"./{relative_path}"
@ -45,11 +57,21 @@ def init_log():
'- <level>{message}</>' + "\n"
return _format
# 优化日志过滤器
def log_filter(record):
ignore_messages = [
"Examining the path of torch.classes raised",
"torch.cuda.is_available()",
"CUDA initialization"
]
return not any(msg in record["message"] for msg in ignore_messages)
logger.add(
sys.stdout,
level=_lvl,
format=format_record,
colorize=True,
filter=log_filter
)
def init_global_state():
@ -73,77 +95,83 @@ def tr(key):
def render_generate_button():
"""渲染生成按钮和处理逻辑"""
if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
from app.services import task as tm
# 重置日志容器和记录
log_container = st.empty()
log_records = []
def log_received(msg):
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
from loguru import logger
logger.add(log_received)
config.save_config()
task_id = st.session_state.get('task_id')
if not task_id:
st.error(tr("请先裁剪视频"))
return
if not st.session_state.get('video_clip_json_path'):
st.error(tr("脚本文件不能为空"))
return
if not st.session_state.get('video_origin_path'):
st.error(tr("视频文件不能为空"))
return
st.toast(tr("生成视频"))
logger.info(tr("开始生成视频"))
# 获取所有参数
script_params = script_settings.get_script_params()
video_params = video_settings.get_video_params()
audio_params = audio_settings.get_audio_params()
subtitle_params = subtitle_settings.get_subtitle_params()
# 合并所有参数
all_params = {
**script_params,
**video_params,
**audio_params,
**subtitle_params
}
# 创建参数对象
params = VideoClipParams(**all_params)
result = tm.start_subclip(
task_id=task_id,
params=params,
subclip_path_videos=st.session_state['subclip_videos']
)
video_files = result.get("videos", [])
st.success(tr("视频生成完成"))
try:
if video_files:
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
logger.error(f"播放视频失败: {e}")
from app.services import task as tm
import torch
# 重置日志容器和记录
log_container = st.empty()
log_records = []
file_utils.open_task_folder(config.root_dir, task_id)
logger.info(tr("视频生成完成"))
def log_received(msg):
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
from loguru import logger
logger.add(log_received)
config.save_config()
task_id = st.session_state.get('task_id')
if not task_id:
st.error(tr("请先裁剪视频"))
return
if not st.session_state.get('video_clip_json_path'):
st.error(tr("脚本文件不能为空"))
return
if not st.session_state.get('video_origin_path'):
st.error(tr("视频文件不能为空"))
return
st.toast(tr("生成视频"))
logger.info(tr("开始生成视频"))
# 获取所有参数
script_params = script_settings.get_script_params()
video_params = video_settings.get_video_params()
audio_params = audio_settings.get_audio_params()
subtitle_params = subtitle_settings.get_subtitle_params()
# 合并所有参数
all_params = {
**script_params,
**video_params,
**audio_params,
**subtitle_params
}
# 创建参数对象
params = VideoClipParams(**all_params)
result = tm.start_subclip(
task_id=task_id,
params=params,
subclip_path_videos=st.session_state['subclip_videos']
)
video_files = result.get("videos", [])
st.success(tr("视生成完成"))
try:
if video_files:
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
logger.error(f"播放视频失败: {e}")
file_utils.open_task_folder(config.root_dir, task_id)
logger.info(tr("视频生成完成"))
finally:
PerformanceMonitor.cleanup_resources()
def main():
"""主函数"""
init_log()
init_global_state()
utils.init_resources()
st.title(f"NarratoAI :sunglasses:📽️")
st.write(tr("Get Help"))

View File

@ -3,6 +3,7 @@ import os
from app.config import config
from app.utils import utils
def render_basic_settings(tr):
"""渲染基础设置面板"""
with st.expander(tr("Basic Settings"), expanded=False):
@ -10,23 +11,24 @@ def render_basic_settings(tr):
left_config_panel = config_panels[0]
middle_config_panel = config_panels[1]
right_config_panel = config_panels[2]
with left_config_panel:
render_language_settings(tr)
render_proxy_settings(tr)
with middle_config_panel:
render_video_llm_settings(tr)
render_vision_llm_settings(tr) # 视频分析模型设置
with right_config_panel:
render_llm_settings(tr)
render_text_llm_settings(tr) # 文案生成模型设置
def render_language_settings(tr):
"""渲染语言设置"""
system_locale = utils.get_system_locale()
i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
locales = utils.load_locales(i18n_dir)
display_languages = []
selected_index = 0
for i, code in enumerate(locales.keys()):
@ -35,24 +37,25 @@ def render_language_settings(tr):
selected_index = i
selected_language = st.selectbox(
tr("Language"),
tr("Language"),
options=display_languages,
index=selected_index
)
if selected_language:
code = selected_language.split(" - ")[0].strip()
st.session_state['ui_language'] = code
config.ui['language'] = code
def render_proxy_settings(tr):
"""渲染代理设置"""
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
if HTTP_PROXY:
config.proxy["http"] = HTTP_PROXY
os.environ["HTTP_PROXY"] = HTTP_PROXY
@ -60,83 +63,172 @@ def render_proxy_settings(tr):
config.proxy["https"] = HTTPS_PROXY
os.environ["HTTPS_PROXY"] = HTTPS_PROXY
def render_video_llm_settings(tr):
"""渲染视频LLM设置"""
video_llm_providers = ['Gemini', 'NarratoAPI']
saved_llm_provider = config.app.get("video_llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(video_llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
def render_vision_llm_settings(tr):
"""渲染视频分析模型设置"""
st.subheader(tr("Vision Model Settings"))
# 视频分析模型提供商选择
vision_providers = ['Gemini', 'NarratoAPI']
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
saved_provider_index = 0
for i, provider in enumerate(vision_providers):
if provider.lower() == saved_vision_provider:
saved_provider_index = i
break
video_llm_provider = st.selectbox(
tr("Video LLM Provider"),
options=video_llm_providers,
index=saved_llm_provider_index
vision_provider = st.selectbox(
tr("Vision Model Provider"),
options=vision_providers,
index=saved_provider_index
)
video_llm_provider = video_llm_provider.lower()
config.app["video_llm_provider"] = video_llm_provider
vision_provider = vision_provider.lower()
config.app["vision_llm_provider"] = vision_provider
st.session_state['vision_llm_providers'] = vision_provider
# 获取已保存的配置
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
# 渲染输入框
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
# 保存配置
if st_llm_api_key:
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
# 获取已保存的视觉模型配置
vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
def render_llm_settings(tr):
"""渲染LLM设置"""
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
# 渲染视觉模型配置输入框
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url)
st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
# 保存视觉模型配置
if st_vision_api_key:
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key # 用于script_settings.py
if st_vision_base_url:
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
if st_vision_model_name:
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
# NarratoAPI 特殊配置
if vision_provider == 'narratoapi':
st.subheader(tr("Narrato Additional Settings"))
# Narrato API 基础配置
narrato_api_key = st.text_input(
tr("Narrato API Key"),
value=config.app.get("narrato_api_key", ""),
type="password",
help="用于访问 Narrato API 的密钥"
)
if narrato_api_key:
config.app["narrato_api_key"] = narrato_api_key
st.session_state['narrato_api_key'] = narrato_api_key
narrato_api_url = st.text_input(
tr("Narrato API URL"),
value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze")
)
if narrato_api_url:
config.app["narrato_api_url"] = narrato_api_url
st.session_state['narrato_api_url'] = narrato_api_url
# 视频分析模型配置
st.markdown("##### " + tr("Vision Model Settings"))
narrato_vision_model = st.text_input(
tr("Vision Model Name"),
value=config.app.get("narrato_vision_model", "gemini-1.5-flash")
)
narrato_vision_key = st.text_input(
tr("Vision Model API Key"),
value=config.app.get("narrato_vision_key", ""),
type="password",
help="用于视频分析的模型 API Key"
)
if narrato_vision_model:
config.app["narrato_vision_model"] = narrato_vision_model
st.session_state['narrato_vision_model'] = narrato_vision_model
if narrato_vision_key:
config.app["narrato_vision_key"] = narrato_vision_key
st.session_state['narrato_vision_key'] = narrato_vision_key
# 文案生成模型配置
st.markdown("##### " + tr("Text Generation Model Settings"))
narrato_llm_model = st.text_input(
tr("LLM Model Name"),
value=config.app.get("narrato_llm_model", "qwen-plus")
)
narrato_llm_key = st.text_input(
tr("LLM Model API Key"),
value=config.app.get("narrato_llm_key", ""),
type="password",
help="用于文案生成的模型 API Key"
)
if narrato_llm_model:
config.app["narrato_llm_model"] = narrato_llm_model
st.session_state['narrato_llm_model'] = narrato_llm_model
if narrato_llm_key:
config.app["narrato_llm_key"] = narrato_llm_key
st.session_state['narrato_llm_key'] = narrato_llm_key
# 批处理配置
narrato_batch_size = st.number_input(
tr("Batch Size"),
min_value=1,
max_value=50,
value=config.app.get("narrato_batch_size", 10),
help="每批处理的图片数量"
)
if narrato_batch_size:
config.app["narrato_batch_size"] = narrato_batch_size
st.session_state['narrato_batch_size'] = narrato_batch_size
def render_text_llm_settings(tr):
"""渲染文案生成模型设置"""
st.subheader(tr("Text Generation Model Settings"))
# 文案生成模型提供商选择
text_providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare']
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
saved_provider_index = 0
for i, provider in enumerate(text_providers):
if provider.lower() == saved_text_provider:
saved_provider_index = i
break
llm_provider = st.selectbox(
tr("LLM Provider"),
options=llm_providers,
index=saved_llm_provider_index
text_provider = st.selectbox(
tr("Text Model Provider"),
options=text_providers,
index=saved_provider_index
)
llm_provider = llm_provider.lower()
config.app["llm_provider"] = llm_provider
text_provider = text_provider.lower()
config.app["text_llm_provider"] = text_provider
# 获取已保存的配置
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
# 渲染输入框
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
# 保存配置
if st_llm_api_key:
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
# 获取已保存的文本模型配置
text_api_key = config.app.get(f"text_{text_provider}_api_key", "")
text_base_url = config.app.get(f"text_{text_provider}_base_url", "")
text_model_name = config.app.get(f"text_{text_provider}_model_name", "")
# Cloudflare 特殊处理
if llm_provider == 'cloudflare':
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
if st_llm_account_id:
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
# 渲染文本模型配置输入框
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
# 保存文本模型配置
if st_text_api_key:
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
if st_text_base_url:
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
if st_text_model_name:
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
# Cloudflare 特殊配置
if text_provider == 'cloudflare':
st_account_id = st.text_input(
tr("Account ID"),
value=config.app.get(f"text_{text_provider}_account_id", "")
)
if st_account_id:
config.app[f"text_{text_provider}_account_id"] = st_account_id

View File

@ -7,8 +7,10 @@ def render_review_panel(tr):
with st.expander(tr("Video Check"), expanded=False):
try:
video_list = st.session_state.get('video_clip_json', [])
subclip_videos = st.session_state.get('subclip_videos', {})
except KeyError:
video_list = []
subclip_videos = {}
# 计算列数和行数
num_videos = len(video_list)
@ -22,44 +24,62 @@ def render_review_panel(tr):
index = row * cols_per_row + col
if index < num_videos:
with cols[col]:
render_video_item(tr, video_list, index)
render_video_item(tr, video_list, subclip_videos, index)
def render_video_item(tr, video_list, index):
def render_video_item(tr, video_list, subclip_videos, index):
"""渲染单个视频项"""
video_info = video_list[index]
video_path = video_info.get('path')
if video_path is not None and os.path.exists(video_path):
initial_narration = video_info.get('narration', '')
initial_picture = video_info.get('picture', '')
initial_timestamp = video_info.get('timestamp', '')
# 显示视频
with open(video_path, 'rb') as video_file:
video_bytes = video_file.read()
st.video(video_bytes)
# 显示信息(只读)
text_panels = st.columns(2)
with text_panels[0]:
st.text_area(
tr("timestamp"),
value=initial_timestamp,
height=20,
key=f"timestamp_{index}",
disabled=True
)
with text_panels[1]:
st.text_area(
tr("Picture description"),
value=initial_picture,
height=20,
key=f"picture_{index}",
disabled=True
)
st.text_area(
tr("Narration"),
value=initial_narration,
height=100,
key=f"narration_{index}",
disabled=True
)
video_script = video_list[index]
# 显示时间戳
timestamp = video_script.get('timestamp', '')
st.text_area(
tr("Timestamp"),
value=timestamp,
height=70,
disabled=True,
key=f"timestamp_{index}"
)
# 显示视频播放器
video_path = subclip_videos.get(timestamp)
if video_path and os.path.exists(video_path):
try:
st.video(video_path)
except Exception as e:
logger.error(f"加载视频失败 {video_path}: {e}")
st.error(f"无法加载视频: {os.path.basename(video_path)}")
else:
st.warning(tr("视频文件未找到"))
# 显示画面描述
st.text_area(
tr("Picture Description"),
value=video_script.get('picture', ''),
height=150,
disabled=True,
key=f"picture_{index}"
)
# 显示旁白文本
narration = st.text_area(
tr("Narration"),
value=video_script.get('narration', ''),
height=150,
key=f"narration_{index}"
)
# 保存修改后的旁白文本
if narration != video_script.get('narration', ''):
video_script['narration'] = narration
st.session_state['video_clip_json'] = video_list
# 显示剪辑模式
ost = st.selectbox(
tr("Clip Mode"),
options=range(1, 10),
index=video_script.get('OST', 1) - 1,
key=f"ost_{index}"
)
# 保存修改后的剪辑模式
if ost != video_script.get('OST', 1):
video_script['OST'] = ost
st.session_state['video_clip_json'] = video_list

View File

@ -1,43 +1,50 @@
import streamlit as st
import os
import glob
import json
import time
import asyncio
import traceback
import requests
import streamlit as st
from loguru import logger
from app.config import config
from app.models.schema import VideoClipParams
from app.services import llm
from app.utils import utils, check_script
from loguru import logger
from app.utils import utils, check_script, vision_analyzer, video_processor
from webui.utils import file_utils
def render_script_panel(tr):
"""渲染脚本配置面板"""
with st.container(border=True):
st.write(tr("Video Script Configuration"))
params = VideoClipParams()
# 渲染脚本文件选择
render_script_file(tr, params)
# 渲染视频文件选择
render_video_file(tr, params)
# 渲染视频主题和提示词
render_video_details(tr)
# 渲染脚本操作按钮
render_script_buttons(tr, params)
def render_script_file(tr, params):
"""渲染脚本文件选择"""
script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]
# 获取已有脚本文件
suffix = "*.json"
script_dir = utils.script_dir()
files = glob.glob(os.path.join(script_dir, suffix))
file_list = []
for file in files:
file_list.append({
"name": os.path.basename(file),
@ -64,15 +71,16 @@ def render_script_file(tr, params):
options=range(len(script_list)),
format_func=lambda x: script_list[x][0]
)
script_path = script_list[selected_script_index][1]
st.session_state['video_clip_json_path'] = script_path
params.video_clip_json_path = script_path
def render_video_file(tr, params):
"""渲染视频文件选择"""
video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
# 获取已有视频文件
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
@ -86,7 +94,7 @@ def render_video_file(tr, params):
options=range(len(video_list)),
format_func=lambda x: video_list[x][0]
)
video_path = video_list[selected_video_index][1]
st.session_state['video_origin_path'] = video_path
params.video_origin_path = video_path
@ -97,16 +105,16 @@ def render_video_file(tr, params):
type=["mp4", "mov", "avi", "flv", "mkv"],
accept_multiple_files=False,
)
if uploaded_file is not None:
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
file_name, file_extension = os.path.splitext(uploaded_file.name)
if os.path.exists(video_file_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
file_name_with_timestamp = f"{file_name}_{timestamp}"
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
with open(video_file_path, "wb") as f:
f.write(uploaded_file.read())
st.success(tr("File Uploaded Successfully"))
@ -115,6 +123,7 @@ def render_video_file(tr, params):
time.sleep(1)
st.rerun()
def render_video_details(tr):
"""渲染视频主题和提示词"""
video_theme = st.text_input(tr("Video Theme"))
@ -128,6 +137,7 @@ def render_video_details(tr):
st.session_state['video_plot'] = prompt
return video_theme, prompt
def render_script_buttons(tr, params):
"""渲染脚本操作按钮"""
# 生成/加载按钮
@ -157,16 +167,17 @@ def render_script_buttons(tr, params):
with button_cols[0]:
if st.button(tr("Check Format"), key="check_format", use_container_width=True):
check_script_format(tr, video_clip_json_details)
with button_cols[1]:
if st.button(tr("Save Script"), key="save_script", use_container_width=True):
save_script(tr, video_clip_json_details)
with button_cols[2]:
script_valid = st.session_state.get('script_format_valid', False)
if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
crop_video(tr, params)
def check_script_format(tr, script_content):
"""检查脚本格式"""
try:
@ -181,6 +192,7 @@ def check_script_format(tr, script_content):
st.error(f"{tr('Script format check error')}: {str(e)}")
st.session_state['script_format_valid'] = False
def load_script(tr, script_path):
"""加载脚本文件"""
try:
@ -193,6 +205,7 @@ def load_script(tr, script_path):
except Exception as e:
st.error(f"{tr('Failed to load script')}: {str(e)}")
def generate_script(tr, params):
"""生成视频脚本"""
progress_bar = st.progress(0)
@ -207,48 +220,354 @@ def generate_script(tr, params):
try:
with st.spinner("正在生成脚本..."):
if not st.session_state.get('video_plot'):
st.warning("视频剧情为空; 会极大影响生成效果!")
if not params.video_origin_path:
st.error("请先选择视频文件")
st.stop()
return
# ===================提取键帧===================
update_progress(10, "正在提取关键帧...")
# 创建临时目录用于存储关键帧
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
# 检查是否已经提取过关键帧
keyframe_files = []
if os.path.exists(video_keyframes_dir):
# 获取已有的关键帧文件
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if params.video_clip_json_path == "" and params.video_origin_path != "":
update_progress(10, "压缩视频中...")
script = llm.generate_script(
video_path=params.video_origin_path,
video_plot=st.session_state.get('video_plot', ''),
video_name=st.session_state.get('video_name', ''),
language=params.video_language,
progress_callback=update_progress
)
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
else:
update_progress(90)
if keyframe_files:
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}")
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)}")
# 如果没有缓存的关键帧,则进行提取
if not keyframe_files:
try:
# 确保目录存在
os.makedirs(video_keyframes_dir, exist_ok=True)
# 初始化视频处理器
processor = video_processor.VideoProcessor(params.video_origin_path)
# 处理视频并提取关键帧
processor.process_video(
output_dir=video_keyframes_dir,
skip_seconds=0
)
# 获取所有关键帧文件路径
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if not keyframe_files:
raise Exception("未提取到任何关键帧")
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)}")
except Exception as e:
# 如果提取失败,清理创建的目录
try:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取失败: {str(e)}")
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
else:
# 从本地加载
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
update_progress(50)
status_text.text("从本地加载中...")
script = f.read()
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
update_progress(100)
status_text.text("从本地加载成功")
# 根据不同的 LLM 提供商处理
video_llm_provider = st.session_state.get('video_llm_providers', 'Gemini').lower()
if video_llm_provider == 'gemini':
try:
# ===================初始化视觉分析器===================
update_progress(30, "正在初始化视觉分析器...")
# 从配置中获取 Gemini 相关配置
vision_api_key = st.session_state.get('vision_gemini_api_key')
vision_model = st.session_state.get('vision_gemini_model_name')
vision_base_url = st.session_state.get('vision_gemini_base_url')
if not vision_api_key or not vision_model:
raise ValueError("未配置 Gemini API Key 或者 模型,请在基础设置中配置")
analyzer = vision_analyzer.VisionAnalyzer(
model_name=vision_model,
api_key=vision_api_key
)
update_progress(40, "正在分析关键帧...")
# ===================创建异步事件循环===================
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 执行异步分析
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=config.app.get("vision_batch_size", 5)
)
)
loop.close()
# ===================处理分析结果===================
update_progress(60, "正在整理分析结果...")
# 合并所有批次的分析结果
frame_analysis = ""
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
# 获取当前批次的图片文件
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
batch_files = keyframe_files[batch_start:batch_end]
# 提取首帧和尾帧的时间戳
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
# 从文件名中提取时间信息
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
# 转换为分:秒格式
def format_timestamp(time_str):
# time_str is HHMMSS from the keyframe filename; parse it positionally
# (int(time_str) would misread e.g. "010534" as 10534 seconds), then render as MM:SS
total_seconds = int(time_str[:2]) * 3600 + int(time_str[2:4]) * 60 + int(time_str[4:6])
minutes = total_seconds // 60
seconds = total_seconds % 60
return f"{minutes:02d}:{seconds:02d}"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
# 添加带时间戳的分析结果
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response']
frame_analysis += "\n"
if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果")
# 保存分析结果
analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
with open(analysis_path, 'w', encoding='utf-8') as f:
f.write(frame_analysis)
update_progress(70, "正在生成脚本...")
# 构建完整的上下文
context = {
'video_name': st.session_state.get('video_name', ''),
'video_plot': st.session_state.get('video_plot', ''),
'frame_analysis': frame_analysis,
'total_frames': len(keyframe_files)
}
# 从配置中获取文本生成相关配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name', 'gemini-1.5-pro')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 构建帧内容列表
frame_content_list = []
for i, result in enumerate(results):
if 'error' in result:
continue
# 获取当前批次的图片文件
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
batch_files = keyframe_files[batch_start:batch_end]
# 提取首帧和尾帧的时间戳
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
# 从文件名中提取时间信息
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
# 转换为分:秒格式
def format_timestamp(time_str):
# time_str is HHMMSS from the keyframe filename; parse it positionally, then render as MM:SS
total_seconds = int(time_str[:2]) * 3600 + int(time_str[2:4]) * 60 + int(time_str[4:6])
minutes = total_seconds // 60
seconds = total_seconds % 60
return f"{minutes:02d}:{seconds:02d}"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
# 构建时间戳范围字符串 (MM:SS-MM:SS 格式)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
frame_content = {
"timestamp": timestamp_range, # 使用时间范围格式 "MM:SS-MM:SS"
"picture": result['response'], # 图片分析结果
"narration": "", # 将由 ScriptProcessor 生成
"OST": 2 # 默认值
}
frame_content_list.append(frame_content)
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
if not frame_content_list:
raise Exception("没有有效的帧内容可以处理")
# 使用 ScriptProcessor 生成脚本
from app.utils.script_generator import ScriptProcessor
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
prompt="" # 使用默认提示词
)
# 处理帧内容生成脚本
script_result = processor.process_frames(frame_content_list)
# 将结果转换为JSON字符串
script = json.dumps(script_result, ensure_ascii=False, indent=2)
except Exception as e:
logger.exception(f"Gemini 处理过程中发生错误\n{traceback.format_exc()}")
raise Exception(f"视觉分析失败: {str(e)}")
else: # NarratoAPI
try:
# 创建临时目录
temp_dir = utils.temp_dir("narrato")
# 打包关键帧
update_progress(30, "正在打包关键帧...")
zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
if not file_utils.create_zip(keyframe_files, zip_path):
raise Exception("打包关键帧失败")
# 获取API配置
api_url = st.session_state.get('narrato_api_url', 'http://127.0.0.1:8000/api/v1/video/analyze')
api_key = st.session_state.get('narrato_api_key')
if not api_key:
raise ValueError("未配置 Narrato API Key请在基础设置中配置")
# 准备API请求
headers = {
'X-API-Key': api_key,
'accept': 'application/json'
}
api_params = {
'batch_size': st.session_state.get('narrato_batch_size', 10),
'use_ai': False,
'start_offset': 0,
'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
'vision_api_key': st.session_state.get('narrato_vision_key'),
'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
'llm_api_key': st.session_state.get('narrato_llm_key'),
'custom_prompt': st.session_state.get('video_plot', '')
}
# 发送API请求
update_progress(40, "正在上传文件...")
with open(zip_path, 'rb') as f:
files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
try:
response = requests.post(
api_url,
headers=headers,
params=api_params,
files=files,
timeout=30 # 设置超时时间
)
response.raise_for_status()
except requests.RequestException as e:
raise Exception(f"API请求失败: {str(e)}")
task_data = response.json()
task_id = task_data.get('task_id')
if not task_id:
raise Exception(f"无效的API响应: {response.text}")
# 轮询任务状态
update_progress(50, "正在等待分析结果...")
retry_count = 0
max_retries = 60 # 最多等待2分钟
while retry_count < max_retries:
try:
status_response = requests.get(
f"{api_url}/tasks/{task_id}",
headers=headers,
timeout=10
)
status_response.raise_for_status()
task_status = status_response.json()['data']
if task_status['status'] == 'SUCCESS':
script = task_status['result']
break
elif task_status['status'] in ['FAILURE', 'RETRY']:
raise Exception(f"任务失败: {task_status.get('error')}")
retry_count += 1
progress = min(70, 50 + (retry_count * 20 / max_retries))
update_progress(progress, "正在分析中...")
time.sleep(2)
except requests.RequestException as e:
logger.warning(f"获取任务状态失败,重试中: {str(e)}")
retry_count += 1
time.sleep(2)
continue
if retry_count >= max_retries:
raise Exception("任务执行超时")
except Exception as e:
logger.exception("NarratoAPI 处理过程中发生错误")
raise Exception(f"NarratoAPI 处理失败: {str(e)}")
finally:
# 清理临时文件
try:
if os.path.exists(zip_path):
os.remove(zip_path)
except Exception as e:
logger.warning(f"清理临时文件失败: {str(e)}")
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
update_progress(90, "脚本生成完成")
time.sleep(0.5)
progress_bar.progress(100)
status_text.text("脚本生成完成!")
st.success("视频脚本生成成功!")
except Exception as err:
st.error(f"生成过程中发生错误: {str(err)}")
logger.exception("生成脚本时发生错误")
finally:
time.sleep(2)
progress_bar.empty()
status_text.empty()
def save_script(tr, video_clip_json_details):
"""保存视频脚本"""
if not video_clip_json_details:
@ -266,21 +585,22 @@ def save_script(tr, video_clip_json_details):
json.dump(data, file, ensure_ascii=False, indent=4)
st.session_state['video_clip_json'] = data
st.session_state['video_clip_json_path'] = save_path
# 更新配置
config.app["video_clip_json_path"] = save_path
# 显示成功消息
st.success(tr("Script saved successfully"))
# 强制重新加载页面以更新选择框
time.sleep(0.5) # 给一点时间让用户看到成功消息
st.rerun()
except Exception as err:
st.error(f"{tr('Failed to save script')}: {str(err)}")
st.stop()
def crop_video(tr, params):
"""裁剪视频"""
progress_bar = st.progress(0)
@ -303,6 +623,7 @@ def crop_video(tr, params):
progress_bar.empty()
status_text.empty()
def get_script_params():
"""获取脚本参数"""
return {
@ -311,4 +632,4 @@ def get_script_params():
'video_origin_path': st.session_state.get('video_origin_path', ''),
'video_name': st.session_state.get('video_name', ''),
'video_plot': st.session_state.get('video_plot', '')
}
}

View File

@ -19,6 +19,18 @@ class WebUIConfig:
project_version: str = "0.1.0"
# 项目根目录
root_dir: str = None
# Gemini API Key
gemini_api_key: str = ""
# 每批处理的图片数量
vision_batch_size: int = 5
# 提示词
vision_prompt: str = """..."""
# Narrato API 配置
narrato_api_url: str = "http://127.0.0.1:8000/api/v1/video/analyze"
narrato_api_key: str = ""
narrato_batch_size: int = 10
narrato_vision_model: str = "gemini-1.5-flash"
narrato_llm_model: str = "qwen-plus"
def __post_init__(self):
"""初始化默认值"""

View File

@ -1,20 +1,8 @@
from .cache import get_fonts_cache, get_video_files_cache, get_songs_cache
from .file_utils import (
open_task_folder, cleanup_temp_files, get_file_list,
save_uploaded_file, create_temp_file, get_file_size, ensure_directory
)
from .performance import monitor_performance
from .performance import monitor_performance, PerformanceMonitor
from .cache import *
from .file_utils import *
__all__ = [
'get_fonts_cache',
'get_video_files_cache',
'get_songs_cache',
'open_task_folder',
'cleanup_temp_files',
'get_file_list',
'save_uploaded_file',
'create_temp_file',
'get_file_size',
'ensure_directory',
'monitor_performance'
'monitor_performance',
'PerformanceMonitor'
]

View File

@ -186,4 +186,45 @@ def ensure_directory(directory):
return True
except Exception as e:
logger.error(f"创建目录失败: {directory}, 错误: {e}")
return False
def create_zip(files: list, zip_path: str, base_dir: str = None) -> bool:
"""
创建zip文件
Args:
files: 要打包的文件列表
zip_path: zip文件保存路径
base_dir: 基础目录用于保持目录结构
Returns:
bool: 是否成功
"""
try:
import zipfile
# 确保目标目录存在
os.makedirs(os.path.dirname(zip_path), exist_ok=True)
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for file in files:
if not os.path.exists(file):
logger.warning(f"文件不存在,跳过: {file}")
continue
# 计算文件在zip中的路径
if base_dir:
arcname = os.path.relpath(file, base_dir)
else:
arcname = os.path.basename(file)
try:
zipf.write(file, arcname)
except Exception as e:
logger.error(f"添加文件到zip失败: {file}, 错误: {e}")
continue
logger.info(f"创建zip文件成功: {zip_path}")
return True
except Exception as e:
logger.error(f"创建zip文件失败: {e}")
return False

View File

@ -1,24 +1,37 @@
import time
import psutil
import os
from loguru import logger
import torch
try:
import psutil
ENABLE_PERFORMANCE_MONITORING = True
except ImportError:
ENABLE_PERFORMANCE_MONITORING = False
logger.warning("psutil not installed. Performance monitoring is disabled.")
class PerformanceMonitor:
@staticmethod
def monitor_memory():
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
if torch.cuda.is_available():
gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
@staticmethod
def cleanup_resources():
if torch.cuda.is_available():
torch.cuda.empty_cache()
import gc
gc.collect()
PerformanceMonitor.monitor_memory()
def monitor_performance():
if not ENABLE_PERFORMANCE_MONITORING:
return {'execution_time': 0, 'memory_usage': 0}
start_time = time.time()
try:
memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 # MB
except:
memory_usage = 0
return {
'execution_time': time.time() - start_time,
'memory_usage': memory_usage
}
def monitor_performance(func):
"""性能监控装饰器"""
def wrapper(*args, **kwargs):
try:
PerformanceMonitor.monitor_memory()
result = func(*args, **kwargs)
return result
finally:
PerformanceMonitor.cleanup_resources()
return wrapper
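A minimal usage sketch of the decorator form introduced here (the wrapped function is an illustrative placeholder):

@monitor_performance
def render_heavy_view():
    ...  # memory usage is logged on entry; GPU/CPU caches are cleaned up on exit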