Merge pull request #47 from linyqh/dev

0.3.2 发版
This commit is contained in:
linyq 2024-11-10 01:41:56 +08:00 committed by GitHub
commit 960db5e622
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 3548 additions and 1319 deletions

View File

@ -130,7 +130,7 @@ docker-compose up
## 开发 💻
1. 安装依赖
```shell
conda create -n narratoai python=3.10
conda create -n narratoai python=3.11
conda activate narratoai
cd narratoai
pip install -r requirements.txt

View File

@ -1,18 +1,24 @@
from fastapi import Request
from fastapi import Request, File, UploadFile
import os
from app.controllers.v1.base import new_router
from app.models.schema import (
VideoScriptResponse,
VideoScriptRequest,
VideoTermsResponse,
VideoTermsRequest,
VideoTranscriptionRequest,
VideoTranscriptionResponse,
)
from app.services import llm
from app.utils import utils
from app.config import config
# 认证依赖项
# router = new_router(dependencies=[Depends(base.verify_token)])
router = new_router()
# 定义上传目录
UPLOAD_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "uploads")
@router.post(
"/scripts",
@ -42,3 +48,46 @@ def generate_video_terms(request: Request, body: VideoTermsRequest):
)
response = {"video_terms": video_terms}
return utils.get_response(200, response)
@router.post(
    "/transcription",
    response_model=VideoTranscriptionResponse,
    summary="Transcribe video content using Gemini"
)
async def transcribe_video(
    request: Request,
    video_name: str,
    language: str = "zh-CN",
    video_file: UploadFile = File(...)
):
    """
    Transcribe video content with Gemini, including timestamped frame
    descriptions and speech content.

    Args:
        video_name: name of the video
        language: language code, defaults to zh-CN
        video_file: the uploaded video file
    """
    # Make sure the upload directory exists before writing into it.
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # Use only the basename of the client-supplied filename to prevent path
    # traversal (e.g. a filename like "../../etc/passwd"); fall back to a
    # default name when the client sends an empty filename.
    safe_name = os.path.basename(video_file.filename or "") or "upload.mp4"
    video_path = os.path.join(UPLOAD_DIR, safe_name)

    # Persist the upload to disk so the LLM helper can read it by path.
    with open(video_path, "wb") as buffer:
        content = await video_file.read()
        buffer.write(content)

    try:
        transcription = llm.gemini_video_transcription(
            video_name=video_name,
            video_path=video_path,
            language=language,
            llm_provider_video=config.app.get("video_llm_provider", "gemini")
        )
        response = {"transcription": transcription}
        return utils.get_response(200, response)
    finally:
        # Remove the temporary upload once processing finishes.
        if os.path.exists(video_path):
            os.remove(video_path)

View File

@ -347,6 +347,7 @@ class VideoClipParams(BaseModel):
voice_name: Optional[str] = Field(default="zh-CN-YunjianNeural", description="语音名称")
voice_volume: Optional[float] = Field(default=1.0, description="语音音量")
voice_rate: Optional[float] = Field(default=1.0, description="语速")
voice_pitch: Optional[float] = Field(default=1.0, description="语调")
bgm_name: Optional[str] = Field(default="random", description="背景音乐名称")
bgm_type: Optional[str] = Field(default="random", description="背景音乐类型")
@ -365,3 +366,13 @@ class VideoClipParams(BaseModel):
custom_position: float = Field(default=70.0, description="自定义位置")
n_threads: Optional[int] = 8 # 线程数,有助于提升视频处理速度
class VideoTranscriptionRequest(BaseModel):
    """Request payload for the video transcription endpoint."""
    # Name of the video to transcribe.
    video_name: str
    # Language code for the transcription output; defaults to zh-CN.
    language: str = "zh-CN"
    class Config:
        # Allow arbitrary (non-pydantic-validated) types in model fields.
        arbitrary_types_allowed = True
class VideoTranscriptionResponse(BaseModel):
    """Response payload carrying the transcription text."""
    # Full transcription text returned by the video LLM.
    transcription: str

View File

@ -73,25 +73,40 @@ def merge_audio_files(task_id: str, audio_file_paths: List[str], total_duration:
def parse_timestamp(timestamp: str):
    """Convert an "MM:SS"-style stamp (colon- or underscore-separated) to seconds."""
    # Normalize underscore separators so "00_06" parses the same as "00:06".
    normalized = timestamp.replace('_', ':')
    return time_to_seconds(normalized)
def extract_timestamp(filename):
    """Extract (start_seconds, end_seconds) from a name like "audio_00_06-00_24.mp3".

    The part after the first underscore and before the extension holds the
    range ("00_06-00_24"); underscores inside each stamp stand in for colons.
    """
    # "audio_00_06-00_24.mp3" -> "00_06-00_24"
    time_part = filename.split('_', 1)[1].split('.')[0]
    start_time, end_time = time_part.split('-')  # "00_06" and "00_24"

    # Restore the colon-separated form before converting to seconds.
    start_seconds = time_to_seconds(start_time.replace('_', ':'))
    end_seconds = time_to_seconds(end_time.replace('_', ':'))
    return start_seconds, end_seconds
def time_to_seconds(times):
    """Convert an "MM:SS" string to its total number of seconds."""
    fields = times.split(':')
    return int(fields[0]) * 60 + int(fields[1])
def time_to_seconds(time_str):
    """Convert a "00:06" or "00_06" style stamp to total seconds.

    Returns 0 (after logging an error) when the input cannot be parsed.
    """
    # Accept underscore-separated stamps by normalizing to colons first.
    normalized = time_str.replace('_', ':')
    try:
        fields = normalized.split(':')
        if len(fields) != 2:
            logger.error(f"Invalid time format: {normalized}")
            return 0
        minutes, seconds = fields
        return int(minutes) * 60 + int(seconds)
    except (ValueError, IndexError) as exc:
        logger.error(f"Error parsing time {normalized}: {str(exc)}")
        return 0
if __name__ == "__main__":

View File

@ -14,6 +14,7 @@ from googleapiclient.errors import ResumableUploadError
from google.api_core.exceptions import *
from google.generativeai.types import *
import subprocess
from typing import Union, TextIO
from app.config import config
from app.utils.utils import clean_model_output
@ -353,7 +354,7 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
return content.replace("\n", "")
def _generate_response_video(prompt: str, llm_provider_video: str, video_file: str | File) -> str:
def _generate_response_video(prompt: str, llm_provider_video: str, video_file: Union[str, TextIO]) -> str:
"""
多模态能力大模型
"""
@ -780,22 +781,28 @@ def screen_matching(huamian: str, wenan: str, llm_provider: str):
if __name__ == "__main__":
# 1. 视频转录
# video_subject = "第二十条之无罪释放"
# video_path = "../../resource/videos/test01.mp4"
# language = "zh-CN"
# gemini_video_transcription(video_subject, video_path, language)
video_subject = "第二十条之无罪释放"
video_path = "/Users/apple/Desktop/home/pipedream_project/downloads/jianzao.mp4"
language = "zh-CN"
gemini_video_transcription(
video_name=video_subject,
video_path=video_path,
language=language,
progress_callback=print,
llm_provider_video="gemini"
)
# 2. 解说文案
video_path = "/Users/apple/Desktop/home/NarratoAI/resource/videos/1.mp4"
# video_path = "E:\\projects\\NarratoAI\\resource\\videos\\1.mp4"
video_plot = """
李自忠拿着儿子李牧名下的存折去银行取钱给儿子救命却被要求证明"你儿子是你儿子"
走投无路时碰到银行被抢劫劫匪给了他两沓钱救命李自忠却因此被银行以抢劫罪起诉并顶格判处20年有期徒刑
苏醒后的李牧坚决为父亲做无罪辩护面对银行的顶级律师团队他一个法学院大一学生能否力挽狂澜创作奇迹挥法律之利剑 持正义之天平
"""
res = generate_script(video_path, video_plot, video_name="第二十条之无罪释放")
# res = generate_script(video_path, video_plot, video_name="海岸")
print("脚本生成成功:\n", res)
res = clean_model_output(res)
aaa = json.loads(res)
print(json.dumps(aaa, indent=2, ensure_ascii=False))
# # 2. 解说文案
# video_path = "/Users/apple/Desktop/home/NarratoAI/resource/videos/1.mp4"
# # video_path = "E:\\projects\\NarratoAI\\resource\\videos\\1.mp4"
# video_plot = """
# 李自忠拿着儿子李牧名下的存折,去银行取钱给儿子救命,却被要求证明"你儿子是你儿子"。
# 走投无路时碰到银行被抢劫劫匪给了他两沓钱救命李自忠却因此被银行以抢劫罪起诉并顶格判处20年有期徒刑
# 苏醒后的李牧坚决为父亲做无罪辩护,面对银行的顶级律师团队,他一个法学院大一学生,能否力挽狂澜,创作奇迹?挥法律之利剑 ,持正义之天平
# """
# res = generate_script(video_path, video_plot, video_name="第二十条之无罪释放")
# # res = generate_script(video_path, video_plot, video_name="海岸")
# print("脚本生成成功:\n", res)
# res = clean_model_output(res)
# aaa = json.loads(res)
# print(json.dumps(aaa, indent=2, ensure_ascii=False))

View File

@ -1,6 +1,7 @@
import os
import subprocess
import random
import traceback
from urllib.parse import urlencode
import requests
@ -274,25 +275,37 @@ def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> di
logger.info(f"video already exists: {video_path}")
return {timestamp: video_path}
# 剪辑视频
start, end = utils.split_timestamp(timestamp)
video = VideoFileClip(origin_video).subclip(start, end)
video.write_videofile(video_path, logger=None) # 禁用 MoviePy 的内置日志
try:
# 剪辑视频
start, end = utils.split_timestamp(timestamp)
video = VideoFileClip(origin_video).subclip(start, end)
# 检查视频是否有音频轨道
if video.audio is not None:
video.write_videofile(video_path, logger=None) # 有音频时正常处理
else:
# 没有音频时使用不同的写入方式
video.write_videofile(video_path, audio=False, logger=None)
video.close() # 确保关闭视频文件
if os.path.getsize(video_path) > 0 and os.path.exists(video_path):
try:
clip = VideoFileClip(video_path)
duration = clip.duration
fps = clip.fps
clip.close()
if duration > 0 and fps > 0:
return {timestamp: video_path}
except Exception as e:
if os.path.getsize(video_path) > 0 and os.path.exists(video_path):
try:
os.remove(video_path)
clip = VideoFileClip(video_path)
duration = clip.duration
fps = clip.fps
clip.close()
if duration > 0 and fps > 0:
return {timestamp: video_path}
except Exception as e:
logger.warning(str(e))
logger.warning(f"无效的视频文件: {video_path}")
logger.warning(f"视频文件验证失败: {video_path} => {str(e)}")
if os.path.exists(video_path):
os.remove(video_path)
except Exception as e:
logger.warning(f"视频剪辑失败: {str(e)}")
if os.path.exists(video_path):
os.remove(video_path)
return {}
@ -327,7 +340,7 @@ def clip_videos(task_id: str, timestamp_terms: List[str], origin_video: str, pro
if progress_callback:
progress_callback(index + 1, total_items)
except Exception as e:
logger.error(f"视频裁剪失败: {utils.to_json(item)} => {str(e)}")
logger.error(f"视频裁剪失败: {utils.to_json(item)} =>\n{str(traceback.format_exc())}")
return {}
logger.success(f"裁剪 {len(video_paths)} videos")
return video_paths

View File

@ -29,7 +29,7 @@ def create(audio_file, subtitle_file: str = ""):
返回:
无返回值但会在指定路径生成字幕文件
"""
global model
global model, device, compute_type
if not model:
model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v2"
model_bin_file = f"{model_path}/model.bin"
@ -43,27 +43,45 @@ def create(audio_file, subtitle_file: str = ""):
)
return None
logger.info(
f"加载模型: {model_path}, 设备: {device}, 计算类型: {compute_type}"
)
# 尝试使用 CUDA如果失败则回退到 CPU
try:
import torch
if torch.cuda.is_available():
try:
logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
model = WhisperModel(
model_size_or_path=model_path,
device="cuda",
compute_type="float16",
local_files_only=True
)
device = "cuda"
compute_type = "float16"
logger.info("成功使用 CUDA 加载模型")
except Exception as e:
logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
logger.warning("回退到 CPU 模式")
device = "cpu"
compute_type = "int8"
else:
logger.info("未检测到 CUDA使用 CPU 模式")
device = "cpu"
compute_type = "int8"
except ImportError:
logger.warning("未安装 torch使用 CPU 模式")
device = "cpu"
compute_type = "int8"
if device == "cpu":
logger.info(f"使用 CPU 加载模型: {model_path}")
model = WhisperModel(
model_size_or_path=model_path,
device=device,
compute_type=compute_type,
local_files_only=True
)
except Exception as e:
logger.error(
f"加载模型失败: {e} \n\n"
f"********************************************\n"
f"这可能是由网络问题引起的. \n"
f"请手动下载模型并将其放入 'app/models' 文件夹中。 \n"
f"see [README.md FAQ](https://github.com/linyqh/NarratoAI) for more details.\n"
f"********************************************\n\n"
f"{traceback.format_exc()}"
)
return None
logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")
logger.info(f"start, output file: {subtitle_file}")
if not subtitle_file:

View File

@ -372,12 +372,13 @@ def start_subclip(task_id: str, params: VideoClipParams, subclip_path_videos: li
list_script=list_script,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_pitch=params.voice_pitch,
force_regenerate=True
)
if audio_files is None:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"音频文件为空可能是网络不可用。如果您在中国请使用VPN。或者手动选择 zh-CN-Yunjian-男性 音频")
"TTS转换音频失败, 可能是网络不可用! 如果您在中国, 请使用VPN.")
return
logger.info(f"合并音频:\n\n {audio_files}")
audio_file = audio_merger.merge_audio_files(task_id, audio_files, total_duration, list_script)

View File

@ -4,11 +4,13 @@ import glob
import random
from typing import List
from typing import Union
import traceback
from loguru import logger
from moviepy.editor import *
from moviepy.video.tools.subtitles import SubtitlesClip
from PIL import ImageFont
from contextlib import contextmanager
from app.models import const
from app.models.schema import MaterialInfo, VideoAspect, VideoConcatMode, VideoParams, VideoClipParams
@ -144,7 +146,7 @@ def combine_videos(
return combined_video_path
def wrap_text(text, max_width, font="Arial", fontsize=60):
def wrap_text(text, max_width, font, fontsize=60):
# 创建字体对象
font = ImageFont.truetype(font, fontsize)
@ -157,7 +159,7 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
if width <= max_width:
return text, height
# logger.warning(f"wrapping text, max_width: {max_width}, text_width: {width}, text: {text}")
logger.debug(f"换行文本, 最大宽度: {max_width}, 文本宽度: {width}, 文本: {text}")
processed = True
@ -198,110 +200,17 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
_wrapped_lines_.append(_txt_)
result = "\n".join(_wrapped_lines_).strip()
height = len(_wrapped_lines_) * height
# logger.warning(f"wrapped text: {result}")
logger.debug(f"换行文本: {result}")
return result, height
def generate_video(
video_path: str,
audio_path: str,
subtitle_path: str,
output_file: str,
params: Union[VideoParams, VideoClipParams],
):
aspect = VideoAspect(params.video_aspect)
video_width, video_height = aspect.to_resolution()
logger.info(f"start, video size: {video_width} x {video_height}")
logger.info(f" ① video: {video_path}")
logger.info(f" ② audio: {audio_path}")
logger.info(f" ③ subtitle: {subtitle_path}")
logger.info(f" ④ output: {output_file}")
# 写入与输出文件相同的目录
output_dir = os.path.dirname(output_file)
font_path = ""
if params.subtitle_enabled:
if not params.font_name:
params.font_name = "STHeitiMedium.ttc"
font_path = os.path.join(utils.font_dir(), params.font_name)
if os.name == "nt":
font_path = font_path.replace("\\", "/")
logger.info(f"using font: {font_path}")
def create_text_clip(subtitle_item):
phrase = subtitle_item[1]
max_width = video_width * 0.9
wrapped_txt, txt_height = wrap_text(
phrase, max_width=max_width, font=font_path, fontsize=params.font_size
)
_clip = TextClip(
wrapped_txt,
font=font_path,
fontsize=params.font_size,
color=params.text_fore_color,
bg_color=params.text_background_color,
stroke_color=params.stroke_color,
stroke_width=params.stroke_width,
print_cmd=False,
)
duration = subtitle_item[0][1] - subtitle_item[0][0]
_clip = _clip.set_start(subtitle_item[0][0])
_clip = _clip.set_end(subtitle_item[0][1])
_clip = _clip.set_duration(duration)
if params.subtitle_position == "bottom":
_clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
elif params.subtitle_position == "top":
_clip = _clip.set_position(("center", video_height * 0.05))
elif params.subtitle_position == "custom":
# 确保字幕完全在屏幕内
margin = 10 # 额外的边距,单位为像素
max_y = video_height - _clip.h - margin
min_y = margin
custom_y = (video_height - _clip.h) * (params.custom_position / 100)
custom_y = max(min_y, min(custom_y, max_y)) # 限制 y 值在有效范围内
_clip = _clip.set_position(("center", custom_y))
else: # center
_clip = _clip.set_position(("center", "center"))
return _clip
video_clip = VideoFileClip(video_path)
audio_clip = AudioFileClip(audio_path).volumex(params.voice_volume)
if subtitle_path and os.path.exists(subtitle_path):
sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
text_clips = []
for item in sub.subtitles:
clip = create_text_clip(subtitle_item=item)
text_clips.append(clip)
video_clip = CompositeVideoClip([video_clip, *text_clips])
bgm_file = get_bgm_file(bgm_type=params.bgm_type, bgm_file=params.bgm_file)
if bgm_file:
try:
bgm_clip = (
AudioFileClip(bgm_file).volumex(params.bgm_volume).audio_fadeout(3)
)
bgm_clip = afx.audio_loop(bgm_clip, duration=video_clip.duration)
audio_clip = CompositeAudioClip([audio_clip, bgm_clip])
except Exception as e:
logger.error(f"failed to add bgm: {str(e)}")
video_clip = video_clip.set_audio(audio_clip)
video_clip.write_videofile(
output_file,
audio_codec="aac",
temp_audiofile_path=output_dir,
threads=params.n_threads,
logger=None,
fps=30,
)
video_clip.close()
del video_clip
logger.success(""
"completed")
@contextmanager
def manage_clip(clip):
    """Yield *clip* for use in a ``with`` block, closing it on exit.

    Guarantees ``clip.close()`` runs even if the body raises, then drops
    the local reference so the underlying resources can be reclaimed.
    """
    try:
        yield clip
    finally:
        clip.close()
        del clip
def generate_video_v2(
@ -310,6 +219,7 @@ def generate_video_v2(
subtitle_path: str,
output_file: str,
params: Union[VideoParams, VideoClipParams],
progress_callback=None,
):
"""
合并所有素材
@ -319,135 +229,163 @@ def generate_video_v2(
subtitle_path: 字幕文件路径
output_file: 输出文件路径
params: 视频参数
progress_callback: 进度回调函数接收 0-100 的进度值
Returns:
"""
aspect = VideoAspect(params.video_aspect)
video_width, video_height = aspect.to_resolution()
total_steps = 4
current_step = 0
def update_progress(step_name):
nonlocal current_step
current_step += 1
if progress_callback:
progress_callback(int(current_step * 100 / total_steps))
logger.info(f"完成步骤: {step_name}")
logger.info(f"开始,视频尺寸: {video_width} x {video_height}")
logger.info(f" ① 视频: {video_path}")
logger.info(f" ② 音频: {audio_path}")
logger.info(f" ③ 字幕: {subtitle_path}")
logger.info(f" ④ 输出: {output_file}")
# 写入与输出文件相同的目录
output_dir = os.path.dirname(output_file)
# 字体设置部分保持不变
font_path = ""
if params.subtitle_enabled:
if not params.font_name:
params.font_name = "STHeitiMedium.ttc"
font_path = os.path.join(utils.font_dir(), params.font_name)
if os.name == "nt":
font_path = font_path.replace("\\", "/")
logger.info(f"使用字体: {font_path}")
# create_text_clip 函数保持不变
def create_text_clip(subtitle_item):
phrase = subtitle_item[1]
max_width = video_width * 0.9
wrapped_txt, txt_height = wrap_text(
phrase, max_width=max_width, font=font_path, fontsize=params.font_size
)
_clip = TextClip(
wrapped_txt,
font=font_path,
fontsize=params.font_size,
color=params.text_fore_color,
bg_color=params.text_background_color,
stroke_color=params.stroke_color,
stroke_width=params.stroke_width,
print_cmd=False,
)
duration = subtitle_item[0][1] - subtitle_item[0][0]
_clip = _clip.set_start(subtitle_item[0][0])
_clip = _clip.set_end(subtitle_item[0][1])
_clip = _clip.set_duration(duration)
if params.subtitle_position == "bottom":
_clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
elif params.subtitle_position == "top":
_clip = _clip.set_position(("center", video_height * 0.05))
elif params.subtitle_position == "custom":
# 确保字幕完全在屏幕内
margin = 10 # 额外的边距,单位为像素
max_y = video_height - _clip.h - margin
min_y = margin
custom_y = (video_height - _clip.h) * (params.custom_position / 100)
custom_y = max(min_y, min(custom_y, max_y)) # 限制 y 值在有效范围内
_clip = _clip.set_position(("center", custom_y))
else: # center
_clip = _clip.set_position(("center", "center"))
return _clip
video_clip = VideoFileClip(video_path)
original_audio = video_clip.audio # 保存原始视频的音轨
video_duration = video_clip.duration
# 处理新的音频文件
new_audio = AudioFileClip(audio_path).volumex(params.voice_volume)
# 字幕处理部分
if subtitle_path and os.path.exists(subtitle_path):
sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
text_clips = []
try:
validate_params(video_path, audio_path, output_file, params)
for item in sub.subtitles:
clip = create_text_clip(subtitle_item=item)
# 确保字幕的开始时间不早于视频开始
start_time = max(clip.start, 0)
# 如果字幕的开始时间晚于视频结束时间,则跳过此字幕
if start_time >= video_duration:
continue
# 调整字幕的结束时间,但不要超过视频长度
end_time = min(clip.end, video_duration)
# 调整字幕的时间范围
clip = clip.set_start(start_time).set_end(end_time)
text_clips.append(clip)
logger.info(f"处理了 {len(text_clips)} 段字幕")
# 创建一个新的视频剪辑,包含所有字幕
video_clip = CompositeVideoClip([video_clip, *text_clips])
with manage_clip(VideoFileClip(video_path)) as video_clip:
aspect = VideoAspect(params.video_aspect)
video_width, video_height = aspect.to_resolution()
# 背景音乐处理部分
logger.info(f"开始,视频尺寸: {video_width} x {video_height}")
logger.info(f" ① 视频: {video_path}")
logger.info(f" ② 音频: {audio_path}")
logger.info(f" ③ 字幕: {subtitle_path}")
logger.info(f" ④ 输出: {output_file}")
output_dir = os.path.dirname(output_file)
update_progress("初始化完成")
# 字体设置
font_path = ""
if params.subtitle_enabled:
if not params.font_name:
params.font_name = "STHeitiMedium.ttc"
font_path = os.path.join(utils.font_dir(), params.font_name)
if os.name == "nt":
font_path = font_path.replace("\\", "/")
logger.info(f"使用字体: {font_path}")
def create_text_clip(subtitle_item):
phrase = subtitle_item[1]
max_width = video_width * 0.9
wrapped_txt, txt_height = wrap_text(
phrase, max_width=max_width, font=font_path, fontsize=params.font_size
)
_clip = TextClip(
wrapped_txt,
font=font_path,
fontsize=params.font_size,
color=params.text_fore_color,
bg_color=params.text_background_color,
stroke_color=params.stroke_color,
stroke_width=params.stroke_width,
print_cmd=False,
)
duration = subtitle_item[0][1] - subtitle_item[0][0]
_clip = _clip.set_start(subtitle_item[0][0])
_clip = _clip.set_end(subtitle_item[0][1])
_clip = _clip.set_duration(duration)
if params.subtitle_position == "bottom":
_clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
elif params.subtitle_position == "top":
_clip = _clip.set_position(("center", video_height * 0.05))
elif params.subtitle_position == "custom":
margin = 10
max_y = video_height - _clip.h - margin
min_y = margin
custom_y = (video_height - _clip.h) * (params.custom_position / 100)
custom_y = max(min_y, min(custom_y, max_y))
_clip = _clip.set_position(("center", custom_y))
else: # center
_clip = _clip.set_position(("center", "center"))
return _clip
update_progress("字体设置完成")
# 处理音频
original_audio = video_clip.audio
video_duration = video_clip.duration
new_audio = AudioFileClip(audio_path)
final_audio = process_audio_tracks(original_audio, new_audio, params, video_duration)
update_progress("音频处理完成")
# 处理字幕
if subtitle_path and os.path.exists(subtitle_path):
video_clip = process_subtitles(subtitle_path, video_clip, video_duration, create_text_clip)
update_progress("字幕处理完成")
# 合并音频和导出
video_clip = video_clip.set_audio(final_audio)
video_clip.write_videofile(
output_file,
audio_codec="aac",
temp_audiofile=os.path.join(output_dir, "temp-audio.m4a"),
threads=params.n_threads,
logger=None,
fps=30,
)
except FileNotFoundError as e:
logger.error(f"文件不存在: {str(e)}")
raise
except Exception as e:
logger.error(f"视频生成失败: {str(e)}")
raise
finally:
logger.success("完成")
def process_audio_tracks(original_audio, new_audio, params, video_duration):
    """Build the final audio mix from the original track, narration and BGM.

    Args:
        original_audio: audio clip taken from the source video, or None.
        new_audio: narration (TTS) audio clip.
        params: render params providing voice_volume / bgm_* settings.
        video_duration: duration in seconds used to loop the BGM.

    Returns:
        A CompositeAudioClip of the collected tracks (falls back to the
        narration clip alone if no tracks were collected).
    """
    audio_tracks = []

    # Keep the source video's own audio only when it actually has one;
    # appending None here would break CompositeAudioClip.
    if original_audio is not None:
        audio_tracks.append(original_audio)

    new_audio = new_audio.volumex(params.voice_volume)
    audio_tracks.append(new_audio)

    # Optional background music, faded out and looped to cover the video.
    bgm_file = get_bgm_file(bgm_type=params.bgm_type, bgm_file=params.bgm_file)
    if bgm_file:
        try:
            bgm_clip = AudioFileClip(bgm_file).volumex(params.bgm_volume).audio_fadeout(3)
            bgm_clip = afx.audio_loop(bgm_clip, duration=video_duration)
            audio_tracks.append(bgm_clip)
        except Exception as e:
            logger.error(f"添加背景音乐失败: {str(e)}")

    return CompositeAudioClip(audio_tracks) if audio_tracks else new_audio
# 合并所有音频轨道
final_audio = CompositeAudioClip(audio_tracks)
video_clip = video_clip.set_audio(final_audio)
video_clip.write_videofile(
output_file,
audio_codec="aac",
temp_audiofile_path=output_dir,
threads=params.n_threads,
logger=None,
fps=30,
)
video_clip.close()
del video_clip
logger.success("完成")
def process_subtitles(subtitle_path, video_clip, video_duration, create_text_clip):
    """Overlay subtitle text clips onto *video_clip*, clamped to its duration.

    Returns the input clip unchanged when no subtitle file is available.
    """
    if not (subtitle_path and os.path.exists(subtitle_path)):
        return video_clip

    sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
    rendered = []
    for entry in sub.subtitles:
        txt_clip = create_text_clip(subtitle_item=entry)
        # Clamp each subtitle to [0, video_duration]; drop ones that start
        # after the video already ended.
        begin = max(txt_clip.start, 0)
        if begin >= video_duration:
            continue
        finish = min(txt_clip.end, video_duration)
        rendered.append(txt_clip.set_start(begin).set_end(finish))

    logger.info(f"处理了 {len(rendered)} 段字幕")
    return CompositeVideoClip([video_clip, *rendered])
def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
@ -499,7 +437,7 @@ def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
def combine_clip_videos(combined_video_path: str,
video_paths: List[str],
video_ost_list: List[bool],
video_ost_list: List[int],
list_script: list,
video_aspect: VideoAspect = VideoAspect.portrait,
threads: int = 2,
@ -509,92 +447,119 @@ def combine_clip_videos(combined_video_path: str,
Args:
combined_video_path: 合并后的存储路径
video_paths: 子视频路径列表
video_ost_list: 原声播放列表
video_ost_list: 原声播放列表 (0: 不保留原声, 1: 只保留原声, 2: 保留原声并保留解说)
list_script: 剪辑脚本
video_aspect: 屏幕比例
threads: 线程数
Returns:
str: 合并后的视频路径
"""
from app.utils.utils import calculate_total_duration
audio_duration = calculate_total_duration(list_script)
logger.info(f"音频的最大持续时间: {audio_duration} s")
# 每个剪辑所需的持续时间
req_dur = audio_duration / len(video_paths)
# req_dur = max_clip_duration
# logger.info(f"每个剪辑的最大长度为 {req_dur} s")
output_dir = os.path.dirname(combined_video_path)
aspect = VideoAspect(video_aspect)
video_width, video_height = aspect.to_resolution()
clips = []
video_duration = 0
# 一遍又一遍地添加下载的剪辑,直到达到音频的持续时间 max_duration
# while video_duration < audio_duration:
for video_path, video_ost in zip(video_paths, video_ost_list):
cache_video_path = utils.root_dir()
clip = VideoFileClip(os.path.join(cache_video_path, video_path))
# 通过 ost 字段判断是否播放原声
if not video_ost:
clip = clip.without_audio()
# # 检查剪辑是否比剩余音频长
# if (audio_duration - video_duration) < clip.duration:
# clip = clip.subclip(0, (audio_duration - video_duration))
# # 仅当计算出的剪辑长度 req_dur 短于实际剪辑时,才缩短剪辑以防止静止图像
# elif req_dur < clip.duration:
# clip = clip.subclip(0, req_dur)
clip = clip.set_fps(30)
try:
clip = VideoFileClip(video_path)
if video_ost == 0: # 不保留原声
clip = clip.without_audio()
# video_ost 为 1 或 2 时都保留原声,不需要特殊处理
clip = clip.set_fps(30)
# 并非所有视频的大小都相同,因此我们需要调整它们的大小
clip_w, clip_h = clip.size
if clip_w != video_width or clip_h != video_height:
clip_ratio = clip.w / clip.h
video_ratio = video_width / video_height
# 处理视频尺寸
clip_w, clip_h = clip.size
if clip_w != video_width or clip_h != video_height:
clip = resize_video_with_padding(
clip,
target_width=video_width,
target_height=video_height
)
logger.info(f"视频 {video_path} 已调整尺寸为 {video_width} x {video_height}")
if clip_ratio == video_ratio:
# 等比例缩放
clip = clip.resize((video_width, video_height))
else:
# 等比缩放视频
if clip_ratio > video_ratio:
# 按照目标宽度等比缩放
scale_factor = video_width / clip_w
else:
# 按照目标高度等比缩放
scale_factor = video_height / clip_h
clips.append(clip)
except Exception as e:
logger.error(f"处理视频 {video_path} 时出错: {str(e)}")
continue
new_width = int(clip_w * scale_factor)
new_height = int(clip_h * scale_factor)
clip_resized = clip.resize(newsize=(new_width, new_height))
if not clips:
raise ValueError("没有有效的视频片段可以合并")
background = ColorClip(size=(video_width, video_height), color=(0, 0, 0))
clip = CompositeVideoClip([
background.set_duration(clip.duration),
clip_resized.set_position("center")
])
try:
video_clip = concatenate_videoclips(clips)
video_clip = video_clip.set_fps(30)
logger.info("开始合并视频...")
video_clip.write_videofile(
filename=combined_video_path,
threads=threads,
logger=None,
audio_codec="aac",
fps=30,
temp_audiofile=os.path.join(output_dir, "temp-audio.m4a")
)
finally:
# 确保资源被正确释放
video_clip.close()
for clip in clips:
clip.close()
logger.info(f"将视频 {video_path} 大小调整为 {video_width} x {video_height}, 剪辑尺寸: {clip_w} x {clip_h}")
clips.append(clip)
video_duration += clip.duration
video_clip = concatenate_videoclips(clips)
video_clip = video_clip.set_fps(30)
logger.info(f"合并视频中...")
video_clip.write_videofile(filename=combined_video_path,
threads=threads,
logger=None,
temp_audiofile_path=output_dir,
audio_codec="aac",
fps=30,
)
video_clip.close()
logger.success(f"completed")
logger.success("视频合并完成")
return combined_video_path
def resize_video_with_padding(clip, target_width: int, target_height: int):
    """Scale *clip* to fit the target frame, centered over black padding.

    A clip whose aspect ratio already matches is simply resized; otherwise it
    is scaled to fit inside the frame and composited onto a black background.
    """
    source_ratio = clip.w / clip.h
    target_ratio = target_width / target_height

    if source_ratio == target_ratio:
        return clip.resize((target_width, target_height))

    # Fit by whichever dimension is the binding constraint.
    if source_ratio > target_ratio:
        scale = target_width / clip.w
    else:
        scale = target_height / clip.h

    resized = clip.resize(newsize=(int(clip.w * scale), int(clip.h * scale)))

    backdrop = ColorClip(
        size=(target_width, target_height),
        color=(0, 0, 0)
    ).set_duration(clip.duration)

    return CompositeVideoClip([
        backdrop,
        resized.set_position("center")
    ])
def validate_params(video_path, audio_path, output_file, params):
    """Validate render inputs up front.

    Raises:
        FileNotFoundError: when an input file or the output directory is missing.
        ValueError: when *params* lacks the required ``video_aspect`` attribute.
    """
    required_files = (
        (video_path, "视频文件不存在"),
        (audio_path, "音频文件不存在"),
    )
    for path, message in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{message}: {path}")

    output_dir = os.path.dirname(output_file)
    if not os.path.exists(output_dir):
        raise FileNotFoundError(f"输出目录不存在: {output_dir}")

    if not hasattr(params, 'video_aspect'):
        raise ValueError("params 缺少必要参数 video_aspect")
if __name__ == "__main__":
# combined_video_path = "../../storage/tasks/12312312/com123.mp4"
#
@ -635,23 +600,23 @@ if __name__ == "__main__":
# ]
# combine_clip_videos(combined_video_path=combined_video_path, video_paths=video_paths, video_ost_list=video_ost_list, list_script=list_script)
cfg = VideoClipParams()
cfg.video_aspect = VideoAspect.portrait
cfg.font_name = "STHeitiMedium.ttc"
cfg.font_size = 60
cfg.stroke_color = "#000000"
cfg.stroke_width = 1.5
cfg.text_fore_color = "#FFFFFF"
cfg.text_background_color = "transparent"
cfg.bgm_type = "random"
cfg.bgm_file = ""
cfg.bgm_volume = 1.0
cfg.subtitle_enabled = True
cfg.subtitle_position = "bottom"
cfg.n_threads = 2
cfg.paragraph_number = 1
cfg.voice_volume = 1.0
# cfg = VideoClipParams()
# cfg.video_aspect = VideoAspect.portrait
# cfg.font_name = "STHeitiMedium.ttc"
# cfg.font_size = 60
# cfg.stroke_color = "#000000"
# cfg.stroke_width = 1.5
# cfg.text_fore_color = "#FFFFFF"
# cfg.text_background_color = "transparent"
# cfg.bgm_type = "random"
# cfg.bgm_file = ""
# cfg.bgm_volume = 1.0
# cfg.subtitle_enabled = True
# cfg.subtitle_position = "bottom"
# cfg.n_threads = 2
# cfg.paragraph_number = 1
#
# cfg.voice_volume = 1.0
# generate_video(video_path=video_file,
# audio_path=audio_file,
@ -659,18 +624,27 @@ if __name__ == "__main__":
# output_file=output_file,
# params=cfg
# )
#
# video_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/combined-1.mp4"
#
# audio_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/audio_00-00-00-07.mp3"
#
# subtitle_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa\subtitle.srt"
#
# output_file = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/final-123.mp4"
#
# generate_video_v2(video_path=video_path,
# audio_path=audio_path,
# subtitle_path=subtitle_path,
# output_file=output_file,
# params=cfg
# )
video_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/combined-1.mp4"
# 合并视频
video_list = [
'./storage/cache_videos/vid-01_03-01_50.mp4',
'./storage/cache_videos/vid-01_55-02_29.mp4',
'./storage/cache_videos/vid-03_24-04_04.mp4',
'./storage/cache_videos/vid-04_50-05_28.mp4'
]
audio_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/audio_00-00-00-07.mp3"
subtitle_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa\subtitle.srt"
output_file = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/final-123.mp4"
generate_video_v2(video_path=video_path,
audio_path=audio_path,
subtitle_path=subtitle_path,
output_file=output_file,
params=cfg
)

View File

@ -1032,11 +1032,11 @@ def is_azure_v2_voice(voice_name: str):
def tts(
    text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str
) -> [SubMaker, None]:
    """Synthesize speech for *text* into *voice_file*.

    Args:
        text: text to speak.
        voice_name: TTS voice identifier.
        voice_rate: speech-rate multiplier (1.0 = normal).
        voice_pitch: pitch multiplier (1.0 = normal).
        voice_file: output audio file path.

    Returns:
        A SubMaker with word timings, or None on failure.
    """
    # Azure v2 voices are currently disabled; v1 handles all voices.
    # if is_azure_v2_voice(voice_name):
    #     return azure_tts_v2(text, voice_name, voice_file)
    return azure_tts_v1(text, voice_name, voice_rate, voice_pitch, voice_file)
def convert_rate_to_percent(rate: float) -> str:
@ -1049,18 +1049,29 @@ def convert_rate_to_percent(rate: float) -> str:
return f"{percent}%"
def convert_pitch_to_percent(rate: float) -> str:
    """Convert a pitch multiplier into edge-tts' signed Hz-offset string.

    1.0 maps to "+0Hz"; each 0.01 above/below 1.0 adds ±1 to the offset.
    edge-tts requires an explicit leading sign, so the value is always
    formatted signed — a rate that rounds to zero (e.g. 1.004) must still
    yield "+0Hz", not the invalid unsigned "0Hz".
    """
    percent = round((rate - 1.0) * 100)
    # "+d" keeps the sign even for zero, matching edge-tts' expected format.
    return f"{percent:+d}Hz"
def azure_tts_v1(
text: str, voice_name: str, voice_rate: float, voice_file: str
text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str
) -> [SubMaker, None]:
voice_name = parse_voice_name(voice_name)
text = text.strip()
rate_str = convert_rate_to_percent(voice_rate)
pitch_str = convert_pitch_to_percent(voice_pitch)
for i in range(3):
try:
logger.info(f"start, voice name: {voice_name}, try: {i + 1}")
async def _do() -> SubMaker:
communicate = edge_tts.Communicate(text, voice_name, rate=rate_str)
communicate = edge_tts.Communicate(text, voice_name, rate=rate_str, pitch=pitch_str, proxy=config.proxy.get("http"))
sub_maker = edge_tts.SubMaker()
with open(voice_file, "wb") as file:
async for chunk in communicate.stream():
@ -1392,7 +1403,7 @@ def get_audio_duration(sub_maker: submaker.SubMaker):
return sub_maker.offset[-1][1] / 10000000
def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: float, force_regenerate: bool = True):
def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: float, voice_pitch: float, force_regenerate: bool = True):
"""
根据JSON文件中的多段文本进行TTS转换
@ -1409,13 +1420,13 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
sub_maker_list = []
for item in list_script:
if not item['OST']:
# timestamp = item['new_timestamp'].replace(':', '@')
timestamp = item['new_timestamp']
if item['OST'] != 1:
# 将时间戳中的冒号替换为下划线
timestamp = item['new_timestamp'].replace(':', '_')
audio_file = os.path.join(output_dir, f"audio_{timestamp}.mp3")
# 检查文件是否已存在,如存在且不强制重新生成,则跳过
if os.path.exists(audio_file):
if os.path.exists(audio_file) and not force_regenerate:
logger.info(f"音频文件已存在,跳过生成: {audio_file}")
audio_files.append(audio_file)
continue
@ -1426,7 +1437,8 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
text=text,
voice_name=voice_name,
voice_rate=voice_rate,
voice_file=audio_file
voice_pitch=voice_pitch,
voice_file=audio_file,
)
if sub_maker is None:

View File

@ -1,115 +1,81 @@
import json
from loguru import logger
import os
from datetime import timedelta
from typing import Dict, Any
def time_to_seconds(time_str):
    """Parse a "MM:SS" or "HH:MM:SS" string into total seconds (float)."""
    fields = [int(p) for p in time_str.split(':')]
    if len(fields) == 3:
        hours, minutes, seconds = fields
        return timedelta(hours=hours, minutes=minutes, seconds=seconds).total_seconds()
    if len(fields) == 2:
        minutes, seconds = fields
        return timedelta(minutes=minutes, seconds=seconds).total_seconds()
    raise ValueError(f"无法解析时间字符串: {time_str}")
def check_format(script_content: str) -> Dict[str, Any]:
"""检查脚本格式
Args:
script_content: 脚本内容
Returns:
Dict: {'success': bool, 'message': str}
"""
try:
# 检查是否为有效的JSON
data = json.loads(script_content)
# 检查是否为列表
if not isinstance(data, list):
return {
'success': False,
'message': '脚本必须是JSON数组格式'
}
# 检查每个片段
for i, clip in enumerate(data):
# 检查必需字段
required_fields = ['narration', 'picture', 'timestamp']
for field in required_fields:
if field not in clip:
return {
'success': False,
'message': f'{i+1}个片段缺少必需字段: {field}'
}
# 检查字段类型
if not isinstance(clip['narration'], str):
return {
'success': False,
'message': f'{i+1}个片段的narration必须是字符串'
}
if not isinstance(clip['picture'], str):
return {
'success': False,
'message': f'{i+1}个片段的picture必须是字符串'
}
if not isinstance(clip['timestamp'], str):
return {
'success': False,
'message': f'{i+1}个片段的timestamp必须是字符串'
}
# 检查字段内容不能为空
if not clip['narration'].strip():
return {
'success': False,
'message': f'{i+1}个片段的narration不能为空'
}
if not clip['picture'].strip():
return {
'success': False,
'message': f'{i+1}个片段的picture不能为空'
}
if not clip['timestamp'].strip():
return {
'success': False,
'message': f'{i+1}个片段的timestamp不能为空'
}
def seconds_to_time_str(seconds):
    """Format a second count as "HH:MM:SS", or "MM:SS" when under one hour."""
    total = int(seconds)
    hours = total // 3600
    minutes, secs = divmod(total % 3600, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"
return {
'success': True,
'message': '脚本格式检查通过'
}
def adjust_timestamp(start_time, duration):
    """Return a "start-end" range whose end is `start_time` plus `duration` seconds."""
    end_seconds = time_to_seconds(start_time) + duration
    return f"{start_time}-{seconds_to_time_str(end_seconds)}"
def estimate_audio_duration(text):
    """Rough TTS duration estimate: ~0.2 seconds of speech per character."""
    return 0.2 * len(text)
def check_script(data, total_duration):
    """Validate a narration script and re-derive timestamps for voiced clips.

    Checks every clip for required fields, OST/narration consistency and
    timestamp overlaps, then — for OST == False clips with narration —
    rewrites the timestamp end from the estimated audio duration (mutates
    `data` in place).

    Args:
        data: list of clip dicts with 'picture', 'timestamp', 'narration', 'OST'
        total_duration: total video length in seconds (only logged; the
            bound check against it is currently commented out below)

    Returns:
        (errors, data): list of human-readable error strings and the
        (possibly mutated) clip list.

    NOTE(review): `"原声播放_" not in item.get('narration')` raises TypeError
    when OST is True and 'narration' is missing/None — confirm inputs always
    carry a narration string. The corresponding error message ("narration
    不为空") also does not match the condition actually tested.
    """
    errors = []
    time_ranges = []
    logger.info("开始检查脚本")
    logger.info(f"视频总时长: {total_duration:.2f} 秒")
    logger.info("=" * 50)
    for i, item in enumerate(data, 1):
        logger.info(f"\n检查第 {i} 项:")
        # Every clip must carry these four keys.
        required_fields = ['picture', 'timestamp', 'narration', 'OST']
        for field in required_fields:
            if field not in item:
                errors.append(f"第 {i} 项缺少 {field} 字段")
                logger.info(f" - 错误: 缺少 {field} 字段")
            else:
                logger.info(f" - {field}: {item[field]}")
        # OST-specific rules: voiced clips need a narration of <= 60 chars;
        # original-sound clips must be tagged with the "原声播放_" marker.
        if item.get('OST') == False:
            if not item.get('narration'):
                errors.append(f"第 {i} 项 OST 为 false但 narration 为空")
                logger.info(" - 错误: OST 为 false但 narration 为空")
            elif len(item['narration']) > 60:
                errors.append(f"第 {i} 项 OST 为 false但 narration 超过 60 字")
                logger.info(f" - 错误: OST 为 false但 narration 超过 60 字 (当前: {len(item['narration'])} 字)")
            else:
                logger.info(" - OST 为 falsenarration 检查通过")
        elif item.get('OST') == True:
            if "原声播放_" not in item.get('narration'):
                errors.append(f"第 {i} 项 OST 为 true但 narration 不为空")
                logger.info(" - 错误: OST 为 true但 narration 不为空")
            else:
                logger.info(" - OST 为 truenarration 检查通过")
        # Reject any timestamp range that overlaps one already seen.
        if 'timestamp' in item:
            start, end = map(time_to_seconds, item['timestamp'].split('-'))
            if any((start < existing_end and end > existing_start) for existing_start, existing_end in time_ranges):
                errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 与其他时间段重叠")
                logger.info(f" - 错误: timestamp '{item['timestamp']}' 与其他时间段重叠")
            else:
                logger.info(f" - timestamp '{item['timestamp']}' 检查通过")
            time_ranges.append((start, end))
            # Total-duration bound check deliberately disabled in this version:
            # if end > total_duration:
            #     errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
            #     logger.info(f" - 错误: timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
            # else:
            #     logger.info(f" - timestamp 在总时长范围内")
        # Voiced clip: replace the timestamp end with start + estimated TTS length.
        if item.get('OST') == False and item.get('narration'):
            estimated_duration = estimate_audio_duration(item['narration'])
            start_time = item['timestamp'].split('-')[0]
            item['timestamp'] = adjust_timestamp(start_time, estimated_duration)
            logger.info(f" - 已调整 timestamp 为 {item['timestamp']} (估算音频时长: {estimated_duration:.2f} 秒)")
    if errors:
        logger.info("检查结果:不通过")
        logger.info("发现以下错误:")
        for error in errors:
            logger.info(f"- {error}")
    else:
        logger.info("检查结果:通过")
        logger.info("所有项目均符合规则要求。")
    return errors, data
if __name__ == "__main__":
file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/test004.json"
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
total_duration = 280
# check_script(data, total_duration)
from app.utils.utils import add_new_timestamps
res = add_new_timestamps(data)
print(json.dumps(res, indent=4, ensure_ascii=False))
except json.JSONDecodeError as e:
return {
'success': False,
'message': f'JSON格式错误: {str(e)}'
}
except Exception as e:
return {
'success': False,
'message': f'检查过程中发生错误: {str(e)}'
}

View File

@ -0,0 +1,392 @@
import os
import json
import traceback
from loguru import logger
import tiktoken
from typing import List, Dict
from datetime import datetime
from openai import OpenAI
import google.generativeai as genai
class BaseGenerator:
    """Shared state for every script generator backend.

    Holds the model identity, the base prompt, the running conversation
    history, and the tail of the previous chunk used to keep consecutive
    narration segments coherent.
    """

    def __init__(self, model_name: str, api_key: str, prompt: str):
        self.model_name = model_name
        self.api_key = api_key
        self.base_prompt = prompt
        # Continuity state shared by all backends.
        self.conversation_history: list = []
        self.chunk_overlap: int = 50
        self.last_chunk_ending: str = ""

    def generate_script(self, scene_description: str, word_count: int) -> str:
        """Produce one narration segment; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses must implement generate_script method")
class OpenAIGenerator(BaseGenerator):
    """Narration generator backed by the OpenAI chat-completions API.

    Keeps a token-budgeted conversation history so consecutive segments
    read as one continuous narration, and retries generation when the
    output length misses the requested character count.
    """

    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(api_key=api_key)
        # Token budget for system prompt + history + new user prompt.
        self.max_tokens = 7000
        try:
            self.encoding = tiktoken.encoding_for_model(self.model_name)
        except KeyError:
            # Fix: log text previously read "使用认编码器" (garbled, missing 默).
            logger.info(f"警告:未找到模型 {self.model_name} 的专用编码器,使用默认编码器")
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def _count_tokens(self, messages: list) -> int:
        """Approximate OpenAI chat token accounting for `messages`:
        3 tokens of per-message overhead, +1 for the role field, plus
        3 tokens of reply priming."""
        num_tokens = 0
        for message in messages:
            num_tokens += 3
            for key, value in message.items():
                num_tokens += len(self.encoding.encode(str(value)))
                if key == "role":
                    num_tokens += 1
        num_tokens += 3
        return num_tokens

    def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None:
        """Drop the oldest history entries until system + history + new user
        prompt fits within self.max_tokens (newest entries are kept)."""
        base_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": new_user_prompt}
        ]
        base_tokens = self._count_tokens(base_messages)
        temp_history = []
        current_tokens = base_tokens
        # Walk newest-to-oldest so recent context survives trimming.
        for message in reversed(self.conversation_history):
            message_tokens = self._count_tokens([message])
            if current_tokens + message_tokens > self.max_tokens:
                break
            temp_history.insert(0, message)
            current_tokens += message_tokens
        self.conversation_history = temp_history

    def generate_script(self, scene_description: str, word_count: int) -> str:
        """Generate one narration segment of roughly `word_count` characters.

        Retries up to 3 times until the length is within ±5 characters; on
        success the exchange is appended to history and its tail cached for
        continuity.
        NOTE(review): after a failed final attempt the last response is
        returned without updating history/last_chunk_ending — confirm this
        is intended.
        """
        max_attempts = 3
        tolerance = 5
        for attempt in range(max_attempts):
            system_prompt, user_prompt = self._create_prompt(scene_description, word_count)
            self._trim_conversation_history(system_prompt, user_prompt)
            messages = [
                {"role": "system", "content": system_prompt},
                *self.conversation_history,
                {"role": "user", "content": user_prompt}
            ]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.7,
                max_tokens=500,
                top_p=0.9,
                frequency_penalty=0.3,
                presence_penalty=0.5
            )
            # Strip wrapping quotes and flatten newlines into one line.
            generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
                                                                                                         '').replace(
                '\n', '')
            current_length = len(generated_script)
            if abs(current_length - word_count) <= tolerance:
                self.conversation_history.append({"role": "user", "content": user_prompt})
                self.conversation_history.append({"role": "assistant", "content": generated_script})
                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
                    generated_script) > self.chunk_overlap else generated_script
                return generated_script
        return generated_script

    def _create_prompt(self, scene_description: str, word_count: int) -> tuple:
        """Build (system_prompt, user_prompt). The system prompt fills the
        base template's {word_count} placeholder via str.format."""
        system_prompt = self.base_prompt.format(word_count=word_count)
        user_prompt = f"""上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
        return system_prompt, user_prompt
class GeminiGenerator(BaseGenerator):
    """Narration generator backed by Google Gemini (google.generativeai)."""

    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        # genai holds the API key in module-global state; configuring again overwrites it.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)

    def generate_script(self, scene_description: str, word_count: int) -> str:
        """Generate one narration segment of ~word_count characters (±5),
        retrying up to 3 times; the accepted segment's tail is cached so the
        next segment can continue the narration.

        NOTE(review): when every attempt misses the length target, the last
        attempt is returned without updating last_chunk_ending — confirm
        this is intentional.
        """
        max_attempts = 3
        tolerance = 5
        for attempt in range(max_attempts):
            prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
            response = self.model.generate_content(prompt)
            # Strip wrapping quotes and flatten newlines into one line.
            generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
            current_length = len(generated_script)
            if abs(current_length - word_count) <= tolerance:
                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
                    generated_script) > self.chunk_overlap else generated_script
                return generated_script
        return generated_script
class QwenGenerator(BaseGenerator):
    """Narration generator backed by Alibaba Qwen via DashScope's
    OpenAI-compatible endpoint (reuses the OpenAI client)."""

    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(
            api_key=api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )

    def generate_script(self, scene_description: str, word_count: int) -> str:
        """Generate one narration segment of ~word_count characters (±5),
        retrying up to 3 times; the accepted segment's tail is cached for
        continuity with the next segment.

        NOTE(review): the base prompt is sent both as the system message and
        embedded in the user prompt — possibly redundant; confirm intended.
        """
        max_attempts = 3
        tolerance = 5
        for attempt in range(max_attempts):
            prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
            messages = [
                {"role": "system", "content": self.base_prompt},
                {"role": "user", "content": prompt}
            ]
            response = self.client.chat.completions.create(
                model=self.model_name,  # e.g. "qwen-plus"
                messages=messages,
                temperature=0.7,
                max_tokens=500,
                top_p=0.9,
                frequency_penalty=0.3,
                presence_penalty=0.5
            )
            # Strip wrapping quotes and flatten newlines into one line.
            generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
                                                                                                         '').replace(
                '\n', '')
            current_length = len(generated_script)
            if abs(current_length - word_count) <= tolerance:
                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
                    generated_script) > self.chunk_overlap else generated_script
                return generated_script
        return generated_script
class MoonshotGenerator(BaseGenerator):
    """Narration generator backed by Moonshot's OpenAI-compatible API."""

    def __init__(self, model_name: str, api_key: str, prompt: str):
        super().__init__(model_name, api_key, prompt)
        self.client = OpenAI(
            api_key=api_key,
            base_url="https://api.moonshot.cn/v1"
        )

    def generate_script(self, scene_description: str, word_count: int) -> str:
        """Generate one narration segment of ~word_count characters (±5),
        retrying up to 3 times; the accepted segment's tail is cached for
        continuity with the next segment.

        NOTE(review): the base prompt is sent both as the system message and
        embedded in the user prompt — possibly redundant; confirm intended.
        """
        max_attempts = 3
        tolerance = 5
        for attempt in range(max_attempts):
            prompt = f"""{self.base_prompt}
上一段文案的结尾{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
当前画面描述{scene_description}
请确保新生成的文案与上文自然衔接保持叙事的连贯性和趣味性
严格字数要求{word_count}允许误差±5"""
            messages = [
                {"role": "system", "content": self.base_prompt},
                {"role": "user", "content": prompt}
            ]
            response = self.client.chat.completions.create(
                model=self.model_name,  # e.g. "moonshot-v1-8k"
                messages=messages,
                temperature=0.7,
                max_tokens=500,
                top_p=0.9,
                frequency_penalty=0.3,
                presence_penalty=0.5
            )
            # Strip wrapping quotes and flatten newlines into one line.
            generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
                                                                                                         '').replace(
                '\n', '')
            current_length = len(generated_script)
            if abs(current_length - word_count) <= tolerance:
                self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
                    generated_script) > self.chunk_overlap else generated_script
                return generated_script
        return generated_script
class ScriptProcessor:
def __init__(self, model_name: str, api_key: str = None, prompt: str = None, video_theme: str = ""):
self.model_name = model_name
self.api_key = api_key
self.video_theme = video_theme
self.prompt = prompt or self._get_default_prompt()
# 根据模型名称选择对应的生成器
if 'gemini' in model_name.lower():
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
elif 'qwen' in model_name.lower():
self.generator = QwenGenerator(model_name, self.api_key, self.prompt)
elif 'moonshot' in model_name.lower():
self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt)
else:
self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt)
def _get_default_prompt(self, word_count=None) -> str:
return f"""你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让{self.video_theme}视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
目标受众热爱生活追求独特体验的18-35岁年轻人
文案风格基于HKRR理论 + 段子手精神
主题{self.video_theme}
创作核心理念
1. 敢于用"温和的违反"制造笑点但不能过于冒犯
2. 巧妙运用中国式幽默让观众会心一笑
3. 保持轻松愉快的叙事基调
爆款内容四要素
快乐元素 Happy
1. 用调侃的语气描述建造过程中的"笨手笨脚"
2. 巧妙植入网络流行梗增加内容的传播性
3. 适时自嘲展现真实且有趣的一面
知识价值 Knowledge
1. 用段子手的方式解释专业知识比如"这根木头不是一般的木头,它比我前任还难搞..."
2. 把复杂的建造技巧转化为生动有趣的比喻
3. 在幽默中传递实用的野外生存技能
情感共鸣 Resonance
1. 描述"真实但夸张"的建造困境
2. 把对自然的感悟融入俏皮话中
3. 用接地气的表达方式拉近与观众距离
节奏控制 Rhythm
1. 严格控制文案字数在{word_count}字左右允许误差不超过5字
2. 像讲段子一样注意铺垫和包袱的节奏
3. 确保每段都有笑点但不强求
4. 段落结尾干净利落不拖泥带水
连贯性要求
1. 新生成的内容必须自然衔接上一段文案的结尾
2. 使用恰当的连接词和过渡语确保叙事流畅
3. 保持人物视角和语气的一致性
4. 避免重复上一段已经提到的信息
5. 确保情节和建造过程的逻辑连续性
字数控制要求
1. 严格控制文案字数在{word_count}字左右允许误差不超过5字
2. 如果内容过长优先精简修饰性词语
3. 如果内容过短可以适当增加细节描写
4. 保持文案结构完整不因字数限制而牺牲内容质量
5. 确保每个笑点和包袱都得到完整表达
我会按顺序提供多段视频画面描述请创作既搞笑又能火爆全网的口播文案
记住要敢于用"温和的违反"制造笑点但要把握好尺度让观众在轻松愉快中感受野外建造的乐趣"""
def calculate_duration_and_word_count(self, time_range: str) -> int:
try:
start_str, end_str = time_range.split('-')
def time_to_seconds(time_str):
minutes, seconds = map(int, time_str.split(':'))
return minutes * 60 + seconds
start_seconds = time_to_seconds(start_str)
end_seconds = time_to_seconds(end_str)
duration = end_seconds - start_seconds
word_count = int(duration / 0.2)
return word_count
except Exception as e:
logger.info(f"时间格式转换错误: {traceback.format_exc()}")
return 100
def process_frames(self, frame_content_list: List[Dict]) -> List[Dict]:
for frame_content in frame_content_list:
word_count = self.calculate_duration_and_word_count(frame_content["timestamp"])
script = self.generator.generate_script(frame_content["picture"], word_count)
frame_content["narration"] = script
frame_content["OST"] = 2
logger.info(f"时间范围: {frame_content['timestamp']}, 建议字数: {word_count}")
logger.info(script)
self._save_results(frame_content_list)
return frame_content_list
def _save_results(self, frame_content_list: List[Dict]):
"""保存处理结果,并添加新的时间戳"""
try:
# 计算新的时间戳
current_time = 0 # 当前时间点(秒)
for frame in frame_content_list:
# 获取原始时间戳的持续时间
start_str, end_str = frame['timestamp'].split('-')
def time_to_seconds(time_str):
minutes, seconds = map(int, time_str.split(':'))
return minutes * 60 + seconds
# 计算当前片段的持续时间
start_seconds = time_to_seconds(start_str)
end_seconds = time_to_seconds(end_str)
duration = end_seconds - start_seconds
# 转换秒数为 MM:SS 格式
def seconds_to_time(seconds):
minutes = seconds // 60
remaining_seconds = seconds % 60
return f"{minutes:02d}:{remaining_seconds:02d}"
# 设置新的时间戳
new_start = seconds_to_time(current_time)
new_end = seconds_to_time(current_time + duration)
frame['new_timestamp'] = f"{new_start}-{new_end}"
# 更新当前时间点
current_time += duration
# 保存结果
file_name = f"storage/json/step2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w', encoding='utf-8') as file:
json.dump(frame_content_list, file, ensure_ascii=False, indent=4)
logger.info(f"保存脚本成功,总时长: {seconds_to_time(current_time)}")
except Exception as e:
logger.error(f"保存结果时发生错误: {str(e)}\n{traceback.format_exc()}")
raise

View File

@ -56,7 +56,7 @@ def to_json(obj):
# 使用serialize函数处理输入对象
serialized_obj = serialize(obj)
# 序列化处理后的对象为JSON字符串
# 序列化处理后的对象为JSON字符串
return json.dumps(serialized_obj, ensure_ascii=False, indent=4)
except Exception as e:
return None
@ -100,7 +100,7 @@ def task_dir(sub_dir: str = ""):
def font_dir(sub_dir: str = ""):
d = resource_dir(f"fonts")
d = resource_dir("fonts")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
@ -109,7 +109,7 @@ def font_dir(sub_dir: str = ""):
def song_dir(sub_dir: str = ""):
d = resource_dir(f"songs")
d = resource_dir("songs")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
@ -423,5 +423,104 @@ def cut_video(params, progress_callback=None):
return task_id, subclip_videos
except Exception as e:
logger.error(f"视频裁剪过程中发生错误: {traceback.format_exc()}")
logger.error(f"视频裁剪过程中发生错误: \n{traceback.format_exc()}")
raise
def temp_dir(sub_dir: str = ""):
    """Return the temp directory under storage, creating it on demand.

    Args:
        sub_dir: optional sub-directory name appended to the temp root.

    Returns:
        str: path of the (sub)directory, guaranteed to exist.
    """
    path = os.path.join(storage_dir(), "temp")
    if sub_dir:
        path = os.path.join(path, sub_dir)
    if not os.path.exists(path):
        os.makedirs(path)
    return path
def clear_keyframes_cache(video_path: str = None):
    """Clear cached keyframes (best-effort; failures are logged, not raised).

    Args:
        video_path: video file path; when given, only that video's cache
            (keyed by a hash of path + mtime) is removed, otherwise the whole
            keyframes cache directory is deleted.
    """
    try:
        keyframes_dir = os.path.join(temp_dir(), "keyframes")
        if not os.path.exists(keyframes_dir):
            return
        if video_path:
            # Remove the cache for this specific video only.
            video_hash = md5(video_path + str(os.path.getmtime(video_path)))
            video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
            if os.path.exists(video_keyframes_dir):
                import shutil
                shutil.rmtree(video_keyframes_dir)
                logger.info(f"已清理视频关键帧缓存: {video_path}")
        else:
            # Remove every cached keyframe.
            import shutil
            shutil.rmtree(keyframes_dir)
            logger.info("已清理所有关键帧缓存")
    except Exception as e:
        logger.error(f"清理关键帧缓存失败: {e}")
def init_resources():
    """Initialize font resources: copy a known system font into the
    resource directory, or download Source Han Sans as a fallback.

    Best-effort: failures are logged, not raised.
    """
    try:
        # Renamed local (was `font_dir`) to stop shadowing the module-level
        # font_dir() helper.
        fonts_path = os.path.join(root_dir(), "resource", "fonts")
        os.makedirs(fonts_path, exist_ok=True)
        # (name, source) pairs: first entry is a download URL fallback, the
        # rest are Windows system font locations.
        font_files = [
            ("SourceHanSansCN-Regular.otf", "https://github.com/adobe-fonts/source-han-sans/raw/release/OTF/SimplifiedChinese/SourceHanSansSC-Regular.otf"),
            ("simhei.ttf", "C:/Windows/Fonts/simhei.ttf"),  # Windows SimHei
            ("simkai.ttf", "C:/Windows/Fonts/simkai.ttf"),  # Windows SimKai
            ("simsun.ttc", "C:/Windows/Fonts/simsun.ttc"),  # Windows SimSun
        ]
        # Prefer a locally installed system font.
        system_font_found = False
        for font_name, source in font_files:
            if not source.startswith("http") and os.path.exists(source):
                target_path = os.path.join(fonts_path, font_name)
                if not os.path.exists(target_path):
                    import shutil
                    shutil.copy2(source, target_path)
                    logger.info(f"已复制系统字体: {font_name}")
                system_font_found = True
                break
        # No system font available: download Source Han Sans once.
        if not system_font_found:
            source_han_path = os.path.join(fonts_path, "SourceHanSansCN-Regular.otf")
            if not os.path.exists(source_han_path):
                download_font(font_files[0][1], source_han_path)
    except Exception as e:
        logger.error(f"初始化资源文件失败: {e}")
def download_font(url: str, font_path: str):
    """Download a font file to `font_path`.

    Raises:
        Exception: re-raised after logging on any network or filesystem error.
    """
    try:
        logger.info(f"正在下载字体文件: {url}")
        import requests
        # Fix: a timeout was missing, so a stalled connection blocked forever.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        with open(font_path, 'wb') as f:
            f.write(response.content)
        logger.info(f"字体文件下载成功: {font_path}")
    except Exception as e:
        logger.error(f"下载字体文件失败: {e}")
        raise

View File

@ -0,0 +1,209 @@
import cv2
import numpy as np
from sklearn.cluster import KMeans
import os
import re
from typing import List, Tuple, Generator
class VideoProcessor:
    """Keyframe-extraction pipeline for a single video file.

    Detects shot boundaries with a frame-difference heuristic, then picks one
    representative frame per shot via 1-cluster KMeans (the frame closest to
    the shot's mean image).
    """

    def __init__(self, video_path: str):
        """Open the video and cache total frame count and fps.

        Args:
            video_path: path to the video file.

        Raises:
            FileNotFoundError: if video_path does not exist.
            RuntimeError: if OpenCV cannot open the file.
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise RuntimeError(f"无法打开视频文件: {video_path}")
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # NOTE(review): fps is truncated to int; non-integer frame rates
        # (e.g. 29.97) will skew derived timestamps slightly — confirm acceptable.
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))

    def __del__(self):
        """Release the capture handle when the object is garbage-collected."""
        if hasattr(self, 'cap'):
            self.cap.release()

    def preprocess_video(self) -> Generator[np.ndarray, None, None]:
        """Lazily yield every frame of the video from the beginning.

        Yields:
            np.ndarray: one decoded BGR frame per iteration.
        """
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to frame 0
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if not ret:
                break
            yield frame

    def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
        """Detect shot boundaries by mean absolute gray-level frame difference.

        Args:
            frames: decoded video frames.
            threshold: mean per-pixel difference above which a cut is assumed.

        Returns:
            List[int]: indices (into `frames`) where a new shot starts.
        """
        shot_boundaries = []
        for i in range(1, len(frames)):
            prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
            curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
            # int cast avoids uint8 wrap-around in the subtraction.
            diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
            if diff > threshold:
                shot_boundaries.append(i)
        return shot_boundaries

    def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
        """Pick one keyframe per shot: the member frame nearest the shot mean.

        Args:
            frames: decoded video frames.
            shot_boundaries: shot start indices from detect_shot_boundaries.

        Returns:
            Tuple[List[np.ndarray], List[int]]: keyframes and their frame indices.

        NOTE(review): frames after the LAST boundary form a trailing shot that
        is never sampled, and an empty boundary list yields no keyframes at
        all — confirm whether that is intended.
        """
        keyframes = []
        keyframe_indices = []
        for i in range(len(shot_boundaries)):
            start = shot_boundaries[i - 1] if i > 0 else 0
            end = shot_boundaries[i]
            shot_frames = frames[start:end]
            frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
                                       for frame in shot_frames])
            # n_clusters=1 => the cluster centre is the mean frame; pick the
            # member frame with the smallest squared distance to it.
            kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
            center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
            keyframes.append(shot_frames[center_idx])
            keyframe_indices.append(start + center_idx)
        return keyframes, keyframe_indices

    def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
                       output_dir: str) -> None:
        """Write keyframes to output_dir as keyframe_<frameidx>_<HHMMSS>.jpg.

        Args:
            keyframes: keyframe images.
            keyframe_indices: matching frame indices (used for the timestamp).
            output_dir: destination directory, created if missing.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        for keyframe, frame_idx in zip(keyframes, keyframe_indices):
            # Frame index -> timestamp in seconds, then HHMMSS for the filename.
            timestamp = frame_idx / self.fps
            hours = int(timestamp // 3600)
            minutes = int((timestamp % 3600) // 60)
            seconds = int(timestamp % 60)
            time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
            output_path = os.path.join(output_dir,
                                       f'keyframe_{frame_idx:06d}_{time_str}.jpg')
            cv2.imwrite(output_path, keyframe)
        print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}")

    def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
        """Extract specific frames by frame number and save them as JPEGs.

        Args:
            frame_numbers: frame indices to extract.
            output_folder: destination directory, created if missing.

        Raises:
            ValueError: if the list is empty or contains out-of-range indices.
        """
        if not frame_numbers:
            raise ValueError("未提供帧号列表")
        if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
            raise ValueError("存在无效的帧号")
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
        for frame_number in frame_numbers:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = self.cap.read()
            if ret:
                # Same HHMMSS naming scheme as save_keyframes.
                timestamp = frame_number / self.fps
                hours = int(timestamp // 3600)
                minutes = int((timestamp % 3600) // 60)
                seconds = int(timestamp % 60)
                time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
                output_path = os.path.join(output_folder,
                                           f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
                cv2.imwrite(output_path, frame)
                print(f"已提取并保存帧 {frame_number}")
            else:
                print(f"无法读取帧 {frame_number}")

    @staticmethod
    def extract_numbers_from_folder(folder_path: str) -> List[int]:
        """Collect frame numbers from saved keyframe filenames in a folder.

        Args:
            folder_path: directory containing keyframe_<idx>_<time>.jpg files.

        Returns:
            List[int]: sorted frame numbers parsed from the filenames.
        """
        files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
        # Matches the naming scheme used above, e.g. keyframe_000123_010534.jpg
        pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
        numbers = []
        for f in files:
            match = pattern.search(f)
            if match:
                numbers.append(int(match.group(1)))
        return sorted(numbers)

    def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
        """Run the full pipeline: decode, detect shots, extract and save keyframes.

        Args:
            output_dir: directory for the saved keyframes.
            skip_seconds: seconds to skip at the start of the video.

        Raises:
            ValueError: if no frames remain after skipping.
        """
        skip_frames = int(skip_seconds * self.fps)
        # NOTE(review): this materializes every frame in memory; long videos
        # may exhaust RAM — confirm input sizes.
        frames = list(self.preprocess_video())
        frames = frames[skip_frames:]
        if not frames:
            raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
        shot_boundaries = self.detect_shot_boundaries(frames)
        keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
        # Re-offset indices so filenames reflect positions in the full video.
        adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
        self.save_keyframes(keyframes, adjusted_indices, output_dir)

View File

@ -0,0 +1,183 @@
import json
from typing import List, Union, Dict
import os
from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
from google.api_core import exceptions
import google.generativeai as genai
import PIL.Image
import traceback
class VisionAnalyzer:
    """Batch image analyzer backed by Google Gemini vision models."""

    def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
        """Configure the Gemini client.

        Raises:
            ValueError: if api_key is missing.
        """
        if not api_key:
            raise ValueError("必须提供API密钥")
        self.model_name = model_name
        self.api_key = api_key
        self._configure_client()

    def _configure_client(self):
        """Bind the API key to the genai module and build the model handle."""
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(self.model_name)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(exceptions.ResourceExhausted)
    )
    async def _generate_content_with_retry(self, prompt, batch):
        """Call generate_content_async with exponential-backoff retries on
        quota (ResourceExhausted) errors."""
        try:
            return await self.model.generate_content_async([prompt, *batch])
        except exceptions.ResourceExhausted as e:
            print(f"API配额限制: {str(e)}")
            raise RetryError("API调用失败")

    async def analyze_images(self,
                             images: Union[List[str], List[PIL.Image.Image]],
                             prompt: str,
                             batch_size: int = 5) -> List[Dict]:
        """Analyze images in batches with the given prompt.

        Args:
            images: file paths (loaded automatically) or PIL images.
            prompt: instruction sent alongside each batch.
            batch_size: images per API call.

        Returns:
            One dict per batch with 'batch_index', 'images_processed',
            'response' (or 'error' after 3 failures) and 'model_used'.

        Raises:
            Exception: wraps any unrecoverable error with a traceback.
        """
        try:
            # Accept paths: load them into PIL images first.
            if isinstance(images[0], str):
                logger.info("正在加载图片...")
                images = self.load_images(images)
            if not images:
                raise ValueError("图片列表为空")
            # Keep only genuine PIL image objects.
            valid_images = []
            for i, img in enumerate(images):
                if not isinstance(img, PIL.Image.Image):
                    logger.error(f"无效的图片对象,索引 {i}: {type(img)}")
                    continue
                valid_images.append(img)
            if not valid_images:
                raise ValueError("没有有效的图片对象")
            images = valid_images
            results = []
            total_batches = (len(images) + batch_size - 1) // batch_size
            with tqdm(total=total_batches, desc="分析进度") as pbar:
                for i in range(0, len(images), batch_size):
                    batch = images[i:i + batch_size]
                    retry_count = 0
                    while retry_count < 3:
                        try:
                            # Small delay between batches to ease rate limits.
                            if i > 0:
                                await asyncio.sleep(2)
                            valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
                            if not valid_batch:
                                raise ValueError(f"批次 {i // batch_size} 中没有有效的图片")
                            response = await self._generate_content_with_retry(prompt, valid_batch)
                            results.append({
                                'batch_index': i // batch_size,
                                'images_processed': len(valid_batch),
                                'response': response.text,
                                'model_used': self.model_name
                            })
                            break
                        except Exception as e:
                            retry_count += 1
                            error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}\n{traceback.format_exc()}"
                            logger.error(error_msg)
                            if retry_count >= 3:
                                # Record the failure instead of aborting the run.
                                results.append({
                                    'batch_index': i // batch_size,
                                    'images_processed': len(batch),
                                    'error': error_msg,
                                    'model_used': self.model_name
                                })
                            else:
                                logger.info(f"批次 {i // batch_size} 处理失败等待60秒后重试...")
                                await asyncio.sleep(60)
                    pbar.update(1)
            return results
        except Exception as e:
            error_msg = f"图片分析过程中发生错误: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            raise Exception(error_msg)

    def save_results_to_txt(self, results: List[Dict], output_dir: str):
        """Write each batch's response text to frame_<start>_<end>.txt.

        NOTE(review): this keys on result['image_paths'], but analyze_images
        never sets that field, so every result is skipped as written —
        confirm which caller populates 'image_paths'.
        """
        os.makedirs(output_dir, exist_ok=True)
        for result in results:
            if not result.get('image_paths'):
                continue
            response_text = result['response']
            image_paths = result['image_paths']
            # Frame range comes from the first/last filename suffixes.
            img_name_start = Path(image_paths[0]).stem.split('_')[-1]
            img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
            txt_path = os.path.join(output_dir, f"frame_{img_name_start}_{img_name_end}.txt")
            with open(txt_path, 'w', encoding='utf-8') as f:
                f.write(response_text.strip())
            print(f"已保存分析结果到: {txt_path}")

    def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
        """Load images from disk, skipping (and logging) any that fail.

        Args:
            image_paths: paths to load.

        Returns:
            Fully-loaded RGB PIL images.

        Raises:
            ValueError: if no image loads successfully.
        """
        images = []
        failed_images = []
        for img_path in image_paths:
            try:
                if not os.path.exists(img_path):
                    logger.error(f"图片文件不存在: {img_path}")
                    failed_images.append(img_path)
                    continue
                img = PIL.Image.open(img_path)
                # Force decoding now so later use cannot hit lazy-load errors.
                img.load()
                # Normalize to RGB for the vision API.
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                images.append(img)
            except Exception as e:
                logger.error(f"无法加载图片 {img_path}: {str(e)}")
                failed_images.append(img_path)
        if failed_images:
            logger.warning(f"以下图片加载失败:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
        if not images:
            raise ValueError("没有成功加载任何图片")
        return images

View File

@ -1,94 +1,92 @@
[app]
project_version="0.2.0"
# 如果你没有 OPENAI API Key可以使用 g4f 代替,或者使用国内的 Moonshot API
# If you don't have an OPENAI API Key, you can use g4f instead
video_llm_provider="gemini"
project_version="0.3.2"
# 支持视频理解的大模型提供商
# gemini
# NarratoAPI
# qwen2-vl (待增加)
vision_llm_provider="gemini"
vision_batch_size = 5
vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
# 支持的提供商 (Supported providers):
# openai
########## Vision Gemini API Key
vision_gemini_api_key = ""
vision_gemini_model_name = "gemini-1.5-flash"
########### Vision NarratoAPI Key
# NarratoAPI 是为了便捷访问不了 Gemini API 的用户, 提供的代理服务
narrato_api_key = ""
narrato_api_url = "https://narratoapi.scsmtech.cn/api/v1"
narrato_vision_model = "gemini-1.5-flash"
narrato_vision_key = ""
narrato_llm_model = "gpt-4o"
narrato_llm_key = ""
# 用于生成文案的大模型支持的提供商 (Supported providers):
# openai (默认)
# moonshot (月之暗面)
# oneapi
# g4f
# azure
# qwen (通义千问)
# gemini
llm_provider="openai"
# 支持多模态视频理解能力的大模型
########## Ollama Settings
# No need to set it unless you want to use your own proxy
ollama_base_url = ""
# Check your available models at https://ollama.com/library
ollama_model_name = ""
text_llm_provider="openai"
########## OpenAI API Key
# Get your API key at https://platform.openai.com/api-keys
openai_api_key = ""
text_openai_api_key = ""
# No need to set it unless you want to use your own proxy
openai_base_url = ""
text_openai_base_url = ""
# Check your available models at https://platform.openai.com/account/limits
openai_model_name = "gpt-4-turbo"
text_openai_model_name = "gpt-4o-mini"
########## Moonshot API Key
# Visit https://platform.moonshot.cn/console/api-keys to get your API key.
moonshot_api_key=""
moonshot_base_url = "https://api.moonshot.cn/v1"
moonshot_model_name = "moonshot-v1-8k"
########## OneAPI API Key
# Visit https://github.com/songquanpeng/one-api to get your API key
oneapi_api_key=""
oneapi_base_url=""
oneapi_model_name=""
text_moonshot_api_key=""
text_moonshot_base_url = "https://api.moonshot.cn/v1"
text_moonshot_model_name = "moonshot-v1-8k"
########## G4F
# Visit https://github.com/xtekky/gpt4free to get more details
# Supported model list: https://github.com/xtekky/gpt4free/blob/main/g4f/models.py
g4f_model_name = "gpt-3.5-turbo"
text_g4f_model_name = "gpt-3.5-turbo"
########## Azure API Key
# Visit https://learn.microsoft.com/zh-cn/azure/ai-services/openai/ to get more details
# API documentation: https://learn.microsoft.com/zh-cn/azure/ai-services/openai/reference
azure_api_key = ""
azure_base_url=""
azure_model_name="gpt-35-turbo" # replace with your model deployment name
azure_api_version = "2024-02-15-preview"
text_azure_api_key = ""
text_azure_base_url=""
text_azure_model_name="gpt-35-turbo" # replace with your model deployment name
text_azure_api_version = "2024-02-15-preview"
########## Gemini API Key
gemini_api_key=""
gemini_model_name = "gemini-1.5-flash"
text_gemini_api_key=""
text_gemini_model_name = "gemini-1.5-flash"
########## Qwen API Key
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key
# Visit below links to get more details
# https://tongyi.aliyun.com/qianwen/
# https://help.aliyun.com/zh/dashscope/developer-reference/model-introduction
qwen_api_key = ""
qwen_model_name = "qwen-max"
text_qwen_api_key = ""
text_qwen_model_name = "qwen-max"
########## DeepSeek API Key
# Visit https://platform.deepseek.com/api_keys to get your API key
deepseek_api_key = ""
deepseek_base_url = "https://api.deepseek.com"
deepseek_model_name = "deepseek-chat"
text_deepseek_api_key = ""
text_deepseek_base_url = "https://api.deepseek.com"
text_deepseek_model_name = "deepseek-chat"
# Subtitle Provider, "whisper"
# If empty, the subtitle will not be generated
# 字幕提供商、可选,支持 whisper 和 faster-whisper-large-v2"whisper"
# 默认为 faster-whisper-large-v2 模型地址https://huggingface.co/guillaumekln/faster-whisper-large-v2
subtitle_provider = "faster-whisper-large-v2"
subtitle_enabled = true
#
# ImageMagick
#
# Once you have installed it, ImageMagick will be automatically detected, except on Windows!
# On Windows, for example "C:\Program Files (x86)\ImageMagick-7.1.1-Q16-HDRI\magick.exe"
# Download from https://imagemagick.org/archive/binaries/ImageMagick-7.1.1-29-Q16-x64-static.exe
# 安装后,将自动检测到 ImageMagickWindows 除外!
# 例如,在 Windows 上 "C:\Program Files (x86)\ImageMagick-7.1.1-Q16-HDRI\magick.exe"
# 下载位置 https://imagemagick.org/archive/binaries/ImageMagick-7.1.1-29-Q16-x64-static.exe
# imagemagick_path = "C:\\Program Files (x86)\\ImageMagick-7.1.1-Q16\\magick.exe"
#
# FFMPEG
#
# 通常情况下ffmpeg 会被自动下载,并且会被自动检测到。
@ -97,12 +95,6 @@
# Install ffmpeg on your system, or set the IMAGEIO_FFMPEG_EXE environment variable.
# 此时你可以手动下载 ffmpeg 并设置 ffmpeg_path下载地址https://www.gyan.dev/ffmpeg/builds/
# Under normal circumstances, ffmpeg is downloaded automatically and detected automatically.
# However, if there is an issue with your environment that prevents automatic downloading, you might encounter the following error:
# RuntimeError: No ffmpeg exe could be found.
# Install ffmpeg on your system, or set the IMAGEIO_FFMPEG_EXE environment variable.
# In such cases, you can manually download ffmpeg and set the ffmpeg_path, download link: https://www.gyan.dev/ffmpeg/builds/
# ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe"
#########################################################################################
@ -132,7 +124,7 @@
material_directory = ""
# Used for state management of the task
# 用于任务的状态管理
enable_redis = false
redis_host = "localhost"
redis_port = 6379
@ -143,7 +135,6 @@
max_concurrent_tasks = 5
# webui界面是否显示配置项
# webui hide basic config panel
hide_config = false
@ -161,7 +152,7 @@
# recommended model_size: "large-v3"
model_size="faster-whisper-large-v2"
# if you want to use GPU, set device="cuda"
# 如果要使用 GPU请设置 device=“cuda”
device="CPU"
compute_type="int8"

View File

@ -1,28 +1,34 @@
requests~=2.31.0
moviepy~=2.0.0.dev2
openai~=1.13.3
faster-whisper~=1.0.1
edge_tts~=6.1.15
uvicorn~=0.27.1
fastapi~=0.110.0
fastapi~=0.115.4
tomli~=2.0.1
streamlit~=1.33.0
streamlit~=1.40.0
loguru~=0.7.2
aiohttp~=3.9.3
aiohttp~=3.10.10
urllib3~=2.2.1
pillow~=10.3.0
pydantic~=2.6.3
g4f~=0.3.0.4
dashscope~=1.15.0
google.generativeai>=0.8.2
google.generativeai>=0.8.3
python-multipart~=0.0.9
redis==5.0.3
# if you use pillow~=10.3.0, you will get "PIL.Image' has no attribute 'ANTIALIAS'" error when resize video
# please install opencv-python to fix "PIL.Image' has no attribute 'ANTIALIAS'" error
opencv-python~=4.9.0.80
opencv-python~=4.10.0.84
# for azure speech
# https://techcommunity.microsoft.com/t5/ai-azure-ai-services-blog/9-more-realistic-ai-voices-for-conversations-now-generally/ba-p/4099471
azure-cognitiveservices-speech~=1.37.0
git-changelog~=2.5.2
watchdog==5.0.2
pydub==0.25.1
psutil>=5.9.0
opencv-python~=4.10.0.84
scikit-learn~=1.5.2
google-generativeai~=0.8.3
Pillow>=11.0.0
python-dotenv~=1.0.1
openai~=1.53.0
tqdm>=4.66.6
tenacity>=9.0.0
tiktoken==0.8.0

885
webui.py
View File

@ -1,6 +1,15 @@
import streamlit as st
import os
import sys
from uuid import uuid4
from app.config import config
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
from webui.utils import cache, file_utils
from app.utils import utils
from app.models.schema import VideoClipParams, VideoAspect
from webui.utils.performance import PerformanceMonitor
# 初始化配置 - 必须是第一个 Streamlit 命令
st.set_page_config(
page_title="NarratoAI",
page_icon="📽️",
@ -13,126 +22,34 @@ st.set_page_config(
},
)
import sys
import os
import glob
import json
import time
import datetime
import traceback
from uuid import uuid4
import platform
import streamlit.components.v1 as components
from loguru import logger
from app.models.const import FILE_TYPE_VIDEOS
from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
from app.services import task as tm, llm, voice, material
from app.utils import utils
# # 将项目的根目录添加到系统路径中,以允许从项目导入模块
root_dir = os.path.dirname(os.path.realpath(__file__))
if root_dir not in sys.path:
sys.path.append(root_dir)
print("******** sys.path ********")
print(sys.path)
print("*" * 20)
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
os.environ["HTTP_PROXY"] = proxy_url_http
os.environ["HTTPS_PROXY"] = proxy_url_https
# 设置页面样式
hide_streamlit_style = """
<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding-top: 6px; padding-bottom: 10px; padding-left: 20px; padding-right: 20px;}</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
st.title(f"NarratoAI :sunglasses:📽️")
support_locales = [
"zh-CN",
"zh-HK",
"zh-TW",
"en-US",
]
font_dir = os.path.join(root_dir, "resource", "fonts")
song_dir = os.path.join(root_dir, "resource", "songs")
i18n_dir = os.path.join(root_dir, "webui", "i18n")
config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
system_locale = utils.get_system_locale()
if 'video_clip_json' not in st.session_state:
st.session_state['video_clip_json'] = []
if 'video_plot' not in st.session_state:
st.session_state['video_plot'] = ''
if 'ui_language' not in st.session_state:
st.session_state['ui_language'] = config.ui.get("language", system_locale)
if 'subclip_videos' not in st.session_state:
st.session_state['subclip_videos'] = {}
def get_all_fonts():
    """Collect every .ttf/.ttc font file name under font_dir, sorted alphabetically."""
    found = [
        name
        for _dirpath, _dirnames, names in os.walk(font_dir)
        for name in names
        if name.endswith((".ttf", ".ttc"))
    ]
    found.sort()
    return found
def get_all_songs():
    """Collect every .mp3 file name under song_dir (directory-walk order)."""
    return [
        name
        for _dirpath, _dirnames, names in os.walk(song_dir)
        for name in names
        if name.endswith(".mp3")
    ]
def open_task_folder(task_id):
    """Open the task's output folder in the OS file browser (best effort).

    Args:
        task_id: task identifier; the folder is storage/tasks/<task_id>
            under root_dir.

    Errors are logged and never raised. Only Windows ("start") and macOS
    ("open") are handled, matching the original behavior.
    """
    try:
        # Bug fix: the original bound platform.system() to a local named
        # `sys`, shadowing the imported sys module inside this function.
        system_name = platform.system()
        path = os.path.join(root_dir, "storage", "tasks", task_id)
        if os.path.exists(path):
            if system_name == 'Windows':
                os.system(f"start {path}")
            if system_name == 'Darwin':
                os.system(f"open {path}")
    except Exception as e:
        logger.error(e)
def scroll_to_bottom():
    """Scroll every Streamlit main section to its bottom via injected JS.

    Renders an invisible (0x0) HTML component whose script walks the parent
    document's 'section.main' elements and sets scrollTop to scrollHeight.
    The dummy argument forces re-execution each time the component renders.
    """
    js = f"""
    <script>
        console.log("scroll_to_bottom");
        function scroll(dummy_var_to_force_repeat_execution){{
            var sections = parent.document.querySelectorAll('section.main');
            console.log(sections);
            for(let index = 0; index<sections.length; index++) {{
                sections[index].scrollTop = sections[index].scrollHeight;
            }}
        }}
        scroll(1);
    </script>
    """
    # height/width 0 keeps the helper component invisible in the page.
    st.components.v1.html(js, height=0, width=0)
def init_log():
"""初始化日志配置"""
from loguru import logger
logger.remove()
_lvl = "DEBUG"
def format_record(record):
# 获取日志记录中的文件全径
# 增加更多需要过滤的警告消息
ignore_messages = [
"Examining the path of torch.classes raised",
"torch.cuda.is_available()",
"CUDA initialization"
]
for msg in ignore_messages:
if msg in record["message"]:
return ""
file_path = record["file"].path
# 将绝对路径转换为相对于项目根目录的路径
relative_path = os.path.relpath(file_path, root_dir)
# 更新记录中的文件路径
relative_path = os.path.relpath(file_path, config.root_dir)
record["file"].path = f"./{relative_path}"
# 返回修改后的格式字符串
# 您可以根据需要调整这里的格式
record['message'] = record['message'].replace(root_dir, ".")
record['message'] = record['message'].replace(config.root_dir, ".")
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \
'<level>{level}</> | ' + \
@ -140,671 +57,143 @@ def init_log():
'- <level>{message}</>' + "\n"
return _format
# 优化日志过滤器
def log_filter(record):
ignore_messages = [
"Examining the path of torch.classes raised",
"torch.cuda.is_available()",
"CUDA initialization"
]
return not any(msg in record["message"] for msg in ignore_messages)
logger.add(
sys.stdout,
level=_lvl,
format=format_record,
colorize=True,
filter=log_filter
)
init_log()
locales = utils.load_locales(i18n_dir)
def init_global_state():
    """Seed the Streamlit session_state keys this UI relies on, if absent."""
    st.session_state.setdefault('video_clip_json', [])
    st.session_state.setdefault('video_plot', '')
    if 'ui_language' not in st.session_state:
        # Fall back to the system locale when no language is configured.
        st.session_state['ui_language'] = config.ui.get("language", utils.get_system_locale())
    st.session_state.setdefault('subclip_videos', {})
def tr(key):
    """Translate *key* into the current UI language.

    Falls back to the key itself when no translation is found. The locale
    files are loaded once and cached on the function object; the original
    re-read every i18n file from disk on each call, which is wasteful for
    a function invoked for every label on every rerun.
    """
    cached = getattr(tr, "_locales", None)
    if cached is None:
        i18n_dir = os.path.join(os.path.dirname(__file__), "webui", "i18n")
        cached = tr._locales = utils.load_locales(i18n_dir)
    loc = cached.get(st.session_state['ui_language'], {})
    return loc.get("Translation", {}).get(key, key)
def render_generate_button():
"""渲染生成按钮和处理逻辑"""
if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
try:
from app.services import task as tm
import torch
# 重置日志容器和记录
log_container = st.empty()
log_records = []
st.write(tr("Get Help"))
def log_received(msg):
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
# 基础设置
with st.expander(tr("Basic Settings"), expanded=False):
config_panels = st.columns(3)
left_config_panel = config_panels[0]
middle_config_panel = config_panels[1]
right_config_panel = config_panels[2]
with left_config_panel:
display_languages = []
selected_index = 0
for i, code in enumerate(locales.keys()):
display_languages.append(f"{code} - {locales[code].get('Language')}")
if code == st.session_state['ui_language']:
selected_index = i
from loguru import logger
logger.add(log_received)
selected_language = st.selectbox(tr("Language"), options=display_languages,
index=selected_index)
if selected_language:
code = selected_language.split(" - ")[0].strip()
st.session_state['ui_language'] = code
config.ui['language'] = code
config.save_config()
task_id = st.session_state.get('task_id')
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
if HTTP_PROXY:
config.proxy["http"] = HTTP_PROXY
if HTTPS_PROXY:
config.proxy["https"] = HTTPS_PROXY
if not task_id:
st.error(tr("请先裁剪视频"))
return
if not st.session_state.get('video_clip_json_path'):
st.error(tr("脚本文件不能为空"))
return
if not st.session_state.get('video_origin_path'):
st.error(tr("视频文件不能为空"))
return
# 视频转录大模型
with middle_config_panel:
video_llm_providers = ['Gemini']
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(video_llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
break
st.toast(tr("生成视频"))
logger.info(tr("开始生成视频"))
video_llm_provider = st.selectbox(tr("Video LLM Provider"), options=video_llm_providers, index=saved_llm_provider_index)
video_llm_provider = video_llm_provider.lower()
config.app["video_llm_provider"] = video_llm_provider
# 获取所有参数
script_params = script_settings.get_script_params()
video_params = video_settings.get_video_params()
audio_params = audio_settings.get_audio_params()
subtitle_params = subtitle_settings.get_subtitle_params()
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
video_llm_account_id = config.app.get(f"{video_llm_provider}_account_id", "")
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
if st_llm_api_key:
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
# 合并所有参数
all_params = {
**script_params,
**video_params,
**audio_params,
**subtitle_params
}
# 大语言模型
with right_config_panel:
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
break
# 创建参数对象
params = VideoClipParams(**all_params)
llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
llm_provider = llm_provider.lower()
config.app["llm_provider"] = llm_provider
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
if st_llm_api_key:
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
if llm_provider == 'cloudflare':
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
if st_llm_account_id:
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
panel = st.columns(3)
left_panel = panel[0]
middle_panel = panel[1]
right_panel = panel[2]
params = VideoClipParams()
# 左侧面板
with left_panel:
with st.container(border=True):
st.write(tr("Video Script Configuration"))
# 脚本语言
video_languages = [
(tr("Auto Detect"), ""),
]
for code in ["zh-CN", "en-US", "zh-TW"]:
video_languages.append((code, code))
selected_index = st.selectbox(tr("Script Language"),
index=0,
options=range(len(video_languages)), # 使用索引作为内部选项值
format_func=lambda x: video_languages[x][0] # 显示给用户的是标签
)
params.video_language = video_languages[selected_index][1]
# 脚本路径
suffix = "*.json"
song_dir = utils.script_dir()
files = glob.glob(os.path.join(song_dir, suffix))
script_list = []
for file in files:
script_list.append({
"name": os.path.basename(file),
"size": os.path.getsize(file),
"file": file,
"ctime": os.path.getctime(file) # 获取文件创建时间
})
# 按创建时间降序排序
script_list.sort(key=lambda x: x["ctime"], reverse=True)
# 本文件 下拉框
script_path = [(tr("Auto Generate"), ""), ]
for file in script_list:
display_name = file['file'].replace(root_dir, "")
script_path.append((display_name, file['file']))
selected_script_index = st.selectbox(tr("Script Files"),
index=0,
options=range(len(script_path)), # 使用索引作为内部选项值
format_func=lambda x: script_path[x][0] # 显示给用户的是标签
)
params.video_clip_json_path = script_path[selected_script_index][1]
config.app["video_clip_json_path"] = params.video_clip_json_path
st.session_state['video_clip_json_path'] = params.video_clip_json_path
# 视频文件处理
video_files = []
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
video_files = video_files[::-1]
video_list = []
for video_file in video_files:
video_list.append({
"name": os.path.basename(video_file),
"size": os.path.getsize(video_file),
"file": video_file,
"ctime": os.path.getctime(video_file) # 获取文件创建时间
})
# 按创建时间降序排序
video_list.sort(key=lambda x: x["ctime"], reverse=True)
video_path = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
for file in video_list:
display_name = file['file'].replace(root_dir, "")
video_path.append((display_name, file['file']))
# 视频文件
selected_video_index = st.selectbox(tr("Video File"),
index=0,
options=range(len(video_path)), # 使用索引作为内部选项值
format_func=lambda x: video_path[x][0] # 显示给用户的是标签
)
params.video_origin_path = video_path[selected_video_index][1]
config.app["video_origin_path"] = params.video_origin_path
st.session_state['video_origin_path'] = params.video_origin_path
# 从本地上传 mp4 文件
if params.video_origin_path == "local":
_supported_types = FILE_TYPE_VIDEOS
uploaded_file = st.file_uploader(
tr("Upload Local Files"),
type=["mp4", "mov", "avi", "flv", "mkv"],
accept_multiple_files=False,
result = tm.start_subclip(
task_id=task_id,
params=params,
subclip_path_videos=st.session_state['subclip_videos']
)
if uploaded_file is not None:
# 构造保存路径
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
file_name, file_extension = os.path.splitext(uploaded_file.name)
# 检查文件是否存在,如果存在则添加时间戳
if os.path.exists(video_file_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
file_name_with_timestamp = f"{file_name}_{timestamp}"
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
# 将文件保存到指定目录
with open(video_file_path, "wb") as f:
f.write(uploaded_file.read())
st.success(tr("File Uploaded Successfully"))
time.sleep(1)
st.rerun()
# 视频名称
video_name = st.text_input(tr("Video Name"))
# 剧情内容
video_plot = st.text_area(
tr("Plot Description"),
value=st.session_state['video_plot'],
height=180
)
# 生成视频脚本
if st.session_state['video_clip_json_path']:
generate_button_name = tr("Video Script Load")
else:
generate_button_name = tr("Video Script Generate")
if st.button(generate_button_name, key="auto_generate_script"):
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
progress_bar.progress(progress)
if message:
status_text.text(f"{progress}% - {message}")
else:
status_text.text(f"进度: {progress}%")
video_files = result.get("videos", [])
st.success(tr("视生成完成"))
try:
with st.spinner("正在生成脚本..."):
if not video_plot:
st.warning("视频剧情为空; 会极大影响生成效果!")
if params.video_clip_json_path == "" and params.video_origin_path != "":
update_progress(10, "压缩视频中...")
# 使用大模型生成视频脚本
script = llm.generate_script(
video_path=params.video_origin_path,
video_plot=video_plot,
video_name=video_name,
language=params.video_language,
progress_callback=update_progress
)
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
else:
update_progress(90)
if video_files:
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
logger.error(f"播放视频失败: {e}")
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
else:
# 从本地加载
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
update_progress(50)
status_text.text("从本地加载中...")
script = f.read()
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
update_progress(100)
status_text.text("从本地加载成功")
file_utils.open_task_folder(config.root_dir, task_id)
logger.info(tr("视频生成完成"))
time.sleep(0.5) # 给进度条一点时间到达100%
progress_bar.progress(100)
status_text.text("脚本生成完成!")
st.success("视频脚本生成成功!")
except Exception as err:
st.error(f"生成过程中发生错误: {str(err)}")
finally:
time.sleep(2) # 给用户一些时间查看最终状态
progress_bar.empty()
status_text.empty()
finally:
PerformanceMonitor.cleanup_resources()
# 视频脚本
video_clip_json_details = st.text_area(
tr("Video Script"),
value=json.dumps(st.session_state.video_clip_json, indent=2, ensure_ascii=False),
height=180
)
def main():
    """Entry point for the Streamlit WebUI.

    Initializes logging, session state and shared resources, then renders
    the page: title, basic settings, the three-column script/video/audio/
    subtitle panels, the review panel, and finally the generate button.
    """
    init_log()
    init_global_state()
    utils.init_resources()
    st.title(f"NarratoAI :sunglasses:📽️")
    st.write(tr("Get Help"))
    # Render the basic settings panel
    basic_settings.render_basic_settings(tr)
    # Render the main three-column panel
    panel = st.columns(3)
    with panel[0]:
        script_settings.render_script_panel(tr)
    with panel[1]:
        video_settings.render_video_panel(tr)
        audio_settings.render_audio_panel(tr)
    with panel[2]:
        subtitle_settings.render_subtitle_panel(tr)
    # Render the video review panel
    review_settings.render_review_panel(tr)
    # Render the generate button and its handling logic
    render_generate_button()
# 保存脚本
button_columns = st.columns(2)
with button_columns[0]:
if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True):
if not video_clip_json_details:
st.error(tr("请输入视频脚本"))
st.stop()
with st.spinner(tr("Save Script")):
script_dir = utils.script_dir()
# 获取当前时间戳,形如 2024-0618-171820
timestamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M%S")
save_path = os.path.join(script_dir, f"{timestamp}.json")
try:
data = utils.add_new_timestamps(json.loads(video_clip_json_details))
except Exception as err:
st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}")
st.stop()
# 存储为新的 JSON 文件
with open(save_path, 'w', encoding='utf-8') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
# 将data的值存储到 session_state 中,类似缓存
st.session_state['video_clip_json'] = data
st.session_state['video_clip_json_path'] = save_path
# 刷新页面
st.rerun()
# 裁剪视频
with button_columns[1]:
if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress):
progress_bar.progress(progress)
status_text.text(f"剪辑进度: {progress}%")
try:
utils.cut_video(params, update_progress)
time.sleep(0.5) # 给进度条一点时间到达100%
progress_bar.progress(100)
status_text.text("剪辑完成!")
st.success("视频剪辑成功完成!")
except Exception as e:
st.error(f"剪辑过程中发生错误: {str(e)}")
finally:
time.sleep(2) # 给用户一些时间查看最终状态
progress_bar.empty()
status_text.empty()
# 新中间面板
with middle_panel:
with st.container(border=True):
st.write(tr("Video Settings"))
# 视频比例
video_aspect_ratios = [
(tr("Portrait"), VideoAspect.portrait.value),
(tr("Landscape"), VideoAspect.landscape.value),
]
selected_index = st.selectbox(
tr("Video Ratio"),
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
)
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
# params.video_clip_duration = st.selectbox(
# tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
# )
# params.video_count = st.selectbox(
# tr("Number of Videos Generated Simultaneously"),
# options=[1, 2, 3, 4, 5],
# index=0,
# )
with st.container(border=True):
st.write(tr("Audio Settings"))
# tts_providers = ['edge', 'azure']
# tts_provider = st.selectbox(tr("TTS Provider"), tts_providers)
voices = voice.get_all_azure_voices(filter_locals=support_locales)
friendly_names = {
v: v.replace("Female", tr("Female"))
.replace("Male", tr("Male"))
.replace("Neural", "")
for v in voices
}
saved_voice_name = config.ui.get("voice_name", "")
saved_voice_name_index = 0
if saved_voice_name in friendly_names:
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
else:
for i, v in enumerate(voices):
if (
v.lower().startswith(st.session_state["ui_language"].lower())
and "V2" not in v
):
saved_voice_name_index = i
break
selected_friendly_name = st.selectbox(
tr("Speech Synthesis"),
options=list(friendly_names.values()),
index=saved_voice_name_index,
)
voice_name = list(friendly_names.keys())[
list(friendly_names.values()).index(selected_friendly_name)
]
params.voice_name = voice_name
config.ui["voice_name"] = voice_name
# 试听语言合成
if st.button(tr("Play Voice")):
play_content = "这是一段试听语言"
if not play_content:
play_content = params.video_script
if not play_content:
play_content = tr("Voice Example")
with st.spinner(tr("Synthesizing Voice")):
temp_dir = utils.storage_dir("temp", create=True)
audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_file=audio_file,
)
# if the voice file generation failed, try again with a default content.
if not sub_maker:
play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_file=audio_file,
)
if sub_maker and os.path.exists(audio_file):
st.audio(audio_file, format="audio/mp3")
if os.path.exists(audio_file):
os.remove(audio_file)
if voice.is_azure_v2_voice(voice_name):
saved_azure_speech_region = config.azure.get("speech_region", "")
saved_azure_speech_key = config.azure.get("speech_key", "")
azure_speech_region = st.text_input(
tr("Speech Region"), value=saved_azure_speech_region
)
azure_speech_key = st.text_input(
tr("Speech Key"), value=saved_azure_speech_key, type="password"
)
config.azure["speech_region"] = azure_speech_region
config.azure["speech_key"] = azure_speech_key
params.voice_volume = st.selectbox(
tr("Speech Volume"),
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
index=2,
)
params.voice_rate = st.selectbox(
tr("Speech Rate"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
index=2,
)
bgm_options = [
(tr("No Background Music"), ""),
(tr("Random Background Music"), "random"),
(tr("Custom Background Music"), "custom"),
]
selected_index = st.selectbox(
tr("Background Music"),
index=1,
options=range(len(bgm_options)), # 使用索引作为内部选项值
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
)
# 获取选择的背景音乐类型
params.bgm_type = bgm_options[selected_index][1]
# 根据选择显示或隐藏组件
if params.bgm_type == "custom":
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
if custom_bgm_file and os.path.exists(custom_bgm_file):
params.bgm_file = custom_bgm_file
# st.write(f":red[已选择自定义背景音乐]**{custom_bgm_file}**")
params.bgm_volume = st.selectbox(
tr("Background Music Volume"),
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
index=2,
)
# 新侧面板
with right_panel:
with st.container(border=True):
st.write(tr("Subtitle Settings"))
params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
font_names = get_all_fonts()
saved_font_name = config.ui.get("font_name", "")
saved_font_name_index = 0
if saved_font_name in font_names:
saved_font_name_index = font_names.index(saved_font_name)
params.font_name = st.selectbox(
tr("Font"), font_names, index=saved_font_name_index
)
config.ui["font_name"] = params.font_name
subtitle_positions = [
(tr("Top"), "top"),
(tr("Center"), "center"),
(tr("Bottom"), "bottom"),
(tr("Custom"), "custom"),
]
selected_index = st.selectbox(
tr("Position"),
index=2,
options=range(len(subtitle_positions)),
format_func=lambda x: subtitle_positions[x][0],
)
params.subtitle_position = subtitle_positions[selected_index][1]
if params.subtitle_position == "custom":
custom_position = st.text_input(
tr("Custom Position (% from top)"), value="70.0"
)
try:
params.custom_position = float(custom_position)
if params.custom_position < 0 or params.custom_position > 100:
st.error(tr("Please enter a value between 0 and 100"))
except ValueError:
logger.error(f"输入的值无效: {traceback.format_exc()}")
st.error(tr("Please enter a valid number"))
font_cols = st.columns([0.3, 0.7])
with font_cols[0]:
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
params.text_fore_color = st.color_picker(
tr("Font Color"), saved_text_fore_color
)
config.ui["text_fore_color"] = params.text_fore_color
with font_cols[1]:
saved_font_size = config.ui.get("font_size", 60)
params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
config.ui["font_size"] = params.font_size
stroke_cols = st.columns([0.3, 0.7])
with stroke_cols[0]:
params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000")
with stroke_cols[1]:
params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5)
# 视频编辑面板
with st.expander(tr("Video Check"), expanded=False):
try:
video_list = st.session_state.video_clip_json
except KeyError as e:
video_list = []
# 计算列数和行数
num_videos = len(video_list)
cols_per_row = 3
rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
# 使用容器展示视频
for row in range(rows):
cols = st.columns(cols_per_row)
for col in range(cols_per_row):
index = row * cols_per_row + col
if index < num_videos:
with cols[col]:
video_info = video_list[index]
video_path = video_info.get('path')
if video_path is not None:
initial_narration = video_info['narration']
initial_picture = video_info['picture']
initial_timestamp = video_info['timestamp']
with open(video_path, 'rb') as video_file:
video_bytes = video_file.read()
st.video(video_bytes)
# 可编辑的输入框
text_panels = st.columns(2)
with text_panels[0]:
text1 = st.text_area(tr("timestamp"), value=initial_timestamp, height=20,
key=f"timestamp_{index}")
with text_panels[1]:
text2 = st.text_area(tr("Picture description"), value=initial_picture, height=20,
key=f"picture_{index}")
text3 = st.text_area(tr("Narration"), value=initial_narration, height=100,
key=f"narration_{index}")
# 重新生成按钮
if st.button(tr("Rebuild"), key=f"rebuild_{index}"):
# 更新video_list中的对应项
video_list[index]['timestamp'] = text1
video_list[index]['picture'] = text2
video_list[index]['narration'] = text3
for video in video_list:
if 'path' in video:
del video['path']
# 更新session_state以确保更改被保存
st.session_state['video_clip_json'] = utils.to_json(video_list)
# 替换原JSON 文件
with open(params.video_clip_json_path, 'w', encoding='utf-8') as file:
json.dump(video_list, file, ensure_ascii=False, indent=4)
utils.cut_video(params, progress_callback=None)
st.rerun()
# 开始按钮
start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
if start_button:
# 重置日志容器和记录
log_container = st.empty()
log_records = []
config.save_config()
task_id = st.session_state.get('task_id')
if st.session_state.get('video_script_json_path') is not None:
params.video_clip_json = st.session_state.get('video_clip_json')
logger.debug(f"当前的脚本文件为:{st.session_state.video_clip_json_path}")
logger.debug(f"当前的视频文件为:{st.session_state.video_origin_path}")
logger.debug(f"裁剪后是视频列表:{st.session_state.subclip_videos}")
if not task_id:
st.error(tr("请先裁剪视频"))
scroll_to_bottom()
st.stop()
if not params.video_clip_json_path:
st.error(tr("脚本文件不能为空"))
scroll_to_bottom()
st.stop()
if not params.video_origin_path:
st.error(tr("视频文件不能为空"))
scroll_to_bottom()
st.stop()
def log_received(msg):
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
logger.add(log_received)
st.toast(tr("生成视频"))
logger.info(tr("开始生成视频"))
logger.info(utils.to_json(params))
scroll_to_bottom()
result = tm.start_subclip(task_id=task_id, params=params, subclip_path_videos=st.session_state.subclip_videos)
video_files = result.get("videos", [])
st.success(tr("视频生成完成"))
try:
if video_files:
# 将视频播放器居中
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
pass
open_task_folder(task_id)
logger.info(tr("视频生成完成"))
scroll_to_bottom()
config.save_config()
if __name__ == "__main__":
main()

22
webui/__init__.py Normal file
View File

@ -0,0 +1,22 @@
"""
NarratoAI WebUI Package
"""
from webui.config.settings import config
from webui.components import (
basic_settings,
video_settings,
audio_settings,
subtitle_settings
)
from webui.utils import cache, file_utils, performance
__all__ = [
'config',
'basic_settings',
'video_settings',
'audio_settings',
'subtitle_settings',
'cache',
'file_utils',
'performance'
]

View File

@ -0,0 +1,15 @@
from .basic_settings import render_basic_settings
from .script_settings import render_script_panel
from .video_settings import render_video_panel
from .audio_settings import render_audio_panel
from .subtitle_settings import render_subtitle_panel
from .review_settings import render_review_panel
__all__ = [
'render_basic_settings',
'render_script_panel',
'render_video_panel',
'render_audio_panel',
'render_subtitle_panel',
'render_review_panel'
]

View File

@ -0,0 +1,198 @@
import streamlit as st
import os
from uuid import uuid4
from app.config import config
from app.services import voice
from app.utils import utils
from webui.utils.cache import get_songs_cache
def render_audio_panel(tr):
    """Audio settings panel: TTS voice options followed by background music."""
    with st.container(border=True):
        st.write(tr("Audio Settings"))
        render_tts_settings(tr)  # speech synthesis configuration
        render_bgm_settings(tr)  # background music configuration
def render_tts_settings(tr):
    """Render the text-to-speech settings: voice choice, parameters, preview."""
    supported = ["zh-CN", "zh-HK", "zh-TW", "en-US"]
    all_voices = voice.get_all_azure_voices(filter_locals=supported)

    # Human-friendly label for each raw Azure voice identifier.
    labels = {
        v: v.replace("Female", tr("Female"))
        .replace("Male", tr("Male"))
        .replace("Neural", "")
        for v in all_voices
    }

    stored_voice = config.ui.get("voice_name", "")
    if stored_voice in labels:
        default_index = list(labels).index(stored_voice)
    else:
        # No saved choice: take the first non-V2 voice matching the UI language.
        ui_lang = st.session_state["ui_language"].lower()
        default_index = next(
            (
                i
                for i, v in enumerate(all_voices)
                if v.lower().startswith(ui_lang) and "V2" not in v
            ),
            0,
        )

    chosen_label = st.selectbox(
        tr("Speech Synthesis"),
        options=list(labels.values()),
        index=default_index,
    )

    # Map the displayed label back to the underlying voice identifier.
    voice_name = list(labels)[list(labels.values()).index(chosen_label)]
    config.ui["voice_name"] = voice_name

    # Azure V2 voices additionally require region/key credentials.
    if voice.is_azure_v2_voice(voice_name):
        render_azure_v2_settings(tr)

    render_voice_parameters(tr)
    render_voice_preview(tr, voice_name)
def render_azure_v2_settings(tr):
    """Render the extra credentials (region + key) required by Azure V2 voices."""
    region_default = config.azure.get("speech_region", "")
    key_default = config.azure.get("speech_key", "")

    region = st.text_input(tr("Speech Region"), value=region_default)
    key = st.text_input(tr("Speech Key"), value=key_default, type="password")

    # Persist whatever the user entered (even empty) back into the config.
    config.azure["speech_region"] = region
    config.azure["speech_key"] = key
def render_voice_parameters(tr):
    """Render volume / rate / pitch selectors and store them in session state."""
    controls = (
        ("Speech Volume", "voice_volume",
         [0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0]),
        ("Speech Rate", "voice_rate",
         [0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0]),
        ("Speech Pitch", "voice_pitch",
         [0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0]),
    )
    # index=2 selects the neutral 1.0 default for every control.
    for label, state_key, choices in controls:
        st.session_state[state_key] = st.selectbox(
            tr(label),
            options=choices,
            index=2,
        )
def render_voice_preview(tr, voice_name):
    """Render the "Play Voice" button and synthesize a short audio preview.

    Args:
        tr: translation function.
        voice_name: raw Azure voice identifier selected by the user.

    Fix: `play_content` was assigned a hard-coded non-empty string, making the
    two subsequent `if not play_content` fallback branches unreachable dead
    code; they have been removed.
    """
    if st.button(tr("Play Voice")):
        # Fixed preview sentence; an English fallback below is used only when
        # synthesis of this content fails.
        play_content = "感谢关注 NarratoAI有任何问题或建议可以关注微信公众号求助或讨论"
        with st.spinner(tr("Synthesizing Voice")):
            temp_dir = utils.storage_dir("temp", create=True)
            audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
            sub_maker = voice.tts(
                text=play_content,
                voice_name=voice_name,
                voice_rate=st.session_state.get('voice_rate', 1.0),
                voice_pitch=st.session_state.get('voice_pitch', 1.0),
                voice_file=audio_file,
            )
            # If synthesis failed, retry once with a known-safe English sentence.
            if not sub_maker:
                play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
                sub_maker = voice.tts(
                    text=play_content,
                    voice_name=voice_name,
                    voice_rate=st.session_state.get('voice_rate', 1.0),
                    voice_pitch=st.session_state.get('voice_pitch', 1.0),
                    voice_file=audio_file,
                )
            if sub_maker and os.path.exists(audio_file):
                st.audio(audio_file, format="audio/mp3")
            # Remove the temporary preview file in all cases.
            if os.path.exists(audio_file):
                os.remove(audio_file)
def render_bgm_settings(tr):
    """Render background-music type, optional custom file path, and volume."""
    bgm_options = [
        (tr("No Background Music"), ""),
        (tr("Random Background Music"), "random"),
        (tr("Custom Background Music"), "custom"),
    ]
    chosen = st.selectbox(
        tr("Background Music"),
        index=1,  # default: random background music
        options=range(len(bgm_options)),
        format_func=lambda idx: bgm_options[idx][0],
    )
    bgm_type = bgm_options[chosen][1]
    st.session_state['bgm_type'] = bgm_type

    # For the custom choice, accept only a path that exists on disk.
    if bgm_type == "custom":
        custom_file = st.text_input(tr("Custom Background Music File"))
        if custom_file and os.path.exists(custom_file):
            st.session_state['bgm_file'] = custom_file

    st.session_state['bgm_volume'] = st.selectbox(
        tr("Background Music Volume"),
        options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        index=2,  # default: 0.2
    )
def get_audio_params():
    """Collect the current audio settings into a plain dict."""
    ss = st.session_state
    return {
        'voice_name': config.ui.get("voice_name", ""),
        'voice_volume': ss.get('voice_volume', 1.0),
        'voice_rate': ss.get('voice_rate', 1.0),
        'voice_pitch': ss.get('voice_pitch', 1.0),
        'bgm_type': ss.get('bgm_type', 'random'),
        'bgm_file': ss.get('bgm_file', ''),
        'bgm_volume': ss.get('bgm_volume', 0.2),
    }

View File

@ -0,0 +1,236 @@
import streamlit as st
import os
from app.config import config
from app.utils import utils
def render_basic_settings(tr):
    """Render the collapsible basic-settings area in three columns."""
    with st.expander(tr("Basic Settings"), expanded=False):
        left, middle, right = st.columns(3)
        with left:
            render_language_settings(tr)
            render_proxy_settings(tr)
        with middle:
            # Vision model: drives video frame analysis.
            render_vision_llm_settings(tr)
        with right:
            # Text model: drives narration script generation.
            render_text_llm_settings(tr)
def render_language_settings(tr):
    """Render the UI language selector and persist the chosen locale code.

    Fix: a stray ``st.subheader(tr("Proxy Settings"))`` call preceded the
    docstring — it mislabeled the language section and, being a statement
    before the string literal, demoted the docstring to a no-op expression.
    The stray call has been removed.
    """
    system_locale = utils.get_system_locale()
    i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
    locales = utils.load_locales(i18n_dir)

    display_languages = []
    selected_index = 0
    for i, code in enumerate(locales.keys()):
        display_languages.append(f"{code} - {locales[code].get('Language')}")
        # Preselect the saved UI language, falling back to the system locale.
        if code == st.session_state.get('ui_language', system_locale):
            selected_index = i

    selected_language = st.selectbox(
        tr("Language"),
        options=display_languages,
        index=selected_index
    )
    if selected_language:
        # Entries look like "zh-CN - <name>"; keep only the locale code.
        code = selected_language.split(" - ")[0].strip()
        st.session_state['ui_language'] = code
        config.ui['language'] = code
def render_proxy_settings(tr):
    """Render HTTP/HTTPS proxy inputs; non-empty values are applied to both
    the config object and the process environment."""
    default_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
    default_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")

    http_proxy = st.text_input(tr("HTTP_PROXY"), value=default_http)
    https_proxy = st.text_input(tr("HTTPs_PROXY"), value=default_https)

    if http_proxy:
        config.proxy["http"] = http_proxy
        os.environ["HTTP_PROXY"] = http_proxy
    if https_proxy:
        config.proxy["https"] = https_proxy
        os.environ["HTTPS_PROXY"] = https_proxy
def render_vision_llm_settings(tr):
    """Render provider selection and credentials for the vision (video
    analysis) model, plus NarratoAPI-specific extras when selected."""
    st.subheader(tr("Vision Model Settings"))

    # Provider dropdown; the saved value is stored lowercase.
    vision_providers = ['Gemini', 'NarratoAPI']
    saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
    saved_provider_index = 0
    for i, provider in enumerate(vision_providers):
        if provider.lower() == saved_vision_provider:
            saved_provider_index = i
            break

    vision_provider = st.selectbox(
        tr("Vision Model Provider"),
        options=vision_providers,
        index=saved_provider_index
    )
    vision_provider = vision_provider.lower()
    config.app["vision_llm_provider"] = vision_provider
    st.session_state['vision_llm_providers'] = vision_provider

    # Per-provider config keys, e.g. "vision_gemini_api_key".
    vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
    vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
    vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")

    st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
    st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url)
    st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)

    # Persist non-empty values to config AND session state (the latter is
    # read by script_settings.py).
    if st_vision_api_key:
        config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
        st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
    if st_vision_base_url:
        config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
        st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
    if st_vision_model_name:
        config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
        st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name

    # NarratoAPI needs additional service credentials and model settings.
    if vision_provider == 'narratoapi':
        st.subheader(tr("Narrato Additional Settings"))

        # Narrato service access.
        narrato_api_key = st.text_input(
            tr("Narrato API Key"),
            value=config.app.get("narrato_api_key", ""),
            type="password",
            help="用于访问 Narrato API 的密钥"
        )
        if narrato_api_key:
            config.app["narrato_api_key"] = narrato_api_key
            st.session_state['narrato_api_key'] = narrato_api_key

        narrato_api_url = st.text_input(
            tr("Narrato API URL"),
            value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze")
        )
        if narrato_api_url:
            config.app["narrato_api_url"] = narrato_api_url
            st.session_state['narrato_api_url'] = narrato_api_url

        # Vision model used by the Narrato service.
        st.markdown("##### " + tr("Vision Model Settings"))
        narrato_vision_model = st.text_input(
            tr("Vision Model Name"),
            value=config.app.get("narrato_vision_model", "gemini-1.5-flash")
        )
        narrato_vision_key = st.text_input(
            tr("Vision Model API Key"),
            value=config.app.get("narrato_vision_key", ""),
            type="password",
            help="用于视频分析的模型 API Key"
        )
        if narrato_vision_model:
            config.app["narrato_vision_model"] = narrato_vision_model
            st.session_state['narrato_vision_model'] = narrato_vision_model
        if narrato_vision_key:
            config.app["narrato_vision_key"] = narrato_vision_key
            st.session_state['narrato_vision_key'] = narrato_vision_key

        # Text-generation model used by the Narrato service.
        st.markdown("##### " + tr("Text Generation Model Settings"))
        narrato_llm_model = st.text_input(
            tr("LLM Model Name"),
            value=config.app.get("narrato_llm_model", "qwen-plus")
        )
        narrato_llm_key = st.text_input(
            tr("LLM Model API Key"),
            value=config.app.get("narrato_llm_key", ""),
            type="password",
            help="用于文案生成的模型 API Key"
        )
        if narrato_llm_model:
            config.app["narrato_llm_model"] = narrato_llm_model
            st.session_state['narrato_llm_model'] = narrato_llm_model
        if narrato_llm_key:
            config.app["narrato_llm_key"] = narrato_llm_key
            st.session_state['narrato_llm_key'] = narrato_llm_key

        # Number of images sent per analysis batch.
        narrato_batch_size = st.number_input(
            tr("Batch Size"),
            min_value=1,
            max_value=50,
            value=config.app.get("narrato_batch_size", 10),
            help="每批处理的图片数量"
        )
        if narrato_batch_size:
            config.app["narrato_batch_size"] = narrato_batch_size
            st.session_state['narrato_batch_size'] = narrato_batch_size
def render_text_llm_settings(tr):
    """Render provider selection and credentials for the text-generation LLM."""
    st.subheader(tr("Text Generation Model Settings"))

    providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare']
    saved = config.app.get("text_llm_provider", "OpenAI").lower()
    default_idx = next(
        (i for i, p in enumerate(providers) if p.lower() == saved), 0
    )

    chosen = st.selectbox(
        tr("Text Model Provider"),
        options=providers,
        index=default_idx
    )
    provider = chosen.lower()
    config.app["text_llm_provider"] = provider

    # Per-provider config keys, e.g. "text_openai_api_key".
    api_key = st.text_input(
        tr("Text API Key"),
        value=config.app.get(f"text_{provider}_api_key", ""),
        type="password",
    )
    base_url = st.text_input(
        tr("Text Base URL"),
        value=config.app.get(f"text_{provider}_base_url", ""),
    )
    model_name = st.text_input(
        tr("Text Model Name"),
        value=config.app.get(f"text_{provider}_model_name", ""),
    )

    # Persist only non-empty values.
    if api_key:
        config.app[f"text_{provider}_api_key"] = api_key
    if base_url:
        config.app[f"text_{provider}_base_url"] = base_url
    if model_name:
        config.app[f"text_{provider}_model_name"] = model_name

    # Cloudflare additionally requires an account id.
    if provider == 'cloudflare':
        account_id = st.text_input(
            tr("Account ID"),
            value=config.app.get(f"text_{provider}_account_id", "")
        )
        if account_id:
            config.app[f"text_{provider}_account_id"] = account_id

View File

@ -0,0 +1,85 @@
import streamlit as st
import os
from loguru import logger
def render_review_panel(tr):
    """Render the video-review grid: every generated subclip with its script.

    Fix: the original wrapped ``st.session_state.get(...)`` in a
    ``try/except KeyError`` — ``dict.get`` never raises KeyError, so the
    handler was dead code and has been removed.
    """
    with st.expander(tr("Video Check"), expanded=False):
        video_list = st.session_state.get('video_clip_json', [])
        subclip_videos = st.session_state.get('subclip_videos', {})

        # Lay the clips out in a 3-wide grid, one st.columns row at a time.
        num_videos = len(video_list)
        cols_per_row = 3
        rows = (num_videos + cols_per_row - 1) // cols_per_row  # ceil division

        for row in range(rows):
            cols = st.columns(cols_per_row)
            for col in range(cols_per_row):
                index = row * cols_per_row + col
                if index < num_videos:
                    with cols[col]:
                        render_video_item(tr, video_list, subclip_videos, index)
def render_video_item(tr, video_list, subclip_videos, index):
    """Render one review cell for script entry *index*: timestamp, player,
    picture description, editable narration and clip-mode selector."""
    video_script = video_list[index]

    # Timestamp (read-only).
    timestamp = video_script.get('timestamp', '')
    st.text_area(
        tr("Timestamp"),
        value=timestamp,
        height=70,
        disabled=True,
        key=f"timestamp_{index}"
    )

    # Video player; subclip_videos maps a timestamp to its clip file path.
    video_path = subclip_videos.get(timestamp)
    if video_path and os.path.exists(video_path):
        try:
            st.video(video_path)
        except Exception as e:
            logger.error(f"加载视频失败 {video_path}: {e}")
            st.error(f"无法加载视频: {os.path.basename(video_path)}")
    else:
        st.warning(tr("视频文件未找到"))

    # Picture description (read-only).
    st.text_area(
        tr("Picture Description"),
        value=video_script.get('picture', ''),
        height=150,
        disabled=True,
        key=f"picture_{index}"
    )

    # Narration (editable).
    narration = st.text_area(
        tr("Narration"),
        value=video_script.get('narration', ''),
        height=150,
        key=f"narration_{index}"
    )
    # Write edited narration back into the shared script list in session state.
    if narration != video_script.get('narration', ''):
        video_script['narration'] = narration
        st.session_state['video_clip_json'] = video_list

    # Clip mode (OST), shown as 1..9.
    # NOTE(review): elsewhere OST defaults to 2; confirm the valid range here.
    ost = st.selectbox(
        tr("Clip Mode"),
        options=range(1, 10),
        index=video_script.get('OST', 1) - 1,
        key=f"ost_{index}"
    )
    # Persist a changed clip mode back into the shared script list.
    if ost != video_script.get('OST', 1):
        video_script['OST'] = ost
        st.session_state['video_clip_json'] = video_list

View File

@ -0,0 +1,652 @@
import os
import glob
import json
import time
import asyncio
import traceback
import requests
import streamlit as st
from loguru import logger
from app.config import config
from app.models.schema import VideoClipParams
from app.utils.script_generator import ScriptProcessor
from app.utils import utils, check_script, vision_analyzer, video_processor
from webui.utils import file_utils
def render_script_panel(tr):
    """Render the script configuration panel: pickers, details, actions."""
    with st.container(border=True):
        st.write(tr("Video Script Configuration"))
        params = VideoClipParams()
        # Order matters: the pickers populate session state that the
        # action buttons read afterwards.
        render_script_file(tr, params)
        render_video_file(tr, params)
        render_video_details(tr)
        render_script_buttons(tr, params)
def render_script_file(tr, params):
    """Render the script-file selector; existing scripts listed newest first."""
    script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]

    # Gather existing *.json scripts with their creation times.
    script_dir = utils.script_dir()
    candidates = [
        {"name": os.path.basename(p), "file": p, "ctime": os.path.getctime(p)}
        for p in glob.glob(os.path.join(script_dir, "*.json"))
    ]
    for entry in sorted(candidates, key=lambda item: item["ctime"], reverse=True):
        # Show paths relative to the project root.
        script_list.append((entry['file'].replace(config.root_dir, ""), entry['file']))

    # Preselect the previously chosen script, if it is still listed.
    saved_script_path = st.session_state.get('video_clip_json_path', '')
    selected_index = 0
    for i, (_, path) in enumerate(script_list):
        if path == saved_script_path:
            selected_index = i
            break

    choice = st.selectbox(
        tr("Script Files"),
        index=selected_index,
        options=range(len(script_list)),
        format_func=lambda idx: script_list[idx][0]
    )
    script_path = script_list[choice][1]
    st.session_state['video_clip_json_path'] = script_path
    params.video_clip_json_path = script_path
def render_video_file(tr, params):
    """Render the video-file selector, including a local-upload option."""
    video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]

    # List already-imported videos from the video directory.
    for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
        video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
        for file in video_files:
            # Show paths relative to the project root.
            display_name = file.replace(config.root_dir, "")
            video_list.append((display_name, file))

    selected_video_index = st.selectbox(
        tr("Video File"),
        index=0,
        options=range(len(video_list)),
        format_func=lambda x: video_list[x][0]
    )
    video_path = video_list[selected_video_index][1]
    st.session_state['video_origin_path'] = video_path
    params.video_origin_path = video_path

    # "local" sentinel: show an uploader and copy the file into video_dir.
    if video_path == "local":
        uploaded_file = st.file_uploader(
            tr("Upload Local Files"),
            type=["mp4", "mov", "avi", "flv", "mkv"],
            accept_multiple_files=False,
        )
        if uploaded_file is not None:
            video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
            file_name, file_extension = os.path.splitext(uploaded_file.name)
            # Avoid clobbering an existing file by appending a timestamp.
            if os.path.exists(video_file_path):
                timestamp = time.strftime("%Y%m%d%H%M%S")
                file_name_with_timestamp = f"{file_name}_{timestamp}"
                video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
            with open(video_file_path, "wb") as f:
                f.write(uploaded_file.read())
            st.success(tr("File Uploaded Successfully"))
            st.session_state['video_origin_path'] = video_file_path
            params.video_origin_path = video_file_path
            # Rerun so the newly uploaded file shows up in the selector.
            time.sleep(1)
            st.rerun()
def render_video_details(tr):
    """Render the video theme and custom-prompt inputs; returns (theme, prompt)."""
    theme = st.text_input(tr("Video Theme"))
    prompt = st.text_area(
        tr("Generation Prompt"),
        value=st.session_state.get('video_plot', ''),
        help=tr("Custom prompt for LLM, leave empty to use default prompt"),
        height=180,
    )
    # Stash both values for the generation step.
    st.session_state['video_theme'] = theme
    st.session_state['custom_prompt'] = prompt
    return theme, prompt
def render_script_buttons(tr, params):
    """Render the generate/load button, the script editor and the action row."""
    script_path = st.session_state.get('video_clip_json_path', '')

    # The main button's label depends on what the selector points at.
    if script_path == "auto":
        button_name = tr("Generate Video Script")
    elif script_path:
        button_name = tr("Load Video Script")
    else:
        button_name = tr("Please Select Script File")

    if st.button(button_name, key="script_action", disabled=not script_path):
        if script_path == "auto":
            generate_script(tr, params)
        else:
            load_script(tr, script_path)

    # Editable JSON view of the current script.
    script_text = json.dumps(
        st.session_state.get('video_clip_json', []),
        indent=2,
        ensure_ascii=False,
    )
    video_clip_json_details = st.text_area(tr("Video Script"), value=script_text, height=180)

    # Check / save / crop action buttons.
    check_col, save_col, crop_col = st.columns(3)
    with check_col:
        if st.button(tr("Check Format"), key="check_format", use_container_width=True):
            check_script_format(tr, video_clip_json_details)
    with save_col:
        if st.button(tr("Save Script"), key="save_script", use_container_width=True):
            save_script(tr, video_clip_json_details)
    with crop_col:
        # Cropping is enabled only after the script passed the format check.
        script_valid = st.session_state.get('script_format_valid', False)
        if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
            crop_video(tr, params)
def check_script_format(tr, script_content):
    """Validate the script JSON and record the verdict in session state."""
    try:
        result = check_script.check_format(script_content)
        ok = bool(result.get('success'))
        if ok:
            st.success(tr("Script format check passed"))
        else:
            st.error(f"{tr('Script format check failed')}: {result.get('message')}")
        st.session_state['script_format_valid'] = ok
    except Exception as e:
        st.error(f"{tr('Script format check error')}: {str(e)}")
        st.session_state['script_format_valid'] = False
def load_script(tr, script_path):
    """Load a script JSON file into session state and rerun the app."""
    try:
        with open(script_path, 'r', encoding='utf-8') as f:
            raw = f.read()
        # Strip any model-output artifacts (e.g. code fences) before parsing.
        cleaned = utils.clean_model_output(raw)
        st.session_state['video_clip_json'] = json.loads(cleaned)
        st.success(tr("Script loaded successfully"))
        st.rerun()
    except Exception as e:
        st.error(f"{tr('Failed to load script')}: {str(e)}")
def _format_frame_seconds(time_str):
    """Convert a zero-padded seconds string (e.g. "000125") to "MM:SS"."""
    seconds = int(time_str)
    return f"{seconds // 60:02d}:{seconds % 60:02d}"


def _batch_time_range(batch_files):
    """Return ("MM:SS", "MM:SS") for the first and last frame of a batch.

    The capture time in seconds is taken from the third underscore-separated
    field of the keyframe filename (with the ".jpg" suffix removed).
    """
    first_frame = os.path.basename(batch_files[0])
    last_frame = os.path.basename(batch_files[-1])
    first_time = first_frame.split('_')[2].replace('.jpg', '')
    last_time = last_frame.split('_')[2].replace('.jpg', '')
    return _format_frame_seconds(first_time), _format_frame_seconds(last_time)


def generate_script(tr, params):
    """Generate the video script: extract keyframes, analyze them with the
    configured vision model, then turn the analysis into a narration script.

    Args:
        tr: translation function.
        params: VideoClipParams; must carry video_origin_path.

    Fixes: ``script`` is now initialized to None, so the final
    ``script is None`` check reports a clean error instead of raising
    NameError when no vision-provider branch matched; the twice-duplicated
    timestamp-extraction code (including two identical nested
    ``format_timestamp`` definitions) is factored into module helpers.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: float, message: str = ""):
        # Drive both the progress bar and the status line.
        progress_bar.progress(progress)
        if message:
            status_text.text(f"{progress}% - {message}")
        else:
            status_text.text(f"进度: {progress}%")

    try:
        with st.spinner("正在生成脚本..."):
            if not params.video_origin_path:
                st.error("请先选择视频文件")
                st.stop()
                return

            # Final script; stays None if no provider branch produces one.
            script = None

            # ===================extract keyframes===================
            update_progress(10, "正在提取关键帧...")

            # Keyframes are cached per (path, mtime), so reruns are cheap.
            keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
            video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
            video_keyframes_dir = os.path.join(keyframes_dir, video_hash)

            keyframe_files = []
            if os.path.exists(video_keyframes_dir):
                # Reuse previously extracted frames.
                for filename in sorted(os.listdir(video_keyframes_dir)):
                    if filename.endswith('.jpg'):
                        keyframe_files.append(os.path.join(video_keyframes_dir, filename))
                if keyframe_files:
                    logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
                    st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}")
                    update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 张")

            # No usable cache: extract the keyframes now.
            if not keyframe_files:
                try:
                    os.makedirs(video_keyframes_dir, exist_ok=True)
                    processor = video_processor.VideoProcessor(params.video_origin_path)
                    processor.process_video(
                        output_dir=video_keyframes_dir,
                        skip_seconds=0
                    )
                    for filename in sorted(os.listdir(video_keyframes_dir)):
                        if filename.endswith('.jpg'):
                            keyframe_files.append(os.path.join(video_keyframes_dir, filename))
                    if not keyframe_files:
                        raise Exception("未提取到任何关键帧")
                    update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 张")
                except Exception as e:
                    # Remove the half-written cache dir so a later run does
                    # not mistake it for a valid cache.
                    try:
                        if os.path.exists(video_keyframes_dir):
                            import shutil
                            shutil.rmtree(video_keyframes_dir)
                    except Exception as cleanup_err:
                        logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
                    raise Exception(f"关键帧提取失败: {str(e)}")

            # Dispatch on the configured vision LLM provider.
            vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
            logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")

            if vision_llm_provider == 'gemini':
                try:
                    # ===================init vision analyzer===================
                    update_progress(30, "正在初始化视觉分析器...")
                    vision_api_key = st.session_state.get('vision_gemini_api_key')
                    vision_model = st.session_state.get('vision_gemini_model_name')
                    vision_base_url = st.session_state.get('vision_gemini_base_url')
                    if not vision_api_key or not vision_model:
                        raise ValueError("未配置 Gemini API Key 或者 模型,请在基础设置中配置")

                    analyzer = vision_analyzer.VisionAnalyzer(
                        model_name=vision_model,
                        api_key=vision_api_key
                    )

                    update_progress(40, "正在分析关键帧...")

                    # Run the async batch analysis on a private event loop.
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    results = loop.run_until_complete(
                        analyzer.analyze_images(
                            images=keyframe_files,
                            prompt=config.app.get('vision_analysis_prompt'),
                            batch_size=config.app.get("vision_batch_size", 5)
                        )
                    )
                    loop.close()

                    # ===================collate analysis results===================
                    update_progress(60, "正在整理分析结果...")
                    frame_analysis = ""
                    for result in results:
                        if 'error' in result:
                            logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
                            continue
                        # Recover which keyframes made up this batch.
                        batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
                        batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
                        batch_files = keyframe_files[batch_start:batch_end]
                        first_timestamp, last_timestamp = _batch_time_range(batch_files)
                        # Append the batch's analysis under its time range.
                        frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
                        frame_analysis += result['response']
                        frame_analysis += "\n"

                    if not frame_analysis.strip():
                        raise Exception("未能生成有效的帧分析结果")

                    # Persist the raw analysis for debugging.
                    analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
                    with open(analysis_path, 'w', encoding='utf-8') as f:
                        f.write(frame_analysis)

                    update_progress(70, "正在生成脚本...")

                    # Text-generation model configuration.
                    text_provider = config.app.get('text_llm_provider', 'gemini').lower()
                    text_api_key = config.app.get(f'text_{text_provider}_api_key')
                    text_model = config.app.get(f'text_{text_provider}_model_name')
                    text_base_url = config.app.get(f'text_{text_provider}_base_url')

                    # Build the per-batch frame content list; narration is
                    # filled in later by ScriptProcessor.
                    frame_content_list = []
                    for i, result in enumerate(results):
                        if 'error' in result:
                            continue
                        batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
                        batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
                        batch_files = keyframe_files[batch_start:batch_end]
                        first_timestamp, last_timestamp = _batch_time_range(batch_files)
                        timestamp_range = f"{first_timestamp}-{last_timestamp}"
                        frame_content = {
                            "timestamp": timestamp_range,  # "MM:SS-MM:SS"
                            "picture": result['response'],
                            "narration": "",
                            "OST": 2  # default clip mode
                        }
                        frame_content_list.append(frame_content)
                        logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")

                    if not frame_content_list:
                        raise Exception("没有有效的帧内容可以处理")

                    # ===================generate narration===================
                    update_progress(90, "正在生成文案...")

                    # NOTE(review): this posts the model config to the Narrato
                    # API even on the Gemini path, with verify=False and the
                    # response unused — confirm whether it is still required.
                    api_params = {
                        'batch_size': st.session_state.get('narrato_batch_size', 10),
                        'use_ai': False,
                        'start_offset': 0,
                        'vision_model': vision_model,
                        'vision_api_key': vision_api_key,
                        'llm_model': text_model,
                        'llm_api_key': text_api_key,
                        'custom_prompt': st.session_state.get('custom_prompt', '')
                    }
                    response = requests.post(
                        f"{config.app.get('narrato_api_url')}/video/config",
                        params=api_params,
                        timeout=30,
                        verify=False
                    )

                    custom_prompt = st.session_state.get('custom_prompt', '')
                    processor = ScriptProcessor(
                        model_name=text_model,
                        api_key=text_api_key,
                        prompt=custom_prompt,
                        video_theme=st.session_state.get('video_theme', '')
                    )
                    # Generate the narration for every frame batch.
                    script_result = processor.process_frames(frame_content_list)
                    script = json.dumps(script_result, ensure_ascii=False, indent=2)
                except Exception as e:
                    logger.exception(f"Gemini 处理过程中发生错误\n{traceback.format_exc()}")
                    raise Exception(f"视觉分析失败: {str(e)}")
            elif vision_llm_provider == 'narratoapi':
                try:
                    temp_dir = utils.temp_dir("narrato")

                    # Package the keyframes for upload.
                    update_progress(30, "正在打包关键帧...")
                    zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
                    if not file_utils.create_zip(keyframe_files, zip_path):
                        raise Exception("打包关键帧失败")

                    api_url = st.session_state.get('narrato_api_url')
                    api_key = st.session_state.get('narrato_api_key')
                    if not api_key:
                        raise ValueError("未配置 Narrato API Key请在基础设置中配置")

                    headers = {
                        'X-API-Key': api_key,
                        'accept': 'application/json'
                    }
                    api_params = {
                        'batch_size': st.session_state.get('narrato_batch_size', 10),
                        'use_ai': False,
                        'start_offset': 0,
                        'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
                        'vision_api_key': st.session_state.get('narrato_vision_key'),
                        'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
                        'llm_api_key': st.session_state.get('narrato_llm_key'),
                        'custom_prompt': st.session_state.get('custom_prompt', '')
                    }

                    logger.info(f"请求NarratoAPI: {api_url}")
                    update_progress(40, "正在上传文件...")
                    with open(zip_path, 'rb') as f:
                        files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
                        try:
                            response = requests.post(
                                f"{api_url}/video/analyze",
                                headers=headers,
                                params=api_params,
                                files=files,
                                timeout=30
                            )
                            response.raise_for_status()
                        except requests.RequestException as e:
                            logger.error(f"Narrato API 请求失败:\n{traceback.format_exc()}")
                            raise Exception(f"API请求失败: {str(e)}")

                    task_data = response.json()
                    task_id = task_data["data"].get('task_id')
                    if not task_id:
                        raise Exception(f"无效的API响应: {response.text}")

                    # Poll the task until it finishes (2s interval, ~2 min cap).
                    update_progress(50, "正在等待分析结果...")
                    retry_count = 0
                    max_retries = 60
                    while retry_count < max_retries:
                        try:
                            status_response = requests.get(
                                f"{api_url}/video/tasks/{task_id}",
                                headers=headers,
                                timeout=10
                            )
                            status_response.raise_for_status()
                            task_status = status_response.json()['data']
                            if task_status['status'] == 'SUCCESS':
                                script = task_status['result']['data']
                                break
                            elif task_status['status'] in ['FAILURE', 'RETRY']:
                                raise Exception(f"任务失败: {task_status.get('error')}")
                            retry_count += 1
                            time.sleep(2)
                        except requests.RequestException as e:
                            # Transient status-poll failure: retry.
                            logger.warning(f"获取任务状态失败,重试中: {str(e)}")
                            retry_count += 1
                            time.sleep(2)
                            continue
                    if retry_count >= max_retries:
                        raise Exception("任务执行超时")
                except Exception as e:
                    logger.exception(f"NarratoAPI 处理过程中发生错误\n{traceback.format_exc()}")
                    raise Exception(f"NarratoAPI 处理失败: {str(e)}")
                finally:
                    # Best-effort cleanup of the uploaded archive.
                    try:
                        if os.path.exists(zip_path):
                            os.remove(zip_path)
                    except Exception as e:
                        logger.warning(f"清理临时文件失败: {str(e)}")
            else:
                logger.exception("Vision Model 未启用,请检查配置")

            if script is None:
                st.error("生成脚本失败,请检查日志")
                st.stop()
            logger.info(f"脚本生成完成\n{script} \n{type(script)}")
            # The script may arrive as a parsed list or a JSON string.
            if isinstance(script, list):
                st.session_state['video_clip_json'] = script
            elif isinstance(script, str):
                st.session_state['video_clip_json'] = json.loads(script)
            update_progress(90, "脚本生成完成")

            time.sleep(0.5)
            progress_bar.progress(100)
            status_text.text("脚本生成完成!")
            st.success("视频脚本生成成功!")
    except Exception as err:
        st.error(f"生成过程中发生错误: {str(err)}")
        logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
    finally:
        # Keep the final state visible briefly before clearing the widgets.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()
def save_script(tr, video_clip_json_details):
    """Persist the edited script JSON to a timestamped file and rerun."""
    if not video_clip_json_details:
        st.error(tr("请输入视频脚本"))
        st.stop()

    with st.spinner(tr("Save Script")):
        timestamp = time.strftime("%Y-%m%d-%H%M%S")
        save_path = os.path.join(utils.script_dir(), f"{timestamp}.json")
        try:
            data = json.loads(video_clip_json_details)
            with open(save_path, 'w', encoding='utf-8') as file:
                json.dump(data, file, ensure_ascii=False, indent=4)
            st.session_state['video_clip_json'] = data
            st.session_state['video_clip_json_path'] = save_path
            config.app["video_clip_json_path"] = save_path
            st.success(tr("Script saved successfully"))
            # Brief pause so the success message is visible, then rerun
            # so the script-file selectbox picks up the new file.
            time.sleep(0.5)
            st.rerun()
        except Exception as err:
            st.error(f"{tr('Failed to save script')}: {str(err)}")
            st.stop()
def crop_video(tr, params):
    """Cut the source video according to the validated script, with progress."""
    progress_bar = st.progress(0)
    status_text = st.empty()

    def on_progress(pct):
        progress_bar.progress(pct)
        status_text.text(f"剪辑进度: {pct}%")

    try:
        utils.cut_video(params, on_progress)
        time.sleep(0.5)
        progress_bar.progress(100)
        status_text.text("剪辑完成!")
        st.success("视频剪辑成功完成!")
    except Exception as e:
        st.error(f"剪辑过程中发生错误: {str(e)}")
    finally:
        # Keep the final state visible briefly before clearing the widgets.
        time.sleep(2)
        progress_bar.empty()
        status_text.empty()
def get_script_params():
    """Collect the current script-related settings from session state."""
    keys = (
        'video_language',
        'video_clip_json_path',
        'video_origin_path',
        'video_name',
        'video_plot',
    )
    return {key: st.session_state.get(key, '') for key in keys}

View File

@ -0,0 +1,129 @@
import streamlit as st
from app.config import config
from webui.utils.cache import get_fonts_cache
import os
def render_subtitle_panel(tr):
    """Render subtitle settings; detail sections only appear when enabled."""
    with st.container(border=True):
        st.write(tr("Subtitle Settings"))
        enabled = st.checkbox(tr("Enable Subtitles"), value=True)
        st.session_state['subtitle_enabled'] = enabled
        if not enabled:
            return
        render_font_settings(tr)
        render_position_settings(tr)
        render_style_settings(tr)
def render_font_settings(tr):
    """Render font family, color and size controls."""
    # Fonts shipped under resource/fonts, listed via the directory cache.
    fonts_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "resource", "fonts",
    )
    available_fonts = get_fonts_cache(fonts_path)

    # Preselect the previously saved font if it is still available.
    saved_font = config.ui.get("font_name", "")
    default_index = (
        available_fonts.index(saved_font) if saved_font in available_fonts else 0
    )

    chosen_font = st.selectbox(
        tr("Font"),
        options=available_fonts,
        index=default_index,
    )
    config.ui["font_name"] = chosen_font
    st.session_state['font_name'] = chosen_font

    color_col, size_col = st.columns([0.3, 0.7])
    with color_col:
        fore_color = st.color_picker(
            tr("Font Color"),
            config.ui.get("text_fore_color", "#FFFFFF"),
        )
        config.ui["text_fore_color"] = fore_color
        st.session_state['text_fore_color'] = fore_color
    with size_col:
        size = st.slider(
            tr("Font Size"),
            min_value=30,
            max_value=100,
            value=config.ui.get("font_size", 60),
        )
        config.ui["font_size"] = size
        st.session_state['font_size'] = size
def render_position_settings(tr):
    """Render the subtitle position selector, with a custom percent option."""
    positions = [
        (tr("Top"), "top"),
        (tr("Center"), "center"),
        (tr("Bottom"), "bottom"),
        (tr("Custom"), "custom"),
    ]
    chosen = st.selectbox(
        tr("Position"),
        index=2,  # default: bottom
        options=range(len(positions)),
        format_func=lambda idx: positions[idx][0],
    )
    position = positions[chosen][1]
    st.session_state['subtitle_position'] = position

    # Custom position: a percentage measured from the top of the frame.
    if position == "custom":
        raw_value = st.text_input(
            tr("Custom Position (% from top)"),
            value="70.0",
        )
        try:
            percent = float(raw_value)
        except ValueError:
            st.error(tr("Please enter a valid number"))
        else:
            if percent < 0 or percent > 100:
                st.error(tr("Please enter a value between 0 and 100"))
            else:
                st.session_state['custom_position'] = percent
def render_style_settings(tr):
    """Render the subtitle stroke (outline) style controls.

    Saves the chosen stroke color and width into Streamlit session state.
    """
    color_col, width_col = st.columns([0.3, 0.7])
    with color_col:
        outline_color = st.color_picker(
            tr("Stroke Color"),
            value="#000000"
        )
        st.session_state['stroke_color'] = outline_color
    with width_col:
        outline_width = st.slider(
            tr("Stroke Width"),
            min_value=0.0,
            max_value=10.0,
            value=1.5,
            step=0.1
        )
        st.session_state['stroke_width'] = outline_width
def get_subtitle_params():
    """Collect the current subtitle parameters from Streamlit session state.

    Returns:
        dict: subtitle settings, with sensible defaults for any value the
        user has not configured in this session.
    """
    # output key -> (session-state key, default value)
    spec = {
        'enabled': ('subtitle_enabled', True),
        'font_name': ('font_name', ''),
        'font_size': ('font_size', 60),
        'text_fore_color': ('text_fore_color', '#FFFFFF'),
        'position': ('subtitle_position', 'bottom'),
        'custom_position': ('custom_position', 70.0),
        'stroke_color': ('stroke_color', '#000000'),
        'stroke_width': ('stroke_width', 1.5),
    }
    return {
        out_key: st.session_state.get(state_key, fallback)
        for out_key, (state_key, fallback) in spec.items()
    }

View File

@ -0,0 +1,47 @@
import streamlit as st
from app.models.schema import VideoClipParams, VideoAspect
def render_video_panel(tr):
    """Render the video configuration panel inside a bordered container."""
    with st.container(border=True):
        st.write(tr("Video Settings"))
        clip_params = VideoClipParams()
        render_video_config(tr, clip_params)
def render_video_config(tr, params):
    """Render the video configuration widgets (aspect ratio and quality).

    Mutates ``params.video_aspect`` with the selected ratio and mirrors
    both choices into Streamlit session state.
    """
    # Aspect ratio selector: portrait vs landscape.
    aspect_choices = [
        (tr("Portrait"), VideoAspect.portrait.value),
        (tr("Landscape"), VideoAspect.landscape.value),
    ]
    aspect_idx = st.selectbox(
        tr("Video Ratio"),
        options=range(len(aspect_choices)),
        format_func=lambda i: aspect_choices[i][0],
    )
    params.video_aspect = VideoAspect(aspect_choices[aspect_idx][1])
    st.session_state['video_aspect'] = params.video_aspect.value

    # Output quality selector; Full HD (index 2) is preselected.
    quality_choices = [
        ("4K (2160p)", "2160p"),
        ("2K (1440p)", "1440p"),
        ("Full HD (1080p)", "1080p"),
        ("HD (720p)", "720p"),
        ("SD (480p)", "480p"),
    ]
    quality_idx = st.selectbox(
        tr("Video Quality"),
        options=range(len(quality_choices)),
        format_func=lambda i: quality_choices[i][0],
        index=2  # default: Full HD (1080p)
    )
    st.session_state['video_quality'] = quality_choices[quality_idx][1]
def get_video_params():
    """Return the currently selected video parameters from session state.

    Returns:
        dict: ``video_aspect`` (defaults to portrait) and ``video_quality``
        (defaults to '1080p').
    """
    state = st.session_state
    return {
        'video_aspect': state.get('video_aspect', VideoAspect.portrait.value),
        'video_quality': state.get('video_quality', '1080p'),
    }

167
webui/config/settings.py Normal file
View File

@ -0,0 +1,167 @@
import os
import tomli
from loguru import logger
from typing import Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class WebUIConfig:
    """Configuration container for the WebUI.

    Dict-valued sections default to ``None`` (a dataclass field cannot use
    a mutable literal default) and are normalized to dicts in
    ``__post_init__``; ``root_dir`` is derived from this file's location
    when not supplied.
    """
    # UI section (fonts, colors, language, ...)
    ui: Dict[str, Any] = None
    # HTTP/HTTPS proxy settings
    proxy: Dict[str, str] = None
    # Application-level settings
    app: Dict[str, Any] = None
    # Azure service credentials
    azure: Dict[str, str] = None
    # Project version string
    project_version: str = "0.1.0"
    # Project root directory (computed when unset)
    root_dir: str = None
    # Gemini API key
    gemini_api_key: str = ""
    # Number of images processed per vision batch
    vision_batch_size: int = 5
    # Prompt template for the vision model
    vision_prompt: str = """..."""
    # Narrato API settings
    narrato_api_url: str = "http://127.0.0.1:8000/api/v1/video/analyze"
    narrato_api_key: str = ""
    narrato_batch_size: int = 10
    narrato_vision_model: str = "gemini-1.5-flash"
    narrato_llm_model: str = "qwen-plus"

    def __post_init__(self):
        """Normalize falsy section values to dicts and resolve root_dir."""
        for section in ("ui", "proxy", "app", "azure"):
            setattr(self, section, getattr(self, section) or {})
        self.root_dir = self.root_dir or os.path.dirname(
            os.path.dirname(os.path.dirname(__file__)))
def load_config(config_path: Optional[str] = None) -> WebUIConfig:
    """Load the WebUI configuration from a TOML file.

    Args:
        config_path: path to the config file; when None the default
            ``.streamlit/webui.toml`` under the webui directory is used,
            falling back to the repo-level ``config.example.toml``.

    Returns:
        WebUIConfig: parsed configuration, or a default-constructed one
        when no file exists or parsing fails.
    """
    try:
        webui_dir = os.path.dirname(os.path.dirname(__file__))
        if config_path is None:
            config_path = os.path.join(webui_dir, ".streamlit", "webui.toml")

        if not os.path.exists(config_path):
            # Fall back to the example config shipped at the repo root.
            example_config = os.path.join(
                os.path.dirname(webui_dir), "config.example.toml")
            if not os.path.exists(example_config):
                logger.warning(f"配置文件不存在: {config_path}")
                return WebUIConfig()
            config_path = example_config

        # TOML must be read in binary mode for tomli.
        with open(config_path, "rb") as fp:
            raw = tomli.load(fp)

        return WebUIConfig(
            ui=raw.get("ui", {}),
            proxy=raw.get("proxy", {}),
            app=raw.get("app", {}),
            azure=raw.get("azure", {}),
            project_version=raw.get("project_version", "0.1.0"),
        )
    except Exception as e:
        logger.error(f"加载配置文件失败: {e}")
        return WebUIConfig()
def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool:
    """Persist the configuration to a TOML file.

    Args:
        config: configuration object to serialize.
        config_path: destination path; defaults to ``.streamlit/webui.toml``.

    Returns:
        bool: True when the file was written successfully.
    """
    try:
        if config_path is None:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(__file__)),
                ".streamlit",
                "webui.toml"
            )
        # Make sure the target directory exists.
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        # Serialize only the persisted sections.
        config_dict = {
            "ui": config.ui,
            "proxy": config.proxy,
            "app": config.app,
            "azure": config.azure,
            "project_version": config.project_version
        }
        import tomli_w
        # BUGFIX: tomli_w.dump requires a file opened in *binary* mode; the
        # previous text-mode open ("w", encoding="utf-8") made tomli_w raise
        # on every call, so saving always failed and returned False.
        with open(config_path, "wb") as f:
            tomli_w.dump(config_dict, f)
        return True
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        return False
def get_config() -> WebUIConfig:
    """Return the process-wide WebUIConfig, loading it on first use.

    The loaded config is memoized on the function object itself, so the
    configuration file is parsed at most once per process.

    Returns:
        WebUIConfig: the shared configuration object.
    """
    cached = getattr(get_config, "_config", None)
    if cached is None:
        cached = load_config()
        get_config._config = cached
    return cached
def update_config(config_dict: Dict[str, Any]) -> bool:
    """Merge *config_dict* into the global configuration and persist it.

    Args:
        config_dict: partial configuration; dict-valued sections are merged
            key-by-key, ``project_version`` is replaced outright.

    Returns:
        bool: True when the merged configuration was saved successfully.
    """
    try:
        cfg = get_config()
        # Merge each dict-valued section that was provided.
        for section in ("ui", "proxy", "app", "azure"):
            if section in config_dict:
                getattr(cfg, section).update(config_dict[section])
        if "project_version" in config_dict:
            cfg.project_version = config_dict["project_version"]
        return save_config(cfg)
    except Exception as e:
        logger.error(f"更新配置失败: {e}")
        return False
# Module-level singleton: the shared WebUI configuration, loaded once at import.
config = get_config()

1
webui/i18n/__init__.py Normal file
View File

@ -0,0 +1 @@
# 空文件,用于标记包

View File

@ -2,20 +2,20 @@
"Language": "简体中文",
"Translation": {
"Video Script Configuration": "**视频脚本配置**",
"Video Script Generate": "生成视频脚本",
"Generate Video Script": "生成视频脚本",
"Video Subject": "视频主题(给定一个关键词,:red[AI自动生成]视频文案)",
"Script Language": "生成视频脚本的语言一般情况AI会自动根据你输入的主题语言输出",
"Script Files": "脚本文件",
"Generate Video Script and Keywords": "点击使用AI根据**主题**生成 【视频文案】 和 【视频关键词】",
"Auto Detect": "自动检测",
"Auto Generate": "自动生成",
"Video Name": "视频名称",
"Video Script": "视频脚本(:blue[①使用AI生成 ②从本机加载]",
"Video Theme": "视频主题",
"Generation Prompt": "自定义提示词",
"Save Script": "保存脚本",
"Crop Video": "裁剪视频",
"Video File": "视频文件(:blue[1⃣支持上传视频文件(限制2G) 2⃣大文件建议直接导入 ./resource/videos 目录]",
"Plot Description": "剧情描述 (:blue[可从 https://www.tvmao.com/ 获取])",
"Generate Video Keywords": "点击使用AI根据**文案**生成【视频关键】",
"Generate Video Keywords": "点击使用AI根据**文案**生成【视频关键词】",
"Please Enter the Video Subject": "请先填写视频文案",
"Generating Video Script and Keywords": "AI正在生成视频文案和关键词...",
"Generating Video Keywords": "AI正在生成视频关键词...",
@ -91,6 +91,40 @@
"Picture description": "图片描述",
"Narration": "视频文案",
"Rebuild": "重新生成",
"Video Script Load": "加载视频脚本"
"Load Video Script": "加载视频脚本",
"Speech Pitch": "语调",
"Please Select Script File": "请选择脚本文件",
"Check Format": "脚本格式检查",
"Script Loaded Successfully": "脚本加载成功",
"Script format check passed": "脚本格式检查通过",
"Script format check failed": "脚本格式检查失败",
"Failed to Load Script": "加载脚本失败",
"Failed to Save Script": "保存脚本失败",
"Script saved successfully": "脚本保存成功",
"Video Script": "视频脚本",
"Video Quality": "视频质量",
"Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词",
"Basic Settings": "基础设置",
"Proxy Settings": "代理设置",
"Language": "界面语言",
"HTTP_PROXY": "HTTP 代理",
"HTTPs_PROXY": "HTTPS 代理",
"Vision Model Settings": "视频分析模型设置",
"Vision Model Provider": "视频分析模型提供商",
"Vision API Key": "视频分析 API 密钥",
"Vision Base URL": "视频分析接口地址",
"Vision Model Name": "视频分析模型名称",
"Narrato Additional Settings": "Narrato 附加设置",
"Narrato API Key": "Narrato API 密钥",
"Narrato API URL": "Narrato API 地址",
"Text Generation Model Settings": "文案生成模型设置",
"LLM Model Name": "大语言模型名称",
"LLM Model API Key": "大语言模型 API 密钥",
"Batch Size": "批处理大小",
"Text Model Provider": "文案生成模型提供商",
"Text API Key": "文案生成 API 密钥",
"Text Base URL": "文案生成接口地址",
"Text Model Name": "文案生成模型名称",
"Account ID": "账户 ID"
}
}

8
webui/utils/__init__.py Normal file
View File

@ -0,0 +1,8 @@
from .performance import monitor_performance, PerformanceMonitor
from .cache import *
from .file_utils import *
__all__ = [
'monitor_performance',
'PerformanceMonitor'
]

33
webui/utils/cache.py Normal file
View File

@ -0,0 +1,33 @@
import streamlit as st
import os
import glob
from app.utils import utils
def get_fonts_cache(font_dir):
    """Return a sorted list of .ttf/.ttc font file names under *font_dir*.

    The list is cached in Streamlit session state under ``fonts_cache``.
    NOTE(review): the cache key ignores *font_dir*, so all callers share
    one cached list — confirm every caller passes the same directory.
    """
    if 'fonts_cache' not in st.session_state:
        found = [
            name
            for _, _, files in os.walk(font_dir)
            for name in files
            if name.endswith((".ttf", ".ttc"))
        ]
        st.session_state['fonts_cache'] = sorted(found)
    return st.session_state['fonts_cache']
def get_video_files_cache():
    """Return video file paths from the resource video directory, cached.

    Matches mp4/mov/avi/mkv files and reverses glob order before caching
    in Streamlit session state under ``video_files_cache``.
    """
    if 'video_files_cache' not in st.session_state:
        matches = []
        for pattern in ("*.mp4", "*.mov", "*.avi", "*.mkv"):
            matches += glob.glob(os.path.join(utils.video_dir(), pattern))
        st.session_state['video_files_cache'] = list(reversed(matches))
    return st.session_state['video_files_cache']
def get_songs_cache(song_dir):
    """Return .mp3 file names found under *song_dir* (recursively).

    Cached in Streamlit session state under ``songs_cache``; note the
    cache key ignores *song_dir*, mirroring the font cache's behavior.
    """
    if 'songs_cache' not in st.session_state:
        st.session_state['songs_cache'] = [
            name
            for _, _, files in os.walk(song_dir)
            for name in files
            if name.endswith(".mp3")
        ]
    return st.session_state['songs_cache']

230
webui/utils/file_utils.py Normal file
View File

@ -0,0 +1,230 @@
import os
import glob
import time
import platform
import shutil
from uuid import uuid4
from loguru import logger
from app.utils import utils
def open_task_folder(root_dir, task_id):
    """Open the task's storage folder in the OS file manager.

    Args:
        root_dir: project root directory.
        task_id: task identifier (sub-folder under ``storage/tasks``).
    """
    try:
        path = os.path.join(root_dir, "storage", "tasks", task_id)
        if not os.path.exists(path):
            return
        system = platform.system()
        # BUGFIX: the previous os.system(f"... {path}") calls broke on paths
        # containing spaces and allowed shell injection via the path. Use
        # os.startfile / argument-list subprocess calls instead (no shell).
        if system == 'Windows':
            os.startfile(path)
        elif system == 'Darwin':
            import subprocess  # local import, matching the file's style
            subprocess.run(["open", path], check=False)
        elif system == 'Linux':
            import subprocess
            subprocess.run(["xdg-open", path], check=False)
    except Exception as e:
        logger.error(f"打开任务文件夹失败: {e}")
def cleanup_temp_files(temp_dir, max_age=3600):
    """Delete entries in *temp_dir* older than *max_age* seconds (by ctime).

    Args:
        temp_dir: directory holding temporary files; missing dir is a no-op.
        max_age: maximum age in seconds before an entry is removed.
    """
    if not os.path.exists(temp_dir):
        return
    for entry in os.listdir(temp_dir):
        entry_path = os.path.join(temp_dir, entry)
        try:
            # Keep entries newer than the age threshold.
            if os.path.getctime(entry_path) >= time.time() - max_age:
                continue
            if os.path.isfile(entry_path):
                os.remove(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
            logger.debug(f"已清理临时文件: {entry_path}")
        except Exception as e:
            logger.error(f"清理临时文件失败: {entry_path}, 错误: {e}")
def get_file_list(directory, file_types=None, sort_by='ctime', reverse=True):
    """List files in *directory* as dictionaries of file metadata.

    Args:
        directory: directory to scan; missing dir yields an empty list.
        file_types: optional list of extensions, e.g. ``['.mp4', '.mov']``.
        sort_by: one of 'ctime', 'mtime', 'size', 'name'.
        reverse: sort descending when True.

    Returns:
        list[dict]: entries with ``name``/``path``/``size``/``ctime``/``mtime``.
    """
    if not os.path.exists(directory):
        return []

    # Collect candidate paths, optionally filtered by extension.
    if file_types:
        candidates = []
        for ext in file_types:
            candidates.extend(glob.glob(os.path.join(directory, f"*{ext}")))
    else:
        candidates = glob.glob(os.path.join(directory, "*"))

    infos = []
    for path in candidates:
        try:
            stat_result = os.stat(path)
        except Exception as e:
            logger.error(f"获取文件信息失败: {path}, 错误: {e}")
            continue
        infos.append({
            "name": os.path.basename(path),
            "path": path,
            "size": stat_result.st_size,
            "ctime": stat_result.st_ctime,
            "mtime": stat_result.st_mtime,
        })

    if sort_by in ('ctime', 'mtime', 'size', 'name'):
        infos.sort(key=lambda info: info.get(sort_by, ''), reverse=reverse)
    return infos
def save_uploaded_file(uploaded_file, save_dir, allowed_types=None):
    """Persist a Streamlit UploadedFile to *save_dir*.

    Args:
        uploaded_file: Streamlit UploadedFile object (has ``.name`` / ``.read()``).
        save_dir: target directory, created if missing.
        allowed_types: optional whitelist of extensions, e.g. ``['.mp4']``.

    Returns:
        str | None: the saved file's path, or None on rejection/failure.
    """
    try:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        stem, ext = os.path.splitext(uploaded_file.name)
        # Reject files whose extension is not whitelisted.
        if allowed_types and ext.lower() not in allowed_types:
            logger.error(f"不支持的文件类型: {ext}")
            return None
        target = os.path.join(save_dir, uploaded_file.name)
        # Avoid clobbering an existing file: append a timestamp to the name.
        if os.path.exists(target):
            stamped = f"{stem}_{time.strftime('%Y%m%d%H%M%S')}{ext}"
            target = os.path.join(save_dir, stamped)
        with open(target, "wb") as out:
            out.write(uploaded_file.read())
        logger.info(f"文件保存成功: {target}")
        return target
    except Exception as e:
        logger.error(f"保存上传文件失败: {e}")
        return None
def create_temp_file(prefix='tmp', suffix='', directory=None):
    """Generate a unique temp-file path (the file itself is not created).

    Args:
        prefix: file-name prefix.
        suffix: file extension including the dot, e.g. ``'.mp4'``.
        directory: target directory; defaults to the app's storage temp dir.

    Returns:
        str | None: the generated path, or None on failure.
    """
    try:
        target_dir = directory if directory is not None else utils.storage_dir("temp", create=True)
        # Create the directory if it does not exist yet.
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        return os.path.join(target_dir, f"{prefix}-{uuid4()}{suffix}")
    except Exception as e:
        logger.error(f"创建临时文件失败: {e}")
        return None
def get_file_size(file_path, format='MB'):
    """Return the size of *file_path* in the requested unit.

    Args:
        file_path: path to the file.
        format: 'B', 'KB', 'MB' or 'GB'; any other value falls back to
            raw bytes (matching 'B').

    Returns:
        int | float: the size (bytes as int, other units as float),
        or 0 on failure.
    """
    # Divisors for the fractional units; 'B' and unknown units return raw bytes.
    divisors = {'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3}
    try:
        raw_bytes = os.path.getsize(file_path)
        unit = format.upper()
        if unit in divisors:
            return raw_bytes / divisors[unit]
        return raw_bytes
    except Exception as e:
        logger.error(f"获取文件大小失败: {file_path}, 错误: {e}")
        return 0
def ensure_directory(directory):
    """Create *directory* (and parents) if it does not exist.

    Uses ``os.makedirs(..., exist_ok=True)`` so a concurrent creator cannot
    make the previous exists-check-then-create sequence raise (TOCTOU race).

    Args:
        directory: directory path.

    Returns:
        bool: True on success, False on failure.
    """
    try:
        os.makedirs(directory, exist_ok=True)
        return True
    except Exception as e:
        logger.error(f"创建目录失败: {directory}, 错误: {e}")
        return False
def create_zip(files: list, zip_path: str, base_dir: str = None, folder_name: str = "demo") -> bool:
    """Bundle *files* into a zip archive.

    Args:
        files: paths of files to include; missing ones are skipped with a warning.
        zip_path: destination path of the archive (parent dir is created).
        base_dir: when given, paths relative to it are preserved in the archive.
        folder_name: top-level folder name entries are nested under.

    Returns:
        bool: True on success (even when some entries were skipped), False otherwise.
    """
    try:
        import zipfile
        # Make sure the destination directory exists.
        os.makedirs(os.path.dirname(zip_path), exist_ok=True)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            for src in files:
                if not os.path.exists(src):
                    logger.warning(f"文件不存在,跳过: {src}")
                    continue
                # Nest each entry under folder_name, preserving relative
                # structure when base_dir is given.
                relative = (os.path.relpath(src, base_dir)
                            if base_dir else os.path.basename(src))
                arcname = os.path.join(folder_name, relative)
                try:
                    archive.write(src, arcname)
                except Exception as e:
                    logger.error(f"添加文件到zip失败: {src}, 错误: {e}")
        return True
    except Exception as e:
        logger.error(f"创建zip文件失败: {e}")
        return False

View File

@ -0,0 +1,37 @@
import psutil
import os
from loguru import logger
import torch
class PerformanceMonitor:
    """Helpers for logging process/GPU memory usage and releasing resources."""

    @staticmethod
    def monitor_memory():
        """Log current RSS memory usage, plus GPU memory when CUDA is available."""
        rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
        logger.debug(f"Memory usage: {rss_mb:.2f} MB")
        if torch.cuda.is_available():
            gpu_mb = torch.cuda.memory_allocated() / 1024 / 1024
            logger.debug(f"GPU Memory usage: {gpu_mb:.2f} MB")

    @staticmethod
    def cleanup_resources():
        """Empty the CUDA cache (if any), force a GC pass, then log memory usage."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        import gc
        gc.collect()
        PerformanceMonitor.monitor_memory()
def monitor_performance(func):
    """Decorator that logs memory usage around *func* and cleans up afterwards.

    FIX: apply ``functools.wraps`` so the wrapped function keeps its name,
    docstring and other metadata — the previous version discarded them,
    which breaks introspection and debugging of decorated functions.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            PerformanceMonitor.monitor_memory()
            return func(*args, **kwargs)
        finally:
            PerformanceMonitor.cleanup_resources()
    return wrapper