Refactor webui code

linyqh 2024-11-09 02:26:39 +08:00
parent 242f8d5355
commit bb18a754fe
19 changed files with 1592 additions and 889 deletions


@@ -29,7 +29,7 @@ def create(audio_file, subtitle_file: str = ""):
返回:
无返回值但会在指定路径生成字幕文件
"""
global model
global model, device, compute_type
if not model:
model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v2"
model_bin_file = f"{model_path}/model.bin"
@@ -43,27 +43,45 @@ def create(audio_file, subtitle_file: str = ""):
)
return None
logger.info(
f"加载模型: {model_path}, 设备: {device}, 计算类型: {compute_type}"
)
# 尝试使用 CUDA，如果失败则回退到 CPU
try:
import torch
if torch.cuda.is_available():
try:
logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
model = WhisperModel(
model_size_or_path=model_path,
device="cuda",
compute_type="float16",
local_files_only=True
)
device = "cuda"
compute_type = "float16"
logger.info("成功使用 CUDA 加载模型")
except Exception as e:
logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
logger.warning("回退到 CPU 模式")
device = "cpu"
compute_type = "int8"
else:
logger.info("未检测到 CUDA使用 CPU 模式")
device = "cpu"
compute_type = "int8"
except ImportError:
logger.warning("未安装 torch使用 CPU 模式")
device = "cpu"
compute_type = "int8"
if device == "cpu":
logger.info(f"使用 CPU 加载模型: {model_path}")
model = WhisperModel(
model_size_or_path=model_path,
device=device,
compute_type=compute_type,
local_files_only=True
)
except Exception as e:
logger.error(
f"加载模型失败: {e} \n\n"
f"********************************************\n"
f"这可能是由网络问题引起的. \n"
f"请手动下载模型并将其放入 'app/models' 文件夹中。 \n"
f"see [README.md FAQ](https://github.com/linyqh/NarratoAI) for more details.\n"
f"********************************************\n\n"
f"{traceback.format_exc()}"
)
return None
logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")
logger.info(f"start, output file: {subtitle_file}")
if not subtitle_file:


@@ -1,115 +1,81 @@
import json
from loguru import logger
import os
from datetime import timedelta
from typing import Dict, Any
def time_to_seconds(time_str):
parts = list(map(int, time_str.split(':')))
if len(parts) == 2:
return timedelta(minutes=parts[0], seconds=parts[1]).total_seconds()
elif len(parts) == 3:
return timedelta(hours=parts[0], minutes=parts[1], seconds=parts[2]).total_seconds()
raise ValueError(f"无法解析时间字符串: {time_str}")
def check_format(script_content: str) -> Dict[str, Any]:
"""检查脚本格式
Args:
script_content: 脚本内容
Returns:
Dict: {'success': bool, 'message': str}
"""
try:
# 检查是否为有效的JSON
data = json.loads(script_content)
# 检查是否为列表
if not isinstance(data, list):
return {
'success': False,
'message': '脚本必须是JSON数组格式'
}
# 检查每个片段
for i, clip in enumerate(data):
# 检查必需字段
required_fields = ['narration', 'picture', 'timestamp']
for field in required_fields:
if field not in clip:
return {
'success': False,
'message': f'{i+1}个片段缺少必需字段: {field}'
}
# 检查字段类型
if not isinstance(clip['narration'], str):
return {
'success': False,
'message': f'{i+1}个片段的narration必须是字符串'
}
if not isinstance(clip['picture'], str):
return {
'success': False,
'message': f'{i+1}个片段的picture必须是字符串'
}
if not isinstance(clip['timestamp'], str):
return {
'success': False,
'message': f'{i+1}个片段的timestamp必须是字符串'
}
# 检查字段内容不能为空
if not clip['narration'].strip():
return {
'success': False,
'message': f'{i+1}个片段的narration不能为空'
}
if not clip['picture'].strip():
return {
'success': False,
'message': f'{i+1}个片段的picture不能为空'
}
if not clip['timestamp'].strip():
return {
'success': False,
'message': f'{i+1}个片段的timestamp不能为空'
}
def seconds_to_time_str(seconds):
hours, remainder = divmod(int(seconds), 3600)
minutes, seconds = divmod(remainder, 60)
if hours > 0:
return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
else:
return f"{minutes:02d}:{seconds:02d}"
return {
'success': True,
'message': '脚本格式检查通过'
}
def adjust_timestamp(start_time, duration):
start_seconds = time_to_seconds(start_time)
end_seconds = start_seconds + duration
return f"{start_time}-{seconds_to_time_str(end_seconds)}"
def estimate_audio_duration(text):
# 假设平均每个字符需要 0.2 秒
return len(text) * 0.2
def check_script(data, total_duration):
errors = []
time_ranges = []
logger.info("开始检查脚本")
logger.info(f"视频总时长: {total_duration:.2f}")
logger.info("=" * 50)
for i, item in enumerate(data, 1):
logger.info(f"\n检查第 {i} 项:")
# 检查所有必需字段
required_fields = ['picture', 'timestamp', 'narration', 'OST']
for field in required_fields:
if field not in item:
errors.append(f"{i} 项缺少 {field} 字段")
logger.info(f" - 错误: 缺少 {field} 字段")
else:
logger.info(f" - {field}: {item[field]}")
# 检查 OST 相关规则
if item.get('OST') == False:
if not item.get('narration'):
errors.append(f"第 {i} 项 OST 为 false，但 narration 为空")
logger.info(" - 错误: OST 为 false，但 narration 为空")
elif len(item['narration']) > 60:
errors.append(f"第 {i} 项 OST 为 false，但 narration 超过 60 字")
logger.info(f" - 错误: OST 为 false，但 narration 超过 60 字 (当前: {len(item['narration'])} 字)")
else:
logger.info(" - OST 为 false，narration 检查通过")
elif item.get('OST') == True:
if "原声播放_" not in item.get('narration'):
errors.append(f"第 {i} 项 OST 为 true，但 narration 不为空")
logger.info(" - 错误: OST 为 true，但 narration 不为空")
else:
logger.info(" - OST 为 true，narration 检查通过")
# 检查 timestamp
if 'timestamp' in item:
start, end = map(time_to_seconds, item['timestamp'].split('-'))
if any((start < existing_end and end > existing_start) for existing_start, existing_end in time_ranges):
errors.append(f"{i} 项 timestamp '{item['timestamp']}' 与其他时间段重叠")
logger.info(f" - 错误: timestamp '{item['timestamp']}' 与其他时间段重叠")
else:
logger.info(f" - timestamp '{item['timestamp']}' 检查通过")
time_ranges.append((start, end))
# if end > total_duration:
# errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
# logger.info(f" - 错误: timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
# else:
# logger.info(f" - timestamp 在总时长范围内")
# 处理 narration 字段
if item.get('OST') == False and item.get('narration'):
estimated_duration = estimate_audio_duration(item['narration'])
start_time = item['timestamp'].split('-')[0]
item['timestamp'] = adjust_timestamp(start_time, estimated_duration)
logger.info(f" - 已调整 timestamp 为 {item['timestamp']} (估算音频时长: {estimated_duration:.2f} 秒)")
if errors:
logger.info("检查结果:不通过")
logger.info("发现以下错误:")
for error in errors:
logger.info(f"- {error}")
else:
logger.info("检查结果:通过")
logger.info("所有项目均符合规则要求。")
return errors, data
if __name__ == "__main__":
file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/test004.json"
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
total_duration = 280
# check_script(data, total_duration)
from app.utils.utils import add_new_timestamps
res = add_new_timestamps(data)
print(json.dumps(res, indent=4, ensure_ascii=False))
except json.JSONDecodeError as e:
return {
'success': False,
'message': f'JSON格式错误: {str(e)}'
}
except Exception as e:
return {
'success': False,
'message': f'检查过程中发生错误: {str(e)}'
}
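For reference, a minimal usage sketch of the new check_format helper. The import path mirrors the webui code below (from app.utils import check_script); the sample clip itself is made up:

import json
from app.utils import check_script

# One clip with the three required non-empty string fields: narration, picture, timestamp.
sample_script = [{
    "narration": "开场旁白",
    "picture": "城市夜景",
    "timestamp": "00:00-00:05",
}]

result = check_script.check_format(json.dumps(sample_script, ensure_ascii=False))
print(result)  # expected: {'success': True, 'message': '脚本格式检查通过'}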


@@ -24,3 +24,4 @@ azure-cognitiveservices-speech~=1.37.0
git-changelog~=2.5.2
watchdog==5.0.2
pydub==0.25.1
psutil>=5.9.0

webui.py

@@ -1,6 +1,14 @@
import streamlit as st
import os
import sys
from uuid import uuid4
from app.config import config
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
from webui.utils import cache, file_utils, performance
from app.utils import utils
from app.models.schema import VideoClipParams, VideoAspect
# 初始化配置 - 必须是第一个 Streamlit 命令
st.set_page_config(
page_title="NarratoAI",
page_icon="📽️",
@@ -13,126 +21,23 @@ st.set_page_config(
},
)
import sys
import os
import glob
import json
import time
import datetime
import traceback
from uuid import uuid4
import platform
import streamlit.components.v1 as components
from loguru import logger
from app.models.const import FILE_TYPE_VIDEOS
from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
from app.services import task as tm, llm, voice, material
from app.utils import utils
# # 将项目的根目录添加到系统路径中,以允许从项目导入模块
root_dir = os.path.dirname(os.path.realpath(__file__))
if root_dir not in sys.path:
sys.path.append(root_dir)
print("******** sys.path ********")
print(sys.path)
print("*" * 20)
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
os.environ["HTTP_PROXY"] = proxy_url_http
os.environ["HTTPS_PROXY"] = proxy_url_https
# 设置页面样式
hide_streamlit_style = """
<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding-top: 6px; padding-bottom: 10px; padding-left: 20px; padding-right: 20px;}</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
st.title(f"NarratoAI :sunglasses:📽️")
support_locales = [
"zh-CN",
"zh-HK",
"zh-TW",
"en-US",
]
font_dir = os.path.join(root_dir, "resource", "fonts")
song_dir = os.path.join(root_dir, "resource", "songs")
i18n_dir = os.path.join(root_dir, "webui", "i18n")
config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
system_locale = utils.get_system_locale()
if 'video_clip_json' not in st.session_state:
st.session_state['video_clip_json'] = []
if 'video_plot' not in st.session_state:
st.session_state['video_plot'] = ''
if 'ui_language' not in st.session_state:
st.session_state['ui_language'] = config.ui.get("language", system_locale)
if 'subclip_videos' not in st.session_state:
st.session_state['subclip_videos'] = {}
def get_all_fonts():
fonts = []
for root, dirs, files in os.walk(font_dir):
for file in files:
if file.endswith(".ttf") or file.endswith(".ttc"):
fonts.append(file)
fonts.sort()
return fonts
def get_all_songs():
songs = []
for root, dirs, files in os.walk(song_dir):
for file in files:
if file.endswith(".mp3"):
songs.append(file)
return songs
def open_task_folder(task_id):
try:
sys = platform.system()
path = os.path.join(root_dir, "storage", "tasks", task_id)
if os.path.exists(path):
if sys == 'Windows':
os.system(f"start {path}")
if sys == 'Darwin':
os.system(f"open {path}")
except Exception as e:
logger.error(e)
def scroll_to_bottom():
js = f"""
<script>
console.log("scroll_to_bottom");
function scroll(dummy_var_to_force_repeat_execution){{
var sections = parent.document.querySelectorAll('section.main');
console.log(sections);
for(let index = 0; index<sections.length; index++) {{
sections[index].scrollTop = sections[index].scrollHeight;
}}
}}
scroll(1);
</script>
"""
st.components.v1.html(js, height=0, width=0)
def init_log():
"""初始化日志配置"""
from loguru import logger
logger.remove()
_lvl = "DEBUG"
def format_record(record):
# 获取日志记录中的文件全路径
file_path = record["file"].path
# 将绝对路径转换为相对于项目根目录的路径
relative_path = os.path.relpath(file_path, root_dir)
# 更新记录中的文件路径
relative_path = os.path.relpath(file_path, config.root_dir)
record["file"].path = f"./{relative_path}"
# 返回修改后的格式字符串
# 您可以根据需要调整这里的格式
record['message'] = record['message'].replace(root_dir, ".")
record['message'] = record['message'].replace(config.root_dir, ".")
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \
'<level>{level}</> | ' + \
@@ -147,672 +52,120 @@ def init_log():
colorize=True,
)
init_log()
locales = utils.load_locales(i18n_dir)
def init_global_state():
"""初始化全局状态"""
if 'video_clip_json' not in st.session_state:
st.session_state['video_clip_json'] = []
if 'video_plot' not in st.session_state:
st.session_state['video_plot'] = ''
if 'ui_language' not in st.session_state:
st.session_state['ui_language'] = config.ui.get("language", utils.get_system_locale())
if 'subclip_videos' not in st.session_state:
st.session_state['subclip_videos'] = {}
def tr(key):
"""翻译函数"""
i18n_dir = os.path.join(os.path.dirname(__file__), "webui", "i18n")
locales = utils.load_locales(i18n_dir)
loc = locales.get(st.session_state['ui_language'], {})
return loc.get("Translation", {}).get(key, key)
def render_generate_button():
"""渲染生成按钮和处理逻辑"""
if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
from app.services import task as tm
# 重置日志容器和记录
log_container = st.empty()
log_records = []
st.write(tr("Get Help"))
def log_received(msg):
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
# 基础设置
with st.expander(tr("Basic Settings"), expanded=False):
config_panels = st.columns(3)
left_config_panel = config_panels[0]
middle_config_panel = config_panels[1]
right_config_panel = config_panels[2]
with left_config_panel:
display_languages = []
selected_index = 0
for i, code in enumerate(locales.keys()):
display_languages.append(f"{code} - {locales[code].get('Language')}")
if code == st.session_state['ui_language']:
selected_index = i
from loguru import logger
logger.add(log_received)
selected_language = st.selectbox(tr("Language"), options=display_languages,
index=selected_index)
if selected_language:
code = selected_language.split(" - ")[0].strip()
st.session_state['ui_language'] = code
config.ui['language'] = code
config.save_config()
task_id = st.session_state.get('task_id')
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
if HTTP_PROXY:
config.proxy["http"] = HTTP_PROXY
if HTTPS_PROXY:
config.proxy["https"] = HTTPS_PROXY
if not task_id:
st.error(tr("请先裁剪视频"))
return
if not st.session_state.get('video_clip_json_path'):
st.error(tr("脚本文件不能为空"))
return
if not st.session_state.get('video_origin_path'):
st.error(tr("视频文件不能为空"))
return
# 视频转录大模型
with middle_config_panel:
video_llm_providers = ['Gemini']
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(video_llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
break
st.toast(tr("生成视频"))
logger.info(tr("开始生成视频"))
video_llm_provider = st.selectbox(tr("Video LLM Provider"), options=video_llm_providers, index=saved_llm_provider_index)
video_llm_provider = video_llm_provider.lower()
config.app["video_llm_provider"] = video_llm_provider
# 获取所有参数
script_params = script_settings.get_script_params()
video_params = video_settings.get_video_params()
audio_params = audio_settings.get_audio_params()
subtitle_params = subtitle_settings.get_subtitle_params()
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
video_llm_account_id = config.app.get(f"{video_llm_provider}_account_id", "")
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
if st_llm_api_key:
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
# 大语言模型
with right_config_panel:
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
break
llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
llm_provider = llm_provider.lower()
config.app["llm_provider"] = llm_provider
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
if st_llm_api_key:
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
if llm_provider == 'cloudflare':
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
if st_llm_account_id:
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
panel = st.columns(3)
left_panel = panel[0]
middle_panel = panel[1]
right_panel = panel[2]
params = VideoClipParams()
# 左侧面板
with left_panel:
with st.container(border=True):
st.write(tr("Video Script Configuration"))
# 脚本语言
video_languages = [
(tr("Auto Detect"), ""),
]
for code in ["zh-CN", "en-US", "zh-TW"]:
video_languages.append((code, code))
selected_index = st.selectbox(tr("Script Language"),
index=0,
options=range(len(video_languages)), # 使用索引作为内部选项值
format_func=lambda x: video_languages[x][0] # 显示给用户的是标签
)
params.video_language = video_languages[selected_index][1]
# 脚本路径
suffix = "*.json"
song_dir = utils.script_dir()
files = glob.glob(os.path.join(song_dir, suffix))
script_list = []
for file in files:
script_list.append({
"name": os.path.basename(file),
"size": os.path.getsize(file),
"file": file,
"ctime": os.path.getctime(file) # 获取文件创建时间
})
# 按创建时间降序排序
script_list.sort(key=lambda x: x["ctime"], reverse=True)
# 本文件 下拉框
script_path = [(tr("Auto Generate"), ""), ]
for file in script_list:
display_name = file['file'].replace(root_dir, "")
script_path.append((display_name, file['file']))
selected_script_index = st.selectbox(tr("Script Files"),
index=0,
options=range(len(script_path)), # 使用索引作为内部选项值
format_func=lambda x: script_path[x][0] # 显示给用户的是标签
)
params.video_clip_json_path = script_path[selected_script_index][1]
config.app["video_clip_json_path"] = params.video_clip_json_path
st.session_state['video_clip_json_path'] = params.video_clip_json_path
# 视频文件处理
video_files = []
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
video_files = video_files[::-1]
video_list = []
for video_file in video_files:
video_list.append({
"name": os.path.basename(video_file),
"size": os.path.getsize(video_file),
"file": video_file,
"ctime": os.path.getctime(video_file) # 获取文件创建时间
})
# 按创建时间降序排序
video_list.sort(key=lambda x: x["ctime"], reverse=True)
video_path = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
for file in video_list:
display_name = file['file'].replace(root_dir, "")
video_path.append((display_name, file['file']))
# 视频文件
selected_video_index = st.selectbox(tr("Video File"),
index=0,
options=range(len(video_path)), # 使用索引作为内部选项值
format_func=lambda x: video_path[x][0] # 显示给用户的是标签
)
params.video_origin_path = video_path[selected_video_index][1]
config.app["video_origin_path"] = params.video_origin_path
st.session_state['video_origin_path'] = params.video_origin_path
# 从本地上传 mp4 文件
if params.video_origin_path == "local":
_supported_types = FILE_TYPE_VIDEOS
uploaded_file = st.file_uploader(
tr("Upload Local Files"),
type=["mp4", "mov", "avi", "flv", "mkv"],
accept_multiple_files=False,
)
if uploaded_file is not None:
# 构造保存路径
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
file_name, file_extension = os.path.splitext(uploaded_file.name)
# 检查文件是否存在,如果存在则添加时间戳
if os.path.exists(video_file_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
file_name_with_timestamp = f"{file_name}_{timestamp}"
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
# 将文件保存到指定目录
with open(video_file_path, "wb") as f:
f.write(uploaded_file.read())
st.success(tr("File Uploaded Successfully"))
time.sleep(1)
st.rerun()
# 视频名称
video_name = st.text_input(tr("Video Name"))
# 剧情内容
video_plot = st.text_area(
tr("Plot Description"),
value=st.session_state['video_plot'],
height=180
)
# 生成视频脚本
if st.session_state['video_clip_json_path']:
generate_button_name = tr("Video Script Load")
else:
generate_button_name = tr("Video Script Generate")
if st.button(generate_button_name, key="auto_generate_script"):
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
progress_bar.progress(progress)
if message:
status_text.text(f"{progress}% - {message}")
else:
status_text.text(f"进度: {progress}%")
try:
with st.spinner("正在生成脚本..."):
if not video_plot:
st.warning("视频剧情为空; 会极大影响生成效果!")
if params.video_clip_json_path == "" and params.video_origin_path != "":
update_progress(10, "压缩视频中...")
# 使用大模型生成视频脚本
script = llm.generate_script(
video_path=params.video_origin_path,
video_plot=video_plot,
video_name=video_name,
language=params.video_language,
progress_callback=update_progress
)
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
else:
update_progress(90)
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
else:
# 从本地加载
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
update_progress(50)
status_text.text("从本地加载中...")
script = f.read()
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
update_progress(100)
status_text.text("从本地加载成功")
time.sleep(0.5) # 给进度条一点时间到达100%
progress_bar.progress(100)
status_text.text("脚本生成完成!")
st.success("视频脚本生成成功!")
except Exception as err:
st.error(f"生成过程中发生错误: {str(err)}")
finally:
time.sleep(2) # 给用户一些时间查看最终状态
progress_bar.empty()
status_text.empty()
# 视频脚本
video_clip_json_details = st.text_area(
tr("Video Script"),
value=json.dumps(st.session_state.video_clip_json, indent=2, ensure_ascii=False),
height=180
)
# 保存脚本
button_columns = st.columns(2)
with button_columns[0]:
if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True):
if not video_clip_json_details:
st.error(tr("请输入视频脚本"))
st.stop()
with st.spinner(tr("Save Script")):
script_dir = utils.script_dir()
# 获取当前时间戳,形如 2024-0618-171820
timestamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M%S")
save_path = os.path.join(script_dir, f"{timestamp}.json")
try:
data = utils.add_new_timestamps(json.loads(video_clip_json_details))
except Exception as err:
st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}")
st.stop()
# 存储为新的 JSON 文件
with open(save_path, 'w', encoding='utf-8') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
# 将data的值存储到 session_state 中,类似缓存
st.session_state['video_clip_json'] = data
st.session_state['video_clip_json_path'] = save_path
# 刷新页面
st.rerun()
# 裁剪视频
with button_columns[1]:
if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress):
progress_bar.progress(progress)
status_text.text(f"剪辑进度: {progress}%")
try:
utils.cut_video(params, update_progress)
time.sleep(0.5) # 给进度条一点时间到达100%
progress_bar.progress(100)
status_text.text("剪辑完成!")
st.success("视频剪辑成功完成!")
except Exception as e:
st.error(f"剪辑过程中发生错误: {str(e)}")
finally:
time.sleep(2) # 给用户一些时间查看最终状态
progress_bar.empty()
status_text.empty()
# 新中间面板
with middle_panel:
with st.container(border=True):
st.write(tr("Video Settings"))
# 视频比例
video_aspect_ratios = [
(tr("Portrait"), VideoAspect.portrait.value),
(tr("Landscape"), VideoAspect.landscape.value),
]
selected_index = st.selectbox(
tr("Video Ratio"),
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
)
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
# params.video_clip_duration = st.selectbox(
# tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
# )
# params.video_count = st.selectbox(
# tr("Number of Videos Generated Simultaneously"),
# options=[1, 2, 3, 4, 5],
# index=0,
# )
with st.container(border=True):
st.write(tr("Audio Settings"))
# tts_providers = ['edge', 'azure']
# tts_provider = st.selectbox(tr("TTS Provider"), tts_providers)
voices = voice.get_all_azure_voices(filter_locals=support_locales)
friendly_names = {
v: v.replace("Female", tr("Female"))
.replace("Male", tr("Male"))
.replace("Neural", "")
for v in voices
# 合并所有参数
all_params = {
**script_params,
**video_params,
**audio_params,
**subtitle_params
}
saved_voice_name = config.ui.get("voice_name", "")
saved_voice_name_index = 0
if saved_voice_name in friendly_names:
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
else:
for i, v in enumerate(voices):
if (
v.lower().startswith(st.session_state["ui_language"].lower())
and "V2" not in v
):
saved_voice_name_index = i
break
selected_friendly_name = st.selectbox(
tr("Speech Synthesis"),
options=list(friendly_names.values()),
index=saved_voice_name_index,
# 创建参数对象
params = VideoClipParams(**all_params)
result = tm.start_subclip(
task_id=task_id,
params=params,
subclip_path_videos=st.session_state['subclip_videos']
)
voice_name = list(friendly_names.keys())[
list(friendly_names.values()).index(selected_friendly_name)
]
params.voice_name = voice_name
config.ui["voice_name"] = voice_name
video_files = result.get("videos", [])
st.success(tr("视频生成完成"))
try:
if video_files:
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
logger.error(f"播放视频失败: {e}")
if voice.is_azure_v2_voice(voice_name):
saved_azure_speech_region = config.azure.get("speech_region", "")
saved_azure_speech_key = config.azure.get("speech_key", "")
azure_speech_region = st.text_input(
tr("Speech Region"), value=saved_azure_speech_region
)
azure_speech_key = st.text_input(
tr("Speech Key"), value=saved_azure_speech_key, type="password"
)
config.azure["speech_region"] = azure_speech_region
config.azure["speech_key"] = azure_speech_key
file_utils.open_task_folder(config.root_dir, task_id)
logger.info(tr("视频生成完成"))
params.voice_volume = st.selectbox(
tr("Speech Volume"),
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
index=2,
)
def main():
"""主函数"""
init_log()
init_global_state()
st.title(f"NarratoAI :sunglasses:📽️")
st.write(tr("Get Help"))
# 渲染基础设置面板
basic_settings.render_basic_settings(tr)
# 渲染主面板
panel = st.columns(3)
with panel[0]:
script_settings.render_script_panel(tr)
with panel[1]:
video_settings.render_video_panel(tr)
audio_settings.render_audio_panel(tr)
with panel[2]:
subtitle_settings.render_subtitle_panel(tr)
# 渲染视频审查面板
review_settings.render_review_panel(tr)
# 渲染生成按钮和处理逻辑
render_generate_button()
params.voice_rate = st.selectbox(
tr("Speech Rate"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
index=2,
)
params.voice_pitch = st.selectbox(
tr("Speech Pitch"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
index=2,
)
# 试听语言合成
if st.button(tr("Play Voice")):
play_content = "感谢关注 NarratoAI有任何问题或建议可以关注微信公众号求助或讨论"
if not play_content:
play_content = params.video_script
if not play_content:
play_content = tr("Voice Example")
with st.spinner(tr("Synthesizing Voice")):
temp_dir = utils.storage_dir("temp", create=True)
audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_pitch=params.voice_pitch,
voice_file=audio_file,
)
# 如果语音文件生成失败,请使用默认内容重试。
if not sub_maker:
play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=params.voice_rate,
voice_pitch=params.voice_pitch,
voice_file=audio_file,
)
if sub_maker and os.path.exists(audio_file):
st.audio(audio_file, format="audio/mp3")
if os.path.exists(audio_file):
os.remove(audio_file)
bgm_options = [
(tr("No Background Music"), ""),
(tr("Random Background Music"), "random"),
(tr("Custom Background Music"), "custom"),
]
selected_index = st.selectbox(
tr("Background Music"),
index=1,
options=range(len(bgm_options)), # 使用索引作为内部选项值
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
)
# 获取选择的背景音乐类型
params.bgm_type = bgm_options[selected_index][1]
# 根据选择显示或隐藏组件
if params.bgm_type == "custom":
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
if custom_bgm_file and os.path.exists(custom_bgm_file):
params.bgm_file = custom_bgm_file
# st.write(f":red[已选择自定义背景音乐]**{custom_bgm_file}**")
params.bgm_volume = st.selectbox(
tr("Background Music Volume"),
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
index=2,
)
# 新侧面板
with right_panel:
with st.container(border=True):
st.write(tr("Subtitle Settings"))
params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
font_names = get_all_fonts()
saved_font_name = config.ui.get("font_name", "")
saved_font_name_index = 0
if saved_font_name in font_names:
saved_font_name_index = font_names.index(saved_font_name)
params.font_name = st.selectbox(
tr("Font"), font_names, index=saved_font_name_index
)
config.ui["font_name"] = params.font_name
subtitle_positions = [
(tr("Top"), "top"),
(tr("Center"), "center"),
(tr("Bottom"), "bottom"),
(tr("Custom"), "custom"),
]
selected_index = st.selectbox(
tr("Position"),
index=2,
options=range(len(subtitle_positions)),
format_func=lambda x: subtitle_positions[x][0],
)
params.subtitle_position = subtitle_positions[selected_index][1]
if params.subtitle_position == "custom":
custom_position = st.text_input(
tr("Custom Position (% from top)"), value="70.0"
)
try:
params.custom_position = float(custom_position)
if params.custom_position < 0 or params.custom_position > 100:
st.error(tr("Please enter a value between 0 and 100"))
except ValueError:
logger.error(f"输入的值无效: {traceback.format_exc()}")
st.error(tr("Please enter a valid number"))
font_cols = st.columns([0.3, 0.7])
with font_cols[0]:
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
params.text_fore_color = st.color_picker(
tr("Font Color"), saved_text_fore_color
)
config.ui["text_fore_color"] = params.text_fore_color
with font_cols[1]:
saved_font_size = config.ui.get("font_size", 60)
params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
config.ui["font_size"] = params.font_size
stroke_cols = st.columns([0.3, 0.7])
with stroke_cols[0]:
params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000")
with stroke_cols[1]:
params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5)
# 视频编辑面板
with st.expander(tr("Video Check"), expanded=False):
try:
video_list = st.session_state.video_clip_json
except KeyError as e:
video_list = []
# 计算列数和行数
num_videos = len(video_list)
cols_per_row = 3
rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
# 使用容器展示视频
for row in range(rows):
cols = st.columns(cols_per_row)
for col in range(cols_per_row):
index = row * cols_per_row + col
if index < num_videos:
with cols[col]:
video_info = video_list[index]
video_path = video_info.get('path')
if video_path is not None:
initial_narration = video_info['narration']
initial_picture = video_info['picture']
initial_timestamp = video_info['timestamp']
with open(video_path, 'rb') as video_file:
video_bytes = video_file.read()
st.video(video_bytes)
# 可编辑的输入框
text_panels = st.columns(2)
with text_panels[0]:
text1 = st.text_area(tr("timestamp"), value=initial_timestamp, height=20,
key=f"timestamp_{index}")
with text_panels[1]:
text2 = st.text_area(tr("Picture description"), value=initial_picture, height=20,
key=f"picture_{index}")
text3 = st.text_area(tr("Narration"), value=initial_narration, height=100,
key=f"narration_{index}")
# 重新生成按钮
if st.button(tr("Rebuild"), key=f"rebuild_{index}"):
# 更新video_list中的对应项
video_list[index]['timestamp'] = text1
video_list[index]['picture'] = text2
video_list[index]['narration'] = text3
for video in video_list:
if 'path' in video:
del video['path']
# 更新session_state以确保更改被保存
st.session_state['video_clip_json'] = utils.to_json(video_list)
# 替换原JSON 文件
with open(params.video_clip_json_path, 'w', encoding='utf-8') as file:
json.dump(video_list, file, ensure_ascii=False, indent=4)
utils.cut_video(params, progress_callback=None)
st.rerun()
# 开始按钮
start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
if start_button:
# 重置日志容器和记录
log_container = st.empty()
log_records = []
config.save_config()
task_id = st.session_state.get('task_id')
if st.session_state.get('video_script_json_path') is not None:
params.video_clip_json = st.session_state.get('video_clip_json')
logger.debug(f"当前的脚本文件为:{st.session_state.video_clip_json_path}")
logger.debug(f"当前的视频文件为:{st.session_state.video_origin_path}")
logger.debug(f"裁剪后是视频列表:{st.session_state.subclip_videos}")
if not task_id:
st.error(tr("请先裁剪视频"))
scroll_to_bottom()
st.stop()
if not params.video_clip_json_path:
st.error(tr("脚本文件不能为空"))
scroll_to_bottom()
st.stop()
if not params.video_origin_path:
st.error(tr("视频文件不能为空"))
scroll_to_bottom()
st.stop()
def log_received(msg):
with log_container:
log_records.append(msg)
st.code("\n".join(log_records))
logger.add(log_received)
st.toast(tr("生成视频"))
logger.info(tr("开始生成视频"))
logger.info(utils.to_json(params))
scroll_to_bottom()
result = tm.start_subclip(task_id=task_id, params=params, subclip_path_videos=st.session_state.subclip_videos)
video_files = result.get("videos", [])
st.success(tr("视频生成完成"))
try:
if video_files:
# 将视频播放器居中
player_cols = st.columns(len(video_files) * 2 + 1)
for i, url in enumerate(video_files):
player_cols[i * 2 + 1].video(url)
except Exception as e:
pass
open_task_folder(task_id)
logger.info(tr("视频生成完成"))
scroll_to_bottom()
config.save_config()
if __name__ == "__main__":
main()

webui/__init__.py

@@ -0,0 +1,22 @@
"""
NarratoAI WebUI Package
"""
from webui.config.settings import config
from webui.components import (
basic_settings,
video_settings,
audio_settings,
subtitle_settings
)
from webui.utils import cache, file_utils, performance
__all__ = [
'config',
'basic_settings',
'video_settings',
'audio_settings',
'subtitle_settings',
'cache',
'file_utils',
'performance'
]


@@ -0,0 +1,15 @@
from .basic_settings import render_basic_settings
from .script_settings import render_script_panel
from .video_settings import render_video_panel
from .audio_settings import render_audio_panel
from .subtitle_settings import render_subtitle_panel
from .review_settings import render_review_panel
__all__ = [
'render_basic_settings',
'render_script_panel',
'render_video_panel',
'render_audio_panel',
'render_subtitle_panel',
'render_review_panel'
]


@@ -0,0 +1,198 @@
import streamlit as st
import os
from uuid import uuid4
from app.config import config
from app.services import voice
from app.utils import utils
from webui.utils.cache import get_songs_cache
def render_audio_panel(tr):
"""渲染音频设置面板"""
with st.container(border=True):
st.write(tr("Audio Settings"))
# 渲染TTS设置
render_tts_settings(tr)
# 渲染背景音乐设置
render_bgm_settings(tr)
def render_tts_settings(tr):
"""渲染TTS(文本转语音)设置"""
# 获取支持的语音列表
support_locales = ["zh-CN", "zh-HK", "zh-TW", "en-US"]
voices = voice.get_all_azure_voices(filter_locals=support_locales)
# 创建友好的显示名称
friendly_names = {
v: v.replace("Female", tr("Female"))
.replace("Male", tr("Male"))
.replace("Neural", "")
for v in voices
}
# 获取保存的语音设置
saved_voice_name = config.ui.get("voice_name", "")
saved_voice_name_index = 0
if saved_voice_name in friendly_names:
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
else:
# 如果没有保存的设置选择与UI语言匹配的第一个语音
for i, v in enumerate(voices):
if (v.lower().startswith(st.session_state["ui_language"].lower())
and "V2" not in v):
saved_voice_name_index = i
break
# 语音选择下拉框
selected_friendly_name = st.selectbox(
tr("Speech Synthesis"),
options=list(friendly_names.values()),
index=saved_voice_name_index,
)
# 获取实际的语音名称
voice_name = list(friendly_names.keys())[
list(friendly_names.values()).index(selected_friendly_name)
]
# 保存设置
config.ui["voice_name"] = voice_name
# Azure V2语音特殊处理
if voice.is_azure_v2_voice(voice_name):
render_azure_v2_settings(tr)
# 语音参数设置
render_voice_parameters(tr)
# 试听按钮
render_voice_preview(tr, voice_name)
def render_azure_v2_settings(tr):
"""渲染Azure V2语音设置"""
saved_azure_speech_region = config.azure.get("speech_region", "")
saved_azure_speech_key = config.azure.get("speech_key", "")
azure_speech_region = st.text_input(
tr("Speech Region"),
value=saved_azure_speech_region
)
azure_speech_key = st.text_input(
tr("Speech Key"),
value=saved_azure_speech_key,
type="password"
)
config.azure["speech_region"] = azure_speech_region
config.azure["speech_key"] = azure_speech_key
def render_voice_parameters(tr):
"""渲染语音参数设置"""
# 音量
voice_volume = st.selectbox(
tr("Speech Volume"),
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
index=2,
)
st.session_state['voice_volume'] = voice_volume
# 语速
voice_rate = st.selectbox(
tr("Speech Rate"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
index=2,
)
st.session_state['voice_rate'] = voice_rate
# 音调
voice_pitch = st.selectbox(
tr("Speech Pitch"),
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
index=2,
)
st.session_state['voice_pitch'] = voice_pitch
def render_voice_preview(tr, voice_name):
"""渲染语音试听功能"""
if st.button(tr("Play Voice")):
play_content = "感谢关注 NarratoAI有任何问题或建议可以关注微信公众号求助或讨论"
if not play_content:
play_content = st.session_state.get('video_script', '')
if not play_content:
play_content = tr("Voice Example")
with st.spinner(tr("Synthesizing Voice")):
temp_dir = utils.storage_dir("temp", create=True)
audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=st.session_state.get('voice_rate', 1.0),
voice_pitch=st.session_state.get('voice_pitch', 1.0),
voice_file=audio_file,
)
# 如果语音文件生成失败,使用默认内容重试
if not sub_maker:
play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
sub_maker = voice.tts(
text=play_content,
voice_name=voice_name,
voice_rate=st.session_state.get('voice_rate', 1.0),
voice_pitch=st.session_state.get('voice_pitch', 1.0),
voice_file=audio_file,
)
if sub_maker and os.path.exists(audio_file):
st.audio(audio_file, format="audio/mp3")
if os.path.exists(audio_file):
os.remove(audio_file)
def render_bgm_settings(tr):
"""渲染背景音乐设置"""
# 背景音乐选项
bgm_options = [
(tr("No Background Music"), ""),
(tr("Random Background Music"), "random"),
(tr("Custom Background Music"), "custom"),
]
selected_index = st.selectbox(
tr("Background Music"),
index=1,
options=range(len(bgm_options)),
format_func=lambda x: bgm_options[x][0],
)
# 获取选择的背景音乐类型
bgm_type = bgm_options[selected_index][1]
st.session_state['bgm_type'] = bgm_type
# 自定义背景音乐处理
if bgm_type == "custom":
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
if custom_bgm_file and os.path.exists(custom_bgm_file):
st.session_state['bgm_file'] = custom_bgm_file
# 背景音乐音量
bgm_volume = st.selectbox(
tr("Background Music Volume"),
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
index=2,
)
st.session_state['bgm_volume'] = bgm_volume
def get_audio_params():
"""获取音频参数"""
return {
'voice_name': config.ui.get("voice_name", ""),
'voice_volume': st.session_state.get('voice_volume', 1.0),
'voice_rate': st.session_state.get('voice_rate', 1.0),
'voice_pitch': st.session_state.get('voice_pitch', 1.0),
'bgm_type': st.session_state.get('bgm_type', 'random'),
'bgm_file': st.session_state.get('bgm_file', ''),
'bgm_volume': st.session_state.get('bgm_volume', 0.2),
}


@@ -0,0 +1,142 @@
import streamlit as st
import os
from app.config import config
from app.utils import utils
def render_basic_settings(tr):
"""渲染基础设置面板"""
with st.expander(tr("Basic Settings"), expanded=False):
config_panels = st.columns(3)
left_config_panel = config_panels[0]
middle_config_panel = config_panels[1]
right_config_panel = config_panels[2]
with left_config_panel:
render_language_settings(tr)
render_proxy_settings(tr)
with middle_config_panel:
render_video_llm_settings(tr)
with right_config_panel:
render_llm_settings(tr)
def render_language_settings(tr):
"""渲染语言设置"""
system_locale = utils.get_system_locale()
i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
locales = utils.load_locales(i18n_dir)
display_languages = []
selected_index = 0
for i, code in enumerate(locales.keys()):
display_languages.append(f"{code} - {locales[code].get('Language')}")
if code == st.session_state.get('ui_language', system_locale):
selected_index = i
selected_language = st.selectbox(
tr("Language"),
options=display_languages,
index=selected_index
)
if selected_language:
code = selected_language.split(" - ")[0].strip()
st.session_state['ui_language'] = code
config.ui['language'] = code
def render_proxy_settings(tr):
"""渲染代理设置"""
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
if HTTP_PROXY:
config.proxy["http"] = HTTP_PROXY
os.environ["HTTP_PROXY"] = HTTP_PROXY
if HTTPS_PROXY:
config.proxy["https"] = HTTPS_PROXY
os.environ["HTTPS_PROXY"] = HTTPS_PROXY
def render_video_llm_settings(tr):
"""渲染视频LLM设置"""
video_llm_providers = ['Gemini', 'NarratoAPI']
saved_llm_provider = config.app.get("video_llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(video_llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
break
video_llm_provider = st.selectbox(
tr("Video LLM Provider"),
options=video_llm_providers,
index=saved_llm_provider_index
)
video_llm_provider = video_llm_provider.lower()
config.app["video_llm_provider"] = video_llm_provider
# 获取已保存的配置
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
# 渲染输入框
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
# 保存配置
if st_llm_api_key:
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
def render_llm_settings(tr):
"""渲染LLM设置"""
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
if provider.lower() == saved_llm_provider:
saved_llm_provider_index = i
break
llm_provider = st.selectbox(
tr("LLM Provider"),
options=llm_providers,
index=saved_llm_provider_index
)
llm_provider = llm_provider.lower()
config.app["llm_provider"] = llm_provider
# 获取已保存的配置
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
# 渲染输入框
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
# 保存配置
if st_llm_api_key:
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
if st_llm_base_url:
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
if st_llm_model_name:
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
# Cloudflare 特殊处理
if llm_provider == 'cloudflare':
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
if st_llm_account_id:
config.app[f"{llm_provider}_account_id"] = st_llm_account_id


@@ -0,0 +1,65 @@
import streamlit as st
import os
from loguru import logger
def render_review_panel(tr):
"""渲染视频审查面板"""
with st.expander(tr("Video Check"), expanded=False):
try:
video_list = st.session_state.get('video_clip_json', [])
except KeyError:
video_list = []
# 计算列数和行数
num_videos = len(video_list)
cols_per_row = 3
rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
# 使用容器展示视频
for row in range(rows):
cols = st.columns(cols_per_row)
for col in range(cols_per_row):
index = row * cols_per_row + col
if index < num_videos:
with cols[col]:
render_video_item(tr, video_list, index)
def render_video_item(tr, video_list, index):
"""渲染单个视频项"""
video_info = video_list[index]
video_path = video_info.get('path')
if video_path is not None and os.path.exists(video_path):
initial_narration = video_info.get('narration', '')
initial_picture = video_info.get('picture', '')
initial_timestamp = video_info.get('timestamp', '')
# 显示视频
with open(video_path, 'rb') as video_file:
video_bytes = video_file.read()
st.video(video_bytes)
# 显示信息(只读)
text_panels = st.columns(2)
with text_panels[0]:
st.text_area(
tr("timestamp"),
value=initial_timestamp,
height=20,
key=f"timestamp_{index}",
disabled=True
)
with text_panels[1]:
st.text_area(
tr("Picture description"),
value=initial_picture,
height=20,
key=f"picture_{index}",
disabled=True
)
st.text_area(
tr("Narration"),
value=initial_narration,
height=100,
key=f"narration_{index}",
disabled=True
)


@@ -0,0 +1,314 @@
import streamlit as st
import os
import glob
import json
import time
from app.config import config
from app.models.schema import VideoClipParams
from app.services import llm
from app.utils import utils, check_script
from loguru import logger
from webui.utils import file_utils
def render_script_panel(tr):
"""渲染脚本配置面板"""
with st.container(border=True):
st.write(tr("Video Script Configuration"))
params = VideoClipParams()
# 渲染脚本文件选择
render_script_file(tr, params)
# 渲染视频文件选择
render_video_file(tr, params)
# 渲染视频主题和提示词
render_video_details(tr)
# 渲染脚本操作按钮
render_script_buttons(tr, params)
def render_script_file(tr, params):
"""渲染脚本文件选择"""
script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]
# 获取已有脚本文件
suffix = "*.json"
script_dir = utils.script_dir()
files = glob.glob(os.path.join(script_dir, suffix))
file_list = []
for file in files:
file_list.append({
"name": os.path.basename(file),
"file": file,
"ctime": os.path.getctime(file)
})
file_list.sort(key=lambda x: x["ctime"], reverse=True)
for file in file_list:
display_name = file['file'].replace(config.root_dir, "")
script_list.append((display_name, file['file']))
# 找到保存的脚本文件在列表中的索引
saved_script_path = st.session_state.get('video_clip_json_path', '')
selected_index = 0
for i, (_, path) in enumerate(script_list):
if path == saved_script_path:
selected_index = i
break
selected_script_index = st.selectbox(
tr("Script Files"),
index=selected_index, # 使用找到的索引
options=range(len(script_list)),
format_func=lambda x: script_list[x][0]
)
script_path = script_list[selected_script_index][1]
st.session_state['video_clip_json_path'] = script_path
params.video_clip_json_path = script_path
def render_video_file(tr, params):
"""渲染视频文件选择"""
video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
# 获取已有视频文件
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
for file in video_files:
display_name = file.replace(config.root_dir, "")
video_list.append((display_name, file))
selected_video_index = st.selectbox(
tr("Video File"),
index=0,
options=range(len(video_list)),
format_func=lambda x: video_list[x][0]
)
video_path = video_list[selected_video_index][1]
st.session_state['video_origin_path'] = video_path
params.video_origin_path = video_path
if video_path == "local":
uploaded_file = st.file_uploader(
tr("Upload Local Files"),
type=["mp4", "mov", "avi", "flv", "mkv"],
accept_multiple_files=False,
)
if uploaded_file is not None:
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
file_name, file_extension = os.path.splitext(uploaded_file.name)
if os.path.exists(video_file_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
file_name_with_timestamp = f"{file_name}_{timestamp}"
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
with open(video_file_path, "wb") as f:
f.write(uploaded_file.read())
st.success(tr("File Uploaded Successfully"))
st.session_state['video_origin_path'] = video_file_path
params.video_origin_path = video_file_path
time.sleep(1)
st.rerun()
def render_video_details(tr):
"""渲染视频主题和提示词"""
video_theme = st.text_input(tr("Video Theme"))
prompt = st.text_area(
tr("Generation Prompt"),
value=st.session_state.get('video_plot', ''),
help=tr("Custom prompt for LLM, leave empty to use default prompt"),
height=180
)
st.session_state['video_name'] = video_theme
st.session_state['video_plot'] = prompt
return video_theme, prompt
def render_script_buttons(tr, params):
"""渲染脚本操作按钮"""
# 生成/加载按钮
script_path = st.session_state.get('video_clip_json_path', '')
if script_path == "auto":
button_name = tr("Generate Video Script")
elif script_path:
button_name = tr("Load Video Script")
else:
button_name = tr("Please Select Script File")
if st.button(button_name, key="script_action", disabled=not script_path):
if script_path == "auto":
generate_script(tr, params)
else:
load_script(tr, script_path)
# 视频脚本编辑区
video_clip_json_details = st.text_area(
tr("Video Script"),
value=json.dumps(st.session_state.get('video_clip_json', []), indent=2, ensure_ascii=False),
height=180
)
# 操作按钮行
button_cols = st.columns(3)
with button_cols[0]:
if st.button(tr("Check Format"), key="check_format", use_container_width=True):
check_script_format(tr, video_clip_json_details)
with button_cols[1]:
if st.button(tr("Save Script"), key="save_script", use_container_width=True):
save_script(tr, video_clip_json_details)
with button_cols[2]:
script_valid = st.session_state.get('script_format_valid', False)
if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
crop_video(tr, params)
def check_script_format(tr, script_content):
"""检查脚本格式"""
try:
result = check_script.check_format(script_content)
if result.get('success'):
st.success(tr("Script format check passed"))
st.session_state['script_format_valid'] = True
else:
st.error(f"{tr('Script format check failed')}: {result.get('message')}")
st.session_state['script_format_valid'] = False
except Exception as e:
st.error(f"{tr('Script format check error')}: {str(e)}")
st.session_state['script_format_valid'] = False
def load_script(tr, script_path):
"""加载脚本文件"""
try:
with open(script_path, 'r', encoding='utf-8') as f:
script = f.read()
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
st.success(tr("Script loaded successfully"))
st.rerun()
except Exception as e:
st.error(f"{tr('Failed to load script')}: {str(e)}")
def generate_script(tr, params):
"""生成视频脚本"""
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
progress_bar.progress(progress)
if message:
status_text.text(f"{progress}% - {message}")
else:
status_text.text(f"进度: {progress}%")
try:
with st.spinner("正在生成脚本..."):
if not st.session_state.get('video_plot'):
st.warning("视频剧情为空; 会极大影响生成效果!")
if params.video_clip_json_path == "" and params.video_origin_path != "":
update_progress(10, "压缩视频中...")
script = llm.generate_script(
video_path=params.video_origin_path,
video_plot=st.session_state.get('video_plot', ''),
video_name=st.session_state.get('video_name', ''),
language=params.video_language,
progress_callback=update_progress
)
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
else:
update_progress(90)
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
else:
# 从本地加载
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
update_progress(50)
status_text.text("从本地加载中...")
script = f.read()
script = utils.clean_model_output(script)
st.session_state['video_clip_json'] = json.loads(script)
update_progress(100)
status_text.text("从本地加载成功")
time.sleep(0.5)
progress_bar.progress(100)
status_text.text("脚本生成完成!")
st.success("视频脚本生成成功!")
except Exception as err:
st.error(f"生成过程中发生错误: {str(err)}")
finally:
time.sleep(2)
progress_bar.empty()
status_text.empty()
def save_script(tr, video_clip_json_details):
"""保存视频脚本"""
if not video_clip_json_details:
st.error(tr("请输入视频脚本"))
st.stop()
with st.spinner(tr("Save Script")):
script_dir = utils.script_dir()
timestamp = time.strftime("%Y-%m%d-%H%M%S")
save_path = os.path.join(script_dir, f"{timestamp}.json")
try:
data = json.loads(video_clip_json_details)
with open(save_path, 'w', encoding='utf-8') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
st.session_state['video_clip_json'] = data
st.session_state['video_clip_json_path'] = save_path
# 更新配置
config.app["video_clip_json_path"] = save_path
# 显示成功消息
st.success(tr("Script saved successfully"))
# 强制重新加载页面以更新选择框
time.sleep(0.5) # 给一点时间让用户看到成功消息
st.rerun()
except Exception as err:
st.error(f"{tr('Failed to save script')}: {str(err)}")
st.stop()
def crop_video(tr, params):
"""裁剪视频"""
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress):
progress_bar.progress(progress)
status_text.text(f"剪辑进度: {progress}%")
try:
utils.cut_video(params, update_progress)
time.sleep(0.5)
progress_bar.progress(100)
status_text.text("剪辑完成!")
st.success("视频剪辑成功完成!")
except Exception as e:
st.error(f"剪辑过程中发生错误: {str(e)}")
finally:
time.sleep(2)
progress_bar.empty()
status_text.empty()
def get_script_params():
"""获取脚本参数"""
return {
'video_language': st.session_state.get('video_language', ''),
'video_clip_json_path': st.session_state.get('video_clip_json_path', ''),
'video_origin_path': st.session_state.get('video_origin_path', ''),
'video_name': st.session_state.get('video_name', ''),
'video_plot': st.session_state.get('video_plot', '')
}


@@ -0,0 +1,129 @@
import streamlit as st
from app.config import config
from webui.utils.cache import get_fonts_cache
import os
def render_subtitle_panel(tr):
"""渲染字幕设置面板"""
with st.container(border=True):
st.write(tr("Subtitle Settings"))
# 启用字幕选项
enable_subtitles = st.checkbox(tr("Enable Subtitles"), value=True)
st.session_state['subtitle_enabled'] = enable_subtitles
if enable_subtitles:
render_font_settings(tr)
render_position_settings(tr)
render_style_settings(tr)
def render_font_settings(tr):
"""渲染字体设置"""
# 获取字体列表
font_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "resource", "fonts")
font_names = get_fonts_cache(font_dir)
# 获取保存的字体设置
saved_font_name = config.ui.get("font_name", "")
saved_font_name_index = 0
if saved_font_name in font_names:
saved_font_name_index = font_names.index(saved_font_name)
# 字体选择
font_name = st.selectbox(
tr("Font"),
options=font_names,
index=saved_font_name_index
)
config.ui["font_name"] = font_name
st.session_state['font_name'] = font_name
# 字体大小
font_cols = st.columns([0.3, 0.7])
with font_cols[0]:
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
text_fore_color = st.color_picker(
tr("Font Color"),
saved_text_fore_color
)
config.ui["text_fore_color"] = text_fore_color
st.session_state['text_fore_color'] = text_fore_color
with font_cols[1]:
saved_font_size = config.ui.get("font_size", 60)
font_size = st.slider(
tr("Font Size"),
min_value=30,
max_value=100,
value=saved_font_size
)
config.ui["font_size"] = font_size
st.session_state['font_size'] = font_size
def render_position_settings(tr):
"""渲染位置设置"""
subtitle_positions = [
(tr("Top"), "top"),
(tr("Center"), "center"),
(tr("Bottom"), "bottom"),
(tr("Custom"), "custom"),
]
selected_index = st.selectbox(
tr("Position"),
index=2,
options=range(len(subtitle_positions)),
format_func=lambda x: subtitle_positions[x][0],
)
subtitle_position = subtitle_positions[selected_index][1]
st.session_state['subtitle_position'] = subtitle_position
# 自定义位置处理
if subtitle_position == "custom":
custom_position = st.text_input(
tr("Custom Position (% from top)"),
value="70.0"
)
try:
custom_position_value = float(custom_position)
if custom_position_value < 0 or custom_position_value > 100:
st.error(tr("Please enter a value between 0 and 100"))
else:
st.session_state['custom_position'] = custom_position_value
except ValueError:
st.error(tr("Please enter a valid number"))
def render_style_settings(tr):
"""渲染样式设置"""
stroke_cols = st.columns([0.3, 0.7])
with stroke_cols[0]:
stroke_color = st.color_picker(
tr("Stroke Color"),
value="#000000"
)
st.session_state['stroke_color'] = stroke_color
with stroke_cols[1]:
stroke_width = st.slider(
tr("Stroke Width"),
min_value=0.0,
max_value=10.0,
value=1.5,
step=0.1
)
st.session_state['stroke_width'] = stroke_width
def get_subtitle_params():
"""获取字幕参数"""
return {
'enabled': st.session_state.get('subtitle_enabled', True),
'font_name': st.session_state.get('font_name', ''),
'font_size': st.session_state.get('font_size', 60),
'text_fore_color': st.session_state.get('text_fore_color', '#FFFFFF'),
'position': st.session_state.get('subtitle_position', 'bottom'),
'custom_position': st.session_state.get('custom_position', 70.0),
'stroke_color': st.session_state.get('stroke_color', '#000000'),
'stroke_width': st.session_state.get('stroke_width', 1.5),
}
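
Usage sketch (not part of the commit): rendering the panel inside a Streamlit page and reading the collected values back. tr is assumed to be the same i18n lookup passed around the rest of the webui, and the module path is an assumption.

import streamlit as st
from webui.components.subtitle_settings import render_subtitle_panel, get_subtitle_params  # assumed path

def demo_subtitle_page(tr):
    render_subtitle_panel(tr)
    params = get_subtitle_params()
    if params['enabled']:
        st.write(f"{params['font_name']} @ {params['font_size']}px, position={params['position']}")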

View File

@ -0,0 +1,47 @@
import streamlit as st
from app.models.schema import VideoClipParams, VideoAspect
def render_video_panel(tr):
"""渲染视频配置面板"""
with st.container(border=True):
st.write(tr("Video Settings"))
params = VideoClipParams()
render_video_config(tr, params)
def render_video_config(tr, params):
"""渲染视频配置"""
# 视频比例
video_aspect_ratios = [
(tr("Portrait"), VideoAspect.portrait.value),
(tr("Landscape"), VideoAspect.landscape.value),
]
selected_index = st.selectbox(
tr("Video Ratio"),
options=range(len(video_aspect_ratios)),
format_func=lambda x: video_aspect_ratios[x][0],
)
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
st.session_state['video_aspect'] = params.video_aspect.value
# 视频画质
video_qualities = [
("4K (2160p)", "2160p"),
("2K (1440p)", "1440p"),
("Full HD (1080p)", "1080p"),
("HD (720p)", "720p"),
("SD (480p)", "480p"),
]
quality_index = st.selectbox(
tr("Video Quality"),
options=range(len(video_qualities)),
format_func=lambda x: video_qualities[x][0],
index=2 # 默认选择 1080p
)
st.session_state['video_quality'] = video_qualities[quality_index][1]
def get_video_params():
"""获取视频参数"""
return {
'video_aspect': st.session_state.get('video_aspect', VideoAspect.portrait.value),
'video_quality': st.session_state.get('video_quality', '1080p')
}
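
Usage sketch (not part of the commit): the quality string kept in session state can be mapped to an output height when building the render arguments. The mapping and the module path are assumptions, not something defined in this commit.

from webui.components.video_settings import get_video_params  # assumed path

QUALITY_TO_HEIGHT = {"2160p": 2160, "1440p": 1440, "1080p": 1080, "720p": 720, "480p": 480}

def target_height() -> int:
    params = get_video_params()
    return QUALITY_TO_HEIGHT.get(params['video_quality'], 1080)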

155
webui/config/settings.py Normal file
View File

@ -0,0 +1,155 @@
import os
import tomli
from loguru import logger
from typing import Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class WebUIConfig:
"""WebUI配置类"""
# UI配置
ui: Dict[str, Any] = None
# 代理配置
proxy: Dict[str, str] = None
# 应用配置
app: Dict[str, Any] = None
# Azure配置
azure: Dict[str, str] = None
# 项目版本
project_version: str = "0.1.0"
# 项目根目录
root_dir: str = None
def __post_init__(self):
"""初始化默认值"""
self.ui = self.ui or {}
self.proxy = self.proxy or {}
self.app = self.app or {}
self.azure = self.azure or {}
self.root_dir = self.root_dir or os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
def load_config(config_path: Optional[str] = None) -> WebUIConfig:
"""加载配置文件
Args:
config_path: 配置文件路径如果为None则使用默认路径
Returns:
WebUIConfig: 配置对象
"""
try:
if config_path is None:
config_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
".streamlit",
"webui.toml"
)
# 如果配置文件不存在,使用示例配置
if not os.path.exists(config_path):
example_config = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"config.example.toml"
)
if os.path.exists(example_config):
config_path = example_config
else:
logger.warning(f"配置文件不存在: {config_path}")
return WebUIConfig()
# 读取配置文件
with open(config_path, "rb") as f:
config_dict = tomli.load(f)
# 创建配置对象
config = WebUIConfig(
ui=config_dict.get("ui", {}),
proxy=config_dict.get("proxy", {}),
app=config_dict.get("app", {}),
azure=config_dict.get("azure", {}),
project_version=config_dict.get("project_version", "0.1.0")
)
return config
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
return WebUIConfig()
def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool:
"""保存配置到文件
Args:
config: 配置对象
config_path: 配置文件路径如果为None则使用默认路径
Returns:
bool: 是否保存成功
"""
try:
if config_path is None:
config_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
".streamlit",
"webui.toml"
)
# 确保目录存在
os.makedirs(os.path.dirname(config_path), exist_ok=True)
# 转换为字典
config_dict = {
"ui": config.ui,
"proxy": config.proxy,
"app": config.app,
"azure": config.azure,
"project_version": config.project_version
}
        # 保存配置tomli_w 需要以二进制模式写入
        import tomli_w
        with open(config_path, "wb") as f:
            tomli_w.dump(config_dict, f)
return True
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
return False
def get_config() -> WebUIConfig:
"""获取全局配置对象
Returns:
WebUIConfig: 配置对象
"""
if not hasattr(get_config, "_config"):
get_config._config = load_config()
return get_config._config
def update_config(config_dict: Dict[str, Any]) -> bool:
"""更新配置
Args:
config_dict: 配置字典
Returns:
bool: 是否更新成功
"""
try:
config = get_config()
# 更新配置
if "ui" in config_dict:
config.ui.update(config_dict["ui"])
if "proxy" in config_dict:
config.proxy.update(config_dict["proxy"])
if "app" in config_dict:
config.app.update(config_dict["app"])
if "azure" in config_dict:
config.azure.update(config_dict["azure"])
if "project_version" in config_dict:
config.project_version = config_dict["project_version"]
# 保存配置
return save_config(config)
except Exception as e:
logger.error(f"更新配置失败: {e}")
return False
# 导出全局配置对象
config = get_config()
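
Usage sketch (not part of the commit): reading the cached config and persisting a UI preference through update_config, which merges the dict and rewrites webui.toml. The font file name is a placeholder.

from webui.config.settings import get_config, update_config

cfg = get_config()
print(cfg.project_version, cfg.ui.get("font_name", ""))

update_config({"ui": {"font_name": "SomeFont.ttf"}})  # placeholder font file name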

1
webui/i18n/__init__.py Normal file
View File

@ -0,0 +1 @@
# 空文件,用于标记包

View File

@ -2,15 +2,15 @@
"Language": "简体中文",
"Translation": {
"Video Script Configuration": "**视频脚本配置**",
"Video Script Generate": "生成视频脚本",
"Generate Video Script": "生成视频脚本",
"Video Subject": "视频主题(给定一个关键词,:red[AI自动生成]视频文案)",
"Script Language": "生成视频脚本的语言一般情况AI会自动根据你输入的主题语言输出",
"Script Files": "脚本文件",
"Generate Video Script and Keywords": "点击使用AI根据**主题**生成 【视频文案】 和 【视频关键词】",
"Auto Detect": "自动检测",
"Auto Generate": "自动生成",
"Video Name": "视频名称",
"Video Script": "视频脚本(:blue[①使用AI生成 ②从本机加载]",
"Video Theme": "视频主题",
"Generation Prompt": "自定义提示词",
"Save Script": "保存脚本",
"Crop Video": "裁剪视频",
"Video File": "视频文件(:blue[1⃣支持上传视频文件(限制2G) 2⃣大文件建议直接导入 ./resource/videos 目录]",
@ -91,7 +91,18 @@
"Picture description": "图片描述",
"Narration": "视频文案",
"Rebuild": "重新生成",
"Video Script Load": "加载视频脚本",
"Speech Pitch": "语调"
"Load Video Script": "加载视频脚本",
"Speech Pitch": "语调",
"Please Select Script File": "请选择脚本文件",
"Check Format": "脚本格式检查",
"Script Loaded Successfully": "脚本加载成功",
"Script format check passed": "脚本格式检查通过",
"Script format check failed": "脚本格式检查失败",
"Failed to Load Script": "加载脚本失败",
"Failed to Save Script": "保存脚本失败",
"Script saved successfully": "脚本保存成功",
"Video Script": "视频脚本",
"Video Quality": "视频质量",
"Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词"
}
}
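
Usage sketch (not part of the commit): one way a tr() helper could resolve keys against this locale file. The actual loader is not shown in this diff, so the file path and lookup logic here are assumptions.

import json

def make_tr(locale_path="webui/i18n/zh.json"):  # assumed location of this locale file
    with open(locale_path, "r", encoding="utf-8") as f:
        translations = json.load(f)["Translation"]
    return lambda key: translations.get(key, key)  # fall back to the key itself

tr = make_tr()
print(tr("Video Quality"))  # -> 视频质量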

20
webui/utils/__init__.py Normal file
View File

@ -0,0 +1,20 @@
from .cache import get_fonts_cache, get_video_files_cache, get_songs_cache
from .file_utils import (
open_task_folder, cleanup_temp_files, get_file_list,
save_uploaded_file, create_temp_file, get_file_size, ensure_directory
)
from .performance import monitor_performance
__all__ = [
'get_fonts_cache',
'get_video_files_cache',
'get_songs_cache',
'open_task_folder',
'cleanup_temp_files',
'get_file_list',
'save_uploaded_file',
'create_temp_file',
'get_file_size',
'ensure_directory',
'monitor_performance'
]

33
webui/utils/cache.py Normal file
View File

@ -0,0 +1,33 @@
import streamlit as st
import os
import glob
from app.utils import utils
def get_fonts_cache(font_dir):
if 'fonts_cache' not in st.session_state:
fonts = []
for root, dirs, files in os.walk(font_dir):
for file in files:
if file.endswith(".ttf") or file.endswith(".ttc"):
fonts.append(file)
fonts.sort()
st.session_state['fonts_cache'] = fonts
return st.session_state['fonts_cache']
def get_video_files_cache():
if 'video_files_cache' not in st.session_state:
video_files = []
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
st.session_state['video_files_cache'] = video_files[::-1]
return st.session_state['video_files_cache']
def get_songs_cache(song_dir):
if 'songs_cache' not in st.session_state:
songs = []
for root, dirs, files in os.walk(song_dir):
for file in files:
if file.endswith(".mp3"):
songs.append(file)
st.session_state['songs_cache'] = songs
return st.session_state['songs_cache']
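
Usage sketch (not part of the commit): the caches live in st.session_state, so repeated reruns of a page reuse the first directory scan instead of walking the filesystem each time. The font directory below mirrors the relative path used by the subtitle panel and is assumed to be resolved against the project root.

import os
import streamlit as st
from webui.utils.cache import get_fonts_cache

font_dir = os.path.join("resource", "fonts")  # assumed relative location
fonts = get_fonts_cache(font_dir)
st.selectbox("Font", options=fonts)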

189
webui/utils/file_utils.py Normal file
View File

@ -0,0 +1,189 @@
import os
import glob
import time
import platform
import shutil
from uuid import uuid4
from loguru import logger
from app.utils import utils
def open_task_folder(root_dir, task_id):
"""打开任务文件夹
Args:
root_dir: 项目根目录
task_id: 任务ID
"""
try:
sys = platform.system()
path = os.path.join(root_dir, "storage", "tasks", task_id)
        if os.path.exists(path):
            # 路径加引号,避免目录名含空格时打开失败
            if sys == 'Windows':
                os.system(f'start "" "{path}"')
            elif sys == 'Darwin':
                os.system(f'open "{path}"')
            elif sys == 'Linux':
                os.system(f'xdg-open "{path}"')
except Exception as e:
logger.error(f"打开任务文件夹失败: {e}")
def cleanup_temp_files(temp_dir, max_age=3600):
"""清理临时文件
Args:
temp_dir: 临时文件目录
max_age: 文件最大保存时间()
"""
if os.path.exists(temp_dir):
for file in os.listdir(temp_dir):
file_path = os.path.join(temp_dir, file)
try:
if os.path.getctime(file_path) < time.time() - max_age:
if os.path.isfile(file_path):
os.remove(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
logger.debug(f"已清理临时文件: {file_path}")
except Exception as e:
logger.error(f"清理临时文件失败: {file_path}, 错误: {e}")
def get_file_list(directory, file_types=None, sort_by='ctime', reverse=True):
"""获取指定目录下的文件列表
Args:
directory: 目录路径
file_types: 文件类型列表 ['.mp4', '.mov']
sort_by: 排序方式支持 'ctime'(创建时间), 'mtime'(修改时间), 'size'(文件大小), 'name'(文件名)
reverse: 是否倒序排序
Returns:
list: 文件信息列表
"""
if not os.path.exists(directory):
return []
files = []
if file_types:
for file_type in file_types:
files.extend(glob.glob(os.path.join(directory, f"*{file_type}")))
else:
files = glob.glob(os.path.join(directory, "*"))
file_list = []
for file_path in files:
try:
file_stat = os.stat(file_path)
file_info = {
"name": os.path.basename(file_path),
"path": file_path,
"size": file_stat.st_size,
"ctime": file_stat.st_ctime,
"mtime": file_stat.st_mtime
}
file_list.append(file_info)
except Exception as e:
logger.error(f"获取文件信息失败: {file_path}, 错误: {e}")
# 排序
if sort_by in ['ctime', 'mtime', 'size', 'name']:
file_list.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
return file_list
def save_uploaded_file(uploaded_file, save_dir, allowed_types=None):
"""保存上传的文件
Args:
uploaded_file: StreamlitUploadedFile对象
save_dir: 保存目录
allowed_types: 允许的文件类型列表 ['.mp4', '.mov']
Returns:
str: 保存后的文件路径失败返回None
"""
try:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_name, file_extension = os.path.splitext(uploaded_file.name)
# 检查文件类型
if allowed_types and file_extension.lower() not in allowed_types:
logger.error(f"不支持的文件类型: {file_extension}")
return None
# 如果文件已存在,添加时间戳
save_path = os.path.join(save_dir, uploaded_file.name)
if os.path.exists(save_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
new_file_name = f"{file_name}_{timestamp}{file_extension}"
save_path = os.path.join(save_dir, new_file_name)
# 保存文件
with open(save_path, "wb") as f:
f.write(uploaded_file.read())
logger.info(f"文件保存成功: {save_path}")
return save_path
except Exception as e:
logger.error(f"保存上传文件失败: {e}")
return None
def create_temp_file(prefix='tmp', suffix='', directory=None):
"""创建临时文件
Args:
prefix: 文件名前缀
suffix: 文件扩展名
directory: 临时文件目录默认使用系统临时目录
Returns:
str: 临时文件路径
"""
try:
if directory is None:
directory = utils.storage_dir("temp", create=True)
if not os.path.exists(directory):
os.makedirs(directory)
temp_file = os.path.join(directory, f"{prefix}-{str(uuid4())}{suffix}")
return temp_file
except Exception as e:
logger.error(f"创建临时文件失败: {e}")
return None
def get_file_size(file_path, format='MB'):
"""获取文件大小
Args:
file_path: 文件路径
format: 返回格式支持 'B', 'KB', 'MB', 'GB'
Returns:
float: 文件大小
"""
try:
size_bytes = os.path.getsize(file_path)
if format.upper() == 'B':
return size_bytes
elif format.upper() == 'KB':
return size_bytes / 1024
elif format.upper() == 'MB':
return size_bytes / (1024 * 1024)
elif format.upper() == 'GB':
return size_bytes / (1024 * 1024 * 1024)
else:
return size_bytes
except Exception as e:
logger.error(f"获取文件大小失败: {file_path}, 错误: {e}")
return 0
def ensure_directory(directory):
"""确保目录存在,如果不存在则创建
Args:
directory: 目录路径
Returns:
bool: 是否成功
"""
try:
if not os.path.exists(directory):
os.makedirs(directory)
return True
except Exception as e:
logger.error(f"创建目录失败: {directory}, 错误: {e}")
return False
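
Usage sketch (not part of the commit): a typical upload flow combining the helpers above. The ./resource/videos directory comes from the upload hint in the zh.json strings; the temp directory is an assumption.

import os
import streamlit as st
from webui.utils.file_utils import save_uploaded_file, get_file_size, cleanup_temp_files

uploaded = st.file_uploader("Video File", type=["mp4", "mov"])
if uploaded:
    path = save_uploaded_file(uploaded, "./resource/videos", allowed_types=[".mp4", ".mov"])
    if path:
        st.write(f"{os.path.basename(path)}: {get_file_size(path, 'MB'):.1f} MB")

cleanup_temp_files("./storage/temp", max_age=3600)  # assumed temp location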

View File

@ -0,0 +1,24 @@
import time
from loguru import logger
try:
import psutil
ENABLE_PERFORMANCE_MONITORING = True
except ImportError:
ENABLE_PERFORMANCE_MONITORING = False
logger.warning("psutil not installed. Performance monitoring is disabled.")
def monitor_performance():
if not ENABLE_PERFORMANCE_MONITORING:
return {'execution_time': 0, 'memory_usage': 0}
start_time = time.time()
try:
memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 # MB
    except Exception:
memory_usage = 0
return {
'execution_time': time.time() - start_time,
'memory_usage': memory_usage
}
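
Usage sketch (not part of the commit): sampling memory around a heavy step. Note that execution_time in the returned dict only covers the sampling call itself, so callers should time the step with their own timestamps, as below.

import time
from loguru import logger
from webui.utils import monitor_performance

start = time.time()
# ... run the expensive step here ...
stats = monitor_performance()
logger.info(f"step took {time.time() - start:.2f}s, rss={stats['memory_usage']:.1f} MB")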