diff --git a/app/services/subtitle.py b/app/services/subtitle.py
index c792667..f37eb65 100644
--- a/app/services/subtitle.py
+++ b/app/services/subtitle.py
@@ -29,7 +29,7 @@ def create(audio_file, subtitle_file: str = ""):
返回:
无返回值,但会在指定路径生成字幕文件。
"""
- global model
+ global model, device, compute_type
if not model:
model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v2"
model_bin_file = f"{model_path}/model.bin"
@@ -43,27 +43,45 @@ def create(audio_file, subtitle_file: str = ""):
)
return None
- logger.info(
- f"加载模型: {model_path}, 设备: {device}, 计算类型: {compute_type}"
- )
+ # 尝试使用 CUDA,如果失败则回退到 CPU
try:
+ import torch
+ if torch.cuda.is_available():
+ try:
+ logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
+ model = WhisperModel(
+ model_size_or_path=model_path,
+ device="cuda",
+ compute_type="float16",
+ local_files_only=True
+ )
+ device = "cuda"
+ compute_type = "float16"
+ logger.info("成功使用 CUDA 加载模型")
+ except Exception as e:
+ logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
+ logger.warning("回退到 CPU 模式")
+ device = "cpu"
+ compute_type = "int8"
+ else:
+ logger.info("未检测到 CUDA,使用 CPU 模式")
+ device = "cpu"
+ compute_type = "int8"
+ except ImportError:
+ logger.warning("未安装 torch,使用 CPU 模式")
+ device = "cpu"
+ compute_type = "int8"
+
+ if device == "cpu":
+ logger.info(f"使用 CPU 加载模型: {model_path}")
model = WhisperModel(
model_size_or_path=model_path,
device=device,
compute_type=compute_type,
local_files_only=True
)
- except Exception as e:
- logger.error(
- f"加载模型失败: {e} \n\n"
- f"********************************************\n"
- f"这可能是由网络问题引起的. \n"
- f"请手动下载模型并将其放入 'app/models' 文件夹中。 \n"
- f"see [README.md FAQ](https://github.com/linyqh/NarratoAI) for more details.\n"
- f"********************************************\n\n"
- f"{traceback.format_exc()}"
- )
- return None
+
+ logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")
logger.info(f"start, output file: {subtitle_file}")
if not subtitle_file:
diff --git a/app/utils/check_script.py b/app/utils/check_script.py
index 623c42a..00e6c0f 100644
--- a/app/utils/check_script.py
+++ b/app/utils/check_script.py
@@ -1,115 +1,81 @@
import json
-from loguru import logger
-import os
-from datetime import timedelta
+from typing import Dict, Any
-def time_to_seconds(time_str):
- parts = list(map(int, time_str.split(':')))
- if len(parts) == 2:
- return timedelta(minutes=parts[0], seconds=parts[1]).total_seconds()
- elif len(parts) == 3:
- return timedelta(hours=parts[0], minutes=parts[1], seconds=parts[2]).total_seconds()
- raise ValueError(f"无法解析时间字符串: {time_str}")
+def check_format(script_content: str) -> Dict[str, Any]:
+ """检查脚本格式
+ Args:
+ script_content: 脚本内容
+ Returns:
+ Dict: {'success': bool, 'message': str}
+ """
+ try:
+ # 检查是否为有效的JSON
+ data = json.loads(script_content)
+
+ # 检查是否为列表
+ if not isinstance(data, list):
+ return {
+ 'success': False,
+ 'message': '脚本必须是JSON数组格式'
+ }
+
+ # 检查每个片段
+ for i, clip in enumerate(data):
+ # 检查必需字段
+ required_fields = ['narration', 'picture', 'timestamp']
+ for field in required_fields:
+ if field not in clip:
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段缺少必需字段: {field}'
+ }
+
+ # 检查字段类型
+ if not isinstance(clip['narration'], str):
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的narration必须是字符串'
+ }
+ if not isinstance(clip['picture'], str):
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的picture必须是字符串'
+ }
+ if not isinstance(clip['timestamp'], str):
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的timestamp必须是字符串'
+ }
+
+ # 检查字段内容不能为空
+ if not clip['narration'].strip():
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的narration不能为空'
+ }
+ if not clip['picture'].strip():
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的picture不能为空'
+ }
+ if not clip['timestamp'].strip():
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的timestamp不能为空'
+ }
-def seconds_to_time_str(seconds):
- hours, remainder = divmod(int(seconds), 3600)
- minutes, seconds = divmod(remainder, 60)
- if hours > 0:
- return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
- else:
- return f"{minutes:02d}:{seconds:02d}"
+ return {
+ 'success': True,
+ 'message': '脚本格式检查通过'
+ }
-def adjust_timestamp(start_time, duration):
- start_seconds = time_to_seconds(start_time)
- end_seconds = start_seconds + duration
- return f"{start_time}-{seconds_to_time_str(end_seconds)}"
-
-def estimate_audio_duration(text):
- # 假设平均每个字符需要 0.2 秒
- return len(text) * 0.2
-
-def check_script(data, total_duration):
- errors = []
- time_ranges = []
-
- logger.info("开始检查脚本")
- logger.info(f"视频总时长: {total_duration:.2f} 秒")
- logger.info("=" * 50)
-
- for i, item in enumerate(data, 1):
- logger.info(f"\n检查第 {i} 项:")
-
- # 检查所有必需字段
- required_fields = ['picture', 'timestamp', 'narration', 'OST']
- for field in required_fields:
- if field not in item:
- errors.append(f"第 {i} 项缺少 {field} 字段")
- logger.info(f" - 错误: 缺少 {field} 字段")
- else:
- logger.info(f" - {field}: {item[field]}")
-
- # 检查 OST 相关规则
- if item.get('OST') == False:
- if not item.get('narration'):
- errors.append(f"第 {i} 项 OST 为 false,但 narration 为空")
- logger.info(" - 错误: OST 为 false,但 narration 为空")
- elif len(item['narration']) > 60:
- errors.append(f"第 {i} 项 OST 为 false,但 narration 超过 60 字")
- logger.info(f" - 错误: OST 为 false,但 narration 超过 60 字 (当前: {len(item['narration'])} 字)")
- else:
- logger.info(" - OST 为 false,narration 检查通过")
- elif item.get('OST') == True:
- if "原声播放_" not in item.get('narration'):
- errors.append(f"第 {i} 项 OST 为 true,但 narration 不为空")
- logger.info(" - 错误: OST 为 true,但 narration 不为空")
- else:
- logger.info(" - OST 为 true,narration 检查通过")
-
- # 检查 timestamp
- if 'timestamp' in item:
- start, end = map(time_to_seconds, item['timestamp'].split('-'))
- if any((start < existing_end and end > existing_start) for existing_start, existing_end in time_ranges):
- errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 与其他时间段重叠")
- logger.info(f" - 错误: timestamp '{item['timestamp']}' 与其他时间段重叠")
- else:
- logger.info(f" - timestamp '{item['timestamp']}' 检查通过")
- time_ranges.append((start, end))
-
- # if end > total_duration:
- # errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
- # logger.info(f" - 错误: timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
- # else:
- # logger.info(f" - timestamp 在总时长范围内")
-
- # 处理 narration 字段
- if item.get('OST') == False and item.get('narration'):
- estimated_duration = estimate_audio_duration(item['narration'])
- start_time = item['timestamp'].split('-')[0]
- item['timestamp'] = adjust_timestamp(start_time, estimated_duration)
- logger.info(f" - 已调整 timestamp 为 {item['timestamp']} (估算音频时长: {estimated_duration:.2f} 秒)")
-
- if errors:
- logger.info("检查结果:不通过")
- logger.info("发现以下错误:")
- for error in errors:
- logger.info(f"- {error}")
- else:
- logger.info("检查结果:通过")
- logger.info("所有项目均符合规则要求。")
-
- return errors, data
-
-
-if __name__ == "__main__":
- file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/test004.json"
-
- with open(file_path, 'r', encoding='utf-8') as f:
- data = json.load(f)
-
- total_duration = 280
-
- # check_script(data, total_duration)
-
- from app.utils.utils import add_new_timestamps
- res = add_new_timestamps(data)
- print(json.dumps(res, indent=4, ensure_ascii=False))
+ except json.JSONDecodeError as e:
+ return {
+ 'success': False,
+ 'message': f'JSON格式错误: {str(e)}'
+ }
+ except Exception as e:
+ return {
+ 'success': False,
+ 'message': f'检查过程中发生错误: {str(e)}'
+ }
diff --git a/requirements.txt b/requirements.txt
index a9445b9..f3be823 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,3 +24,4 @@ azure-cognitiveservices-speech~=1.37.0
git-changelog~=2.5.2
watchdog==5.0.2
pydub==0.25.1
+psutil>=5.9.0
diff --git a/webui.py b/webui.py
index faae899..d2a02f0 100644
--- a/webui.py
+++ b/webui.py
@@ -1,6 +1,14 @@
import streamlit as st
+import os
+import sys
+from uuid import uuid4
from app.config import config
+from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
+from webui.utils import cache, file_utils, performance
+from app.utils import utils
+from app.models.schema import VideoClipParams, VideoAspect
+# 初始化配置 - 必须是第一个 Streamlit 命令
st.set_page_config(
page_title="NarratoAI",
page_icon="📽️",
@@ -13,126 +21,23 @@ st.set_page_config(
},
)
-import sys
-import os
-import glob
-import json
-import time
-import datetime
-import traceback
-from uuid import uuid4
-import platform
-import streamlit.components.v1 as components
-from loguru import logger
-
-from app.models.const import FILE_TYPE_VIDEOS
-from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
-from app.services import task as tm, llm, voice, material
-from app.utils import utils
-
-# # 将项目的根目录添加到系统路径中,以允许从项目导入模块
-root_dir = os.path.dirname(os.path.realpath(__file__))
-if root_dir not in sys.path:
- sys.path.append(root_dir)
- print("******** sys.path ********")
- print(sys.path)
- print("*" * 20)
-
-proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
-proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
-os.environ["HTTP_PROXY"] = proxy_url_http
-os.environ["HTTPS_PROXY"] = proxy_url_https
-
+# 设置页面样式
hide_streamlit_style = """
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
-st.title(f"NarratoAI :sunglasses:📽️")
-support_locales = [
- "zh-CN",
- "zh-HK",
- "zh-TW",
- "en-US",
-]
-font_dir = os.path.join(root_dir, "resource", "fonts")
-song_dir = os.path.join(root_dir, "resource", "songs")
-i18n_dir = os.path.join(root_dir, "webui", "i18n")
-config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
-system_locale = utils.get_system_locale()
-
-if 'video_clip_json' not in st.session_state:
- st.session_state['video_clip_json'] = []
-if 'video_plot' not in st.session_state:
- st.session_state['video_plot'] = ''
-if 'ui_language' not in st.session_state:
- st.session_state['ui_language'] = config.ui.get("language", system_locale)
-if 'subclip_videos' not in st.session_state:
- st.session_state['subclip_videos'] = {}
-
-
-def get_all_fonts():
- fonts = []
- for root, dirs, files in os.walk(font_dir):
- for file in files:
- if file.endswith(".ttf") or file.endswith(".ttc"):
- fonts.append(file)
- fonts.sort()
- return fonts
-
-
-def get_all_songs():
- songs = []
- for root, dirs, files in os.walk(song_dir):
- for file in files:
- if file.endswith(".mp3"):
- songs.append(file)
- return songs
-
-
-def open_task_folder(task_id):
- try:
- sys = platform.system()
- path = os.path.join(root_dir, "storage", "tasks", task_id)
- if os.path.exists(path):
- if sys == 'Windows':
- os.system(f"start {path}")
- if sys == 'Darwin':
- os.system(f"open {path}")
- except Exception as e:
- logger.error(e)
-
-
-def scroll_to_bottom():
- js = f"""
-
- """
- st.components.v1.html(js, height=0, width=0)
-
def init_log():
+ """初始化日志配置"""
+ from loguru import logger
logger.remove()
_lvl = "DEBUG"
def format_record(record):
- # 获取日志记录中的文件全径
file_path = record["file"].path
- # 将绝对路径转换为相对于项目根目录的路径
- relative_path = os.path.relpath(file_path, root_dir)
- # 更新记录中的文件路径
+ relative_path = os.path.relpath(file_path, config.root_dir)
record["file"].path = f"./{relative_path}"
- # 返回修改后的格式字符串
- # 您可以根据需要调整这里的格式
- record['message'] = record['message'].replace(root_dir, ".")
+ record['message'] = record['message'].replace(config.root_dir, ".")
_format = '{time:%Y-%m-%d %H:%M:%S}> | ' + \
'{level}> | ' + \
@@ -147,672 +52,120 @@ def init_log():
colorize=True,
)
-
-init_log()
-
-locales = utils.load_locales(i18n_dir)
-
+def init_global_state():
+ """初始化全局状态"""
+ if 'video_clip_json' not in st.session_state:
+ st.session_state['video_clip_json'] = []
+ if 'video_plot' not in st.session_state:
+ st.session_state['video_plot'] = ''
+ if 'ui_language' not in st.session_state:
+ st.session_state['ui_language'] = config.ui.get("language", utils.get_system_locale())
+ if 'subclip_videos' not in st.session_state:
+ st.session_state['subclip_videos'] = {}
def tr(key):
+ """翻译函数"""
+ i18n_dir = os.path.join(os.path.dirname(__file__), "webui", "i18n")
+ locales = utils.load_locales(i18n_dir)
loc = locales.get(st.session_state['ui_language'], {})
return loc.get("Translation", {}).get(key, key)
+def render_generate_button():
+ """渲染生成按钮和处理逻辑"""
+ if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
+ from app.services import task as tm
+
+ # 重置日志容器和记录
+ log_container = st.empty()
+ log_records = []
-st.write(tr("Get Help"))
+ def log_received(msg):
+ with log_container:
+ log_records.append(msg)
+ st.code("\n".join(log_records))
-# 基础设置
-with st.expander(tr("Basic Settings"), expanded=False):
- config_panels = st.columns(3)
- left_config_panel = config_panels[0]
- middle_config_panel = config_panels[1]
- right_config_panel = config_panels[2]
- with left_config_panel:
- display_languages = []
- selected_index = 0
- for i, code in enumerate(locales.keys()):
- display_languages.append(f"{code} - {locales[code].get('Language')}")
- if code == st.session_state['ui_language']:
- selected_index = i
+ from loguru import logger
+ logger.add(log_received)
- selected_language = st.selectbox(tr("Language"), options=display_languages,
- index=selected_index)
- if selected_language:
- code = selected_language.split(" - ")[0].strip()
- st.session_state['ui_language'] = code
- config.ui['language'] = code
+ config.save_config()
+ task_id = st.session_state.get('task_id')
- HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
- HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
- if HTTP_PROXY:
- config.proxy["http"] = HTTP_PROXY
- if HTTPS_PROXY:
- config.proxy["https"] = HTTPS_PROXY
+ if not task_id:
+ st.error(tr("请先裁剪视频"))
+ return
+ if not st.session_state.get('video_clip_json_path'):
+ st.error(tr("脚本文件不能为空"))
+ return
+ if not st.session_state.get('video_origin_path'):
+ st.error(tr("视频文件不能为空"))
+ return
- # 视频转录大模型
- with middle_config_panel:
- video_llm_providers = ['Gemini']
- saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
- saved_llm_provider_index = 0
- for i, provider in enumerate(video_llm_providers):
- if provider.lower() == saved_llm_provider:
- saved_llm_provider_index = i
- break
+ st.toast(tr("生成视频"))
+ logger.info(tr("开始生成视频"))
- video_llm_provider = st.selectbox(tr("Video LLM Provider"), options=video_llm_providers, index=saved_llm_provider_index)
- video_llm_provider = video_llm_provider.lower()
- config.app["video_llm_provider"] = video_llm_provider
+ # 获取所有参数
+ script_params = script_settings.get_script_params()
+ video_params = video_settings.get_video_params()
+ audio_params = audio_settings.get_audio_params()
+ subtitle_params = subtitle_settings.get_subtitle_params()
- video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
- video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
- video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
- video_llm_account_id = config.app.get(f"{video_llm_provider}_account_id", "")
- st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
- st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
- st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
- if st_llm_api_key:
- config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
- if st_llm_base_url:
- config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
- if st_llm_model_name:
- config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
-
- # 大语言模型
- with right_config_panel:
- llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
- saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
- saved_llm_provider_index = 0
- for i, provider in enumerate(llm_providers):
- if provider.lower() == saved_llm_provider:
- saved_llm_provider_index = i
- break
-
- llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
- llm_provider = llm_provider.lower()
- config.app["llm_provider"] = llm_provider
-
- llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
- llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
- llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
- llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
- st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
- st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
- st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
- if st_llm_api_key:
- config.app[f"{llm_provider}_api_key"] = st_llm_api_key
- if st_llm_base_url:
- config.app[f"{llm_provider}_base_url"] = st_llm_base_url
- if st_llm_model_name:
- config.app[f"{llm_provider}_model_name"] = st_llm_model_name
-
- if llm_provider == 'cloudflare':
- st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
- if st_llm_account_id:
- config.app[f"{llm_provider}_account_id"] = st_llm_account_id
-
-panel = st.columns(3)
-left_panel = panel[0]
-middle_panel = panel[1]
-right_panel = panel[2]
-
-params = VideoClipParams()
-
-# 左侧面板
-with left_panel:
- with st.container(border=True):
- st.write(tr("Video Script Configuration"))
- # 脚本语言
- video_languages = [
- (tr("Auto Detect"), ""),
- ]
- for code in ["zh-CN", "en-US", "zh-TW"]:
- video_languages.append((code, code))
-
- selected_index = st.selectbox(tr("Script Language"),
- index=0,
- options=range(len(video_languages)), # 使用索引作为内部选项值
- format_func=lambda x: video_languages[x][0] # 显示给用户的是标签
- )
- params.video_language = video_languages[selected_index][1]
-
- # 脚本路径
- suffix = "*.json"
- song_dir = utils.script_dir()
- files = glob.glob(os.path.join(song_dir, suffix))
- script_list = []
- for file in files:
- script_list.append({
- "name": os.path.basename(file),
- "size": os.path.getsize(file),
- "file": file,
- "ctime": os.path.getctime(file) # 获取文件创建时间
- })
-
- # 按创建时间降序排序
- script_list.sort(key=lambda x: x["ctime"], reverse=True)
-
- # 本文件 下拉框
- script_path = [(tr("Auto Generate"), ""), ]
- for file in script_list:
- display_name = file['file'].replace(root_dir, "")
- script_path.append((display_name, file['file']))
- selected_script_index = st.selectbox(tr("Script Files"),
- index=0,
- options=range(len(script_path)), # 使用索引作为内部选项值
- format_func=lambda x: script_path[x][0] # 显示给用户的是标签
- )
- params.video_clip_json_path = script_path[selected_script_index][1]
- config.app["video_clip_json_path"] = params.video_clip_json_path
- st.session_state['video_clip_json_path'] = params.video_clip_json_path
-
- # 视频文件处理
- video_files = []
- for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
- video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
- video_files = video_files[::-1]
-
- video_list = []
- for video_file in video_files:
- video_list.append({
- "name": os.path.basename(video_file),
- "size": os.path.getsize(video_file),
- "file": video_file,
- "ctime": os.path.getctime(video_file) # 获取文件创建时间
- })
- # 按创建时间降序排序
- video_list.sort(key=lambda x: x["ctime"], reverse=True)
- video_path = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
- for file in video_list:
- display_name = file['file'].replace(root_dir, "")
- video_path.append((display_name, file['file']))
-
- # 视频文件
- selected_video_index = st.selectbox(tr("Video File"),
- index=0,
- options=range(len(video_path)), # 使用索引作为内部选项值
- format_func=lambda x: video_path[x][0] # 显示给用户的是标签
- )
- params.video_origin_path = video_path[selected_video_index][1]
- config.app["video_origin_path"] = params.video_origin_path
- st.session_state['video_origin_path'] = params.video_origin_path
-
- # 从本地上传 mp4 文件
- if params.video_origin_path == "local":
- _supported_types = FILE_TYPE_VIDEOS
- uploaded_file = st.file_uploader(
- tr("Upload Local Files"),
- type=["mp4", "mov", "avi", "flv", "mkv"],
- accept_multiple_files=False,
- )
- if uploaded_file is not None:
- # 构造保存路径
- video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
- file_name, file_extension = os.path.splitext(uploaded_file.name)
- # 检查文件是否存在,如果存在则添加时间戳
- if os.path.exists(video_file_path):
- timestamp = time.strftime("%Y%m%d%H%M%S")
- file_name_with_timestamp = f"{file_name}_{timestamp}"
- video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
- # 将文件保存到指定目录
- with open(video_file_path, "wb") as f:
- f.write(uploaded_file.read())
- st.success(tr("File Uploaded Successfully"))
- time.sleep(1)
- st.rerun()
- # 视频名称
- video_name = st.text_input(tr("Video Name"))
- # 剧情内容
- video_plot = st.text_area(
- tr("Plot Description"),
- value=st.session_state['video_plot'],
- height=180
- )
-
- # 生成视频脚本
- if st.session_state['video_clip_json_path']:
- generate_button_name = tr("Video Script Load")
- else:
- generate_button_name = tr("Video Script Generate")
- if st.button(generate_button_name, key="auto_generate_script"):
- progress_bar = st.progress(0)
- status_text = st.empty()
-
- def update_progress(progress: float, message: str = ""):
- progress_bar.progress(progress)
- if message:
- status_text.text(f"{progress}% - {message}")
- else:
- status_text.text(f"进度: {progress}%")
-
- try:
- with st.spinner("正在生成脚本..."):
- if not video_plot:
- st.warning("视频剧情为空; 会极大影响生成效果!")
- if params.video_clip_json_path == "" and params.video_origin_path != "":
- update_progress(10, "压缩视频中...")
- # 使用大模型生成视频脚本
- script = llm.generate_script(
- video_path=params.video_origin_path,
- video_plot=video_plot,
- video_name=video_name,
- language=params.video_language,
- progress_callback=update_progress
- )
- if script is None:
- st.error("生成脚本失败,请检查日志")
- st.stop()
- else:
- update_progress(90)
-
- script = utils.clean_model_output(script)
- st.session_state['video_clip_json'] = json.loads(script)
- else:
- # 从本地加载
- with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
- update_progress(50)
- status_text.text("从本地加载中...")
- script = f.read()
- script = utils.clean_model_output(script)
- st.session_state['video_clip_json'] = json.loads(script)
- update_progress(100)
- status_text.text("从本地加载成功")
-
- time.sleep(0.5) # 给进度条一点时间到达100%
- progress_bar.progress(100)
- status_text.text("脚本生成完成!")
- st.success("视频脚本生成成功!")
- except Exception as err:
- st.error(f"生成过程中发生错误: {str(err)}")
- finally:
- time.sleep(2) # 给用户一些时间查看最终状态
- progress_bar.empty()
- status_text.empty()
-
- # 视频脚本
- video_clip_json_details = st.text_area(
- tr("Video Script"),
- value=json.dumps(st.session_state.video_clip_json, indent=2, ensure_ascii=False),
- height=180
- )
-
- # 保存脚本
- button_columns = st.columns(2)
- with button_columns[0]:
- if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True):
- if not video_clip_json_details:
- st.error(tr("请输入视频脚本"))
- st.stop()
-
- with st.spinner(tr("Save Script")):
- script_dir = utils.script_dir()
- # 获取当前时间戳,形如 2024-0618-171820
- timestamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M%S")
- save_path = os.path.join(script_dir, f"{timestamp}.json")
-
- try:
- data = utils.add_new_timestamps(json.loads(video_clip_json_details))
- except Exception as err:
- st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}")
- st.stop()
-
- # 存储为新的 JSON 文件
- with open(save_path, 'w', encoding='utf-8') as file:
- json.dump(data, file, ensure_ascii=False, indent=4)
- # 将data的值存储到 session_state 中,类似缓存
- st.session_state['video_clip_json'] = data
- st.session_state['video_clip_json_path'] = save_path
- # 刷新页面
- st.rerun()
-
- # 裁剪视频
- with button_columns[1]:
- if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
- progress_bar = st.progress(0)
- status_text = st.empty()
-
- def update_progress(progress):
- progress_bar.progress(progress)
- status_text.text(f"剪辑进度: {progress}%")
-
- try:
- utils.cut_video(params, update_progress)
- time.sleep(0.5) # 给进度条一点时间到达100%
- progress_bar.progress(100)
- status_text.text("剪辑完成!")
- st.success("视频剪辑成功完成!")
- except Exception as e:
- st.error(f"剪辑过程中发生错误: {str(e)}")
- finally:
- time.sleep(2) # 给用户一些时间查看最终状态
- progress_bar.empty()
- status_text.empty()
-
-# 新中间面板
-with middle_panel:
- with st.container(border=True):
- st.write(tr("Video Settings"))
-
- # 视频比例
- video_aspect_ratios = [
- (tr("Portrait"), VideoAspect.portrait.value),
- (tr("Landscape"), VideoAspect.landscape.value),
- ]
- selected_index = st.selectbox(
- tr("Video Ratio"),
- options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
- format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
- )
- params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
-
- # params.video_clip_duration = st.selectbox(
- # tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
- # )
- # params.video_count = st.selectbox(
- # tr("Number of Videos Generated Simultaneously"),
- # options=[1, 2, 3, 4, 5],
- # index=0,
- # )
- with st.container(border=True):
- st.write(tr("Audio Settings"))
-
- # tts_providers = ['edge', 'azure']
- # tts_provider = st.selectbox(tr("TTS Provider"), tts_providers)
-
- voices = voice.get_all_azure_voices(filter_locals=support_locales)
- friendly_names = {
- v: v.replace("Female", tr("Female"))
- .replace("Male", tr("Male"))
- .replace("Neural", "")
- for v in voices
+ # 合并所有参数
+ all_params = {
+ **script_params,
+ **video_params,
+ **audio_params,
+ **subtitle_params
}
- saved_voice_name = config.ui.get("voice_name", "")
- saved_voice_name_index = 0
- if saved_voice_name in friendly_names:
- saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
- else:
- for i, v in enumerate(voices):
- if (
- v.lower().startswith(st.session_state["ui_language"].lower())
- and "V2" not in v
- ):
- saved_voice_name_index = i
- break
- selected_friendly_name = st.selectbox(
- tr("Speech Synthesis"),
- options=list(friendly_names.values()),
- index=saved_voice_name_index,
+ # 创建参数对象
+ params = VideoClipParams(**all_params)
+
+ result = tm.start_subclip(
+ task_id=task_id,
+ params=params,
+ subclip_path_videos=st.session_state['subclip_videos']
)
- voice_name = list(friendly_names.keys())[
- list(friendly_names.values()).index(selected_friendly_name)
- ]
- params.voice_name = voice_name
- config.ui["voice_name"] = voice_name
+ video_files = result.get("videos", [])
+ st.success(tr("视频生成完成"))
+
+ try:
+ if video_files:
+ player_cols = st.columns(len(video_files) * 2 + 1)
+ for i, url in enumerate(video_files):
+ player_cols[i * 2 + 1].video(url)
+ except Exception as e:
+ logger.error(f"播放视频失败: {e}")
- if voice.is_azure_v2_voice(voice_name):
- saved_azure_speech_region = config.azure.get("speech_region", "")
- saved_azure_speech_key = config.azure.get("speech_key", "")
- azure_speech_region = st.text_input(
- tr("Speech Region"), value=saved_azure_speech_region
- )
- azure_speech_key = st.text_input(
- tr("Speech Key"), value=saved_azure_speech_key, type="password"
- )
- config.azure["speech_region"] = azure_speech_region
- config.azure["speech_key"] = azure_speech_key
+ file_utils.open_task_folder(config.root_dir, task_id)
+ logger.info(tr("视频生成完成"))
- params.voice_volume = st.selectbox(
- tr("Speech Volume"),
- options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
- index=2,
- )
+def main():
+ """主函数"""
+ init_log()
+ init_global_state()
+
+ st.title(f"NarratoAI :sunglasses:📽️")
+ st.write(tr("Get Help"))
+
+ # 渲染基础设置面板
+ basic_settings.render_basic_settings(tr)
+
+ # 渲染主面板
+ panel = st.columns(3)
+ with panel[0]:
+ script_settings.render_script_panel(tr)
+ with panel[1]:
+ video_settings.render_video_panel(tr)
+ audio_settings.render_audio_panel(tr)
+ with panel[2]:
+ subtitle_settings.render_subtitle_panel(tr)
+
+ # 渲染视频审查面板
+ review_settings.render_review_panel(tr)
+
+ # 渲染生成按钮和处理逻辑
+ render_generate_button()
- params.voice_rate = st.selectbox(
- tr("Speech Rate"),
- options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
- index=2,
- )
-
- params.voice_pitch = st.selectbox(
- tr("Speech Pitch"),
- options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
- index=2,
- )
-
- # 试听语言合成
- if st.button(tr("Play Voice")):
- play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论"
- if not play_content:
- play_content = params.video_script
- if not play_content:
- play_content = tr("Voice Example")
- with st.spinner(tr("Synthesizing Voice")):
- temp_dir = utils.storage_dir("temp", create=True)
- audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
- sub_maker = voice.tts(
- text=play_content,
- voice_name=voice_name,
- voice_rate=params.voice_rate,
- voice_pitch=params.voice_pitch,
- voice_file=audio_file,
- )
- # 如果语音文件生成失败,请使用默认内容重试。
- if not sub_maker:
- play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
- sub_maker = voice.tts(
- text=play_content,
- voice_name=voice_name,
- voice_rate=params.voice_rate,
- voice_pitch=params.voice_pitch,
- voice_file=audio_file,
- )
-
- if sub_maker and os.path.exists(audio_file):
- st.audio(audio_file, format="audio/mp3")
- if os.path.exists(audio_file):
- os.remove(audio_file)
-
- bgm_options = [
- (tr("No Background Music"), ""),
- (tr("Random Background Music"), "random"),
- (tr("Custom Background Music"), "custom"),
- ]
- selected_index = st.selectbox(
- tr("Background Music"),
- index=1,
- options=range(len(bgm_options)), # 使用索引作为内部选项值
- format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
- )
- # 获取选择的背景音乐类型
- params.bgm_type = bgm_options[selected_index][1]
-
- # 根据选择显示或隐藏组件
- if params.bgm_type == "custom":
- custom_bgm_file = st.text_input(tr("Custom Background Music File"))
- if custom_bgm_file and os.path.exists(custom_bgm_file):
- params.bgm_file = custom_bgm_file
- # st.write(f":red[已选择自定义背景音乐]:**{custom_bgm_file}**")
- params.bgm_volume = st.selectbox(
- tr("Background Music Volume"),
- options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
- index=2,
- )
-
-# 新侧面板
-with right_panel:
- with st.container(border=True):
- st.write(tr("Subtitle Settings"))
- params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
- font_names = get_all_fonts()
- saved_font_name = config.ui.get("font_name", "")
- saved_font_name_index = 0
- if saved_font_name in font_names:
- saved_font_name_index = font_names.index(saved_font_name)
- params.font_name = st.selectbox(
- tr("Font"), font_names, index=saved_font_name_index
- )
- config.ui["font_name"] = params.font_name
-
- subtitle_positions = [
- (tr("Top"), "top"),
- (tr("Center"), "center"),
- (tr("Bottom"), "bottom"),
- (tr("Custom"), "custom"),
- ]
- selected_index = st.selectbox(
- tr("Position"),
- index=2,
- options=range(len(subtitle_positions)),
- format_func=lambda x: subtitle_positions[x][0],
- )
- params.subtitle_position = subtitle_positions[selected_index][1]
-
- if params.subtitle_position == "custom":
- custom_position = st.text_input(
- tr("Custom Position (% from top)"), value="70.0"
- )
- try:
- params.custom_position = float(custom_position)
- if params.custom_position < 0 or params.custom_position > 100:
- st.error(tr("Please enter a value between 0 and 100"))
- except ValueError:
- logger.error(f"输入的值无效: {traceback.format_exc()}")
- st.error(tr("Please enter a valid number"))
-
- font_cols = st.columns([0.3, 0.7])
- with font_cols[0]:
- saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
- params.text_fore_color = st.color_picker(
- tr("Font Color"), saved_text_fore_color
- )
- config.ui["text_fore_color"] = params.text_fore_color
-
- with font_cols[1]:
- saved_font_size = config.ui.get("font_size", 60)
- params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
- config.ui["font_size"] = params.font_size
-
- stroke_cols = st.columns([0.3, 0.7])
- with stroke_cols[0]:
- params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000")
- with stroke_cols[1]:
- params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5)
-
-# 视频编辑面板
-with st.expander(tr("Video Check"), expanded=False):
- try:
- video_list = st.session_state.video_clip_json
- except KeyError as e:
- video_list = []
-
- # 计算列数和行数
- num_videos = len(video_list)
- cols_per_row = 3
- rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
-
- # 使用容器展示视频
- for row in range(rows):
- cols = st.columns(cols_per_row)
- for col in range(cols_per_row):
- index = row * cols_per_row + col
- if index < num_videos:
- with cols[col]:
- video_info = video_list[index]
- video_path = video_info.get('path')
- if video_path is not None:
- initial_narration = video_info['narration']
- initial_picture = video_info['picture']
- initial_timestamp = video_info['timestamp']
-
- with open(video_path, 'rb') as video_file:
- video_bytes = video_file.read()
- st.video(video_bytes)
-
- # 可编辑的输入框
- text_panels = st.columns(2)
- with text_panels[0]:
- text1 = st.text_area(tr("timestamp"), value=initial_timestamp, height=20,
- key=f"timestamp_{index}")
- with text_panels[1]:
- text2 = st.text_area(tr("Picture description"), value=initial_picture, height=20,
- key=f"picture_{index}")
- text3 = st.text_area(tr("Narration"), value=initial_narration, height=100,
- key=f"narration_{index}")
-
- # 重新生成按钮
- if st.button(tr("Rebuild"), key=f"rebuild_{index}"):
- # 更新video_list中的对应项
- video_list[index]['timestamp'] = text1
- video_list[index]['picture'] = text2
- video_list[index]['narration'] = text3
-
- for video in video_list:
- if 'path' in video:
- del video['path']
- # 更新session_state以确保更改被保存
- st.session_state['video_clip_json'] = utils.to_json(video_list)
- # 替换原JSON 文件
- with open(params.video_clip_json_path, 'w', encoding='utf-8') as file:
- json.dump(video_list, file, ensure_ascii=False, indent=4)
- utils.cut_video(params, progress_callback=None)
- st.rerun()
-
-# 开始按钮
-start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
-if start_button:
- # 重置日志容器和记录
- log_container = st.empty()
- log_records = []
-
- config.save_config()
- task_id = st.session_state.get('task_id')
- if st.session_state.get('video_script_json_path') is not None:
- params.video_clip_json = st.session_state.get('video_clip_json')
-
- logger.debug(f"当前的脚本文件为:{st.session_state.video_clip_json_path}")
- logger.debug(f"当前的视频文件为:{st.session_state.video_origin_path}")
- logger.debug(f"裁剪后是视频列表:{st.session_state.subclip_videos}")
-
- if not task_id:
- st.error(tr("请先裁剪视频"))
- scroll_to_bottom()
- st.stop()
- if not params.video_clip_json_path:
- st.error(tr("脚本文件不能为空"))
- scroll_to_bottom()
- st.stop()
- if not params.video_origin_path:
- st.error(tr("视频文件不能为空"))
- scroll_to_bottom()
- st.stop()
-
- def log_received(msg):
- with log_container:
- log_records.append(msg)
- st.code("\n".join(log_records))
-
- logger.add(log_received)
-
- st.toast(tr("生成视频"))
- logger.info(tr("开始生成视频"))
- logger.info(utils.to_json(params))
- scroll_to_bottom()
-
- result = tm.start_subclip(task_id=task_id, params=params, subclip_path_videos=st.session_state.subclip_videos)
-
- video_files = result.get("videos", [])
- st.success(tr("视频生成完成"))
- try:
- if video_files:
- # 将视频播放器居中
- player_cols = st.columns(len(video_files) * 2 + 1)
- for i, url in enumerate(video_files):
- player_cols[i * 2 + 1].video(url)
- except Exception as e:
- pass
-
- open_task_folder(task_id)
- logger.info(tr("视频生成完成"))
- scroll_to_bottom()
-
-config.save_config()
+if __name__ == "__main__":
+ main()
diff --git a/webui/__init__.py b/webui/__init__.py
new file mode 100644
index 0000000..3c0a334
--- /dev/null
+++ b/webui/__init__.py
@@ -0,0 +1,22 @@
+"""
+NarratoAI WebUI Package
+"""
+from webui.config.settings import config
+from webui.components import (
+ basic_settings,
+ video_settings,
+ audio_settings,
+ subtitle_settings
+)
+from webui.utils import cache, file_utils, performance
+
+__all__ = [
+ 'config',
+ 'basic_settings',
+ 'video_settings',
+ 'audio_settings',
+ 'subtitle_settings',
+ 'cache',
+ 'file_utils',
+ 'performance'
+]
\ No newline at end of file
diff --git a/webui/components/__init__.py b/webui/components/__init__.py
new file mode 100644
index 0000000..6aafcd7
--- /dev/null
+++ b/webui/components/__init__.py
@@ -0,0 +1,15 @@
+from .basic_settings import render_basic_settings
+from .script_settings import render_script_panel
+from .video_settings import render_video_panel
+from .audio_settings import render_audio_panel
+from .subtitle_settings import render_subtitle_panel
+from .review_settings import render_review_panel
+
+__all__ = [
+ 'render_basic_settings',
+ 'render_script_panel',
+ 'render_video_panel',
+ 'render_audio_panel',
+ 'render_subtitle_panel',
+ 'render_review_panel'
+]
\ No newline at end of file
diff --git a/webui/components/audio_settings.py b/webui/components/audio_settings.py
new file mode 100644
index 0000000..a189f65
--- /dev/null
+++ b/webui/components/audio_settings.py
@@ -0,0 +1,198 @@
+import streamlit as st
+import os
+from uuid import uuid4
+from app.config import config
+from app.services import voice
+from app.utils import utils
+from webui.utils.cache import get_songs_cache
+
+def render_audio_panel(tr):
+ """渲染音频设置面板"""
+ with st.container(border=True):
+ st.write(tr("Audio Settings"))
+
+ # 渲染TTS设置
+ render_tts_settings(tr)
+
+ # 渲染背景音乐设置
+ render_bgm_settings(tr)
+
+def render_tts_settings(tr):
+ """渲染TTS(文本转语音)设置"""
+ # 获取支持的语音列表
+ support_locales = ["zh-CN", "zh-HK", "zh-TW", "en-US"]
+ voices = voice.get_all_azure_voices(filter_locals=support_locales)
+
+ # 创建友好的显示名称
+ friendly_names = {
+ v: v.replace("Female", tr("Female"))
+ .replace("Male", tr("Male"))
+ .replace("Neural", "")
+ for v in voices
+ }
+
+ # 获取保存的语音设置
+ saved_voice_name = config.ui.get("voice_name", "")
+ saved_voice_name_index = 0
+
+ if saved_voice_name in friendly_names:
+ saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
+ else:
+ # 如果没有保存的设置,选择与UI语言匹配的第一个语音
+ for i, v in enumerate(voices):
+ if (v.lower().startswith(st.session_state["ui_language"].lower())
+ and "V2" not in v):
+ saved_voice_name_index = i
+ break
+
+ # 语音选择下拉框
+ selected_friendly_name = st.selectbox(
+ tr("Speech Synthesis"),
+ options=list(friendly_names.values()),
+ index=saved_voice_name_index,
+ )
+
+ # 获取实际的语音名称
+ voice_name = list(friendly_names.keys())[
+ list(friendly_names.values()).index(selected_friendly_name)
+ ]
+
+ # 保存设置
+ config.ui["voice_name"] = voice_name
+
+ # Azure V2语音特殊处理
+ if voice.is_azure_v2_voice(voice_name):
+ render_azure_v2_settings(tr)
+
+ # 语音参数设置
+ render_voice_parameters(tr)
+
+ # 试听按钮
+ render_voice_preview(tr, voice_name)
+
+def render_azure_v2_settings(tr):
+ """渲染Azure V2语音设置"""
+ saved_azure_speech_region = config.azure.get("speech_region", "")
+ saved_azure_speech_key = config.azure.get("speech_key", "")
+
+ azure_speech_region = st.text_input(
+ tr("Speech Region"),
+ value=saved_azure_speech_region
+ )
+ azure_speech_key = st.text_input(
+ tr("Speech Key"),
+ value=saved_azure_speech_key,
+ type="password"
+ )
+
+ config.azure["speech_region"] = azure_speech_region
+ config.azure["speech_key"] = azure_speech_key
+
+def render_voice_parameters(tr):
+ """渲染语音参数设置"""
+ # 音量
+ voice_volume = st.selectbox(
+ tr("Speech Volume"),
+ options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
+ index=2,
+ )
+ st.session_state['voice_volume'] = voice_volume
+
+ # 语速
+ voice_rate = st.selectbox(
+ tr("Speech Rate"),
+ options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
+ index=2,
+ )
+ st.session_state['voice_rate'] = voice_rate
+
+ # 音调
+ voice_pitch = st.selectbox(
+ tr("Speech Pitch"),
+ options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
+ index=2,
+ )
+ st.session_state['voice_pitch'] = voice_pitch
+
+def render_voice_preview(tr, voice_name):
+ """渲染语音试听功能"""
+ if st.button(tr("Play Voice")):
+ play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论"
+ if not play_content:
+ play_content = st.session_state.get('video_script', '')
+ if not play_content:
+ play_content = tr("Voice Example")
+
+ with st.spinner(tr("Synthesizing Voice")):
+ temp_dir = utils.storage_dir("temp", create=True)
+ audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
+
+ sub_maker = voice.tts(
+ text=play_content,
+ voice_name=voice_name,
+ voice_rate=st.session_state.get('voice_rate', 1.0),
+ voice_pitch=st.session_state.get('voice_pitch', 1.0),
+ voice_file=audio_file,
+ )
+
+ # 如果语音文件生成失败,使用默认内容重试
+ if not sub_maker:
+ play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
+ sub_maker = voice.tts(
+ text=play_content,
+ voice_name=voice_name,
+ voice_rate=st.session_state.get('voice_rate', 1.0),
+ voice_pitch=st.session_state.get('voice_pitch', 1.0),
+ voice_file=audio_file,
+ )
+
+ if sub_maker and os.path.exists(audio_file):
+ st.audio(audio_file, format="audio/mp3")
+ if os.path.exists(audio_file):
+ os.remove(audio_file)
+
+def render_bgm_settings(tr):
+ """渲染背景音乐设置"""
+ # 背景音乐选项
+ bgm_options = [
+ (tr("No Background Music"), ""),
+ (tr("Random Background Music"), "random"),
+ (tr("Custom Background Music"), "custom"),
+ ]
+
+ selected_index = st.selectbox(
+ tr("Background Music"),
+ index=1,
+ options=range(len(bgm_options)),
+ format_func=lambda x: bgm_options[x][0],
+ )
+
+ # 获取选择的背景音乐类型
+ bgm_type = bgm_options[selected_index][1]
+ st.session_state['bgm_type'] = bgm_type
+
+ # 自定义背景音乐处理
+ if bgm_type == "custom":
+ custom_bgm_file = st.text_input(tr("Custom Background Music File"))
+ if custom_bgm_file and os.path.exists(custom_bgm_file):
+ st.session_state['bgm_file'] = custom_bgm_file
+
+ # 背景音乐音量
+ bgm_volume = st.selectbox(
+ tr("Background Music Volume"),
+ options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+ index=2,
+ )
+ st.session_state['bgm_volume'] = bgm_volume
+
+def get_audio_params():
+ """获取音频参数"""
+ return {
+ 'voice_name': config.ui.get("voice_name", ""),
+ 'voice_volume': st.session_state.get('voice_volume', 1.0),
+ 'voice_rate': st.session_state.get('voice_rate', 1.0),
+ 'voice_pitch': st.session_state.get('voice_pitch', 1.0),
+ 'bgm_type': st.session_state.get('bgm_type', 'random'),
+ 'bgm_file': st.session_state.get('bgm_file', ''),
+ 'bgm_volume': st.session_state.get('bgm_volume', 0.2),
+ }
\ No newline at end of file
diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py
new file mode 100644
index 0000000..f8dd65d
--- /dev/null
+++ b/webui/components/basic_settings.py
@@ -0,0 +1,142 @@
+import streamlit as st
+import os
+from app.config import config
+from app.utils import utils
+
+def render_basic_settings(tr):
+ """渲染基础设置面板"""
+ with st.expander(tr("Basic Settings"), expanded=False):
+ config_panels = st.columns(3)
+ left_config_panel = config_panels[0]
+ middle_config_panel = config_panels[1]
+ right_config_panel = config_panels[2]
+
+ with left_config_panel:
+ render_language_settings(tr)
+ render_proxy_settings(tr)
+
+ with middle_config_panel:
+ render_video_llm_settings(tr)
+
+ with right_config_panel:
+ render_llm_settings(tr)
+
+def render_language_settings(tr):
+ """渲染语言设置"""
+ system_locale = utils.get_system_locale()
+ i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
+ locales = utils.load_locales(i18n_dir)
+
+ display_languages = []
+ selected_index = 0
+ for i, code in enumerate(locales.keys()):
+ display_languages.append(f"{code} - {locales[code].get('Language')}")
+ if code == st.session_state.get('ui_language', system_locale):
+ selected_index = i
+
+ selected_language = st.selectbox(
+ tr("Language"),
+ options=display_languages,
+ index=selected_index
+ )
+
+ if selected_language:
+ code = selected_language.split(" - ")[0].strip()
+ st.session_state['ui_language'] = code
+ config.ui['language'] = code
+
+def render_proxy_settings(tr):
+ """渲染代理设置"""
+ proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
+ proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
+
+ HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
+ HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
+
+ if HTTP_PROXY:
+ config.proxy["http"] = HTTP_PROXY
+ os.environ["HTTP_PROXY"] = HTTP_PROXY
+ if HTTPS_PROXY:
+ config.proxy["https"] = HTTPS_PROXY
+ os.environ["HTTPS_PROXY"] = HTTPS_PROXY
+
+def render_video_llm_settings(tr):
+ """渲染视频LLM设置"""
+ video_llm_providers = ['Gemini', 'NarratoAPI']
+ saved_llm_provider = config.app.get("video_llm_provider", "OpenAI").lower()
+ saved_llm_provider_index = 0
+
+ for i, provider in enumerate(video_llm_providers):
+ if provider.lower() == saved_llm_provider:
+ saved_llm_provider_index = i
+ break
+
+ video_llm_provider = st.selectbox(
+ tr("Video LLM Provider"),
+ options=video_llm_providers,
+ index=saved_llm_provider_index
+ )
+ video_llm_provider = video_llm_provider.lower()
+ config.app["video_llm_provider"] = video_llm_provider
+
+ # 获取已保存的配置
+ video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
+ video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
+ video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
+
+ # 渲染输入框
+ st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
+ st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
+ st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
+
+ # 保存配置
+ if st_llm_api_key:
+ config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
+ if st_llm_base_url:
+ config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
+ if st_llm_model_name:
+ config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
+
+def render_llm_settings(tr):
+ """渲染LLM设置"""
+ llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
+ saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
+ saved_llm_provider_index = 0
+
+ for i, provider in enumerate(llm_providers):
+ if provider.lower() == saved_llm_provider:
+ saved_llm_provider_index = i
+ break
+
+ llm_provider = st.selectbox(
+ tr("LLM Provider"),
+ options=llm_providers,
+ index=saved_llm_provider_index
+ )
+ llm_provider = llm_provider.lower()
+ config.app["llm_provider"] = llm_provider
+
+ # 获取已保存的配置
+ llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
+ llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
+ llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
+ llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
+
+ # 渲染输入框
+ st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
+ st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
+ st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
+
+ # 保存配置
+ if st_llm_api_key:
+ config.app[f"{llm_provider}_api_key"] = st_llm_api_key
+ if st_llm_base_url:
+ config.app[f"{llm_provider}_base_url"] = st_llm_base_url
+ if st_llm_model_name:
+ config.app[f"{llm_provider}_model_name"] = st_llm_model_name
+
+ # Cloudflare 特殊处理
+ if llm_provider == 'cloudflare':
+ st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
+ if st_llm_account_id:
+ config.app[f"{llm_provider}_account_id"] = st_llm_account_id
\ No newline at end of file
diff --git a/webui/components/review_settings.py b/webui/components/review_settings.py
new file mode 100644
index 0000000..43f3844
--- /dev/null
+++ b/webui/components/review_settings.py
@@ -0,0 +1,65 @@
+import streamlit as st
+import os
+from loguru import logger
+
+def render_review_panel(tr):
+ """渲染视频审查面板"""
+ with st.expander(tr("Video Check"), expanded=False):
+ try:
+ video_list = st.session_state.get('video_clip_json', [])
+ except KeyError:
+ video_list = []
+
+ # 计算列数和行数
+ num_videos = len(video_list)
+ cols_per_row = 3
+ rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
+
+ # 使用容器展示视频
+ for row in range(rows):
+ cols = st.columns(cols_per_row)
+ for col in range(cols_per_row):
+ index = row * cols_per_row + col
+ if index < num_videos:
+ with cols[col]:
+ render_video_item(tr, video_list, index)
+
+def render_video_item(tr, video_list, index):
+ """渲染单个视频项"""
+ video_info = video_list[index]
+ video_path = video_info.get('path')
+ if video_path is not None and os.path.exists(video_path):
+ initial_narration = video_info.get('narration', '')
+ initial_picture = video_info.get('picture', '')
+ initial_timestamp = video_info.get('timestamp', '')
+
+ # 显示视频
+ with open(video_path, 'rb') as video_file:
+ video_bytes = video_file.read()
+ st.video(video_bytes)
+
+ # 显示信息(只读)
+ text_panels = st.columns(2)
+ with text_panels[0]:
+ st.text_area(
+ tr("timestamp"),
+ value=initial_timestamp,
+ height=20,
+ key=f"timestamp_{index}",
+ disabled=True
+ )
+ with text_panels[1]:
+ st.text_area(
+ tr("Picture description"),
+ value=initial_picture,
+ height=20,
+ key=f"picture_{index}",
+ disabled=True
+ )
+ st.text_area(
+ tr("Narration"),
+ value=initial_narration,
+ height=100,
+ key=f"narration_{index}",
+ disabled=True
+ )
\ No newline at end of file
diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py
new file mode 100644
index 0000000..cc217d8
--- /dev/null
+++ b/webui/components/script_settings.py
@@ -0,0 +1,314 @@
+import streamlit as st
+import os
+import glob
+import json
+import time
+from app.config import config
+from app.models.schema import VideoClipParams
+from app.services import llm
+from app.utils import utils, check_script
+from loguru import logger
+from webui.utils import file_utils
+
+def render_script_panel(tr):
+ """渲染脚本配置面板"""
+ with st.container(border=True):
+ st.write(tr("Video Script Configuration"))
+ params = VideoClipParams()
+
+ # 渲染脚本文件选择
+ render_script_file(tr, params)
+
+ # 渲染视频文件选择
+ render_video_file(tr, params)
+
+ # 渲染视频主题和提示词
+ render_video_details(tr)
+
+ # 渲染脚本操作按钮
+ render_script_buttons(tr, params)
+
+def render_script_file(tr, params):
+ """渲染脚本文件选择"""
+ script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]
+
+ # 获取已有脚本文件
+ suffix = "*.json"
+ script_dir = utils.script_dir()
+ files = glob.glob(os.path.join(script_dir, suffix))
+ file_list = []
+
+ for file in files:
+ file_list.append({
+ "name": os.path.basename(file),
+ "file": file,
+ "ctime": os.path.getctime(file)
+ })
+
+ file_list.sort(key=lambda x: x["ctime"], reverse=True)
+ for file in file_list:
+ display_name = file['file'].replace(config.root_dir, "")
+ script_list.append((display_name, file['file']))
+
+ # 找到保存的脚本文件在列表中的索引
+ saved_script_path = st.session_state.get('video_clip_json_path', '')
+ selected_index = 0
+ for i, (_, path) in enumerate(script_list):
+ if path == saved_script_path:
+ selected_index = i
+ break
+
+ selected_script_index = st.selectbox(
+ tr("Script Files"),
+ index=selected_index, # 使用找到的索引
+ options=range(len(script_list)),
+ format_func=lambda x: script_list[x][0]
+ )
+
+ script_path = script_list[selected_script_index][1]
+ st.session_state['video_clip_json_path'] = script_path
+ params.video_clip_json_path = script_path
+
+def render_video_file(tr, params):
+ """渲染视频文件选择"""
+ video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
+
+ # 获取已有视频文件
+ for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
+ video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
+ for file in video_files:
+ display_name = file.replace(config.root_dir, "")
+ video_list.append((display_name, file))
+
+ selected_video_index = st.selectbox(
+ tr("Video File"),
+ index=0,
+ options=range(len(video_list)),
+ format_func=lambda x: video_list[x][0]
+ )
+
+ video_path = video_list[selected_video_index][1]
+ st.session_state['video_origin_path'] = video_path
+ params.video_origin_path = video_path
+
+ if video_path == "local":
+ uploaded_file = st.file_uploader(
+ tr("Upload Local Files"),
+ type=["mp4", "mov", "avi", "flv", "mkv"],
+ accept_multiple_files=False,
+ )
+
+ if uploaded_file is not None:
+ video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
+ file_name, file_extension = os.path.splitext(uploaded_file.name)
+
+ if os.path.exists(video_file_path):
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ file_name_with_timestamp = f"{file_name}_{timestamp}"
+ video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
+
+ with open(video_file_path, "wb") as f:
+ f.write(uploaded_file.read())
+ st.success(tr("File Uploaded Successfully"))
+ st.session_state['video_origin_path'] = video_file_path
+ params.video_origin_path = video_file_path
+ time.sleep(1)
+ st.rerun()
+
+def render_video_details(tr):
+ """渲染视频主题和提示词"""
+ video_theme = st.text_input(tr("Video Theme"))
+ prompt = st.text_area(
+ tr("Generation Prompt"),
+ value=st.session_state.get('video_plot', ''),
+ help=tr("Custom prompt for LLM, leave empty to use default prompt"),
+ height=180
+ )
+ st.session_state['video_name'] = video_theme
+ st.session_state['video_plot'] = prompt
+ return video_theme, prompt
+
+def render_script_buttons(tr, params):
+ """渲染脚本操作按钮"""
+ # 生成/加载按钮
+ script_path = st.session_state.get('video_clip_json_path', '')
+ if script_path == "auto":
+ button_name = tr("Generate Video Script")
+ elif script_path:
+ button_name = tr("Load Video Script")
+ else:
+ button_name = tr("Please Select Script File")
+
+ if st.button(button_name, key="script_action", disabled=not script_path):
+ if script_path == "auto":
+ generate_script(tr, params)
+ else:
+ load_script(tr, script_path)
+
+ # 视频脚本编辑区
+ video_clip_json_details = st.text_area(
+ tr("Video Script"),
+ value=json.dumps(st.session_state.get('video_clip_json', []), indent=2, ensure_ascii=False),
+ height=180
+ )
+
+ # 操作按钮行
+ button_cols = st.columns(3)
+ with button_cols[0]:
+ if st.button(tr("Check Format"), key="check_format", use_container_width=True):
+ check_script_format(tr, video_clip_json_details)
+
+ with button_cols[1]:
+ if st.button(tr("Save Script"), key="save_script", use_container_width=True):
+ save_script(tr, video_clip_json_details)
+
+ with button_cols[2]:
+ script_valid = st.session_state.get('script_format_valid', False)
+ if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
+ crop_video(tr, params)
+
+def check_script_format(tr, script_content):
+ """检查脚本格式"""
+ try:
+ result = check_script.check_format(script_content)
+ if result.get('success'):
+ st.success(tr("Script format check passed"))
+ st.session_state['script_format_valid'] = True
+ else:
+ st.error(f"{tr('Script format check failed')}: {result.get('message')}")
+ st.session_state['script_format_valid'] = False
+ except Exception as e:
+ st.error(f"{tr('Script format check error')}: {str(e)}")
+ st.session_state['script_format_valid'] = False
+
+def load_script(tr, script_path):
+ """加载脚本文件"""
+ try:
+ with open(script_path, 'r', encoding='utf-8') as f:
+ script = f.read()
+ script = utils.clean_model_output(script)
+ st.session_state['video_clip_json'] = json.loads(script)
+ st.success(tr("Script loaded successfully"))
+ st.rerun()
+ except Exception as e:
+ st.error(f"{tr('Failed to load script')}: {str(e)}")
+
+def generate_script(tr, params):
+ """生成视频脚本"""
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ def update_progress(progress: float, message: str = ""):
+ progress_bar.progress(progress)
+ if message:
+ status_text.text(f"{progress}% - {message}")
+ else:
+ status_text.text(f"进度: {progress}%")
+
+ try:
+ with st.spinner("正在生成脚本..."):
+ if not st.session_state.get('video_plot'):
+ st.warning("视频剧情为空; 会极大影响生成效果!")
+
+ if params.video_clip_json_path == "" and params.video_origin_path != "":
+ update_progress(10, "压缩视频中...")
+ script = llm.generate_script(
+ video_path=params.video_origin_path,
+ video_plot=st.session_state.get('video_plot', ''),
+ video_name=st.session_state.get('video_name', ''),
+ language=params.video_language,
+ progress_callback=update_progress
+ )
+ if script is None:
+ st.error("生成脚本失败,请检查日志")
+ st.stop()
+ else:
+ update_progress(90)
+
+ script = utils.clean_model_output(script)
+ st.session_state['video_clip_json'] = json.loads(script)
+ else:
+ # 从本地加载
+ with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
+ update_progress(50)
+ status_text.text("从本地加载中...")
+ script = f.read()
+ script = utils.clean_model_output(script)
+ st.session_state['video_clip_json'] = json.loads(script)
+ update_progress(100)
+ status_text.text("从本地加载成功")
+
+ time.sleep(0.5)
+ progress_bar.progress(100)
+ status_text.text("脚本生成完成!")
+ st.success("视频脚本生成成功!")
+ except Exception as err:
+ st.error(f"生成过程中发生错误: {str(err)}")
+ finally:
+ time.sleep(2)
+ progress_bar.empty()
+ status_text.empty()
+
+def save_script(tr, video_clip_json_details):
+ """保存视频脚本"""
+ if not video_clip_json_details:
+ st.error(tr("请输入视频脚本"))
+ st.stop()
+
+ with st.spinner(tr("Save Script")):
+ script_dir = utils.script_dir()
+ timestamp = time.strftime("%Y-%m%d-%H%M%S")
+ save_path = os.path.join(script_dir, f"{timestamp}.json")
+
+ try:
+ data = json.loads(video_clip_json_details)
+ with open(save_path, 'w', encoding='utf-8') as file:
+ json.dump(data, file, ensure_ascii=False, indent=4)
+ st.session_state['video_clip_json'] = data
+ st.session_state['video_clip_json_path'] = save_path
+
+ # 更新配置
+ config.app["video_clip_json_path"] = save_path
+
+ # 显示成功消息
+ st.success(tr("Script saved successfully"))
+
+ # 强制重新加载页面以更新选择框
+ time.sleep(0.5) # 给一点时间让用户看到成功消息
+ st.rerun()
+
+ except Exception as err:
+ st.error(f"{tr('Failed to save script')}: {str(err)}")
+ st.stop()
+
+def crop_video(tr, params):
+ """裁剪视频"""
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ def update_progress(progress):
+ progress_bar.progress(progress)
+ status_text.text(f"剪辑进度: {progress}%")
+
+ try:
+ utils.cut_video(params, update_progress)
+ time.sleep(0.5)
+ progress_bar.progress(100)
+ status_text.text("剪辑完成!")
+ st.success("视频剪辑成功完成!")
+ except Exception as e:
+ st.error(f"剪辑过程中发生错误: {str(e)}")
+ finally:
+ time.sleep(2)
+ progress_bar.empty()
+ status_text.empty()
+
+def get_script_params():
+ """获取脚本参数"""
+ return {
+ 'video_language': st.session_state.get('video_language', ''),
+ 'video_clip_json_path': st.session_state.get('video_clip_json_path', ''),
+ 'video_origin_path': st.session_state.get('video_origin_path', ''),
+ 'video_name': st.session_state.get('video_name', ''),
+ 'video_plot': st.session_state.get('video_plot', '')
+ }
\ No newline at end of file
diff --git a/webui/components/subtitle_settings.py b/webui/components/subtitle_settings.py
new file mode 100644
index 0000000..9b94e3c
--- /dev/null
+++ b/webui/components/subtitle_settings.py
@@ -0,0 +1,129 @@
+import streamlit as st
+from app.config import config
+from webui.utils.cache import get_fonts_cache
+import os
+
+def render_subtitle_panel(tr):
+ """渲染字幕设置面板"""
+ with st.container(border=True):
+ st.write(tr("Subtitle Settings"))
+
+ # 启用字幕选项
+ enable_subtitles = st.checkbox(tr("Enable Subtitles"), value=True)
+ st.session_state['subtitle_enabled'] = enable_subtitles
+
+ if enable_subtitles:
+ render_font_settings(tr)
+ render_position_settings(tr)
+ render_style_settings(tr)
+
+def render_font_settings(tr):
+ """渲染字体设置"""
+ # 获取字体列表
+ font_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "resource", "fonts")
+ font_names = get_fonts_cache(font_dir)
+
+ # 获取保存的字体设置
+ saved_font_name = config.ui.get("font_name", "")
+ saved_font_name_index = 0
+ if saved_font_name in font_names:
+ saved_font_name_index = font_names.index(saved_font_name)
+
+ # 字体选择
+ font_name = st.selectbox(
+ tr("Font"),
+ options=font_names,
+ index=saved_font_name_index
+ )
+ config.ui["font_name"] = font_name
+ st.session_state['font_name'] = font_name
+
+ # 字体大小
+ font_cols = st.columns([0.3, 0.7])
+ with font_cols[0]:
+ saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
+ text_fore_color = st.color_picker(
+ tr("Font Color"),
+ saved_text_fore_color
+ )
+ config.ui["text_fore_color"] = text_fore_color
+ st.session_state['text_fore_color'] = text_fore_color
+
+ with font_cols[1]:
+ saved_font_size = config.ui.get("font_size", 60)
+ font_size = st.slider(
+ tr("Font Size"),
+ min_value=30,
+ max_value=100,
+ value=saved_font_size
+ )
+ config.ui["font_size"] = font_size
+ st.session_state['font_size'] = font_size
+
+def render_position_settings(tr):
+ """渲染位置设置"""
+ subtitle_positions = [
+ (tr("Top"), "top"),
+ (tr("Center"), "center"),
+ (tr("Bottom"), "bottom"),
+ (tr("Custom"), "custom"),
+ ]
+
+ selected_index = st.selectbox(
+ tr("Position"),
+ index=2,
+ options=range(len(subtitle_positions)),
+ format_func=lambda x: subtitle_positions[x][0],
+ )
+
+ subtitle_position = subtitle_positions[selected_index][1]
+ st.session_state['subtitle_position'] = subtitle_position
+
+ # 自定义位置处理
+ if subtitle_position == "custom":
+ custom_position = st.text_input(
+ tr("Custom Position (% from top)"),
+ value="70.0"
+ )
+ try:
+ custom_position_value = float(custom_position)
+ if custom_position_value < 0 or custom_position_value > 100:
+ st.error(tr("Please enter a value between 0 and 100"))
+ else:
+ st.session_state['custom_position'] = custom_position_value
+ except ValueError:
+ st.error(tr("Please enter a valid number"))
+
+def render_style_settings(tr):
+ """渲染样式设置"""
+ stroke_cols = st.columns([0.3, 0.7])
+
+ with stroke_cols[0]:
+ stroke_color = st.color_picker(
+ tr("Stroke Color"),
+ value="#000000"
+ )
+ st.session_state['stroke_color'] = stroke_color
+
+ with stroke_cols[1]:
+ stroke_width = st.slider(
+ tr("Stroke Width"),
+ min_value=0.0,
+ max_value=10.0,
+ value=1.5,
+ step=0.1
+ )
+ st.session_state['stroke_width'] = stroke_width
+
+def get_subtitle_params():
+ """获取字幕参数"""
+ return {
+ 'enabled': st.session_state.get('subtitle_enabled', True),
+ 'font_name': st.session_state.get('font_name', ''),
+ 'font_size': st.session_state.get('font_size', 60),
+ 'text_fore_color': st.session_state.get('text_fore_color', '#FFFFFF'),
+ 'position': st.session_state.get('subtitle_position', 'bottom'),
+ 'custom_position': st.session_state.get('custom_position', 70.0),
+ 'stroke_color': st.session_state.get('stroke_color', '#000000'),
+ 'stroke_width': st.session_state.get('stroke_width', 1.5),
+ }
\ No newline at end of file
diff --git a/webui/components/video_settings.py b/webui/components/video_settings.py
new file mode 100644
index 0000000..7942bee
--- /dev/null
+++ b/webui/components/video_settings.py
@@ -0,0 +1,47 @@
+import streamlit as st
+from app.models.schema import VideoClipParams, VideoAspect
+
+def render_video_panel(tr):
+ """渲染视频配置面板"""
+ with st.container(border=True):
+ st.write(tr("Video Settings"))
+ params = VideoClipParams()
+ render_video_config(tr, params)
+
+def render_video_config(tr, params):
+ """渲染视频配置"""
+ # 视频比例
+ video_aspect_ratios = [
+ (tr("Portrait"), VideoAspect.portrait.value),
+ (tr("Landscape"), VideoAspect.landscape.value),
+ ]
+ selected_index = st.selectbox(
+ tr("Video Ratio"),
+ options=range(len(video_aspect_ratios)),
+ format_func=lambda x: video_aspect_ratios[x][0],
+ )
+ params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
+ st.session_state['video_aspect'] = params.video_aspect.value
+
+ # 视频画质
+ video_qualities = [
+ ("4K (2160p)", "2160p"),
+ ("2K (1440p)", "1440p"),
+ ("Full HD (1080p)", "1080p"),
+ ("HD (720p)", "720p"),
+ ("SD (480p)", "480p"),
+ ]
+ quality_index = st.selectbox(
+ tr("Video Quality"),
+ options=range(len(video_qualities)),
+ format_func=lambda x: video_qualities[x][0],
+ index=2 # 默认选择 1080p
+ )
+ st.session_state['video_quality'] = video_qualities[quality_index][1]
+
+def get_video_params():
+ """获取视频参数"""
+ return {
+ 'video_aspect': st.session_state.get('video_aspect', VideoAspect.portrait.value),
+ 'video_quality': st.session_state.get('video_quality', '1080p')
+ }
\ No newline at end of file
diff --git a/webui/config/settings.py b/webui/config/settings.py
new file mode 100644
index 0000000..6ad4db3
--- /dev/null
+++ b/webui/config/settings.py
@@ -0,0 +1,155 @@
+import os
+import tomli
+from loguru import logger
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+
+@dataclass
+class WebUIConfig:
+ """WebUI配置类"""
+ # UI配置
+ ui: Dict[str, Any] = None
+ # 代理配置
+ proxy: Dict[str, str] = None
+ # 应用配置
+ app: Dict[str, Any] = None
+ # Azure配置
+ azure: Dict[str, str] = None
+ # 项目版本
+ project_version: str = "0.1.0"
+ # 项目根目录
+ root_dir: str = None
+
+ def __post_init__(self):
+ """初始化默认值"""
+ self.ui = self.ui or {}
+ self.proxy = self.proxy or {}
+ self.app = self.app or {}
+ self.azure = self.azure or {}
+ self.root_dir = self.root_dir or os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+def load_config(config_path: Optional[str] = None) -> WebUIConfig:
+ """加载配置文件
+ Args:
+ config_path: 配置文件路径,如果为None则使用默认路径
+ Returns:
+ WebUIConfig: 配置对象
+ """
+ try:
+ if config_path is None:
+ config_path = os.path.join(
+ os.path.dirname(os.path.dirname(__file__)),
+ ".streamlit",
+ "webui.toml"
+ )
+
+ # 如果配置文件不存在,使用示例配置
+ if not os.path.exists(config_path):
+ example_config = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+ "config.example.toml"
+ )
+ if os.path.exists(example_config):
+ config_path = example_config
+ else:
+ logger.warning(f"配置文件不存在: {config_path}")
+ return WebUIConfig()
+
+ # 读取配置文件
+ with open(config_path, "rb") as f:
+ config_dict = tomli.load(f)
+
+ # 创建配置对象
+ config = WebUIConfig(
+ ui=config_dict.get("ui", {}),
+ proxy=config_dict.get("proxy", {}),
+ app=config_dict.get("app", {}),
+ azure=config_dict.get("azure", {}),
+ project_version=config_dict.get("project_version", "0.1.0")
+ )
+
+ return config
+
+ except Exception as e:
+ logger.error(f"加载配置文件失败: {e}")
+ return WebUIConfig()
+
+def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool:
+ """保存配置到文件
+ Args:
+ config: 配置对象
+ config_path: 配置文件路径,如果为None则使用默认路径
+ Returns:
+ bool: 是否保存成功
+ """
+ try:
+ if config_path is None:
+ config_path = os.path.join(
+ os.path.dirname(os.path.dirname(__file__)),
+ ".streamlit",
+ "webui.toml"
+ )
+
+ # 确保目录存在
+ os.makedirs(os.path.dirname(config_path), exist_ok=True)
+
+ # 转换为字典
+ config_dict = {
+ "ui": config.ui,
+ "proxy": config.proxy,
+ "app": config.app,
+ "azure": config.azure,
+ "project_version": config.project_version
+ }
+
+ # 保存配置
+ with open(config_path, "w", encoding="utf-8") as f:
+ import tomli_w
+ tomli_w.dump(config_dict, f)
+
+ return True
+
+ except Exception as e:
+ logger.error(f"保存配置文件失败: {e}")
+ return False
+
+def get_config() -> WebUIConfig:
+ """获取全局配置对象
+ Returns:
+ WebUIConfig: 配置对象
+ """
+ if not hasattr(get_config, "_config"):
+ get_config._config = load_config()
+ return get_config._config
+
+def update_config(config_dict: Dict[str, Any]) -> bool:
+ """更新配置
+ Args:
+ config_dict: 配置字典
+ Returns:
+ bool: 是否更新成功
+ """
+ try:
+ config = get_config()
+
+ # 更新配置
+ if "ui" in config_dict:
+ config.ui.update(config_dict["ui"])
+ if "proxy" in config_dict:
+ config.proxy.update(config_dict["proxy"])
+ if "app" in config_dict:
+ config.app.update(config_dict["app"])
+ if "azure" in config_dict:
+ config.azure.update(config_dict["azure"])
+ if "project_version" in config_dict:
+ config.project_version = config_dict["project_version"]
+
+ # 保存配置
+ return save_config(config)
+
+ except Exception as e:
+ logger.error(f"更新配置失败: {e}")
+ return False
+
+# 导出全局配置对象
+config = get_config()
\ No newline at end of file
diff --git a/webui/i18n/__init__.py b/webui/i18n/__init__.py
new file mode 100644
index 0000000..0f05c76
--- /dev/null
+++ b/webui/i18n/__init__.py
@@ -0,0 +1 @@
+# 空文件,用于标记包
\ No newline at end of file
diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json
index f1bc6b2..0481306 100644
--- a/webui/i18n/zh.json
+++ b/webui/i18n/zh.json
@@ -2,15 +2,15 @@
"Language": "简体中文",
"Translation": {
"Video Script Configuration": "**视频脚本配置**",
- "Video Script Generate": "生成视频脚本",
+ "Generate Video Script": "生成视频脚本",
"Video Subject": "视频主题(给定一个关键词,:red[AI自动生成]视频文案)",
"Script Language": "生成视频脚本的语言(一般情况AI会自动根据你输入的主题语言输出)",
"Script Files": "脚本文件",
"Generate Video Script and Keywords": "点击使用AI根据**主题**生成 【视频文案】 和 【视频关键词】",
"Auto Detect": "自动检测",
"Auto Generate": "自动生成",
- "Video Name": "视频名称",
- "Video Script": "视频脚本(:blue[①使用AI生成 ②从本机加载])",
+ "Video Theme": "视频主题",
+ "Generation Prompt": "自定义提示词",
"Save Script": "保存脚本",
"Crop Video": "裁剪视频",
"Video File": "视频文件(:blue[1️⃣支持上传视频文件(限制2G) 2️⃣大文件建议直接导入 ./resource/videos 目录])",
@@ -91,7 +91,18 @@
"Picture description": "图片描述",
"Narration": "视频文案",
"Rebuild": "重新生成",
- "Video Script Load": "加载视频脚本",
- "Speech Pitch": "语调"
+ "Load Video Script": "加载视频脚本",
+ "Speech Pitch": "语调",
+ "Please Select Script File": "请选择脚本文件",
+ "Check Format": "脚本格式检查",
+ "Script Loaded Successfully": "脚本加载成功",
+ "Script format check passed": "脚本格式检查通过",
+ "Script format check failed": "脚本格式检查失败",
+ "Failed to Load Script": "加载脚本失败",
+ "Failed to Save Script": "保存脚本失败",
+ "Script saved successfully": "脚本保存成功",
+ "Video Script": "视频脚本",
+ "Video Quality": "视频质量",
+ "Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词"
}
}
\ No newline at end of file
diff --git a/webui/utils/__init__.py b/webui/utils/__init__.py
new file mode 100644
index 0000000..b6a1870
--- /dev/null
+++ b/webui/utils/__init__.py
@@ -0,0 +1,20 @@
+from .cache import get_fonts_cache, get_video_files_cache, get_songs_cache
+from .file_utils import (
+ open_task_folder, cleanup_temp_files, get_file_list,
+ save_uploaded_file, create_temp_file, get_file_size, ensure_directory
+)
+from .performance import monitor_performance
+
+__all__ = [
+ 'get_fonts_cache',
+ 'get_video_files_cache',
+ 'get_songs_cache',
+ 'open_task_folder',
+ 'cleanup_temp_files',
+ 'get_file_list',
+ 'save_uploaded_file',
+ 'create_temp_file',
+ 'get_file_size',
+ 'ensure_directory',
+ 'monitor_performance'
+]
\ No newline at end of file
diff --git a/webui/utils/cache.py b/webui/utils/cache.py
new file mode 100644
index 0000000..6cc3b05
--- /dev/null
+++ b/webui/utils/cache.py
@@ -0,0 +1,33 @@
+import streamlit as st
+import os
+import glob
+from app.utils import utils
+
+def get_fonts_cache(font_dir):
+ if 'fonts_cache' not in st.session_state:
+ fonts = []
+ for root, dirs, files in os.walk(font_dir):
+ for file in files:
+ if file.endswith(".ttf") or file.endswith(".ttc"):
+ fonts.append(file)
+ fonts.sort()
+ st.session_state['fonts_cache'] = fonts
+ return st.session_state['fonts_cache']
+
+def get_video_files_cache():
+ if 'video_files_cache' not in st.session_state:
+ video_files = []
+ for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
+ video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
+ st.session_state['video_files_cache'] = video_files[::-1]
+ return st.session_state['video_files_cache']
+
+def get_songs_cache(song_dir):
+ if 'songs_cache' not in st.session_state:
+ songs = []
+ for root, dirs, files in os.walk(song_dir):
+ for file in files:
+ if file.endswith(".mp3"):
+ songs.append(file)
+ st.session_state['songs_cache'] = songs
+ return st.session_state['songs_cache']
\ No newline at end of file
diff --git a/webui/utils/file_utils.py b/webui/utils/file_utils.py
new file mode 100644
index 0000000..458efa6
--- /dev/null
+++ b/webui/utils/file_utils.py
@@ -0,0 +1,189 @@
+import os
+import glob
+import time
+import platform
+import shutil
+from uuid import uuid4
+from loguru import logger
+from app.utils import utils
+
+def open_task_folder(root_dir, task_id):
+ """打开任务文件夹
+ Args:
+ root_dir: 项目根目录
+ task_id: 任务ID
+ """
+ try:
+ sys = platform.system()
+ path = os.path.join(root_dir, "storage", "tasks", task_id)
+ if os.path.exists(path):
+ if sys == 'Windows':
+ os.system(f"start {path}")
+ if sys == 'Darwin':
+ os.system(f"open {path}")
+ if sys == 'Linux':
+ os.system(f"xdg-open {path}")
+ except Exception as e:
+ logger.error(f"打开任务文件夹失败: {e}")
+
+def cleanup_temp_files(temp_dir, max_age=3600):
+ """清理临时文件
+ Args:
+ temp_dir: 临时文件目录
+ max_age: 文件最大保存时间(秒)
+ """
+ if os.path.exists(temp_dir):
+ for file in os.listdir(temp_dir):
+ file_path = os.path.join(temp_dir, file)
+ try:
+ if os.path.getctime(file_path) < time.time() - max_age:
+ if os.path.isfile(file_path):
+ os.remove(file_path)
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path)
+ logger.debug(f"已清理临时文件: {file_path}")
+ except Exception as e:
+ logger.error(f"清理临时文件失败: {file_path}, 错误: {e}")
+
+def get_file_list(directory, file_types=None, sort_by='ctime', reverse=True):
+ """获取指定目录下的文件列表
+ Args:
+ directory: 目录路径
+ file_types: 文件类型列表,如 ['.mp4', '.mov']
+ sort_by: 排序方式,支持 'ctime'(创建时间), 'mtime'(修改时间), 'size'(文件大小), 'name'(文件名)
+ reverse: 是否倒序排序
+ Returns:
+ list: 文件信息列表
+ """
+ if not os.path.exists(directory):
+ return []
+
+ files = []
+ if file_types:
+ for file_type in file_types:
+ files.extend(glob.glob(os.path.join(directory, f"*{file_type}")))
+ else:
+ files = glob.glob(os.path.join(directory, "*"))
+
+ file_list = []
+ for file_path in files:
+ try:
+ file_stat = os.stat(file_path)
+ file_info = {
+ "name": os.path.basename(file_path),
+ "path": file_path,
+ "size": file_stat.st_size,
+ "ctime": file_stat.st_ctime,
+ "mtime": file_stat.st_mtime
+ }
+ file_list.append(file_info)
+ except Exception as e:
+ logger.error(f"获取文件信息失败: {file_path}, 错误: {e}")
+
+ # 排序
+ if sort_by in ['ctime', 'mtime', 'size', 'name']:
+ file_list.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
+
+ return file_list
+
+def save_uploaded_file(uploaded_file, save_dir, allowed_types=None):
+ """保存上传的文件
+ Args:
+ uploaded_file: StreamlitUploadedFile对象
+ save_dir: 保存目录
+ allowed_types: 允许的文件类型列表,如 ['.mp4', '.mov']
+ Returns:
+ str: 保存后的文件路径,失败返回None
+ """
+ try:
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ file_name, file_extension = os.path.splitext(uploaded_file.name)
+
+ # 检查文件类型
+ if allowed_types and file_extension.lower() not in allowed_types:
+ logger.error(f"不支持的文件类型: {file_extension}")
+ return None
+
+ # 如果文件已存在,添加时间戳
+ save_path = os.path.join(save_dir, uploaded_file.name)
+ if os.path.exists(save_path):
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ new_file_name = f"{file_name}_{timestamp}{file_extension}"
+ save_path = os.path.join(save_dir, new_file_name)
+
+ # 保存文件
+ with open(save_path, "wb") as f:
+ f.write(uploaded_file.read())
+
+ logger.info(f"文件保存成功: {save_path}")
+ return save_path
+
+ except Exception as e:
+ logger.error(f"保存上传文件失败: {e}")
+ return None
+
+def create_temp_file(prefix='tmp', suffix='', directory=None):
+ """创建临时文件
+ Args:
+ prefix: 文件名前缀
+ suffix: 文件扩展名
+ directory: 临时文件目录,默认使用系统临时目录
+ Returns:
+ str: 临时文件路径
+ """
+ try:
+ if directory is None:
+ directory = utils.storage_dir("temp", create=True)
+
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ temp_file = os.path.join(directory, f"{prefix}-{str(uuid4())}{suffix}")
+ return temp_file
+
+ except Exception as e:
+ logger.error(f"创建临时文件失败: {e}")
+ return None
+
+def get_file_size(file_path, format='MB'):
+ """获取文件大小
+ Args:
+ file_path: 文件路径
+ format: 返回格式,支持 'B', 'KB', 'MB', 'GB'
+ Returns:
+ float: 文件大小
+ """
+ try:
+ size_bytes = os.path.getsize(file_path)
+
+ if format.upper() == 'B':
+ return size_bytes
+ elif format.upper() == 'KB':
+ return size_bytes / 1024
+ elif format.upper() == 'MB':
+ return size_bytes / (1024 * 1024)
+ elif format.upper() == 'GB':
+ return size_bytes / (1024 * 1024 * 1024)
+ else:
+ return size_bytes
+
+ except Exception as e:
+ logger.error(f"获取文件大小失败: {file_path}, 错误: {e}")
+ return 0
+
+def ensure_directory(directory):
+ """确保目录存在,如果不存在则创建
+ Args:
+ directory: 目录路径
+ Returns:
+ bool: 是否成功
+ """
+ try:
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ return True
+ except Exception as e:
+ logger.error(f"创建目录失败: {directory}, 错误: {e}")
+ return False
\ No newline at end of file
diff --git a/webui/utils/performance.py b/webui/utils/performance.py
new file mode 100644
index 0000000..76b7dcb
--- /dev/null
+++ b/webui/utils/performance.py
@@ -0,0 +1,24 @@
+import time
+from loguru import logger
+
+try:
+ import psutil
+ ENABLE_PERFORMANCE_MONITORING = True
+except ImportError:
+ ENABLE_PERFORMANCE_MONITORING = False
+ logger.warning("psutil not installed. Performance monitoring is disabled.")
+
+def monitor_performance():
+ if not ENABLE_PERFORMANCE_MONITORING:
+ return {'execution_time': 0, 'memory_usage': 0}
+
+ start_time = time.time()
+ try:
+ memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 # MB
+ except:
+ memory_usage = 0
+
+ return {
+ 'execution_time': time.time() - start_time,
+ 'memory_usage': memory_usage
+ }
\ No newline at end of file