mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 18:42:49 +00:00
webui 代码重构;
This commit is contained in:
parent
242f8d5355
commit
bb18a754fe
@ -29,7 +29,7 @@ def create(audio_file, subtitle_file: str = ""):
|
||||
返回:
|
||||
无返回值,但会在指定路径生成字幕文件。
|
||||
"""
|
||||
global model
|
||||
global model, device, compute_type
|
||||
if not model:
|
||||
model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v2"
|
||||
model_bin_file = f"{model_path}/model.bin"
|
||||
@ -43,27 +43,45 @@ def create(audio_file, subtitle_file: str = ""):
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"加载模型: {model_path}, 设备: {device}, 计算类型: {compute_type}"
|
||||
)
|
||||
# 尝试使用 CUDA,如果失败则回退到 CPU
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
|
||||
model = WhisperModel(
|
||||
model_size_or_path=model_path,
|
||||
device="cuda",
|
||||
compute_type="float16",
|
||||
local_files_only=True
|
||||
)
|
||||
device = "cuda"
|
||||
compute_type = "float16"
|
||||
logger.info("成功使用 CUDA 加载模型")
|
||||
except Exception as e:
|
||||
logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
|
||||
logger.warning("回退到 CPU 模式")
|
||||
device = "cpu"
|
||||
compute_type = "int8"
|
||||
else:
|
||||
logger.info("未检测到 CUDA,使用 CPU 模式")
|
||||
device = "cpu"
|
||||
compute_type = "int8"
|
||||
except ImportError:
|
||||
logger.warning("未安装 torch,使用 CPU 模式")
|
||||
device = "cpu"
|
||||
compute_type = "int8"
|
||||
|
||||
if device == "cpu":
|
||||
logger.info(f"使用 CPU 加载模型: {model_path}")
|
||||
model = WhisperModel(
|
||||
model_size_or_path=model_path,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
local_files_only=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"加载模型失败: {e} \n\n"
|
||||
f"********************************************\n"
|
||||
f"这可能是由网络问题引起的. \n"
|
||||
f"请手动下载模型并将其放入 'app/models' 文件夹中。 \n"
|
||||
f"see [README.md FAQ](https://github.com/linyqh/NarratoAI) for more details.\n"
|
||||
f"********************************************\n\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")
|
||||
|
||||
logger.info(f"start, output file: {subtitle_file}")
|
||||
if not subtitle_file:
|
||||
|
||||
@ -1,115 +1,81 @@
|
||||
import json
|
||||
from loguru import logger
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Any
|
||||
|
||||
def time_to_seconds(time_str):
|
||||
parts = list(map(int, time_str.split(':')))
|
||||
if len(parts) == 2:
|
||||
return timedelta(minutes=parts[0], seconds=parts[1]).total_seconds()
|
||||
elif len(parts) == 3:
|
||||
return timedelta(hours=parts[0], minutes=parts[1], seconds=parts[2]).total_seconds()
|
||||
raise ValueError(f"无法解析时间字符串: {time_str}")
|
||||
def check_format(script_content: str) -> Dict[str, Any]:
|
||||
"""检查脚本格式
|
||||
Args:
|
||||
script_content: 脚本内容
|
||||
Returns:
|
||||
Dict: {'success': bool, 'message': str}
|
||||
"""
|
||||
try:
|
||||
# 检查是否为有效的JSON
|
||||
data = json.loads(script_content)
|
||||
|
||||
# 检查是否为列表
|
||||
if not isinstance(data, list):
|
||||
return {
|
||||
'success': False,
|
||||
'message': '脚本必须是JSON数组格式'
|
||||
}
|
||||
|
||||
# 检查每个片段
|
||||
for i, clip in enumerate(data):
|
||||
# 检查必需字段
|
||||
required_fields = ['narration', 'picture', 'timestamp']
|
||||
for field in required_fields:
|
||||
if field not in clip:
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段缺少必需字段: {field}'
|
||||
}
|
||||
|
||||
# 检查字段类型
|
||||
if not isinstance(clip['narration'], str):
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段的narration必须是字符串'
|
||||
}
|
||||
if not isinstance(clip['picture'], str):
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段的picture必须是字符串'
|
||||
}
|
||||
if not isinstance(clip['timestamp'], str):
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段的timestamp必须是字符串'
|
||||
}
|
||||
|
||||
# 检查字段内容不能为空
|
||||
if not clip['narration'].strip():
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段的narration不能为空'
|
||||
}
|
||||
if not clip['picture'].strip():
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段的picture不能为空'
|
||||
}
|
||||
if not clip['timestamp'].strip():
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'第{i+1}个片段的timestamp不能为空'
|
||||
}
|
||||
|
||||
def seconds_to_time_str(seconds):
|
||||
hours, remainder = divmod(int(seconds), 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
if hours > 0:
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
||||
else:
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
return {
|
||||
'success': True,
|
||||
'message': '脚本格式检查通过'
|
||||
}
|
||||
|
||||
def adjust_timestamp(start_time, duration):
|
||||
start_seconds = time_to_seconds(start_time)
|
||||
end_seconds = start_seconds + duration
|
||||
return f"{start_time}-{seconds_to_time_str(end_seconds)}"
|
||||
|
||||
def estimate_audio_duration(text):
|
||||
# 假设平均每个字符需要 0.2 秒
|
||||
return len(text) * 0.2
|
||||
|
||||
def check_script(data, total_duration):
|
||||
errors = []
|
||||
time_ranges = []
|
||||
|
||||
logger.info("开始检查脚本")
|
||||
logger.info(f"视频总时长: {total_duration:.2f} 秒")
|
||||
logger.info("=" * 50)
|
||||
|
||||
for i, item in enumerate(data, 1):
|
||||
logger.info(f"\n检查第 {i} 项:")
|
||||
|
||||
# 检查所有必需字段
|
||||
required_fields = ['picture', 'timestamp', 'narration', 'OST']
|
||||
for field in required_fields:
|
||||
if field not in item:
|
||||
errors.append(f"第 {i} 项缺少 {field} 字段")
|
||||
logger.info(f" - 错误: 缺少 {field} 字段")
|
||||
else:
|
||||
logger.info(f" - {field}: {item[field]}")
|
||||
|
||||
# 检查 OST 相关规则
|
||||
if item.get('OST') == False:
|
||||
if not item.get('narration'):
|
||||
errors.append(f"第 {i} 项 OST 为 false,但 narration 为空")
|
||||
logger.info(" - 错误: OST 为 false,但 narration 为空")
|
||||
elif len(item['narration']) > 60:
|
||||
errors.append(f"第 {i} 项 OST 为 false,但 narration 超过 60 字")
|
||||
logger.info(f" - 错误: OST 为 false,但 narration 超过 60 字 (当前: {len(item['narration'])} 字)")
|
||||
else:
|
||||
logger.info(" - OST 为 false,narration 检查通过")
|
||||
elif item.get('OST') == True:
|
||||
if "原声播放_" not in item.get('narration'):
|
||||
errors.append(f"第 {i} 项 OST 为 true,但 narration 不为空")
|
||||
logger.info(" - 错误: OST 为 true,但 narration 不为空")
|
||||
else:
|
||||
logger.info(" - OST 为 true,narration 检查通过")
|
||||
|
||||
# 检查 timestamp
|
||||
if 'timestamp' in item:
|
||||
start, end = map(time_to_seconds, item['timestamp'].split('-'))
|
||||
if any((start < existing_end and end > existing_start) for existing_start, existing_end in time_ranges):
|
||||
errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 与其他时间段重叠")
|
||||
logger.info(f" - 错误: timestamp '{item['timestamp']}' 与其他时间段重叠")
|
||||
else:
|
||||
logger.info(f" - timestamp '{item['timestamp']}' 检查通过")
|
||||
time_ranges.append((start, end))
|
||||
|
||||
# if end > total_duration:
|
||||
# errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
|
||||
# logger.info(f" - 错误: timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
|
||||
# else:
|
||||
# logger.info(f" - timestamp 在总时长范围内")
|
||||
|
||||
# 处理 narration 字段
|
||||
if item.get('OST') == False and item.get('narration'):
|
||||
estimated_duration = estimate_audio_duration(item['narration'])
|
||||
start_time = item['timestamp'].split('-')[0]
|
||||
item['timestamp'] = adjust_timestamp(start_time, estimated_duration)
|
||||
logger.info(f" - 已调整 timestamp 为 {item['timestamp']} (估算音频时长: {estimated_duration:.2f} 秒)")
|
||||
|
||||
if errors:
|
||||
logger.info("检查结果:不通过")
|
||||
logger.info("发现以下错误:")
|
||||
for error in errors:
|
||||
logger.info(f"- {error}")
|
||||
else:
|
||||
logger.info("检查结果:通过")
|
||||
logger.info("所有项目均符合规则要求。")
|
||||
|
||||
return errors, data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/test004.json"
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
total_duration = 280
|
||||
|
||||
# check_script(data, total_duration)
|
||||
|
||||
from app.utils.utils import add_new_timestamps
|
||||
res = add_new_timestamps(data)
|
||||
print(json.dumps(res, indent=4, ensure_ascii=False))
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'JSON格式错误: {str(e)}'
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'检查过程中发生错误: {str(e)}'
|
||||
}
|
||||
|
||||
@ -24,3 +24,4 @@ azure-cognitiveservices-speech~=1.37.0
|
||||
git-changelog~=2.5.2
|
||||
watchdog==5.0.2
|
||||
pydub==0.25.1
|
||||
psutil>=5.9.0
|
||||
|
||||
869
webui.py
869
webui.py
@ -1,6 +1,14 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
import sys
|
||||
from uuid import uuid4
|
||||
from app.config import config
|
||||
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
|
||||
from webui.utils import cache, file_utils, performance
|
||||
from app.utils import utils
|
||||
from app.models.schema import VideoClipParams, VideoAspect
|
||||
|
||||
# 初始化配置 - 必须是第一个 Streamlit 命令
|
||||
st.set_page_config(
|
||||
page_title="NarratoAI",
|
||||
page_icon="📽️",
|
||||
@ -13,126 +21,23 @@ st.set_page_config(
|
||||
},
|
||||
)
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import traceback
|
||||
from uuid import uuid4
|
||||
import platform
|
||||
import streamlit.components.v1 as components
|
||||
from loguru import logger
|
||||
|
||||
from app.models.const import FILE_TYPE_VIDEOS
|
||||
from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
|
||||
from app.services import task as tm, llm, voice, material
|
||||
from app.utils import utils
|
||||
|
||||
# # 将项目的根目录添加到系统路径中,以允许从项目导入模块
|
||||
root_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
if root_dir not in sys.path:
|
||||
sys.path.append(root_dir)
|
||||
print("******** sys.path ********")
|
||||
print(sys.path)
|
||||
print("*" * 20)
|
||||
|
||||
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
os.environ["HTTP_PROXY"] = proxy_url_http
|
||||
os.environ["HTTPS_PROXY"] = proxy_url_https
|
||||
|
||||
# 设置页面样式
|
||||
hide_streamlit_style = """
|
||||
<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding-top: 6px; padding-bottom: 10px; padding-left: 20px; padding-right: 20px;}</style>
|
||||
"""
|
||||
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
||||
st.title(f"NarratoAI :sunglasses:📽️")
|
||||
support_locales = [
|
||||
"zh-CN",
|
||||
"zh-HK",
|
||||
"zh-TW",
|
||||
"en-US",
|
||||
]
|
||||
font_dir = os.path.join(root_dir, "resource", "fonts")
|
||||
song_dir = os.path.join(root_dir, "resource", "songs")
|
||||
i18n_dir = os.path.join(root_dir, "webui", "i18n")
|
||||
config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
|
||||
system_locale = utils.get_system_locale()
|
||||
|
||||
if 'video_clip_json' not in st.session_state:
|
||||
st.session_state['video_clip_json'] = []
|
||||
if 'video_plot' not in st.session_state:
|
||||
st.session_state['video_plot'] = ''
|
||||
if 'ui_language' not in st.session_state:
|
||||
st.session_state['ui_language'] = config.ui.get("language", system_locale)
|
||||
if 'subclip_videos' not in st.session_state:
|
||||
st.session_state['subclip_videos'] = {}
|
||||
|
||||
|
||||
def get_all_fonts():
|
||||
fonts = []
|
||||
for root, dirs, files in os.walk(font_dir):
|
||||
for file in files:
|
||||
if file.endswith(".ttf") or file.endswith(".ttc"):
|
||||
fonts.append(file)
|
||||
fonts.sort()
|
||||
return fonts
|
||||
|
||||
|
||||
def get_all_songs():
|
||||
songs = []
|
||||
for root, dirs, files in os.walk(song_dir):
|
||||
for file in files:
|
||||
if file.endswith(".mp3"):
|
||||
songs.append(file)
|
||||
return songs
|
||||
|
||||
|
||||
def open_task_folder(task_id):
|
||||
try:
|
||||
sys = platform.system()
|
||||
path = os.path.join(root_dir, "storage", "tasks", task_id)
|
||||
if os.path.exists(path):
|
||||
if sys == 'Windows':
|
||||
os.system(f"start {path}")
|
||||
if sys == 'Darwin':
|
||||
os.system(f"open {path}")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
def scroll_to_bottom():
|
||||
js = f"""
|
||||
<script>
|
||||
console.log("scroll_to_bottom");
|
||||
function scroll(dummy_var_to_force_repeat_execution){{
|
||||
var sections = parent.document.querySelectorAll('section.main');
|
||||
console.log(sections);
|
||||
for(let index = 0; index<sections.length; index++) {{
|
||||
sections[index].scrollTop = sections[index].scrollHeight;
|
||||
}}
|
||||
}}
|
||||
scroll(1);
|
||||
</script>
|
||||
"""
|
||||
st.components.v1.html(js, height=0, width=0)
|
||||
|
||||
|
||||
def init_log():
|
||||
"""初始化日志配置"""
|
||||
from loguru import logger
|
||||
logger.remove()
|
||||
_lvl = "DEBUG"
|
||||
|
||||
def format_record(record):
|
||||
# 获取日志记录中的文件全径
|
||||
file_path = record["file"].path
|
||||
# 将绝对路径转换为相对于项目根目录的路径
|
||||
relative_path = os.path.relpath(file_path, root_dir)
|
||||
# 更新记录中的文件路径
|
||||
relative_path = os.path.relpath(file_path, config.root_dir)
|
||||
record["file"].path = f"./{relative_path}"
|
||||
# 返回修改后的格式字符串
|
||||
# 您可以根据需要调整这里的格式
|
||||
record['message'] = record['message'].replace(root_dir, ".")
|
||||
record['message'] = record['message'].replace(config.root_dir, ".")
|
||||
|
||||
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \
|
||||
'<level>{level}</> | ' + \
|
||||
@ -147,672 +52,120 @@ def init_log():
|
||||
colorize=True,
|
||||
)
|
||||
|
||||
|
||||
init_log()
|
||||
|
||||
locales = utils.load_locales(i18n_dir)
|
||||
|
||||
def init_global_state():
|
||||
"""初始化全局状态"""
|
||||
if 'video_clip_json' not in st.session_state:
|
||||
st.session_state['video_clip_json'] = []
|
||||
if 'video_plot' not in st.session_state:
|
||||
st.session_state['video_plot'] = ''
|
||||
if 'ui_language' not in st.session_state:
|
||||
st.session_state['ui_language'] = config.ui.get("language", utils.get_system_locale())
|
||||
if 'subclip_videos' not in st.session_state:
|
||||
st.session_state['subclip_videos'] = {}
|
||||
|
||||
def tr(key):
|
||||
"""翻译函数"""
|
||||
i18n_dir = os.path.join(os.path.dirname(__file__), "webui", "i18n")
|
||||
locales = utils.load_locales(i18n_dir)
|
||||
loc = locales.get(st.session_state['ui_language'], {})
|
||||
return loc.get("Translation", {}).get(key, key)
|
||||
|
||||
def render_generate_button():
|
||||
"""渲染生成按钮和处理逻辑"""
|
||||
if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
|
||||
from app.services import task as tm
|
||||
|
||||
# 重置日志容器和记录
|
||||
log_container = st.empty()
|
||||
log_records = []
|
||||
|
||||
st.write(tr("Get Help"))
|
||||
def log_received(msg):
|
||||
with log_container:
|
||||
log_records.append(msg)
|
||||
st.code("\n".join(log_records))
|
||||
|
||||
# 基础设置
|
||||
with st.expander(tr("Basic Settings"), expanded=False):
|
||||
config_panels = st.columns(3)
|
||||
left_config_panel = config_panels[0]
|
||||
middle_config_panel = config_panels[1]
|
||||
right_config_panel = config_panels[2]
|
||||
with left_config_panel:
|
||||
display_languages = []
|
||||
selected_index = 0
|
||||
for i, code in enumerate(locales.keys()):
|
||||
display_languages.append(f"{code} - {locales[code].get('Language')}")
|
||||
if code == st.session_state['ui_language']:
|
||||
selected_index = i
|
||||
from loguru import logger
|
||||
logger.add(log_received)
|
||||
|
||||
selected_language = st.selectbox(tr("Language"), options=display_languages,
|
||||
index=selected_index)
|
||||
if selected_language:
|
||||
code = selected_language.split(" - ")[0].strip()
|
||||
st.session_state['ui_language'] = code
|
||||
config.ui['language'] = code
|
||||
config.save_config()
|
||||
task_id = st.session_state.get('task_id')
|
||||
|
||||
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
|
||||
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
|
||||
if HTTP_PROXY:
|
||||
config.proxy["http"] = HTTP_PROXY
|
||||
if HTTPS_PROXY:
|
||||
config.proxy["https"] = HTTPS_PROXY
|
||||
if not task_id:
|
||||
st.error(tr("请先裁剪视频"))
|
||||
return
|
||||
if not st.session_state.get('video_clip_json_path'):
|
||||
st.error(tr("脚本文件不能为空"))
|
||||
return
|
||||
if not st.session_state.get('video_origin_path'):
|
||||
st.error(tr("视频文件不能为空"))
|
||||
return
|
||||
|
||||
# 视频转录大模型
|
||||
with middle_config_panel:
|
||||
video_llm_providers = ['Gemini']
|
||||
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
|
||||
saved_llm_provider_index = 0
|
||||
for i, provider in enumerate(video_llm_providers):
|
||||
if provider.lower() == saved_llm_provider:
|
||||
saved_llm_provider_index = i
|
||||
break
|
||||
st.toast(tr("生成视频"))
|
||||
logger.info(tr("开始生成视频"))
|
||||
|
||||
video_llm_provider = st.selectbox(tr("Video LLM Provider"), options=video_llm_providers, index=saved_llm_provider_index)
|
||||
video_llm_provider = video_llm_provider.lower()
|
||||
config.app["video_llm_provider"] = video_llm_provider
|
||||
# 获取所有参数
|
||||
script_params = script_settings.get_script_params()
|
||||
video_params = video_settings.get_video_params()
|
||||
audio_params = audio_settings.get_audio_params()
|
||||
subtitle_params = subtitle_settings.get_subtitle_params()
|
||||
|
||||
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
|
||||
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
|
||||
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
|
||||
video_llm_account_id = config.app.get(f"{video_llm_provider}_account_id", "")
|
||||
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
|
||||
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
|
||||
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
|
||||
if st_llm_api_key:
|
||||
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
|
||||
if st_llm_base_url:
|
||||
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
|
||||
if st_llm_model_name:
|
||||
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
|
||||
|
||||
# 大语言模型
|
||||
with right_config_panel:
|
||||
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
|
||||
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
|
||||
saved_llm_provider_index = 0
|
||||
for i, provider in enumerate(llm_providers):
|
||||
if provider.lower() == saved_llm_provider:
|
||||
saved_llm_provider_index = i
|
||||
break
|
||||
|
||||
llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
|
||||
llm_provider = llm_provider.lower()
|
||||
config.app["llm_provider"] = llm_provider
|
||||
|
||||
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
|
||||
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
|
||||
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
|
||||
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
|
||||
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
|
||||
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
|
||||
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
|
||||
if st_llm_api_key:
|
||||
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
|
||||
if st_llm_base_url:
|
||||
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
|
||||
if st_llm_model_name:
|
||||
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
|
||||
|
||||
if llm_provider == 'cloudflare':
|
||||
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
|
||||
if st_llm_account_id:
|
||||
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
|
||||
|
||||
panel = st.columns(3)
|
||||
left_panel = panel[0]
|
||||
middle_panel = panel[1]
|
||||
right_panel = panel[2]
|
||||
|
||||
params = VideoClipParams()
|
||||
|
||||
# 左侧面板
|
||||
with left_panel:
|
||||
with st.container(border=True):
|
||||
st.write(tr("Video Script Configuration"))
|
||||
# 脚本语言
|
||||
video_languages = [
|
||||
(tr("Auto Detect"), ""),
|
||||
]
|
||||
for code in ["zh-CN", "en-US", "zh-TW"]:
|
||||
video_languages.append((code, code))
|
||||
|
||||
selected_index = st.selectbox(tr("Script Language"),
|
||||
index=0,
|
||||
options=range(len(video_languages)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: video_languages[x][0] # 显示给用户的是标签
|
||||
)
|
||||
params.video_language = video_languages[selected_index][1]
|
||||
|
||||
# 脚本路径
|
||||
suffix = "*.json"
|
||||
song_dir = utils.script_dir()
|
||||
files = glob.glob(os.path.join(song_dir, suffix))
|
||||
script_list = []
|
||||
for file in files:
|
||||
script_list.append({
|
||||
"name": os.path.basename(file),
|
||||
"size": os.path.getsize(file),
|
||||
"file": file,
|
||||
"ctime": os.path.getctime(file) # 获取文件创建时间
|
||||
})
|
||||
|
||||
# 按创建时间降序排序
|
||||
script_list.sort(key=lambda x: x["ctime"], reverse=True)
|
||||
|
||||
# 本文件 下拉框
|
||||
script_path = [(tr("Auto Generate"), ""), ]
|
||||
for file in script_list:
|
||||
display_name = file['file'].replace(root_dir, "")
|
||||
script_path.append((display_name, file['file']))
|
||||
selected_script_index = st.selectbox(tr("Script Files"),
|
||||
index=0,
|
||||
options=range(len(script_path)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: script_path[x][0] # 显示给用户的是标签
|
||||
)
|
||||
params.video_clip_json_path = script_path[selected_script_index][1]
|
||||
config.app["video_clip_json_path"] = params.video_clip_json_path
|
||||
st.session_state['video_clip_json_path'] = params.video_clip_json_path
|
||||
|
||||
# 视频文件处理
|
||||
video_files = []
|
||||
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
|
||||
video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
|
||||
video_files = video_files[::-1]
|
||||
|
||||
video_list = []
|
||||
for video_file in video_files:
|
||||
video_list.append({
|
||||
"name": os.path.basename(video_file),
|
||||
"size": os.path.getsize(video_file),
|
||||
"file": video_file,
|
||||
"ctime": os.path.getctime(video_file) # 获取文件创建时间
|
||||
})
|
||||
# 按创建时间降序排序
|
||||
video_list.sort(key=lambda x: x["ctime"], reverse=True)
|
||||
video_path = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
|
||||
for file in video_list:
|
||||
display_name = file['file'].replace(root_dir, "")
|
||||
video_path.append((display_name, file['file']))
|
||||
|
||||
# 视频文件
|
||||
selected_video_index = st.selectbox(tr("Video File"),
|
||||
index=0,
|
||||
options=range(len(video_path)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: video_path[x][0] # 显示给用户的是标签
|
||||
)
|
||||
params.video_origin_path = video_path[selected_video_index][1]
|
||||
config.app["video_origin_path"] = params.video_origin_path
|
||||
st.session_state['video_origin_path'] = params.video_origin_path
|
||||
|
||||
# 从本地上传 mp4 文件
|
||||
if params.video_origin_path == "local":
|
||||
_supported_types = FILE_TYPE_VIDEOS
|
||||
uploaded_file = st.file_uploader(
|
||||
tr("Upload Local Files"),
|
||||
type=["mp4", "mov", "avi", "flv", "mkv"],
|
||||
accept_multiple_files=False,
|
||||
)
|
||||
if uploaded_file is not None:
|
||||
# 构造保存路径
|
||||
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
|
||||
file_name, file_extension = os.path.splitext(uploaded_file.name)
|
||||
# 检查文件是否存在,如果存在则添加时间戳
|
||||
if os.path.exists(video_file_path):
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
file_name_with_timestamp = f"{file_name}_{timestamp}"
|
||||
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
|
||||
# 将文件保存到指定目录
|
||||
with open(video_file_path, "wb") as f:
|
||||
f.write(uploaded_file.read())
|
||||
st.success(tr("File Uploaded Successfully"))
|
||||
time.sleep(1)
|
||||
st.rerun()
|
||||
# 视频名称
|
||||
video_name = st.text_input(tr("Video Name"))
|
||||
# 剧情内容
|
||||
video_plot = st.text_area(
|
||||
tr("Plot Description"),
|
||||
value=st.session_state['video_plot'],
|
||||
height=180
|
||||
)
|
||||
|
||||
# 生成视频脚本
|
||||
if st.session_state['video_clip_json_path']:
|
||||
generate_button_name = tr("Video Script Load")
|
||||
else:
|
||||
generate_button_name = tr("Video Script Generate")
|
||||
if st.button(generate_button_name, key="auto_generate_script"):
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
def update_progress(progress: float, message: str = ""):
|
||||
progress_bar.progress(progress)
|
||||
if message:
|
||||
status_text.text(f"{progress}% - {message}")
|
||||
else:
|
||||
status_text.text(f"进度: {progress}%")
|
||||
|
||||
try:
|
||||
with st.spinner("正在生成脚本..."):
|
||||
if not video_plot:
|
||||
st.warning("视频剧情为空; 会极大影响生成效果!")
|
||||
if params.video_clip_json_path == "" and params.video_origin_path != "":
|
||||
update_progress(10, "压缩视频中...")
|
||||
# 使用大模型生成视频脚本
|
||||
script = llm.generate_script(
|
||||
video_path=params.video_origin_path,
|
||||
video_plot=video_plot,
|
||||
video_name=video_name,
|
||||
language=params.video_language,
|
||||
progress_callback=update_progress
|
||||
)
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
st.stop()
|
||||
else:
|
||||
update_progress(90)
|
||||
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
else:
|
||||
# 从本地加载
|
||||
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
|
||||
update_progress(50)
|
||||
status_text.text("从本地加载中...")
|
||||
script = f.read()
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
update_progress(100)
|
||||
status_text.text("从本地加载成功")
|
||||
|
||||
time.sleep(0.5) # 给进度条一点时间到达100%
|
||||
progress_bar.progress(100)
|
||||
status_text.text("脚本生成完成!")
|
||||
st.success("视频脚本生成成功!")
|
||||
except Exception as err:
|
||||
st.error(f"生成过程中发生错误: {str(err)}")
|
||||
finally:
|
||||
time.sleep(2) # 给用户一些时间查看最终状态
|
||||
progress_bar.empty()
|
||||
status_text.empty()
|
||||
|
||||
# 视频脚本
|
||||
video_clip_json_details = st.text_area(
|
||||
tr("Video Script"),
|
||||
value=json.dumps(st.session_state.video_clip_json, indent=2, ensure_ascii=False),
|
||||
height=180
|
||||
)
|
||||
|
||||
# 保存脚本
|
||||
button_columns = st.columns(2)
|
||||
with button_columns[0]:
|
||||
if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True):
|
||||
if not video_clip_json_details:
|
||||
st.error(tr("请输入视频脚本"))
|
||||
st.stop()
|
||||
|
||||
with st.spinner(tr("Save Script")):
|
||||
script_dir = utils.script_dir()
|
||||
# 获取当前时间戳,形如 2024-0618-171820
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M%S")
|
||||
save_path = os.path.join(script_dir, f"{timestamp}.json")
|
||||
|
||||
try:
|
||||
data = utils.add_new_timestamps(json.loads(video_clip_json_details))
|
||||
except Exception as err:
|
||||
st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}")
|
||||
st.stop()
|
||||
|
||||
# 存储为新的 JSON 文件
|
||||
with open(save_path, 'w', encoding='utf-8') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
# 将data的值存储到 session_state 中,类似缓存
|
||||
st.session_state['video_clip_json'] = data
|
||||
st.session_state['video_clip_json_path'] = save_path
|
||||
# 刷新页面
|
||||
st.rerun()
|
||||
|
||||
# 裁剪视频
|
||||
with button_columns[1]:
|
||||
if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
def update_progress(progress):
|
||||
progress_bar.progress(progress)
|
||||
status_text.text(f"剪辑进度: {progress}%")
|
||||
|
||||
try:
|
||||
utils.cut_video(params, update_progress)
|
||||
time.sleep(0.5) # 给进度条一点时间到达100%
|
||||
progress_bar.progress(100)
|
||||
status_text.text("剪辑完成!")
|
||||
st.success("视频剪辑成功完成!")
|
||||
except Exception as e:
|
||||
st.error(f"剪辑过程中发生错误: {str(e)}")
|
||||
finally:
|
||||
time.sleep(2) # 给用户一些时间查看最终状态
|
||||
progress_bar.empty()
|
||||
status_text.empty()
|
||||
|
||||
# 新中间面板
|
||||
with middle_panel:
|
||||
with st.container(border=True):
|
||||
st.write(tr("Video Settings"))
|
||||
|
||||
# 视频比例
|
||||
video_aspect_ratios = [
|
||||
(tr("Portrait"), VideoAspect.portrait.value),
|
||||
(tr("Landscape"), VideoAspect.landscape.value),
|
||||
]
|
||||
selected_index = st.selectbox(
|
||||
tr("Video Ratio"),
|
||||
options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
|
||||
)
|
||||
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
|
||||
|
||||
# params.video_clip_duration = st.selectbox(
|
||||
# tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
|
||||
# )
|
||||
# params.video_count = st.selectbox(
|
||||
# tr("Number of Videos Generated Simultaneously"),
|
||||
# options=[1, 2, 3, 4, 5],
|
||||
# index=0,
|
||||
# )
|
||||
with st.container(border=True):
|
||||
st.write(tr("Audio Settings"))
|
||||
|
||||
# tts_providers = ['edge', 'azure']
|
||||
# tts_provider = st.selectbox(tr("TTS Provider"), tts_providers)
|
||||
|
||||
voices = voice.get_all_azure_voices(filter_locals=support_locales)
|
||||
friendly_names = {
|
||||
v: v.replace("Female", tr("Female"))
|
||||
.replace("Male", tr("Male"))
|
||||
.replace("Neural", "")
|
||||
for v in voices
|
||||
# 合并所有参数
|
||||
all_params = {
|
||||
**script_params,
|
||||
**video_params,
|
||||
**audio_params,
|
||||
**subtitle_params
|
||||
}
|
||||
saved_voice_name = config.ui.get("voice_name", "")
|
||||
saved_voice_name_index = 0
|
||||
if saved_voice_name in friendly_names:
|
||||
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
|
||||
else:
|
||||
for i, v in enumerate(voices):
|
||||
if (
|
||||
v.lower().startswith(st.session_state["ui_language"].lower())
|
||||
and "V2" not in v
|
||||
):
|
||||
saved_voice_name_index = i
|
||||
break
|
||||
|
||||
selected_friendly_name = st.selectbox(
|
||||
tr("Speech Synthesis"),
|
||||
options=list(friendly_names.values()),
|
||||
index=saved_voice_name_index,
|
||||
# 创建参数对象
|
||||
params = VideoClipParams(**all_params)
|
||||
|
||||
result = tm.start_subclip(
|
||||
task_id=task_id,
|
||||
params=params,
|
||||
subclip_path_videos=st.session_state['subclip_videos']
|
||||
)
|
||||
|
||||
voice_name = list(friendly_names.keys())[
|
||||
list(friendly_names.values()).index(selected_friendly_name)
|
||||
]
|
||||
params.voice_name = voice_name
|
||||
config.ui["voice_name"] = voice_name
|
||||
video_files = result.get("videos", [])
|
||||
st.success(tr("视频生成完成"))
|
||||
|
||||
try:
|
||||
if video_files:
|
||||
player_cols = st.columns(len(video_files) * 2 + 1)
|
||||
for i, url in enumerate(video_files):
|
||||
player_cols[i * 2 + 1].video(url)
|
||||
except Exception as e:
|
||||
logger.error(f"播放视频失败: {e}")
|
||||
|
||||
if voice.is_azure_v2_voice(voice_name):
|
||||
saved_azure_speech_region = config.azure.get("speech_region", "")
|
||||
saved_azure_speech_key = config.azure.get("speech_key", "")
|
||||
azure_speech_region = st.text_input(
|
||||
tr("Speech Region"), value=saved_azure_speech_region
|
||||
)
|
||||
azure_speech_key = st.text_input(
|
||||
tr("Speech Key"), value=saved_azure_speech_key, type="password"
|
||||
)
|
||||
config.azure["speech_region"] = azure_speech_region
|
||||
config.azure["speech_key"] = azure_speech_key
|
||||
file_utils.open_task_folder(config.root_dir, task_id)
|
||||
logger.info(tr("视频生成完成"))
|
||||
|
||||
params.voice_volume = st.selectbox(
|
||||
tr("Speech Volume"),
|
||||
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
|
||||
index=2,
|
||||
)
|
||||
def main():
|
||||
"""主函数"""
|
||||
init_log()
|
||||
init_global_state()
|
||||
|
||||
st.title(f"NarratoAI :sunglasses:📽️")
|
||||
st.write(tr("Get Help"))
|
||||
|
||||
# 渲染基础设置面板
|
||||
basic_settings.render_basic_settings(tr)
|
||||
|
||||
# 渲染主面板
|
||||
panel = st.columns(3)
|
||||
with panel[0]:
|
||||
script_settings.render_script_panel(tr)
|
||||
with panel[1]:
|
||||
video_settings.render_video_panel(tr)
|
||||
audio_settings.render_audio_panel(tr)
|
||||
with panel[2]:
|
||||
subtitle_settings.render_subtitle_panel(tr)
|
||||
|
||||
# 渲染视频审查面板
|
||||
review_settings.render_review_panel(tr)
|
||||
|
||||
# 渲染生成按钮和处理逻辑
|
||||
render_generate_button()
|
||||
|
||||
params.voice_rate = st.selectbox(
|
||||
tr("Speech Rate"),
|
||||
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
|
||||
index=2,
|
||||
)
|
||||
|
||||
params.voice_pitch = st.selectbox(
|
||||
tr("Speech Pitch"),
|
||||
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
|
||||
index=2,
|
||||
)
|
||||
|
||||
# 试听语言合成
|
||||
if st.button(tr("Play Voice")):
|
||||
play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论"
|
||||
if not play_content:
|
||||
play_content = params.video_script
|
||||
if not play_content:
|
||||
play_content = tr("Voice Example")
|
||||
with st.spinner(tr("Synthesizing Voice")):
|
||||
temp_dir = utils.storage_dir("temp", create=True)
|
||||
audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
|
||||
sub_maker = voice.tts(
|
||||
text=play_content,
|
||||
voice_name=voice_name,
|
||||
voice_rate=params.voice_rate,
|
||||
voice_pitch=params.voice_pitch,
|
||||
voice_file=audio_file,
|
||||
)
|
||||
# 如果语音文件生成失败,请使用默认内容重试。
|
||||
if not sub_maker:
|
||||
play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
|
||||
sub_maker = voice.tts(
|
||||
text=play_content,
|
||||
voice_name=voice_name,
|
||||
voice_rate=params.voice_rate,
|
||||
voice_pitch=params.voice_pitch,
|
||||
voice_file=audio_file,
|
||||
)
|
||||
|
||||
if sub_maker and os.path.exists(audio_file):
|
||||
st.audio(audio_file, format="audio/mp3")
|
||||
if os.path.exists(audio_file):
|
||||
os.remove(audio_file)
|
||||
|
||||
bgm_options = [
|
||||
(tr("No Background Music"), ""),
|
||||
(tr("Random Background Music"), "random"),
|
||||
(tr("Custom Background Music"), "custom"),
|
||||
]
|
||||
selected_index = st.selectbox(
|
||||
tr("Background Music"),
|
||||
index=1,
|
||||
options=range(len(bgm_options)), # 使用索引作为内部选项值
|
||||
format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
|
||||
)
|
||||
# 获取选择的背景音乐类型
|
||||
params.bgm_type = bgm_options[selected_index][1]
|
||||
|
||||
# 根据选择显示或隐藏组件
|
||||
if params.bgm_type == "custom":
|
||||
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
|
||||
if custom_bgm_file and os.path.exists(custom_bgm_file):
|
||||
params.bgm_file = custom_bgm_file
|
||||
# st.write(f":red[已选择自定义背景音乐]:**{custom_bgm_file}**")
|
||||
params.bgm_volume = st.selectbox(
|
||||
tr("Background Music Volume"),
|
||||
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
index=2,
|
||||
)
|
||||
|
||||
# 新侧面板
|
||||
with right_panel:
|
||||
with st.container(border=True):
|
||||
st.write(tr("Subtitle Settings"))
|
||||
params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
|
||||
font_names = get_all_fonts()
|
||||
saved_font_name = config.ui.get("font_name", "")
|
||||
saved_font_name_index = 0
|
||||
if saved_font_name in font_names:
|
||||
saved_font_name_index = font_names.index(saved_font_name)
|
||||
params.font_name = st.selectbox(
|
||||
tr("Font"), font_names, index=saved_font_name_index
|
||||
)
|
||||
config.ui["font_name"] = params.font_name
|
||||
|
||||
subtitle_positions = [
|
||||
(tr("Top"), "top"),
|
||||
(tr("Center"), "center"),
|
||||
(tr("Bottom"), "bottom"),
|
||||
(tr("Custom"), "custom"),
|
||||
]
|
||||
selected_index = st.selectbox(
|
||||
tr("Position"),
|
||||
index=2,
|
||||
options=range(len(subtitle_positions)),
|
||||
format_func=lambda x: subtitle_positions[x][0],
|
||||
)
|
||||
params.subtitle_position = subtitle_positions[selected_index][1]
|
||||
|
||||
if params.subtitle_position == "custom":
|
||||
custom_position = st.text_input(
|
||||
tr("Custom Position (% from top)"), value="70.0"
|
||||
)
|
||||
try:
|
||||
params.custom_position = float(custom_position)
|
||||
if params.custom_position < 0 or params.custom_position > 100:
|
||||
st.error(tr("Please enter a value between 0 and 100"))
|
||||
except ValueError:
|
||||
logger.error(f"输入的值无效: {traceback.format_exc()}")
|
||||
st.error(tr("Please enter a valid number"))
|
||||
|
||||
font_cols = st.columns([0.3, 0.7])
|
||||
with font_cols[0]:
|
||||
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
|
||||
params.text_fore_color = st.color_picker(
|
||||
tr("Font Color"), saved_text_fore_color
|
||||
)
|
||||
config.ui["text_fore_color"] = params.text_fore_color
|
||||
|
||||
with font_cols[1]:
|
||||
saved_font_size = config.ui.get("font_size", 60)
|
||||
params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
|
||||
config.ui["font_size"] = params.font_size
|
||||
|
||||
stroke_cols = st.columns([0.3, 0.7])
|
||||
with stroke_cols[0]:
|
||||
params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000")
|
||||
with stroke_cols[1]:
|
||||
params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5)
|
||||
|
||||
# 视频编辑面板
|
||||
with st.expander(tr("Video Check"), expanded=False):
|
||||
try:
|
||||
video_list = st.session_state.video_clip_json
|
||||
except KeyError as e:
|
||||
video_list = []
|
||||
|
||||
# 计算列数和行数
|
||||
num_videos = len(video_list)
|
||||
cols_per_row = 3
|
||||
rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
|
||||
|
||||
# 使用容器展示视频
|
||||
for row in range(rows):
|
||||
cols = st.columns(cols_per_row)
|
||||
for col in range(cols_per_row):
|
||||
index = row * cols_per_row + col
|
||||
if index < num_videos:
|
||||
with cols[col]:
|
||||
video_info = video_list[index]
|
||||
video_path = video_info.get('path')
|
||||
if video_path is not None:
|
||||
initial_narration = video_info['narration']
|
||||
initial_picture = video_info['picture']
|
||||
initial_timestamp = video_info['timestamp']
|
||||
|
||||
with open(video_path, 'rb') as video_file:
|
||||
video_bytes = video_file.read()
|
||||
st.video(video_bytes)
|
||||
|
||||
# 可编辑的输入框
|
||||
text_panels = st.columns(2)
|
||||
with text_panels[0]:
|
||||
text1 = st.text_area(tr("timestamp"), value=initial_timestamp, height=20,
|
||||
key=f"timestamp_{index}")
|
||||
with text_panels[1]:
|
||||
text2 = st.text_area(tr("Picture description"), value=initial_picture, height=20,
|
||||
key=f"picture_{index}")
|
||||
text3 = st.text_area(tr("Narration"), value=initial_narration, height=100,
|
||||
key=f"narration_{index}")
|
||||
|
||||
# 重新生成按钮
|
||||
if st.button(tr("Rebuild"), key=f"rebuild_{index}"):
|
||||
# 更新video_list中的对应项
|
||||
video_list[index]['timestamp'] = text1
|
||||
video_list[index]['picture'] = text2
|
||||
video_list[index]['narration'] = text3
|
||||
|
||||
for video in video_list:
|
||||
if 'path' in video:
|
||||
del video['path']
|
||||
# 更新session_state以确保更改被保存
|
||||
st.session_state['video_clip_json'] = utils.to_json(video_list)
|
||||
# 替换原JSON 文件
|
||||
with open(params.video_clip_json_path, 'w', encoding='utf-8') as file:
|
||||
json.dump(video_list, file, ensure_ascii=False, indent=4)
|
||||
utils.cut_video(params, progress_callback=None)
|
||||
st.rerun()
|
||||
|
||||
# 开始按钮
|
||||
start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
|
||||
if start_button:
|
||||
# 重置日志容器和记录
|
||||
log_container = st.empty()
|
||||
log_records = []
|
||||
|
||||
config.save_config()
|
||||
task_id = st.session_state.get('task_id')
|
||||
if st.session_state.get('video_script_json_path') is not None:
|
||||
params.video_clip_json = st.session_state.get('video_clip_json')
|
||||
|
||||
logger.debug(f"当前的脚本文件为:{st.session_state.video_clip_json_path}")
|
||||
logger.debug(f"当前的视频文件为:{st.session_state.video_origin_path}")
|
||||
logger.debug(f"裁剪后是视频列表:{st.session_state.subclip_videos}")
|
||||
|
||||
if not task_id:
|
||||
st.error(tr("请先裁剪视频"))
|
||||
scroll_to_bottom()
|
||||
st.stop()
|
||||
if not params.video_clip_json_path:
|
||||
st.error(tr("脚本文件不能为空"))
|
||||
scroll_to_bottom()
|
||||
st.stop()
|
||||
if not params.video_origin_path:
|
||||
st.error(tr("视频文件不能为空"))
|
||||
scroll_to_bottom()
|
||||
st.stop()
|
||||
|
||||
def log_received(msg):
|
||||
with log_container:
|
||||
log_records.append(msg)
|
||||
st.code("\n".join(log_records))
|
||||
|
||||
logger.add(log_received)
|
||||
|
||||
st.toast(tr("生成视频"))
|
||||
logger.info(tr("开始生成视频"))
|
||||
logger.info(utils.to_json(params))
|
||||
scroll_to_bottom()
|
||||
|
||||
result = tm.start_subclip(task_id=task_id, params=params, subclip_path_videos=st.session_state.subclip_videos)
|
||||
|
||||
video_files = result.get("videos", [])
|
||||
st.success(tr("视频生成完成"))
|
||||
try:
|
||||
if video_files:
|
||||
# 将视频播放器居中
|
||||
player_cols = st.columns(len(video_files) * 2 + 1)
|
||||
for i, url in enumerate(video_files):
|
||||
player_cols[i * 2 + 1].video(url)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
open_task_folder(task_id)
|
||||
logger.info(tr("视频生成完成"))
|
||||
scroll_to_bottom()
|
||||
|
||||
config.save_config()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
22
webui/__init__.py
Normal file
22
webui/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""
|
||||
NarratoAI WebUI Package
|
||||
"""
|
||||
from webui.config.settings import config
|
||||
from webui.components import (
|
||||
basic_settings,
|
||||
video_settings,
|
||||
audio_settings,
|
||||
subtitle_settings
|
||||
)
|
||||
from webui.utils import cache, file_utils, performance
|
||||
|
||||
__all__ = [
|
||||
'config',
|
||||
'basic_settings',
|
||||
'video_settings',
|
||||
'audio_settings',
|
||||
'subtitle_settings',
|
||||
'cache',
|
||||
'file_utils',
|
||||
'performance'
|
||||
]
|
||||
15
webui/components/__init__.py
Normal file
15
webui/components/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
from .basic_settings import render_basic_settings
|
||||
from .script_settings import render_script_panel
|
||||
from .video_settings import render_video_panel
|
||||
from .audio_settings import render_audio_panel
|
||||
from .subtitle_settings import render_subtitle_panel
|
||||
from .review_settings import render_review_panel
|
||||
|
||||
__all__ = [
|
||||
'render_basic_settings',
|
||||
'render_script_panel',
|
||||
'render_video_panel',
|
||||
'render_audio_panel',
|
||||
'render_subtitle_panel',
|
||||
'render_review_panel'
|
||||
]
|
||||
198
webui/components/audio_settings.py
Normal file
198
webui/components/audio_settings.py
Normal file
@ -0,0 +1,198 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
from uuid import uuid4
|
||||
from app.config import config
|
||||
from app.services import voice
|
||||
from app.utils import utils
|
||||
from webui.utils.cache import get_songs_cache
|
||||
|
||||
def render_audio_panel(tr):
|
||||
"""渲染音频设置面板"""
|
||||
with st.container(border=True):
|
||||
st.write(tr("Audio Settings"))
|
||||
|
||||
# 渲染TTS设置
|
||||
render_tts_settings(tr)
|
||||
|
||||
# 渲染背景音乐设置
|
||||
render_bgm_settings(tr)
|
||||
|
||||
def render_tts_settings(tr):
|
||||
"""渲染TTS(文本转语音)设置"""
|
||||
# 获取支持的语音列表
|
||||
support_locales = ["zh-CN", "zh-HK", "zh-TW", "en-US"]
|
||||
voices = voice.get_all_azure_voices(filter_locals=support_locales)
|
||||
|
||||
# 创建友好的显示名称
|
||||
friendly_names = {
|
||||
v: v.replace("Female", tr("Female"))
|
||||
.replace("Male", tr("Male"))
|
||||
.replace("Neural", "")
|
||||
for v in voices
|
||||
}
|
||||
|
||||
# 获取保存的语音设置
|
||||
saved_voice_name = config.ui.get("voice_name", "")
|
||||
saved_voice_name_index = 0
|
||||
|
||||
if saved_voice_name in friendly_names:
|
||||
saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
|
||||
else:
|
||||
# 如果没有保存的设置,选择与UI语言匹配的第一个语音
|
||||
for i, v in enumerate(voices):
|
||||
if (v.lower().startswith(st.session_state["ui_language"].lower())
|
||||
and "V2" not in v):
|
||||
saved_voice_name_index = i
|
||||
break
|
||||
|
||||
# 语音选择下拉框
|
||||
selected_friendly_name = st.selectbox(
|
||||
tr("Speech Synthesis"),
|
||||
options=list(friendly_names.values()),
|
||||
index=saved_voice_name_index,
|
||||
)
|
||||
|
||||
# 获取实际的语音名称
|
||||
voice_name = list(friendly_names.keys())[
|
||||
list(friendly_names.values()).index(selected_friendly_name)
|
||||
]
|
||||
|
||||
# 保存设置
|
||||
config.ui["voice_name"] = voice_name
|
||||
|
||||
# Azure V2语音特殊处理
|
||||
if voice.is_azure_v2_voice(voice_name):
|
||||
render_azure_v2_settings(tr)
|
||||
|
||||
# 语音参数设置
|
||||
render_voice_parameters(tr)
|
||||
|
||||
# 试听按钮
|
||||
render_voice_preview(tr, voice_name)
|
||||
|
||||
def render_azure_v2_settings(tr):
|
||||
"""渲染Azure V2语音设置"""
|
||||
saved_azure_speech_region = config.azure.get("speech_region", "")
|
||||
saved_azure_speech_key = config.azure.get("speech_key", "")
|
||||
|
||||
azure_speech_region = st.text_input(
|
||||
tr("Speech Region"),
|
||||
value=saved_azure_speech_region
|
||||
)
|
||||
azure_speech_key = st.text_input(
|
||||
tr("Speech Key"),
|
||||
value=saved_azure_speech_key,
|
||||
type="password"
|
||||
)
|
||||
|
||||
config.azure["speech_region"] = azure_speech_region
|
||||
config.azure["speech_key"] = azure_speech_key
|
||||
|
||||
def render_voice_parameters(tr):
|
||||
"""渲染语音参数设置"""
|
||||
# 音量
|
||||
voice_volume = st.selectbox(
|
||||
tr("Speech Volume"),
|
||||
options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
|
||||
index=2,
|
||||
)
|
||||
st.session_state['voice_volume'] = voice_volume
|
||||
|
||||
# 语速
|
||||
voice_rate = st.selectbox(
|
||||
tr("Speech Rate"),
|
||||
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
|
||||
index=2,
|
||||
)
|
||||
st.session_state['voice_rate'] = voice_rate
|
||||
|
||||
# 音调
|
||||
voice_pitch = st.selectbox(
|
||||
tr("Speech Pitch"),
|
||||
options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
|
||||
index=2,
|
||||
)
|
||||
st.session_state['voice_pitch'] = voice_pitch
|
||||
|
||||
def render_voice_preview(tr, voice_name):
|
||||
"""渲染语音试听功能"""
|
||||
if st.button(tr("Play Voice")):
|
||||
play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论"
|
||||
if not play_content:
|
||||
play_content = st.session_state.get('video_script', '')
|
||||
if not play_content:
|
||||
play_content = tr("Voice Example")
|
||||
|
||||
with st.spinner(tr("Synthesizing Voice")):
|
||||
temp_dir = utils.storage_dir("temp", create=True)
|
||||
audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
|
||||
|
||||
sub_maker = voice.tts(
|
||||
text=play_content,
|
||||
voice_name=voice_name,
|
||||
voice_rate=st.session_state.get('voice_rate', 1.0),
|
||||
voice_pitch=st.session_state.get('voice_pitch', 1.0),
|
||||
voice_file=audio_file,
|
||||
)
|
||||
|
||||
# 如果语音文件生成失败,使用默认内容重试
|
||||
if not sub_maker:
|
||||
play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
|
||||
sub_maker = voice.tts(
|
||||
text=play_content,
|
||||
voice_name=voice_name,
|
||||
voice_rate=st.session_state.get('voice_rate', 1.0),
|
||||
voice_pitch=st.session_state.get('voice_pitch', 1.0),
|
||||
voice_file=audio_file,
|
||||
)
|
||||
|
||||
if sub_maker and os.path.exists(audio_file):
|
||||
st.audio(audio_file, format="audio/mp3")
|
||||
if os.path.exists(audio_file):
|
||||
os.remove(audio_file)
|
||||
|
||||
def render_bgm_settings(tr):
|
||||
"""渲染背景音乐设置"""
|
||||
# 背景音乐选项
|
||||
bgm_options = [
|
||||
(tr("No Background Music"), ""),
|
||||
(tr("Random Background Music"), "random"),
|
||||
(tr("Custom Background Music"), "custom"),
|
||||
]
|
||||
|
||||
selected_index = st.selectbox(
|
||||
tr("Background Music"),
|
||||
index=1,
|
||||
options=range(len(bgm_options)),
|
||||
format_func=lambda x: bgm_options[x][0],
|
||||
)
|
||||
|
||||
# 获取选择的背景音乐类型
|
||||
bgm_type = bgm_options[selected_index][1]
|
||||
st.session_state['bgm_type'] = bgm_type
|
||||
|
||||
# 自定义背景音乐处理
|
||||
if bgm_type == "custom":
|
||||
custom_bgm_file = st.text_input(tr("Custom Background Music File"))
|
||||
if custom_bgm_file and os.path.exists(custom_bgm_file):
|
||||
st.session_state['bgm_file'] = custom_bgm_file
|
||||
|
||||
# 背景音乐音量
|
||||
bgm_volume = st.selectbox(
|
||||
tr("Background Music Volume"),
|
||||
options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
index=2,
|
||||
)
|
||||
st.session_state['bgm_volume'] = bgm_volume
|
||||
|
||||
def get_audio_params():
|
||||
"""获取音频参数"""
|
||||
return {
|
||||
'voice_name': config.ui.get("voice_name", ""),
|
||||
'voice_volume': st.session_state.get('voice_volume', 1.0),
|
||||
'voice_rate': st.session_state.get('voice_rate', 1.0),
|
||||
'voice_pitch': st.session_state.get('voice_pitch', 1.0),
|
||||
'bgm_type': st.session_state.get('bgm_type', 'random'),
|
||||
'bgm_file': st.session_state.get('bgm_file', ''),
|
||||
'bgm_volume': st.session_state.get('bgm_volume', 0.2),
|
||||
}
|
||||
142
webui/components/basic_settings.py
Normal file
142
webui/components/basic_settings.py
Normal file
@ -0,0 +1,142 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
from app.config import config
|
||||
from app.utils import utils
|
||||
|
||||
def render_basic_settings(tr):
|
||||
"""渲染基础设置面板"""
|
||||
with st.expander(tr("Basic Settings"), expanded=False):
|
||||
config_panels = st.columns(3)
|
||||
left_config_panel = config_panels[0]
|
||||
middle_config_panel = config_panels[1]
|
||||
right_config_panel = config_panels[2]
|
||||
|
||||
with left_config_panel:
|
||||
render_language_settings(tr)
|
||||
render_proxy_settings(tr)
|
||||
|
||||
with middle_config_panel:
|
||||
render_video_llm_settings(tr)
|
||||
|
||||
with right_config_panel:
|
||||
render_llm_settings(tr)
|
||||
|
||||
def render_language_settings(tr):
|
||||
"""渲染语言设置"""
|
||||
system_locale = utils.get_system_locale()
|
||||
i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
|
||||
locales = utils.load_locales(i18n_dir)
|
||||
|
||||
display_languages = []
|
||||
selected_index = 0
|
||||
for i, code in enumerate(locales.keys()):
|
||||
display_languages.append(f"{code} - {locales[code].get('Language')}")
|
||||
if code == st.session_state.get('ui_language', system_locale):
|
||||
selected_index = i
|
||||
|
||||
selected_language = st.selectbox(
|
||||
tr("Language"),
|
||||
options=display_languages,
|
||||
index=selected_index
|
||||
)
|
||||
|
||||
if selected_language:
|
||||
code = selected_language.split(" - ")[0].strip()
|
||||
st.session_state['ui_language'] = code
|
||||
config.ui['language'] = code
|
||||
|
||||
def render_proxy_settings(tr):
|
||||
"""渲染代理设置"""
|
||||
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
|
||||
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
|
||||
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
|
||||
|
||||
if HTTP_PROXY:
|
||||
config.proxy["http"] = HTTP_PROXY
|
||||
os.environ["HTTP_PROXY"] = HTTP_PROXY
|
||||
if HTTPS_PROXY:
|
||||
config.proxy["https"] = HTTPS_PROXY
|
||||
os.environ["HTTPS_PROXY"] = HTTPS_PROXY
|
||||
|
||||
def render_video_llm_settings(tr):
|
||||
"""渲染视频LLM设置"""
|
||||
video_llm_providers = ['Gemini', 'NarratoAPI']
|
||||
saved_llm_provider = config.app.get("video_llm_provider", "OpenAI").lower()
|
||||
saved_llm_provider_index = 0
|
||||
|
||||
for i, provider in enumerate(video_llm_providers):
|
||||
if provider.lower() == saved_llm_provider:
|
||||
saved_llm_provider_index = i
|
||||
break
|
||||
|
||||
video_llm_provider = st.selectbox(
|
||||
tr("Video LLM Provider"),
|
||||
options=video_llm_providers,
|
||||
index=saved_llm_provider_index
|
||||
)
|
||||
video_llm_provider = video_llm_provider.lower()
|
||||
config.app["video_llm_provider"] = video_llm_provider
|
||||
|
||||
# 获取已保存的配置
|
||||
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
|
||||
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
|
||||
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
|
||||
|
||||
# 渲染输入框
|
||||
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
|
||||
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
|
||||
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
|
||||
|
||||
# 保存配置
|
||||
if st_llm_api_key:
|
||||
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
|
||||
if st_llm_base_url:
|
||||
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
|
||||
if st_llm_model_name:
|
||||
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
|
||||
|
||||
def render_llm_settings(tr):
|
||||
"""渲染LLM设置"""
|
||||
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
|
||||
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
|
||||
saved_llm_provider_index = 0
|
||||
|
||||
for i, provider in enumerate(llm_providers):
|
||||
if provider.lower() == saved_llm_provider:
|
||||
saved_llm_provider_index = i
|
||||
break
|
||||
|
||||
llm_provider = st.selectbox(
|
||||
tr("LLM Provider"),
|
||||
options=llm_providers,
|
||||
index=saved_llm_provider_index
|
||||
)
|
||||
llm_provider = llm_provider.lower()
|
||||
config.app["llm_provider"] = llm_provider
|
||||
|
||||
# 获取已保存的配置
|
||||
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
|
||||
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
|
||||
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
|
||||
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
|
||||
|
||||
# 渲染输入框
|
||||
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
|
||||
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
|
||||
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
|
||||
|
||||
# 保存配置
|
||||
if st_llm_api_key:
|
||||
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
|
||||
if st_llm_base_url:
|
||||
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
|
||||
if st_llm_model_name:
|
||||
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
|
||||
|
||||
# Cloudflare 特殊处理
|
||||
if llm_provider == 'cloudflare':
|
||||
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
|
||||
if st_llm_account_id:
|
||||
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
|
||||
65
webui/components/review_settings.py
Normal file
65
webui/components/review_settings.py
Normal file
@ -0,0 +1,65 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
def render_review_panel(tr):
|
||||
"""渲染视频审查面板"""
|
||||
with st.expander(tr("Video Check"), expanded=False):
|
||||
try:
|
||||
video_list = st.session_state.get('video_clip_json', [])
|
||||
except KeyError:
|
||||
video_list = []
|
||||
|
||||
# 计算列数和行数
|
||||
num_videos = len(video_list)
|
||||
cols_per_row = 3
|
||||
rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
|
||||
|
||||
# 使用容器展示视频
|
||||
for row in range(rows):
|
||||
cols = st.columns(cols_per_row)
|
||||
for col in range(cols_per_row):
|
||||
index = row * cols_per_row + col
|
||||
if index < num_videos:
|
||||
with cols[col]:
|
||||
render_video_item(tr, video_list, index)
|
||||
|
||||
def render_video_item(tr, video_list, index):
|
||||
"""渲染单个视频项"""
|
||||
video_info = video_list[index]
|
||||
video_path = video_info.get('path')
|
||||
if video_path is not None and os.path.exists(video_path):
|
||||
initial_narration = video_info.get('narration', '')
|
||||
initial_picture = video_info.get('picture', '')
|
||||
initial_timestamp = video_info.get('timestamp', '')
|
||||
|
||||
# 显示视频
|
||||
with open(video_path, 'rb') as video_file:
|
||||
video_bytes = video_file.read()
|
||||
st.video(video_bytes)
|
||||
|
||||
# 显示信息(只读)
|
||||
text_panels = st.columns(2)
|
||||
with text_panels[0]:
|
||||
st.text_area(
|
||||
tr("timestamp"),
|
||||
value=initial_timestamp,
|
||||
height=20,
|
||||
key=f"timestamp_{index}",
|
||||
disabled=True
|
||||
)
|
||||
with text_panels[1]:
|
||||
st.text_area(
|
||||
tr("Picture description"),
|
||||
value=initial_picture,
|
||||
height=20,
|
||||
key=f"picture_{index}",
|
||||
disabled=True
|
||||
)
|
||||
st.text_area(
|
||||
tr("Narration"),
|
||||
value=initial_narration,
|
||||
height=100,
|
||||
key=f"narration_{index}",
|
||||
disabled=True
|
||||
)
|
||||
314
webui/components/script_settings.py
Normal file
314
webui/components/script_settings.py
Normal file
@ -0,0 +1,314 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
import time
|
||||
from app.config import config
|
||||
from app.models.schema import VideoClipParams
|
||||
from app.services import llm
|
||||
from app.utils import utils, check_script
|
||||
from loguru import logger
|
||||
from webui.utils import file_utils
|
||||
|
||||
def render_script_panel(tr):
|
||||
"""渲染脚本配置面板"""
|
||||
with st.container(border=True):
|
||||
st.write(tr("Video Script Configuration"))
|
||||
params = VideoClipParams()
|
||||
|
||||
# 渲染脚本文件选择
|
||||
render_script_file(tr, params)
|
||||
|
||||
# 渲染视频文件选择
|
||||
render_video_file(tr, params)
|
||||
|
||||
# 渲染视频主题和提示词
|
||||
render_video_details(tr)
|
||||
|
||||
# 渲染脚本操作按钮
|
||||
render_script_buttons(tr, params)
|
||||
|
||||
def render_script_file(tr, params):
|
||||
"""渲染脚本文件选择"""
|
||||
script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]
|
||||
|
||||
# 获取已有脚本文件
|
||||
suffix = "*.json"
|
||||
script_dir = utils.script_dir()
|
||||
files = glob.glob(os.path.join(script_dir, suffix))
|
||||
file_list = []
|
||||
|
||||
for file in files:
|
||||
file_list.append({
|
||||
"name": os.path.basename(file),
|
||||
"file": file,
|
||||
"ctime": os.path.getctime(file)
|
||||
})
|
||||
|
||||
file_list.sort(key=lambda x: x["ctime"], reverse=True)
|
||||
for file in file_list:
|
||||
display_name = file['file'].replace(config.root_dir, "")
|
||||
script_list.append((display_name, file['file']))
|
||||
|
||||
# 找到保存的脚本文件在列表中的索引
|
||||
saved_script_path = st.session_state.get('video_clip_json_path', '')
|
||||
selected_index = 0
|
||||
for i, (_, path) in enumerate(script_list):
|
||||
if path == saved_script_path:
|
||||
selected_index = i
|
||||
break
|
||||
|
||||
selected_script_index = st.selectbox(
|
||||
tr("Script Files"),
|
||||
index=selected_index, # 使用找到的索引
|
||||
options=range(len(script_list)),
|
||||
format_func=lambda x: script_list[x][0]
|
||||
)
|
||||
|
||||
script_path = script_list[selected_script_index][1]
|
||||
st.session_state['video_clip_json_path'] = script_path
|
||||
params.video_clip_json_path = script_path
|
||||
|
||||
def render_video_file(tr, params):
|
||||
"""渲染视频文件选择"""
|
||||
video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
|
||||
|
||||
# 获取已有视频文件
|
||||
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
|
||||
video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
|
||||
for file in video_files:
|
||||
display_name = file.replace(config.root_dir, "")
|
||||
video_list.append((display_name, file))
|
||||
|
||||
selected_video_index = st.selectbox(
|
||||
tr("Video File"),
|
||||
index=0,
|
||||
options=range(len(video_list)),
|
||||
format_func=lambda x: video_list[x][0]
|
||||
)
|
||||
|
||||
video_path = video_list[selected_video_index][1]
|
||||
st.session_state['video_origin_path'] = video_path
|
||||
params.video_origin_path = video_path
|
||||
|
||||
if video_path == "local":
|
||||
uploaded_file = st.file_uploader(
|
||||
tr("Upload Local Files"),
|
||||
type=["mp4", "mov", "avi", "flv", "mkv"],
|
||||
accept_multiple_files=False,
|
||||
)
|
||||
|
||||
if uploaded_file is not None:
|
||||
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
|
||||
file_name, file_extension = os.path.splitext(uploaded_file.name)
|
||||
|
||||
if os.path.exists(video_file_path):
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
file_name_with_timestamp = f"{file_name}_{timestamp}"
|
||||
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
|
||||
|
||||
with open(video_file_path, "wb") as f:
|
||||
f.write(uploaded_file.read())
|
||||
st.success(tr("File Uploaded Successfully"))
|
||||
st.session_state['video_origin_path'] = video_file_path
|
||||
params.video_origin_path = video_file_path
|
||||
time.sleep(1)
|
||||
st.rerun()
|
||||
|
||||
def render_video_details(tr):
|
||||
"""渲染视频主题和提示词"""
|
||||
video_theme = st.text_input(tr("Video Theme"))
|
||||
prompt = st.text_area(
|
||||
tr("Generation Prompt"),
|
||||
value=st.session_state.get('video_plot', ''),
|
||||
help=tr("Custom prompt for LLM, leave empty to use default prompt"),
|
||||
height=180
|
||||
)
|
||||
st.session_state['video_name'] = video_theme
|
||||
st.session_state['video_plot'] = prompt
|
||||
return video_theme, prompt
|
||||
|
||||
def render_script_buttons(tr, params):
|
||||
"""渲染脚本操作按钮"""
|
||||
# 生成/加载按钮
|
||||
script_path = st.session_state.get('video_clip_json_path', '')
|
||||
if script_path == "auto":
|
||||
button_name = tr("Generate Video Script")
|
||||
elif script_path:
|
||||
button_name = tr("Load Video Script")
|
||||
else:
|
||||
button_name = tr("Please Select Script File")
|
||||
|
||||
if st.button(button_name, key="script_action", disabled=not script_path):
|
||||
if script_path == "auto":
|
||||
generate_script(tr, params)
|
||||
else:
|
||||
load_script(tr, script_path)
|
||||
|
||||
# 视频脚本编辑区
|
||||
video_clip_json_details = st.text_area(
|
||||
tr("Video Script"),
|
||||
value=json.dumps(st.session_state.get('video_clip_json', []), indent=2, ensure_ascii=False),
|
||||
height=180
|
||||
)
|
||||
|
||||
# 操作按钮行
|
||||
button_cols = st.columns(3)
|
||||
with button_cols[0]:
|
||||
if st.button(tr("Check Format"), key="check_format", use_container_width=True):
|
||||
check_script_format(tr, video_clip_json_details)
|
||||
|
||||
with button_cols[1]:
|
||||
if st.button(tr("Save Script"), key="save_script", use_container_width=True):
|
||||
save_script(tr, video_clip_json_details)
|
||||
|
||||
with button_cols[2]:
|
||||
script_valid = st.session_state.get('script_format_valid', False)
|
||||
if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
|
||||
crop_video(tr, params)
|
||||
|
||||
def check_script_format(tr, script_content):
|
||||
"""检查脚本格式"""
|
||||
try:
|
||||
result = check_script.check_format(script_content)
|
||||
if result.get('success'):
|
||||
st.success(tr("Script format check passed"))
|
||||
st.session_state['script_format_valid'] = True
|
||||
else:
|
||||
st.error(f"{tr('Script format check failed')}: {result.get('message')}")
|
||||
st.session_state['script_format_valid'] = False
|
||||
except Exception as e:
|
||||
st.error(f"{tr('Script format check error')}: {str(e)}")
|
||||
st.session_state['script_format_valid'] = False
|
||||
|
||||
def load_script(tr, script_path):
|
||||
"""加载脚本文件"""
|
||||
try:
|
||||
with open(script_path, 'r', encoding='utf-8') as f:
|
||||
script = f.read()
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
st.success(tr("Script loaded successfully"))
|
||||
st.rerun()
|
||||
except Exception as e:
|
||||
st.error(f"{tr('Failed to load script')}: {str(e)}")
|
||||
|
||||
def generate_script(tr, params):
|
||||
"""生成视频脚本"""
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
def update_progress(progress: float, message: str = ""):
|
||||
progress_bar.progress(progress)
|
||||
if message:
|
||||
status_text.text(f"{progress}% - {message}")
|
||||
else:
|
||||
status_text.text(f"进度: {progress}%")
|
||||
|
||||
try:
|
||||
with st.spinner("正在生成脚本..."):
|
||||
if not st.session_state.get('video_plot'):
|
||||
st.warning("视频剧情为空; 会极大影响生成效果!")
|
||||
|
||||
if params.video_clip_json_path == "" and params.video_origin_path != "":
|
||||
update_progress(10, "压缩视频中...")
|
||||
script = llm.generate_script(
|
||||
video_path=params.video_origin_path,
|
||||
video_plot=st.session_state.get('video_plot', ''),
|
||||
video_name=st.session_state.get('video_name', ''),
|
||||
language=params.video_language,
|
||||
progress_callback=update_progress
|
||||
)
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
st.stop()
|
||||
else:
|
||||
update_progress(90)
|
||||
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
else:
|
||||
# 从本地加载
|
||||
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
|
||||
update_progress(50)
|
||||
status_text.text("从本地加载中...")
|
||||
script = f.read()
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
update_progress(100)
|
||||
status_text.text("从本地加载成功")
|
||||
|
||||
time.sleep(0.5)
|
||||
progress_bar.progress(100)
|
||||
status_text.text("脚本生成完成!")
|
||||
st.success("视频脚本生成成功!")
|
||||
except Exception as err:
|
||||
st.error(f"生成过程中发生错误: {str(err)}")
|
||||
finally:
|
||||
time.sleep(2)
|
||||
progress_bar.empty()
|
||||
status_text.empty()
|
||||
|
||||
def save_script(tr, video_clip_json_details):
|
||||
"""保存视频脚本"""
|
||||
if not video_clip_json_details:
|
||||
st.error(tr("请输入视频脚本"))
|
||||
st.stop()
|
||||
|
||||
with st.spinner(tr("Save Script")):
|
||||
script_dir = utils.script_dir()
|
||||
timestamp = time.strftime("%Y-%m%d-%H%M%S")
|
||||
save_path = os.path.join(script_dir, f"{timestamp}.json")
|
||||
|
||||
try:
|
||||
data = json.loads(video_clip_json_details)
|
||||
with open(save_path, 'w', encoding='utf-8') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
st.session_state['video_clip_json'] = data
|
||||
st.session_state['video_clip_json_path'] = save_path
|
||||
|
||||
# 更新配置
|
||||
config.app["video_clip_json_path"] = save_path
|
||||
|
||||
# 显示成功消息
|
||||
st.success(tr("Script saved successfully"))
|
||||
|
||||
# 强制重新加载页面以更新选择框
|
||||
time.sleep(0.5) # 给一点时间让用户看到成功消息
|
||||
st.rerun()
|
||||
|
||||
except Exception as err:
|
||||
st.error(f"{tr('Failed to save script')}: {str(err)}")
|
||||
st.stop()
|
||||
|
||||
def crop_video(tr, params):
|
||||
"""裁剪视频"""
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
def update_progress(progress):
|
||||
progress_bar.progress(progress)
|
||||
status_text.text(f"剪辑进度: {progress}%")
|
||||
|
||||
try:
|
||||
utils.cut_video(params, update_progress)
|
||||
time.sleep(0.5)
|
||||
progress_bar.progress(100)
|
||||
status_text.text("剪辑完成!")
|
||||
st.success("视频剪辑成功完成!")
|
||||
except Exception as e:
|
||||
st.error(f"剪辑过程中发生错误: {str(e)}")
|
||||
finally:
|
||||
time.sleep(2)
|
||||
progress_bar.empty()
|
||||
status_text.empty()
|
||||
|
||||
def get_script_params():
|
||||
"""获取脚本参数"""
|
||||
return {
|
||||
'video_language': st.session_state.get('video_language', ''),
|
||||
'video_clip_json_path': st.session_state.get('video_clip_json_path', ''),
|
||||
'video_origin_path': st.session_state.get('video_origin_path', ''),
|
||||
'video_name': st.session_state.get('video_name', ''),
|
||||
'video_plot': st.session_state.get('video_plot', '')
|
||||
}
|
||||
129
webui/components/subtitle_settings.py
Normal file
129
webui/components/subtitle_settings.py
Normal file
@ -0,0 +1,129 @@
|
||||
import streamlit as st
|
||||
from app.config import config
|
||||
from webui.utils.cache import get_fonts_cache
|
||||
import os
|
||||
|
||||
def render_subtitle_panel(tr):
|
||||
"""渲染字幕设置面板"""
|
||||
with st.container(border=True):
|
||||
st.write(tr("Subtitle Settings"))
|
||||
|
||||
# 启用字幕选项
|
||||
enable_subtitles = st.checkbox(tr("Enable Subtitles"), value=True)
|
||||
st.session_state['subtitle_enabled'] = enable_subtitles
|
||||
|
||||
if enable_subtitles:
|
||||
render_font_settings(tr)
|
||||
render_position_settings(tr)
|
||||
render_style_settings(tr)
|
||||
|
||||
def render_font_settings(tr):
|
||||
"""渲染字体设置"""
|
||||
# 获取字体列表
|
||||
font_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "resource", "fonts")
|
||||
font_names = get_fonts_cache(font_dir)
|
||||
|
||||
# 获取保存的字体设置
|
||||
saved_font_name = config.ui.get("font_name", "")
|
||||
saved_font_name_index = 0
|
||||
if saved_font_name in font_names:
|
||||
saved_font_name_index = font_names.index(saved_font_name)
|
||||
|
||||
# 字体选择
|
||||
font_name = st.selectbox(
|
||||
tr("Font"),
|
||||
options=font_names,
|
||||
index=saved_font_name_index
|
||||
)
|
||||
config.ui["font_name"] = font_name
|
||||
st.session_state['font_name'] = font_name
|
||||
|
||||
# 字体大小
|
||||
font_cols = st.columns([0.3, 0.7])
|
||||
with font_cols[0]:
|
||||
saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
|
||||
text_fore_color = st.color_picker(
|
||||
tr("Font Color"),
|
||||
saved_text_fore_color
|
||||
)
|
||||
config.ui["text_fore_color"] = text_fore_color
|
||||
st.session_state['text_fore_color'] = text_fore_color
|
||||
|
||||
with font_cols[1]:
|
||||
saved_font_size = config.ui.get("font_size", 60)
|
||||
font_size = st.slider(
|
||||
tr("Font Size"),
|
||||
min_value=30,
|
||||
max_value=100,
|
||||
value=saved_font_size
|
||||
)
|
||||
config.ui["font_size"] = font_size
|
||||
st.session_state['font_size'] = font_size
|
||||
|
||||
def render_position_settings(tr):
|
||||
"""渲染位置设置"""
|
||||
subtitle_positions = [
|
||||
(tr("Top"), "top"),
|
||||
(tr("Center"), "center"),
|
||||
(tr("Bottom"), "bottom"),
|
||||
(tr("Custom"), "custom"),
|
||||
]
|
||||
|
||||
selected_index = st.selectbox(
|
||||
tr("Position"),
|
||||
index=2,
|
||||
options=range(len(subtitle_positions)),
|
||||
format_func=lambda x: subtitle_positions[x][0],
|
||||
)
|
||||
|
||||
subtitle_position = subtitle_positions[selected_index][1]
|
||||
st.session_state['subtitle_position'] = subtitle_position
|
||||
|
||||
# 自定义位置处理
|
||||
if subtitle_position == "custom":
|
||||
custom_position = st.text_input(
|
||||
tr("Custom Position (% from top)"),
|
||||
value="70.0"
|
||||
)
|
||||
try:
|
||||
custom_position_value = float(custom_position)
|
||||
if custom_position_value < 0 or custom_position_value > 100:
|
||||
st.error(tr("Please enter a value between 0 and 100"))
|
||||
else:
|
||||
st.session_state['custom_position'] = custom_position_value
|
||||
except ValueError:
|
||||
st.error(tr("Please enter a valid number"))
|
||||
|
||||
def render_style_settings(tr):
|
||||
"""渲染样式设置"""
|
||||
stroke_cols = st.columns([0.3, 0.7])
|
||||
|
||||
with stroke_cols[0]:
|
||||
stroke_color = st.color_picker(
|
||||
tr("Stroke Color"),
|
||||
value="#000000"
|
||||
)
|
||||
st.session_state['stroke_color'] = stroke_color
|
||||
|
||||
with stroke_cols[1]:
|
||||
stroke_width = st.slider(
|
||||
tr("Stroke Width"),
|
||||
min_value=0.0,
|
||||
max_value=10.0,
|
||||
value=1.5,
|
||||
step=0.1
|
||||
)
|
||||
st.session_state['stroke_width'] = stroke_width
|
||||
|
||||
def get_subtitle_params():
|
||||
"""获取字幕参数"""
|
||||
return {
|
||||
'enabled': st.session_state.get('subtitle_enabled', True),
|
||||
'font_name': st.session_state.get('font_name', ''),
|
||||
'font_size': st.session_state.get('font_size', 60),
|
||||
'text_fore_color': st.session_state.get('text_fore_color', '#FFFFFF'),
|
||||
'position': st.session_state.get('subtitle_position', 'bottom'),
|
||||
'custom_position': st.session_state.get('custom_position', 70.0),
|
||||
'stroke_color': st.session_state.get('stroke_color', '#000000'),
|
||||
'stroke_width': st.session_state.get('stroke_width', 1.5),
|
||||
}
|
||||
47
webui/components/video_settings.py
Normal file
47
webui/components/video_settings.py
Normal file
@ -0,0 +1,47 @@
|
||||
import streamlit as st
|
||||
from app.models.schema import VideoClipParams, VideoAspect
|
||||
|
||||
def render_video_panel(tr):
|
||||
"""渲染视频配置面板"""
|
||||
with st.container(border=True):
|
||||
st.write(tr("Video Settings"))
|
||||
params = VideoClipParams()
|
||||
render_video_config(tr, params)
|
||||
|
||||
def render_video_config(tr, params):
|
||||
"""渲染视频配置"""
|
||||
# 视频比例
|
||||
video_aspect_ratios = [
|
||||
(tr("Portrait"), VideoAspect.portrait.value),
|
||||
(tr("Landscape"), VideoAspect.landscape.value),
|
||||
]
|
||||
selected_index = st.selectbox(
|
||||
tr("Video Ratio"),
|
||||
options=range(len(video_aspect_ratios)),
|
||||
format_func=lambda x: video_aspect_ratios[x][0],
|
||||
)
|
||||
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
|
||||
st.session_state['video_aspect'] = params.video_aspect.value
|
||||
|
||||
# 视频画质
|
||||
video_qualities = [
|
||||
("4K (2160p)", "2160p"),
|
||||
("2K (1440p)", "1440p"),
|
||||
("Full HD (1080p)", "1080p"),
|
||||
("HD (720p)", "720p"),
|
||||
("SD (480p)", "480p"),
|
||||
]
|
||||
quality_index = st.selectbox(
|
||||
tr("Video Quality"),
|
||||
options=range(len(video_qualities)),
|
||||
format_func=lambda x: video_qualities[x][0],
|
||||
index=2 # 默认选择 1080p
|
||||
)
|
||||
st.session_state['video_quality'] = video_qualities[quality_index][1]
|
||||
|
||||
def get_video_params():
|
||||
"""获取视频参数"""
|
||||
return {
|
||||
'video_aspect': st.session_state.get('video_aspect', VideoAspect.portrait.value),
|
||||
'video_quality': st.session_state.get('video_quality', '1080p')
|
||||
}
|
||||
155
webui/config/settings.py
Normal file
155
webui/config/settings.py
Normal file
@ -0,0 +1,155 @@
|
||||
import os
|
||||
import tomli
|
||||
from loguru import logger
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class WebUIConfig:
|
||||
"""WebUI配置类"""
|
||||
# UI配置
|
||||
ui: Dict[str, Any] = None
|
||||
# 代理配置
|
||||
proxy: Dict[str, str] = None
|
||||
# 应用配置
|
||||
app: Dict[str, Any] = None
|
||||
# Azure配置
|
||||
azure: Dict[str, str] = None
|
||||
# 项目版本
|
||||
project_version: str = "0.1.0"
|
||||
# 项目根目录
|
||||
root_dir: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化默认值"""
|
||||
self.ui = self.ui or {}
|
||||
self.proxy = self.proxy or {}
|
||||
self.app = self.app or {}
|
||||
self.azure = self.azure or {}
|
||||
self.root_dir = self.root_dir or os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
def load_config(config_path: Optional[str] = None) -> WebUIConfig:
|
||||
"""加载配置文件
|
||||
Args:
|
||||
config_path: 配置文件路径,如果为None则使用默认路径
|
||||
Returns:
|
||||
WebUIConfig: 配置对象
|
||||
"""
|
||||
try:
|
||||
if config_path is None:
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
".streamlit",
|
||||
"webui.toml"
|
||||
)
|
||||
|
||||
# 如果配置文件不存在,使用示例配置
|
||||
if not os.path.exists(config_path):
|
||||
example_config = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"config.example.toml"
|
||||
)
|
||||
if os.path.exists(example_config):
|
||||
config_path = example_config
|
||||
else:
|
||||
logger.warning(f"配置文件不存在: {config_path}")
|
||||
return WebUIConfig()
|
||||
|
||||
# 读取配置文件
|
||||
with open(config_path, "rb") as f:
|
||||
config_dict = tomli.load(f)
|
||||
|
||||
# 创建配置对象
|
||||
config = WebUIConfig(
|
||||
ui=config_dict.get("ui", {}),
|
||||
proxy=config_dict.get("proxy", {}),
|
||||
app=config_dict.get("app", {}),
|
||||
azure=config_dict.get("azure", {}),
|
||||
project_version=config_dict.get("project_version", "0.1.0")
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
return WebUIConfig()
|
||||
|
||||
def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool:
|
||||
"""保存配置到文件
|
||||
Args:
|
||||
config: 配置对象
|
||||
config_path: 配置文件路径,如果为None则使用默认路径
|
||||
Returns:
|
||||
bool: 是否保存成功
|
||||
"""
|
||||
try:
|
||||
if config_path is None:
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
".streamlit",
|
||||
"webui.toml"
|
||||
)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
||||
|
||||
# 转换为字典
|
||||
config_dict = {
|
||||
"ui": config.ui,
|
||||
"proxy": config.proxy,
|
||||
"app": config.app,
|
||||
"azure": config.azure,
|
||||
"project_version": config.project_version
|
||||
}
|
||||
|
||||
# 保存配置
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
import tomli_w
|
||||
tomli_w.dump(config_dict, f)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
return False
|
||||
|
||||
def get_config() -> WebUIConfig:
|
||||
"""获取全局配置对象
|
||||
Returns:
|
||||
WebUIConfig: 配置对象
|
||||
"""
|
||||
if not hasattr(get_config, "_config"):
|
||||
get_config._config = load_config()
|
||||
return get_config._config
|
||||
|
||||
def update_config(config_dict: Dict[str, Any]) -> bool:
|
||||
"""更新配置
|
||||
Args:
|
||||
config_dict: 配置字典
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
config = get_config()
|
||||
|
||||
# 更新配置
|
||||
if "ui" in config_dict:
|
||||
config.ui.update(config_dict["ui"])
|
||||
if "proxy" in config_dict:
|
||||
config.proxy.update(config_dict["proxy"])
|
||||
if "app" in config_dict:
|
||||
config.app.update(config_dict["app"])
|
||||
if "azure" in config_dict:
|
||||
config.azure.update(config_dict["azure"])
|
||||
if "project_version" in config_dict:
|
||||
config.project_version = config_dict["project_version"]
|
||||
|
||||
# 保存配置
|
||||
return save_config(config)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置失败: {e}")
|
||||
return False
|
||||
|
||||
# 导出全局配置对象
|
||||
config = get_config()
|
||||
1
webui/i18n/__init__.py
Normal file
1
webui/i18n/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# 空文件,用于标记包
|
||||
@ -2,15 +2,15 @@
|
||||
"Language": "简体中文",
|
||||
"Translation": {
|
||||
"Video Script Configuration": "**视频脚本配置**",
|
||||
"Video Script Generate": "生成视频脚本",
|
||||
"Generate Video Script": "生成视频脚本",
|
||||
"Video Subject": "视频主题(给定一个关键词,:red[AI自动生成]视频文案)",
|
||||
"Script Language": "生成视频脚本的语言(一般情况AI会自动根据你输入的主题语言输出)",
|
||||
"Script Files": "脚本文件",
|
||||
"Generate Video Script and Keywords": "点击使用AI根据**主题**生成 【视频文案】 和 【视频关键词】",
|
||||
"Auto Detect": "自动检测",
|
||||
"Auto Generate": "自动生成",
|
||||
"Video Name": "视频名称",
|
||||
"Video Script": "视频脚本(:blue[①使用AI生成 ②从本机加载])",
|
||||
"Video Theme": "视频主题",
|
||||
"Generation Prompt": "自定义提示词",
|
||||
"Save Script": "保存脚本",
|
||||
"Crop Video": "裁剪视频",
|
||||
"Video File": "视频文件(:blue[1️⃣支持上传视频文件(限制2G) 2️⃣大文件建议直接导入 ./resource/videos 目录])",
|
||||
@ -91,7 +91,18 @@
|
||||
"Picture description": "图片描述",
|
||||
"Narration": "视频文案",
|
||||
"Rebuild": "重新生成",
|
||||
"Video Script Load": "加载视频脚本",
|
||||
"Speech Pitch": "语调"
|
||||
"Load Video Script": "加载视频脚本",
|
||||
"Speech Pitch": "语调",
|
||||
"Please Select Script File": "请选择脚本文件",
|
||||
"Check Format": "脚本格式检查",
|
||||
"Script Loaded Successfully": "脚本加载成功",
|
||||
"Script format check passed": "脚本格式检查通过",
|
||||
"Script format check failed": "脚本格式检查失败",
|
||||
"Failed to Load Script": "加载脚本失败",
|
||||
"Failed to Save Script": "保存脚本失败",
|
||||
"Script saved successfully": "脚本保存成功",
|
||||
"Video Script": "视频脚本",
|
||||
"Video Quality": "视频质量",
|
||||
"Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词"
|
||||
}
|
||||
}
|
||||
20
webui/utils/__init__.py
Normal file
20
webui/utils/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
from .cache import get_fonts_cache, get_video_files_cache, get_songs_cache
|
||||
from .file_utils import (
|
||||
open_task_folder, cleanup_temp_files, get_file_list,
|
||||
save_uploaded_file, create_temp_file, get_file_size, ensure_directory
|
||||
)
|
||||
from .performance import monitor_performance
|
||||
|
||||
__all__ = [
|
||||
'get_fonts_cache',
|
||||
'get_video_files_cache',
|
||||
'get_songs_cache',
|
||||
'open_task_folder',
|
||||
'cleanup_temp_files',
|
||||
'get_file_list',
|
||||
'save_uploaded_file',
|
||||
'create_temp_file',
|
||||
'get_file_size',
|
||||
'ensure_directory',
|
||||
'monitor_performance'
|
||||
]
|
||||
33
webui/utils/cache.py
Normal file
33
webui/utils/cache.py
Normal file
@ -0,0 +1,33 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
import glob
|
||||
from app.utils import utils
|
||||
|
||||
def get_fonts_cache(font_dir):
|
||||
if 'fonts_cache' not in st.session_state:
|
||||
fonts = []
|
||||
for root, dirs, files in os.walk(font_dir):
|
||||
for file in files:
|
||||
if file.endswith(".ttf") or file.endswith(".ttc"):
|
||||
fonts.append(file)
|
||||
fonts.sort()
|
||||
st.session_state['fonts_cache'] = fonts
|
||||
return st.session_state['fonts_cache']
|
||||
|
||||
def get_video_files_cache():
|
||||
if 'video_files_cache' not in st.session_state:
|
||||
video_files = []
|
||||
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
|
||||
video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
|
||||
st.session_state['video_files_cache'] = video_files[::-1]
|
||||
return st.session_state['video_files_cache']
|
||||
|
||||
def get_songs_cache(song_dir):
|
||||
if 'songs_cache' not in st.session_state:
|
||||
songs = []
|
||||
for root, dirs, files in os.walk(song_dir):
|
||||
for file in files:
|
||||
if file.endswith(".mp3"):
|
||||
songs.append(file)
|
||||
st.session_state['songs_cache'] = songs
|
||||
return st.session_state['songs_cache']
|
||||
189
webui/utils/file_utils.py
Normal file
189
webui/utils/file_utils.py
Normal file
@ -0,0 +1,189 @@
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import platform
|
||||
import shutil
|
||||
from uuid import uuid4
|
||||
from loguru import logger
|
||||
from app.utils import utils
|
||||
|
||||
def open_task_folder(root_dir, task_id):
|
||||
"""打开任务文件夹
|
||||
Args:
|
||||
root_dir: 项目根目录
|
||||
task_id: 任务ID
|
||||
"""
|
||||
try:
|
||||
sys = platform.system()
|
||||
path = os.path.join(root_dir, "storage", "tasks", task_id)
|
||||
if os.path.exists(path):
|
||||
if sys == 'Windows':
|
||||
os.system(f"start {path}")
|
||||
if sys == 'Darwin':
|
||||
os.system(f"open {path}")
|
||||
if sys == 'Linux':
|
||||
os.system(f"xdg-open {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"打开任务文件夹失败: {e}")
|
||||
|
||||
def cleanup_temp_files(temp_dir, max_age=3600):
|
||||
"""清理临时文件
|
||||
Args:
|
||||
temp_dir: 临时文件目录
|
||||
max_age: 文件最大保存时间(秒)
|
||||
"""
|
||||
if os.path.exists(temp_dir):
|
||||
for file in os.listdir(temp_dir):
|
||||
file_path = os.path.join(temp_dir, file)
|
||||
try:
|
||||
if os.path.getctime(file_path) < time.time() - max_age:
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
logger.debug(f"已清理临时文件: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理临时文件失败: {file_path}, 错误: {e}")
|
||||
|
||||
def get_file_list(directory, file_types=None, sort_by='ctime', reverse=True):
|
||||
"""获取指定目录下的文件列表
|
||||
Args:
|
||||
directory: 目录路径
|
||||
file_types: 文件类型列表,如 ['.mp4', '.mov']
|
||||
sort_by: 排序方式,支持 'ctime'(创建时间), 'mtime'(修改时间), 'size'(文件大小), 'name'(文件名)
|
||||
reverse: 是否倒序排序
|
||||
Returns:
|
||||
list: 文件信息列表
|
||||
"""
|
||||
if not os.path.exists(directory):
|
||||
return []
|
||||
|
||||
files = []
|
||||
if file_types:
|
||||
for file_type in file_types:
|
||||
files.extend(glob.glob(os.path.join(directory, f"*{file_type}")))
|
||||
else:
|
||||
files = glob.glob(os.path.join(directory, "*"))
|
||||
|
||||
file_list = []
|
||||
for file_path in files:
|
||||
try:
|
||||
file_stat = os.stat(file_path)
|
||||
file_info = {
|
||||
"name": os.path.basename(file_path),
|
||||
"path": file_path,
|
||||
"size": file_stat.st_size,
|
||||
"ctime": file_stat.st_ctime,
|
||||
"mtime": file_stat.st_mtime
|
||||
}
|
||||
file_list.append(file_info)
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件信息失败: {file_path}, 错误: {e}")
|
||||
|
||||
# 排序
|
||||
if sort_by in ['ctime', 'mtime', 'size', 'name']:
|
||||
file_list.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
|
||||
|
||||
return file_list
|
||||
|
||||
def save_uploaded_file(uploaded_file, save_dir, allowed_types=None):
|
||||
"""保存上传的文件
|
||||
Args:
|
||||
uploaded_file: StreamlitUploadedFile对象
|
||||
save_dir: 保存目录
|
||||
allowed_types: 允许的文件类型列表,如 ['.mp4', '.mov']
|
||||
Returns:
|
||||
str: 保存后的文件路径,失败返回None
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
file_name, file_extension = os.path.splitext(uploaded_file.name)
|
||||
|
||||
# 检查文件类型
|
||||
if allowed_types and file_extension.lower() not in allowed_types:
|
||||
logger.error(f"不支持的文件类型: {file_extension}")
|
||||
return None
|
||||
|
||||
# 如果文件已存在,添加时间戳
|
||||
save_path = os.path.join(save_dir, uploaded_file.name)
|
||||
if os.path.exists(save_path):
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
new_file_name = f"{file_name}_{timestamp}{file_extension}"
|
||||
save_path = os.path.join(save_dir, new_file_name)
|
||||
|
||||
# 保存文件
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(uploaded_file.read())
|
||||
|
||||
logger.info(f"文件保存成功: {save_path}")
|
||||
return save_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存上传文件失败: {e}")
|
||||
return None
|
||||
|
||||
def create_temp_file(prefix='tmp', suffix='', directory=None):
|
||||
"""创建临时文件
|
||||
Args:
|
||||
prefix: 文件名前缀
|
||||
suffix: 文件扩展名
|
||||
directory: 临时文件目录,默认使用系统临时目录
|
||||
Returns:
|
||||
str: 临时文件路径
|
||||
"""
|
||||
try:
|
||||
if directory is None:
|
||||
directory = utils.storage_dir("temp", create=True)
|
||||
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
temp_file = os.path.join(directory, f"{prefix}-{str(uuid4())}{suffix}")
|
||||
return temp_file
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建临时文件失败: {e}")
|
||||
return None
|
||||
|
||||
def get_file_size(file_path, format='MB'):
|
||||
"""获取文件大小
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
format: 返回格式,支持 'B', 'KB', 'MB', 'GB'
|
||||
Returns:
|
||||
float: 文件大小
|
||||
"""
|
||||
try:
|
||||
size_bytes = os.path.getsize(file_path)
|
||||
|
||||
if format.upper() == 'B':
|
||||
return size_bytes
|
||||
elif format.upper() == 'KB':
|
||||
return size_bytes / 1024
|
||||
elif format.upper() == 'MB':
|
||||
return size_bytes / (1024 * 1024)
|
||||
elif format.upper() == 'GB':
|
||||
return size_bytes / (1024 * 1024 * 1024)
|
||||
else:
|
||||
return size_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件大小失败: {file_path}, 错误: {e}")
|
||||
return 0
|
||||
|
||||
def ensure_directory(directory):
|
||||
"""确保目录存在,如果不存在则创建
|
||||
Args:
|
||||
directory: 目录路径
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败: {directory}, 错误: {e}")
|
||||
return False
|
||||
24
webui/utils/performance.py
Normal file
24
webui/utils/performance.py
Normal file
@ -0,0 +1,24 @@
|
||||
import time
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import psutil
|
||||
ENABLE_PERFORMANCE_MONITORING = True
|
||||
except ImportError:
|
||||
ENABLE_PERFORMANCE_MONITORING = False
|
||||
logger.warning("psutil not installed. Performance monitoring is disabled.")
|
||||
|
||||
def monitor_performance():
|
||||
if not ENABLE_PERFORMANCE_MONITORING:
|
||||
return {'execution_time': 0, 'memory_usage': 0}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 # MB
|
||||
except:
|
||||
memory_usage = 0
|
||||
|
||||
return {
|
||||
'execution_time': time.time() - start_time,
|
||||
'memory_usage': memory_usage
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user