diff --git a/app/services/subtitle.py b/app/services/subtitle.py index c792667..f37eb65 100644 --- a/app/services/subtitle.py +++ b/app/services/subtitle.py @@ -29,7 +29,7 @@ def create(audio_file, subtitle_file: str = ""): 返回: 无返回值,但会在指定路径生成字幕文件。 """ - global model + global model, device, compute_type if not model: model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v2" model_bin_file = f"{model_path}/model.bin" @@ -43,27 +43,45 @@ def create(audio_file, subtitle_file: str = ""): ) return None - logger.info( - f"加载模型: {model_path}, 设备: {device}, 计算类型: {compute_type}" - ) + # 尝试使用 CUDA,如果失败则回退到 CPU try: + import torch + if torch.cuda.is_available(): + try: + logger.info(f"尝试使用 CUDA 加载模型: {model_path}") + model = WhisperModel( + model_size_or_path=model_path, + device="cuda", + compute_type="float16", + local_files_only=True + ) + device = "cuda" + compute_type = "float16" + logger.info("成功使用 CUDA 加载模型") + except Exception as e: + logger.warning(f"CUDA 加载失败,错误信息: {str(e)}") + logger.warning("回退到 CPU 模式") + device = "cpu" + compute_type = "int8" + else: + logger.info("未检测到 CUDA,使用 CPU 模式") + device = "cpu" + compute_type = "int8" + except ImportError: + logger.warning("未安装 torch,使用 CPU 模式") + device = "cpu" + compute_type = "int8" + + if device == "cpu": + logger.info(f"使用 CPU 加载模型: {model_path}") model = WhisperModel( model_size_or_path=model_path, device=device, compute_type=compute_type, local_files_only=True ) - except Exception as e: - logger.error( - f"加载模型失败: {e} \n\n" - f"********************************************\n" - f"这可能是由网络问题引起的. \n" - f"请手动下载模型并将其放入 'app/models' 文件夹中。 \n" - f"see [README.md FAQ](https://github.com/linyqh/NarratoAI) for more details.\n" - f"********************************************\n\n" - f"{traceback.format_exc()}" - ) - return None + + logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}") logger.info(f"start, output file: {subtitle_file}") if not subtitle_file: diff --git a/app/utils/check_script.py b/app/utils/check_script.py index 623c42a..00e6c0f 100644 --- a/app/utils/check_script.py +++ b/app/utils/check_script.py @@ -1,115 +1,81 @@ import json -from loguru import logger -import os -from datetime import timedelta +from typing import Dict, Any -def time_to_seconds(time_str): - parts = list(map(int, time_str.split(':'))) - if len(parts) == 2: - return timedelta(minutes=parts[0], seconds=parts[1]).total_seconds() - elif len(parts) == 3: - return timedelta(hours=parts[0], minutes=parts[1], seconds=parts[2]).total_seconds() - raise ValueError(f"无法解析时间字符串: {time_str}") +def check_format(script_content: str) -> Dict[str, Any]: + """检查脚本格式 + Args: + script_content: 脚本内容 + Returns: + Dict: {'success': bool, 'message': str} + """ + try: + # 检查是否为有效的JSON + data = json.loads(script_content) + + # 检查是否为列表 + if not isinstance(data, list): + return { + 'success': False, + 'message': '脚本必须是JSON数组格式' + } + + # 检查每个片段 + for i, clip in enumerate(data): + # 检查必需字段 + required_fields = ['narration', 'picture', 'timestamp'] + for field in required_fields: + if field not in clip: + return { + 'success': False, + 'message': f'第{i+1}个片段缺少必需字段: {field}' + } + + # 检查字段类型 + if not isinstance(clip['narration'], str): + return { + 'success': False, + 'message': f'第{i+1}个片段的narration必须是字符串' + } + if not isinstance(clip['picture'], str): + return { + 'success': False, + 'message': f'第{i+1}个片段的picture必须是字符串' + } + if not isinstance(clip['timestamp'], str): + return { + 'success': False, + 'message': f'第{i+1}个片段的timestamp必须是字符串' + } + + # 检查字段内容不能为空 + if not clip['narration'].strip(): + return { + 'success': False, + 'message': f'第{i+1}个片段的narration不能为空' + } + if not clip['picture'].strip(): + return { + 'success': False, + 'message': f'第{i+1}个片段的picture不能为空' + } + if not clip['timestamp'].strip(): + return { + 'success': False, + 'message': f'第{i+1}个片段的timestamp不能为空' + } -def seconds_to_time_str(seconds): - hours, remainder = divmod(int(seconds), 3600) - minutes, seconds = divmod(remainder, 60) - if hours > 0: - return f"{hours:02d}:{minutes:02d}:{seconds:02d}" - else: - return f"{minutes:02d}:{seconds:02d}" + return { + 'success': True, + 'message': '脚本格式检查通过' + } -def adjust_timestamp(start_time, duration): - start_seconds = time_to_seconds(start_time) - end_seconds = start_seconds + duration - return f"{start_time}-{seconds_to_time_str(end_seconds)}" - -def estimate_audio_duration(text): - # 假设平均每个字符需要 0.2 秒 - return len(text) * 0.2 - -def check_script(data, total_duration): - errors = [] - time_ranges = [] - - logger.info("开始检查脚本") - logger.info(f"视频总时长: {total_duration:.2f} 秒") - logger.info("=" * 50) - - for i, item in enumerate(data, 1): - logger.info(f"\n检查第 {i} 项:") - - # 检查所有必需字段 - required_fields = ['picture', 'timestamp', 'narration', 'OST'] - for field in required_fields: - if field not in item: - errors.append(f"第 {i} 项缺少 {field} 字段") - logger.info(f" - 错误: 缺少 {field} 字段") - else: - logger.info(f" - {field}: {item[field]}") - - # 检查 OST 相关规则 - if item.get('OST') == False: - if not item.get('narration'): - errors.append(f"第 {i} 项 OST 为 false,但 narration 为空") - logger.info(" - 错误: OST 为 false,但 narration 为空") - elif len(item['narration']) > 60: - errors.append(f"第 {i} 项 OST 为 false,但 narration 超过 60 字") - logger.info(f" - 错误: OST 为 false,但 narration 超过 60 字 (当前: {len(item['narration'])} 字)") - else: - logger.info(" - OST 为 false,narration 检查通过") - elif item.get('OST') == True: - if "原声播放_" not in item.get('narration'): - errors.append(f"第 {i} 项 OST 为 true,但 narration 不为空") - logger.info(" - 错误: OST 为 true,但 narration 不为空") - else: - logger.info(" - OST 为 true,narration 检查通过") - - # 检查 timestamp - if 'timestamp' in item: - start, end = map(time_to_seconds, item['timestamp'].split('-')) - if any((start < existing_end and end > existing_start) for existing_start, existing_end in time_ranges): - errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 与其他时间段重叠") - logger.info(f" - 错误: timestamp '{item['timestamp']}' 与其他时间段重叠") - else: - logger.info(f" - timestamp '{item['timestamp']}' 检查通过") - time_ranges.append((start, end)) - - # if end > total_duration: - # errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒") - # logger.info(f" - 错误: timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒") - # else: - # logger.info(f" - timestamp 在总时长范围内") - - # 处理 narration 字段 - if item.get('OST') == False and item.get('narration'): - estimated_duration = estimate_audio_duration(item['narration']) - start_time = item['timestamp'].split('-')[0] - item['timestamp'] = adjust_timestamp(start_time, estimated_duration) - logger.info(f" - 已调整 timestamp 为 {item['timestamp']} (估算音频时长: {estimated_duration:.2f} 秒)") - - if errors: - logger.info("检查结果:不通过") - logger.info("发现以下错误:") - for error in errors: - logger.info(f"- {error}") - else: - logger.info("检查结果:通过") - logger.info("所有项目均符合规则要求。") - - return errors, data - - -if __name__ == "__main__": - file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/test004.json" - - with open(file_path, 'r', encoding='utf-8') as f: - data = json.load(f) - - total_duration = 280 - - # check_script(data, total_duration) - - from app.utils.utils import add_new_timestamps - res = add_new_timestamps(data) - print(json.dumps(res, indent=4, ensure_ascii=False)) + except json.JSONDecodeError as e: + return { + 'success': False, + 'message': f'JSON格式错误: {str(e)}' + } + except Exception as e: + return { + 'success': False, + 'message': f'检查过程中发生错误: {str(e)}' + } diff --git a/requirements.txt b/requirements.txt index a9445b9..f3be823 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ azure-cognitiveservices-speech~=1.37.0 git-changelog~=2.5.2 watchdog==5.0.2 pydub==0.25.1 +psutil>=5.9.0 diff --git a/webui.py b/webui.py index faae899..d2a02f0 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,14 @@ import streamlit as st +import os +import sys +from uuid import uuid4 from app.config import config +from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings +from webui.utils import cache, file_utils, performance +from app.utils import utils +from app.models.schema import VideoClipParams, VideoAspect +# 初始化配置 - 必须是第一个 Streamlit 命令 st.set_page_config( page_title="NarratoAI", page_icon="📽️", @@ -13,126 +21,23 @@ st.set_page_config( }, ) -import sys -import os -import glob -import json -import time -import datetime -import traceback -from uuid import uuid4 -import platform -import streamlit.components.v1 as components -from loguru import logger - -from app.models.const import FILE_TYPE_VIDEOS -from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode -from app.services import task as tm, llm, voice, material -from app.utils import utils - -# # 将项目的根目录添加到系统路径中,以允许从项目导入模块 -root_dir = os.path.dirname(os.path.realpath(__file__)) -if root_dir not in sys.path: - sys.path.append(root_dir) - print("******** sys.path ********") - print(sys.path) - print("*" * 20) - -proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "") -proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "") -os.environ["HTTP_PROXY"] = proxy_url_http -os.environ["HTTPS_PROXY"] = proxy_url_https - +# 设置页面样式 hide_streamlit_style = """ """ st.markdown(hide_streamlit_style, unsafe_allow_html=True) -st.title(f"NarratoAI :sunglasses:📽️") -support_locales = [ - "zh-CN", - "zh-HK", - "zh-TW", - "en-US", -] -font_dir = os.path.join(root_dir, "resource", "fonts") -song_dir = os.path.join(root_dir, "resource", "songs") -i18n_dir = os.path.join(root_dir, "webui", "i18n") -config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml") -system_locale = utils.get_system_locale() - -if 'video_clip_json' not in st.session_state: - st.session_state['video_clip_json'] = [] -if 'video_plot' not in st.session_state: - st.session_state['video_plot'] = '' -if 'ui_language' not in st.session_state: - st.session_state['ui_language'] = config.ui.get("language", system_locale) -if 'subclip_videos' not in st.session_state: - st.session_state['subclip_videos'] = {} - - -def get_all_fonts(): - fonts = [] - for root, dirs, files in os.walk(font_dir): - for file in files: - if file.endswith(".ttf") or file.endswith(".ttc"): - fonts.append(file) - fonts.sort() - return fonts - - -def get_all_songs(): - songs = [] - for root, dirs, files in os.walk(song_dir): - for file in files: - if file.endswith(".mp3"): - songs.append(file) - return songs - - -def open_task_folder(task_id): - try: - sys = platform.system() - path = os.path.join(root_dir, "storage", "tasks", task_id) - if os.path.exists(path): - if sys == 'Windows': - os.system(f"start {path}") - if sys == 'Darwin': - os.system(f"open {path}") - except Exception as e: - logger.error(e) - - -def scroll_to_bottom(): - js = f""" - - """ - st.components.v1.html(js, height=0, width=0) - def init_log(): + """初始化日志配置""" + from loguru import logger logger.remove() _lvl = "DEBUG" def format_record(record): - # 获取日志记录中的文件全径 file_path = record["file"].path - # 将绝对路径转换为相对于项目根目录的路径 - relative_path = os.path.relpath(file_path, root_dir) - # 更新记录中的文件路径 + relative_path = os.path.relpath(file_path, config.root_dir) record["file"].path = f"./{relative_path}" - # 返回修改后的格式字符串 - # 您可以根据需要调整这里的格式 - record['message'] = record['message'].replace(root_dir, ".") + record['message'] = record['message'].replace(config.root_dir, ".") _format = '{time:%Y-%m-%d %H:%M:%S} | ' + \ '{level} | ' + \ @@ -147,672 +52,120 @@ def init_log(): colorize=True, ) - -init_log() - -locales = utils.load_locales(i18n_dir) - +def init_global_state(): + """初始化全局状态""" + if 'video_clip_json' not in st.session_state: + st.session_state['video_clip_json'] = [] + if 'video_plot' not in st.session_state: + st.session_state['video_plot'] = '' + if 'ui_language' not in st.session_state: + st.session_state['ui_language'] = config.ui.get("language", utils.get_system_locale()) + if 'subclip_videos' not in st.session_state: + st.session_state['subclip_videos'] = {} def tr(key): + """翻译函数""" + i18n_dir = os.path.join(os.path.dirname(__file__), "webui", "i18n") + locales = utils.load_locales(i18n_dir) loc = locales.get(st.session_state['ui_language'], {}) return loc.get("Translation", {}).get(key, key) +def render_generate_button(): + """渲染生成按钮和处理逻辑""" + if st.button(tr("Generate Video"), use_container_width=True, type="primary"): + from app.services import task as tm + + # 重置日志容器和记录 + log_container = st.empty() + log_records = [] -st.write(tr("Get Help")) + def log_received(msg): + with log_container: + log_records.append(msg) + st.code("\n".join(log_records)) -# 基础设置 -with st.expander(tr("Basic Settings"), expanded=False): - config_panels = st.columns(3) - left_config_panel = config_panels[0] - middle_config_panel = config_panels[1] - right_config_panel = config_panels[2] - with left_config_panel: - display_languages = [] - selected_index = 0 - for i, code in enumerate(locales.keys()): - display_languages.append(f"{code} - {locales[code].get('Language')}") - if code == st.session_state['ui_language']: - selected_index = i + from loguru import logger + logger.add(log_received) - selected_language = st.selectbox(tr("Language"), options=display_languages, - index=selected_index) - if selected_language: - code = selected_language.split(" - ")[0].strip() - st.session_state['ui_language'] = code - config.ui['language'] = code + config.save_config() + task_id = st.session_state.get('task_id') - HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http) - HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https) - if HTTP_PROXY: - config.proxy["http"] = HTTP_PROXY - if HTTPS_PROXY: - config.proxy["https"] = HTTPS_PROXY + if not task_id: + st.error(tr("请先裁剪视频")) + return + if not st.session_state.get('video_clip_json_path'): + st.error(tr("脚本文件不能为空")) + return + if not st.session_state.get('video_origin_path'): + st.error(tr("视频文件不能为空")) + return - # 视频转录大模型 - with middle_config_panel: - video_llm_providers = ['Gemini'] - saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower() - saved_llm_provider_index = 0 - for i, provider in enumerate(video_llm_providers): - if provider.lower() == saved_llm_provider: - saved_llm_provider_index = i - break + st.toast(tr("生成视频")) + logger.info(tr("开始生成视频")) - video_llm_provider = st.selectbox(tr("Video LLM Provider"), options=video_llm_providers, index=saved_llm_provider_index) - video_llm_provider = video_llm_provider.lower() - config.app["video_llm_provider"] = video_llm_provider + # 获取所有参数 + script_params = script_settings.get_script_params() + video_params = video_settings.get_video_params() + audio_params = audio_settings.get_audio_params() + subtitle_params = subtitle_settings.get_subtitle_params() - video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "") - video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "") - video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "") - video_llm_account_id = config.app.get(f"{video_llm_provider}_account_id", "") - st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password") - st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url) - st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name) - if st_llm_api_key: - config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key - if st_llm_base_url: - config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url - if st_llm_model_name: - config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name - - # 大语言模型 - with right_config_panel: - llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"] - saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower() - saved_llm_provider_index = 0 - for i, provider in enumerate(llm_providers): - if provider.lower() == saved_llm_provider: - saved_llm_provider_index = i - break - - llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index) - llm_provider = llm_provider.lower() - config.app["llm_provider"] = llm_provider - - llm_api_key = config.app.get(f"{llm_provider}_api_key", "") - llm_base_url = config.app.get(f"{llm_provider}_base_url", "") - llm_model_name = config.app.get(f"{llm_provider}_model_name", "") - llm_account_id = config.app.get(f"{llm_provider}_account_id", "") - st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password") - st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url) - st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name) - if st_llm_api_key: - config.app[f"{llm_provider}_api_key"] = st_llm_api_key - if st_llm_base_url: - config.app[f"{llm_provider}_base_url"] = st_llm_base_url - if st_llm_model_name: - config.app[f"{llm_provider}_model_name"] = st_llm_model_name - - if llm_provider == 'cloudflare': - st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id) - if st_llm_account_id: - config.app[f"{llm_provider}_account_id"] = st_llm_account_id - -panel = st.columns(3) -left_panel = panel[0] -middle_panel = panel[1] -right_panel = panel[2] - -params = VideoClipParams() - -# 左侧面板 -with left_panel: - with st.container(border=True): - st.write(tr("Video Script Configuration")) - # 脚本语言 - video_languages = [ - (tr("Auto Detect"), ""), - ] - for code in ["zh-CN", "en-US", "zh-TW"]: - video_languages.append((code, code)) - - selected_index = st.selectbox(tr("Script Language"), - index=0, - options=range(len(video_languages)), # 使用索引作为内部选项值 - format_func=lambda x: video_languages[x][0] # 显示给用户的是标签 - ) - params.video_language = video_languages[selected_index][1] - - # 脚本路径 - suffix = "*.json" - song_dir = utils.script_dir() - files = glob.glob(os.path.join(song_dir, suffix)) - script_list = [] - for file in files: - script_list.append({ - "name": os.path.basename(file), - "size": os.path.getsize(file), - "file": file, - "ctime": os.path.getctime(file) # 获取文件创建时间 - }) - - # 按创建时间降序排序 - script_list.sort(key=lambda x: x["ctime"], reverse=True) - - # 本文件 下拉框 - script_path = [(tr("Auto Generate"), ""), ] - for file in script_list: - display_name = file['file'].replace(root_dir, "") - script_path.append((display_name, file['file'])) - selected_script_index = st.selectbox(tr("Script Files"), - index=0, - options=range(len(script_path)), # 使用索引作为内部选项值 - format_func=lambda x: script_path[x][0] # 显示给用户的是标签 - ) - params.video_clip_json_path = script_path[selected_script_index][1] - config.app["video_clip_json_path"] = params.video_clip_json_path - st.session_state['video_clip_json_path'] = params.video_clip_json_path - - # 视频文件处理 - video_files = [] - for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]: - video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix))) - video_files = video_files[::-1] - - video_list = [] - for video_file in video_files: - video_list.append({ - "name": os.path.basename(video_file), - "size": os.path.getsize(video_file), - "file": video_file, - "ctime": os.path.getctime(video_file) # 获取文件创建时间 - }) - # 按创建时间降序排序 - video_list.sort(key=lambda x: x["ctime"], reverse=True) - video_path = [(tr("None"), ""), (tr("Upload Local Files"), "local")] - for file in video_list: - display_name = file['file'].replace(root_dir, "") - video_path.append((display_name, file['file'])) - - # 视频文件 - selected_video_index = st.selectbox(tr("Video File"), - index=0, - options=range(len(video_path)), # 使用索引作为内部选项值 - format_func=lambda x: video_path[x][0] # 显示给用户的是标签 - ) - params.video_origin_path = video_path[selected_video_index][1] - config.app["video_origin_path"] = params.video_origin_path - st.session_state['video_origin_path'] = params.video_origin_path - - # 从本地上传 mp4 文件 - if params.video_origin_path == "local": - _supported_types = FILE_TYPE_VIDEOS - uploaded_file = st.file_uploader( - tr("Upload Local Files"), - type=["mp4", "mov", "avi", "flv", "mkv"], - accept_multiple_files=False, - ) - if uploaded_file is not None: - # 构造保存路径 - video_file_path = os.path.join(utils.video_dir(), uploaded_file.name) - file_name, file_extension = os.path.splitext(uploaded_file.name) - # 检查文件是否存在,如果存在则添加时间戳 - if os.path.exists(video_file_path): - timestamp = time.strftime("%Y%m%d%H%M%S") - file_name_with_timestamp = f"{file_name}_{timestamp}" - video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension) - # 将文件保存到指定目录 - with open(video_file_path, "wb") as f: - f.write(uploaded_file.read()) - st.success(tr("File Uploaded Successfully")) - time.sleep(1) - st.rerun() - # 视频名称 - video_name = st.text_input(tr("Video Name")) - # 剧情内容 - video_plot = st.text_area( - tr("Plot Description"), - value=st.session_state['video_plot'], - height=180 - ) - - # 生成视频脚本 - if st.session_state['video_clip_json_path']: - generate_button_name = tr("Video Script Load") - else: - generate_button_name = tr("Video Script Generate") - if st.button(generate_button_name, key="auto_generate_script"): - progress_bar = st.progress(0) - status_text = st.empty() - - def update_progress(progress: float, message: str = ""): - progress_bar.progress(progress) - if message: - status_text.text(f"{progress}% - {message}") - else: - status_text.text(f"进度: {progress}%") - - try: - with st.spinner("正在生成脚本..."): - if not video_plot: - st.warning("视频剧情为空; 会极大影响生成效果!") - if params.video_clip_json_path == "" and params.video_origin_path != "": - update_progress(10, "压缩视频中...") - # 使用大模型生成视频脚本 - script = llm.generate_script( - video_path=params.video_origin_path, - video_plot=video_plot, - video_name=video_name, - language=params.video_language, - progress_callback=update_progress - ) - if script is None: - st.error("生成脚本失败,请检查日志") - st.stop() - else: - update_progress(90) - - script = utils.clean_model_output(script) - st.session_state['video_clip_json'] = json.loads(script) - else: - # 从本地加载 - with open(params.video_clip_json_path, 'r', encoding='utf-8') as f: - update_progress(50) - status_text.text("从本地加载中...") - script = f.read() - script = utils.clean_model_output(script) - st.session_state['video_clip_json'] = json.loads(script) - update_progress(100) - status_text.text("从本地加载成功") - - time.sleep(0.5) # 给进度条一点时间到达100% - progress_bar.progress(100) - status_text.text("脚本生成完成!") - st.success("视频脚本生成成功!") - except Exception as err: - st.error(f"生成过程中发生错误: {str(err)}") - finally: - time.sleep(2) # 给用户一些时间查看最终状态 - progress_bar.empty() - status_text.empty() - - # 视频脚本 - video_clip_json_details = st.text_area( - tr("Video Script"), - value=json.dumps(st.session_state.video_clip_json, indent=2, ensure_ascii=False), - height=180 - ) - - # 保存脚本 - button_columns = st.columns(2) - with button_columns[0]: - if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True): - if not video_clip_json_details: - st.error(tr("请输入视频脚本")) - st.stop() - - with st.spinner(tr("Save Script")): - script_dir = utils.script_dir() - # 获取当前时间戳,形如 2024-0618-171820 - timestamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M%S") - save_path = os.path.join(script_dir, f"{timestamp}.json") - - try: - data = utils.add_new_timestamps(json.loads(video_clip_json_details)) - except Exception as err: - st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}") - st.stop() - - # 存储为新的 JSON 文件 - with open(save_path, 'w', encoding='utf-8') as file: - json.dump(data, file, ensure_ascii=False, indent=4) - # 将data的值存储到 session_state 中,类似缓存 - st.session_state['video_clip_json'] = data - st.session_state['video_clip_json_path'] = save_path - # 刷新页面 - st.rerun() - - # 裁剪视频 - with button_columns[1]: - if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True): - progress_bar = st.progress(0) - status_text = st.empty() - - def update_progress(progress): - progress_bar.progress(progress) - status_text.text(f"剪辑进度: {progress}%") - - try: - utils.cut_video(params, update_progress) - time.sleep(0.5) # 给进度条一点时间到达100% - progress_bar.progress(100) - status_text.text("剪辑完成!") - st.success("视频剪辑成功完成!") - except Exception as e: - st.error(f"剪辑过程中发生错误: {str(e)}") - finally: - time.sleep(2) # 给用户一些时间查看最终状态 - progress_bar.empty() - status_text.empty() - -# 新中间面板 -with middle_panel: - with st.container(border=True): - st.write(tr("Video Settings")) - - # 视频比例 - video_aspect_ratios = [ - (tr("Portrait"), VideoAspect.portrait.value), - (tr("Landscape"), VideoAspect.landscape.value), - ] - selected_index = st.selectbox( - tr("Video Ratio"), - options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值 - format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签 - ) - params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1]) - - # params.video_clip_duration = st.selectbox( - # tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1 - # ) - # params.video_count = st.selectbox( - # tr("Number of Videos Generated Simultaneously"), - # options=[1, 2, 3, 4, 5], - # index=0, - # ) - with st.container(border=True): - st.write(tr("Audio Settings")) - - # tts_providers = ['edge', 'azure'] - # tts_provider = st.selectbox(tr("TTS Provider"), tts_providers) - - voices = voice.get_all_azure_voices(filter_locals=support_locales) - friendly_names = { - v: v.replace("Female", tr("Female")) - .replace("Male", tr("Male")) - .replace("Neural", "") - for v in voices + # 合并所有参数 + all_params = { + **script_params, + **video_params, + **audio_params, + **subtitle_params } - saved_voice_name = config.ui.get("voice_name", "") - saved_voice_name_index = 0 - if saved_voice_name in friendly_names: - saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name) - else: - for i, v in enumerate(voices): - if ( - v.lower().startswith(st.session_state["ui_language"].lower()) - and "V2" not in v - ): - saved_voice_name_index = i - break - selected_friendly_name = st.selectbox( - tr("Speech Synthesis"), - options=list(friendly_names.values()), - index=saved_voice_name_index, + # 创建参数对象 + params = VideoClipParams(**all_params) + + result = tm.start_subclip( + task_id=task_id, + params=params, + subclip_path_videos=st.session_state['subclip_videos'] ) - voice_name = list(friendly_names.keys())[ - list(friendly_names.values()).index(selected_friendly_name) - ] - params.voice_name = voice_name - config.ui["voice_name"] = voice_name + video_files = result.get("videos", []) + st.success(tr("视频生成完成")) + + try: + if video_files: + player_cols = st.columns(len(video_files) * 2 + 1) + for i, url in enumerate(video_files): + player_cols[i * 2 + 1].video(url) + except Exception as e: + logger.error(f"播放视频失败: {e}") - if voice.is_azure_v2_voice(voice_name): - saved_azure_speech_region = config.azure.get("speech_region", "") - saved_azure_speech_key = config.azure.get("speech_key", "") - azure_speech_region = st.text_input( - tr("Speech Region"), value=saved_azure_speech_region - ) - azure_speech_key = st.text_input( - tr("Speech Key"), value=saved_azure_speech_key, type="password" - ) - config.azure["speech_region"] = azure_speech_region - config.azure["speech_key"] = azure_speech_key + file_utils.open_task_folder(config.root_dir, task_id) + logger.info(tr("视频生成完成")) - params.voice_volume = st.selectbox( - tr("Speech Volume"), - options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0], - index=2, - ) +def main(): + """主函数""" + init_log() + init_global_state() + + st.title(f"NarratoAI :sunglasses:📽️") + st.write(tr("Get Help")) + + # 渲染基础设置面板 + basic_settings.render_basic_settings(tr) + + # 渲染主面板 + panel = st.columns(3) + with panel[0]: + script_settings.render_script_panel(tr) + with panel[1]: + video_settings.render_video_panel(tr) + audio_settings.render_audio_panel(tr) + with panel[2]: + subtitle_settings.render_subtitle_panel(tr) + + # 渲染视频审查面板 + review_settings.render_review_panel(tr) + + # 渲染生成按钮和处理逻辑 + render_generate_button() - params.voice_rate = st.selectbox( - tr("Speech Rate"), - options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0], - index=2, - ) - - params.voice_pitch = st.selectbox( - tr("Speech Pitch"), - options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0], - index=2, - ) - - # 试听语言合成 - if st.button(tr("Play Voice")): - play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论" - if not play_content: - play_content = params.video_script - if not play_content: - play_content = tr("Voice Example") - with st.spinner(tr("Synthesizing Voice")): - temp_dir = utils.storage_dir("temp", create=True) - audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3") - sub_maker = voice.tts( - text=play_content, - voice_name=voice_name, - voice_rate=params.voice_rate, - voice_pitch=params.voice_pitch, - voice_file=audio_file, - ) - # 如果语音文件生成失败,请使用默认内容重试。 - if not sub_maker: - play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content." - sub_maker = voice.tts( - text=play_content, - voice_name=voice_name, - voice_rate=params.voice_rate, - voice_pitch=params.voice_pitch, - voice_file=audio_file, - ) - - if sub_maker and os.path.exists(audio_file): - st.audio(audio_file, format="audio/mp3") - if os.path.exists(audio_file): - os.remove(audio_file) - - bgm_options = [ - (tr("No Background Music"), ""), - (tr("Random Background Music"), "random"), - (tr("Custom Background Music"), "custom"), - ] - selected_index = st.selectbox( - tr("Background Music"), - index=1, - options=range(len(bgm_options)), # 使用索引作为内部选项值 - format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签 - ) - # 获取选择的背景音乐类型 - params.bgm_type = bgm_options[selected_index][1] - - # 根据选择显示或隐藏组件 - if params.bgm_type == "custom": - custom_bgm_file = st.text_input(tr("Custom Background Music File")) - if custom_bgm_file and os.path.exists(custom_bgm_file): - params.bgm_file = custom_bgm_file - # st.write(f":red[已选择自定义背景音乐]:**{custom_bgm_file}**") - params.bgm_volume = st.selectbox( - tr("Background Music Volume"), - options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], - index=2, - ) - -# 新侧面板 -with right_panel: - with st.container(border=True): - st.write(tr("Subtitle Settings")) - params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True) - font_names = get_all_fonts() - saved_font_name = config.ui.get("font_name", "") - saved_font_name_index = 0 - if saved_font_name in font_names: - saved_font_name_index = font_names.index(saved_font_name) - params.font_name = st.selectbox( - tr("Font"), font_names, index=saved_font_name_index - ) - config.ui["font_name"] = params.font_name - - subtitle_positions = [ - (tr("Top"), "top"), - (tr("Center"), "center"), - (tr("Bottom"), "bottom"), - (tr("Custom"), "custom"), - ] - selected_index = st.selectbox( - tr("Position"), - index=2, - options=range(len(subtitle_positions)), - format_func=lambda x: subtitle_positions[x][0], - ) - params.subtitle_position = subtitle_positions[selected_index][1] - - if params.subtitle_position == "custom": - custom_position = st.text_input( - tr("Custom Position (% from top)"), value="70.0" - ) - try: - params.custom_position = float(custom_position) - if params.custom_position < 0 or params.custom_position > 100: - st.error(tr("Please enter a value between 0 and 100")) - except ValueError: - logger.error(f"输入的值无效: {traceback.format_exc()}") - st.error(tr("Please enter a valid number")) - - font_cols = st.columns([0.3, 0.7]) - with font_cols[0]: - saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF") - params.text_fore_color = st.color_picker( - tr("Font Color"), saved_text_fore_color - ) - config.ui["text_fore_color"] = params.text_fore_color - - with font_cols[1]: - saved_font_size = config.ui.get("font_size", 60) - params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size) - config.ui["font_size"] = params.font_size - - stroke_cols = st.columns([0.3, 0.7]) - with stroke_cols[0]: - params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000") - with stroke_cols[1]: - params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5) - -# 视频编辑面板 -with st.expander(tr("Video Check"), expanded=False): - try: - video_list = st.session_state.video_clip_json - except KeyError as e: - video_list = [] - - # 计算列数和行数 - num_videos = len(video_list) - cols_per_row = 3 - rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数 - - # 使用容器展示视频 - for row in range(rows): - cols = st.columns(cols_per_row) - for col in range(cols_per_row): - index = row * cols_per_row + col - if index < num_videos: - with cols[col]: - video_info = video_list[index] - video_path = video_info.get('path') - if video_path is not None: - initial_narration = video_info['narration'] - initial_picture = video_info['picture'] - initial_timestamp = video_info['timestamp'] - - with open(video_path, 'rb') as video_file: - video_bytes = video_file.read() - st.video(video_bytes) - - # 可编辑的输入框 - text_panels = st.columns(2) - with text_panels[0]: - text1 = st.text_area(tr("timestamp"), value=initial_timestamp, height=20, - key=f"timestamp_{index}") - with text_panels[1]: - text2 = st.text_area(tr("Picture description"), value=initial_picture, height=20, - key=f"picture_{index}") - text3 = st.text_area(tr("Narration"), value=initial_narration, height=100, - key=f"narration_{index}") - - # 重新生成按钮 - if st.button(tr("Rebuild"), key=f"rebuild_{index}"): - # 更新video_list中的对应项 - video_list[index]['timestamp'] = text1 - video_list[index]['picture'] = text2 - video_list[index]['narration'] = text3 - - for video in video_list: - if 'path' in video: - del video['path'] - # 更新session_state以确保更改被保存 - st.session_state['video_clip_json'] = utils.to_json(video_list) - # 替换原JSON 文件 - with open(params.video_clip_json_path, 'w', encoding='utf-8') as file: - json.dump(video_list, file, ensure_ascii=False, indent=4) - utils.cut_video(params, progress_callback=None) - st.rerun() - -# 开始按钮 -start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary") -if start_button: - # 重置日志容器和记录 - log_container = st.empty() - log_records = [] - - config.save_config() - task_id = st.session_state.get('task_id') - if st.session_state.get('video_script_json_path') is not None: - params.video_clip_json = st.session_state.get('video_clip_json') - - logger.debug(f"当前的脚本文件为:{st.session_state.video_clip_json_path}") - logger.debug(f"当前的视频文件为:{st.session_state.video_origin_path}") - logger.debug(f"裁剪后是视频列表:{st.session_state.subclip_videos}") - - if not task_id: - st.error(tr("请先裁剪视频")) - scroll_to_bottom() - st.stop() - if not params.video_clip_json_path: - st.error(tr("脚本文件不能为空")) - scroll_to_bottom() - st.stop() - if not params.video_origin_path: - st.error(tr("视频文件不能为空")) - scroll_to_bottom() - st.stop() - - def log_received(msg): - with log_container: - log_records.append(msg) - st.code("\n".join(log_records)) - - logger.add(log_received) - - st.toast(tr("生成视频")) - logger.info(tr("开始生成视频")) - logger.info(utils.to_json(params)) - scroll_to_bottom() - - result = tm.start_subclip(task_id=task_id, params=params, subclip_path_videos=st.session_state.subclip_videos) - - video_files = result.get("videos", []) - st.success(tr("视频生成完成")) - try: - if video_files: - # 将视频播放器居中 - player_cols = st.columns(len(video_files) * 2 + 1) - for i, url in enumerate(video_files): - player_cols[i * 2 + 1].video(url) - except Exception as e: - pass - - open_task_folder(task_id) - logger.info(tr("视频生成完成")) - scroll_to_bottom() - -config.save_config() +if __name__ == "__main__": + main() diff --git a/webui/__init__.py b/webui/__init__.py new file mode 100644 index 0000000..3c0a334 --- /dev/null +++ b/webui/__init__.py @@ -0,0 +1,22 @@ +""" +NarratoAI WebUI Package +""" +from webui.config.settings import config +from webui.components import ( + basic_settings, + video_settings, + audio_settings, + subtitle_settings +) +from webui.utils import cache, file_utils, performance + +__all__ = [ + 'config', + 'basic_settings', + 'video_settings', + 'audio_settings', + 'subtitle_settings', + 'cache', + 'file_utils', + 'performance' +] \ No newline at end of file diff --git a/webui/components/__init__.py b/webui/components/__init__.py new file mode 100644 index 0000000..6aafcd7 --- /dev/null +++ b/webui/components/__init__.py @@ -0,0 +1,15 @@ +from .basic_settings import render_basic_settings +from .script_settings import render_script_panel +from .video_settings import render_video_panel +from .audio_settings import render_audio_panel +from .subtitle_settings import render_subtitle_panel +from .review_settings import render_review_panel + +__all__ = [ + 'render_basic_settings', + 'render_script_panel', + 'render_video_panel', + 'render_audio_panel', + 'render_subtitle_panel', + 'render_review_panel' +] \ No newline at end of file diff --git a/webui/components/audio_settings.py b/webui/components/audio_settings.py new file mode 100644 index 0000000..a189f65 --- /dev/null +++ b/webui/components/audio_settings.py @@ -0,0 +1,198 @@ +import streamlit as st +import os +from uuid import uuid4 +from app.config import config +from app.services import voice +from app.utils import utils +from webui.utils.cache import get_songs_cache + +def render_audio_panel(tr): + """渲染音频设置面板""" + with st.container(border=True): + st.write(tr("Audio Settings")) + + # 渲染TTS设置 + render_tts_settings(tr) + + # 渲染背景音乐设置 + render_bgm_settings(tr) + +def render_tts_settings(tr): + """渲染TTS(文本转语音)设置""" + # 获取支持的语音列表 + support_locales = ["zh-CN", "zh-HK", "zh-TW", "en-US"] + voices = voice.get_all_azure_voices(filter_locals=support_locales) + + # 创建友好的显示名称 + friendly_names = { + v: v.replace("Female", tr("Female")) + .replace("Male", tr("Male")) + .replace("Neural", "") + for v in voices + } + + # 获取保存的语音设置 + saved_voice_name = config.ui.get("voice_name", "") + saved_voice_name_index = 0 + + if saved_voice_name in friendly_names: + saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name) + else: + # 如果没有保存的设置,选择与UI语言匹配的第一个语音 + for i, v in enumerate(voices): + if (v.lower().startswith(st.session_state["ui_language"].lower()) + and "V2" not in v): + saved_voice_name_index = i + break + + # 语音选择下拉框 + selected_friendly_name = st.selectbox( + tr("Speech Synthesis"), + options=list(friendly_names.values()), + index=saved_voice_name_index, + ) + + # 获取实际的语音名称 + voice_name = list(friendly_names.keys())[ + list(friendly_names.values()).index(selected_friendly_name) + ] + + # 保存设置 + config.ui["voice_name"] = voice_name + + # Azure V2语音特殊处理 + if voice.is_azure_v2_voice(voice_name): + render_azure_v2_settings(tr) + + # 语音参数设置 + render_voice_parameters(tr) + + # 试听按钮 + render_voice_preview(tr, voice_name) + +def render_azure_v2_settings(tr): + """渲染Azure V2语音设置""" + saved_azure_speech_region = config.azure.get("speech_region", "") + saved_azure_speech_key = config.azure.get("speech_key", "") + + azure_speech_region = st.text_input( + tr("Speech Region"), + value=saved_azure_speech_region + ) + azure_speech_key = st.text_input( + tr("Speech Key"), + value=saved_azure_speech_key, + type="password" + ) + + config.azure["speech_region"] = azure_speech_region + config.azure["speech_key"] = azure_speech_key + +def render_voice_parameters(tr): + """渲染语音参数设置""" + # 音量 + voice_volume = st.selectbox( + tr("Speech Volume"), + options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0], + index=2, + ) + st.session_state['voice_volume'] = voice_volume + + # 语速 + voice_rate = st.selectbox( + tr("Speech Rate"), + options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0], + index=2, + ) + st.session_state['voice_rate'] = voice_rate + + # 音调 + voice_pitch = st.selectbox( + tr("Speech Pitch"), + options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0], + index=2, + ) + st.session_state['voice_pitch'] = voice_pitch + +def render_voice_preview(tr, voice_name): + """渲染语音试听功能""" + if st.button(tr("Play Voice")): + play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论" + if not play_content: + play_content = st.session_state.get('video_script', '') + if not play_content: + play_content = tr("Voice Example") + + with st.spinner(tr("Synthesizing Voice")): + temp_dir = utils.storage_dir("temp", create=True) + audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3") + + sub_maker = voice.tts( + text=play_content, + voice_name=voice_name, + voice_rate=st.session_state.get('voice_rate', 1.0), + voice_pitch=st.session_state.get('voice_pitch', 1.0), + voice_file=audio_file, + ) + + # 如果语音文件生成失败,使用默认内容重试 + if not sub_maker: + play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content." + sub_maker = voice.tts( + text=play_content, + voice_name=voice_name, + voice_rate=st.session_state.get('voice_rate', 1.0), + voice_pitch=st.session_state.get('voice_pitch', 1.0), + voice_file=audio_file, + ) + + if sub_maker and os.path.exists(audio_file): + st.audio(audio_file, format="audio/mp3") + if os.path.exists(audio_file): + os.remove(audio_file) + +def render_bgm_settings(tr): + """渲染背景音乐设置""" + # 背景音乐选项 + bgm_options = [ + (tr("No Background Music"), ""), + (tr("Random Background Music"), "random"), + (tr("Custom Background Music"), "custom"), + ] + + selected_index = st.selectbox( + tr("Background Music"), + index=1, + options=range(len(bgm_options)), + format_func=lambda x: bgm_options[x][0], + ) + + # 获取选择的背景音乐类型 + bgm_type = bgm_options[selected_index][1] + st.session_state['bgm_type'] = bgm_type + + # 自定义背景音乐处理 + if bgm_type == "custom": + custom_bgm_file = st.text_input(tr("Custom Background Music File")) + if custom_bgm_file and os.path.exists(custom_bgm_file): + st.session_state['bgm_file'] = custom_bgm_file + + # 背景音乐音量 + bgm_volume = st.selectbox( + tr("Background Music Volume"), + options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], + index=2, + ) + st.session_state['bgm_volume'] = bgm_volume + +def get_audio_params(): + """获取音频参数""" + return { + 'voice_name': config.ui.get("voice_name", ""), + 'voice_volume': st.session_state.get('voice_volume', 1.0), + 'voice_rate': st.session_state.get('voice_rate', 1.0), + 'voice_pitch': st.session_state.get('voice_pitch', 1.0), + 'bgm_type': st.session_state.get('bgm_type', 'random'), + 'bgm_file': st.session_state.get('bgm_file', ''), + 'bgm_volume': st.session_state.get('bgm_volume', 0.2), + } \ No newline at end of file diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py new file mode 100644 index 0000000..f8dd65d --- /dev/null +++ b/webui/components/basic_settings.py @@ -0,0 +1,142 @@ +import streamlit as st +import os +from app.config import config +from app.utils import utils + +def render_basic_settings(tr): + """渲染基础设置面板""" + with st.expander(tr("Basic Settings"), expanded=False): + config_panels = st.columns(3) + left_config_panel = config_panels[0] + middle_config_panel = config_panels[1] + right_config_panel = config_panels[2] + + with left_config_panel: + render_language_settings(tr) + render_proxy_settings(tr) + + with middle_config_panel: + render_video_llm_settings(tr) + + with right_config_panel: + render_llm_settings(tr) + +def render_language_settings(tr): + """渲染语言设置""" + system_locale = utils.get_system_locale() + i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n") + locales = utils.load_locales(i18n_dir) + + display_languages = [] + selected_index = 0 + for i, code in enumerate(locales.keys()): + display_languages.append(f"{code} - {locales[code].get('Language')}") + if code == st.session_state.get('ui_language', system_locale): + selected_index = i + + selected_language = st.selectbox( + tr("Language"), + options=display_languages, + index=selected_index + ) + + if selected_language: + code = selected_language.split(" - ")[0].strip() + st.session_state['ui_language'] = code + config.ui['language'] = code + +def render_proxy_settings(tr): + """渲染代理设置""" + proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "") + proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "") + + HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http) + HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https) + + if HTTP_PROXY: + config.proxy["http"] = HTTP_PROXY + os.environ["HTTP_PROXY"] = HTTP_PROXY + if HTTPS_PROXY: + config.proxy["https"] = HTTPS_PROXY + os.environ["HTTPS_PROXY"] = HTTPS_PROXY + +def render_video_llm_settings(tr): + """渲染视频LLM设置""" + video_llm_providers = ['Gemini', 'NarratoAPI'] + saved_llm_provider = config.app.get("video_llm_provider", "OpenAI").lower() + saved_llm_provider_index = 0 + + for i, provider in enumerate(video_llm_providers): + if provider.lower() == saved_llm_provider: + saved_llm_provider_index = i + break + + video_llm_provider = st.selectbox( + tr("Video LLM Provider"), + options=video_llm_providers, + index=saved_llm_provider_index + ) + video_llm_provider = video_llm_provider.lower() + config.app["video_llm_provider"] = video_llm_provider + + # 获取已保存的配置 + video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "") + video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "") + video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "") + + # 渲染输入框 + st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password") + st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url) + st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name) + + # 保存配置 + if st_llm_api_key: + config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key + if st_llm_base_url: + config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url + if st_llm_model_name: + config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name + +def render_llm_settings(tr): + """渲染LLM设置""" + llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"] + saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower() + saved_llm_provider_index = 0 + + for i, provider in enumerate(llm_providers): + if provider.lower() == saved_llm_provider: + saved_llm_provider_index = i + break + + llm_provider = st.selectbox( + tr("LLM Provider"), + options=llm_providers, + index=saved_llm_provider_index + ) + llm_provider = llm_provider.lower() + config.app["llm_provider"] = llm_provider + + # 获取已保存的配置 + llm_api_key = config.app.get(f"{llm_provider}_api_key", "") + llm_base_url = config.app.get(f"{llm_provider}_base_url", "") + llm_model_name = config.app.get(f"{llm_provider}_model_name", "") + llm_account_id = config.app.get(f"{llm_provider}_account_id", "") + + # 渲染输入框 + st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password") + st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url) + st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name) + + # 保存配置 + if st_llm_api_key: + config.app[f"{llm_provider}_api_key"] = st_llm_api_key + if st_llm_base_url: + config.app[f"{llm_provider}_base_url"] = st_llm_base_url + if st_llm_model_name: + config.app[f"{llm_provider}_model_name"] = st_llm_model_name + + # Cloudflare 特殊处理 + if llm_provider == 'cloudflare': + st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id) + if st_llm_account_id: + config.app[f"{llm_provider}_account_id"] = st_llm_account_id \ No newline at end of file diff --git a/webui/components/review_settings.py b/webui/components/review_settings.py new file mode 100644 index 0000000..43f3844 --- /dev/null +++ b/webui/components/review_settings.py @@ -0,0 +1,65 @@ +import streamlit as st +import os +from loguru import logger + +def render_review_panel(tr): + """渲染视频审查面板""" + with st.expander(tr("Video Check"), expanded=False): + try: + video_list = st.session_state.get('video_clip_json', []) + except KeyError: + video_list = [] + + # 计算列数和行数 + num_videos = len(video_list) + cols_per_row = 3 + rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数 + + # 使用容器展示视频 + for row in range(rows): + cols = st.columns(cols_per_row) + for col in range(cols_per_row): + index = row * cols_per_row + col + if index < num_videos: + with cols[col]: + render_video_item(tr, video_list, index) + +def render_video_item(tr, video_list, index): + """渲染单个视频项""" + video_info = video_list[index] + video_path = video_info.get('path') + if video_path is not None and os.path.exists(video_path): + initial_narration = video_info.get('narration', '') + initial_picture = video_info.get('picture', '') + initial_timestamp = video_info.get('timestamp', '') + + # 显示视频 + with open(video_path, 'rb') as video_file: + video_bytes = video_file.read() + st.video(video_bytes) + + # 显示信息(只读) + text_panels = st.columns(2) + with text_panels[0]: + st.text_area( + tr("timestamp"), + value=initial_timestamp, + height=20, + key=f"timestamp_{index}", + disabled=True + ) + with text_panels[1]: + st.text_area( + tr("Picture description"), + value=initial_picture, + height=20, + key=f"picture_{index}", + disabled=True + ) + st.text_area( + tr("Narration"), + value=initial_narration, + height=100, + key=f"narration_{index}", + disabled=True + ) \ No newline at end of file diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py new file mode 100644 index 0000000..cc217d8 --- /dev/null +++ b/webui/components/script_settings.py @@ -0,0 +1,314 @@ +import streamlit as st +import os +import glob +import json +import time +from app.config import config +from app.models.schema import VideoClipParams +from app.services import llm +from app.utils import utils, check_script +from loguru import logger +from webui.utils import file_utils + +def render_script_panel(tr): + """渲染脚本配置面板""" + with st.container(border=True): + st.write(tr("Video Script Configuration")) + params = VideoClipParams() + + # 渲染脚本文件选择 + render_script_file(tr, params) + + # 渲染视频文件选择 + render_video_file(tr, params) + + # 渲染视频主题和提示词 + render_video_details(tr) + + # 渲染脚本操作按钮 + render_script_buttons(tr, params) + +def render_script_file(tr, params): + """渲染脚本文件选择""" + script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")] + + # 获取已有脚本文件 + suffix = "*.json" + script_dir = utils.script_dir() + files = glob.glob(os.path.join(script_dir, suffix)) + file_list = [] + + for file in files: + file_list.append({ + "name": os.path.basename(file), + "file": file, + "ctime": os.path.getctime(file) + }) + + file_list.sort(key=lambda x: x["ctime"], reverse=True) + for file in file_list: + display_name = file['file'].replace(config.root_dir, "") + script_list.append((display_name, file['file'])) + + # 找到保存的脚本文件在列表中的索引 + saved_script_path = st.session_state.get('video_clip_json_path', '') + selected_index = 0 + for i, (_, path) in enumerate(script_list): + if path == saved_script_path: + selected_index = i + break + + selected_script_index = st.selectbox( + tr("Script Files"), + index=selected_index, # 使用找到的索引 + options=range(len(script_list)), + format_func=lambda x: script_list[x][0] + ) + + script_path = script_list[selected_script_index][1] + st.session_state['video_clip_json_path'] = script_path + params.video_clip_json_path = script_path + +def render_video_file(tr, params): + """渲染视频文件选择""" + video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")] + + # 获取已有视频文件 + for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]: + video_files = glob.glob(os.path.join(utils.video_dir(), suffix)) + for file in video_files: + display_name = file.replace(config.root_dir, "") + video_list.append((display_name, file)) + + selected_video_index = st.selectbox( + tr("Video File"), + index=0, + options=range(len(video_list)), + format_func=lambda x: video_list[x][0] + ) + + video_path = video_list[selected_video_index][1] + st.session_state['video_origin_path'] = video_path + params.video_origin_path = video_path + + if video_path == "local": + uploaded_file = st.file_uploader( + tr("Upload Local Files"), + type=["mp4", "mov", "avi", "flv", "mkv"], + accept_multiple_files=False, + ) + + if uploaded_file is not None: + video_file_path = os.path.join(utils.video_dir(), uploaded_file.name) + file_name, file_extension = os.path.splitext(uploaded_file.name) + + if os.path.exists(video_file_path): + timestamp = time.strftime("%Y%m%d%H%M%S") + file_name_with_timestamp = f"{file_name}_{timestamp}" + video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension) + + with open(video_file_path, "wb") as f: + f.write(uploaded_file.read()) + st.success(tr("File Uploaded Successfully")) + st.session_state['video_origin_path'] = video_file_path + params.video_origin_path = video_file_path + time.sleep(1) + st.rerun() + +def render_video_details(tr): + """渲染视频主题和提示词""" + video_theme = st.text_input(tr("Video Theme")) + prompt = st.text_area( + tr("Generation Prompt"), + value=st.session_state.get('video_plot', ''), + help=tr("Custom prompt for LLM, leave empty to use default prompt"), + height=180 + ) + st.session_state['video_name'] = video_theme + st.session_state['video_plot'] = prompt + return video_theme, prompt + +def render_script_buttons(tr, params): + """渲染脚本操作按钮""" + # 生成/加载按钮 + script_path = st.session_state.get('video_clip_json_path', '') + if script_path == "auto": + button_name = tr("Generate Video Script") + elif script_path: + button_name = tr("Load Video Script") + else: + button_name = tr("Please Select Script File") + + if st.button(button_name, key="script_action", disabled=not script_path): + if script_path == "auto": + generate_script(tr, params) + else: + load_script(tr, script_path) + + # 视频脚本编辑区 + video_clip_json_details = st.text_area( + tr("Video Script"), + value=json.dumps(st.session_state.get('video_clip_json', []), indent=2, ensure_ascii=False), + height=180 + ) + + # 操作按钮行 + button_cols = st.columns(3) + with button_cols[0]: + if st.button(tr("Check Format"), key="check_format", use_container_width=True): + check_script_format(tr, video_clip_json_details) + + with button_cols[1]: + if st.button(tr("Save Script"), key="save_script", use_container_width=True): + save_script(tr, video_clip_json_details) + + with button_cols[2]: + script_valid = st.session_state.get('script_format_valid', False) + if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True): + crop_video(tr, params) + +def check_script_format(tr, script_content): + """检查脚本格式""" + try: + result = check_script.check_format(script_content) + if result.get('success'): + st.success(tr("Script format check passed")) + st.session_state['script_format_valid'] = True + else: + st.error(f"{tr('Script format check failed')}: {result.get('message')}") + st.session_state['script_format_valid'] = False + except Exception as e: + st.error(f"{tr('Script format check error')}: {str(e)}") + st.session_state['script_format_valid'] = False + +def load_script(tr, script_path): + """加载脚本文件""" + try: + with open(script_path, 'r', encoding='utf-8') as f: + script = f.read() + script = utils.clean_model_output(script) + st.session_state['video_clip_json'] = json.loads(script) + st.success(tr("Script loaded successfully")) + st.rerun() + except Exception as e: + st.error(f"{tr('Failed to load script')}: {str(e)}") + +def generate_script(tr, params): + """生成视频脚本""" + progress_bar = st.progress(0) + status_text = st.empty() + + def update_progress(progress: float, message: str = ""): + progress_bar.progress(progress) + if message: + status_text.text(f"{progress}% - {message}") + else: + status_text.text(f"进度: {progress}%") + + try: + with st.spinner("正在生成脚本..."): + if not st.session_state.get('video_plot'): + st.warning("视频剧情为空; 会极大影响生成效果!") + + if params.video_clip_json_path == "" and params.video_origin_path != "": + update_progress(10, "压缩视频中...") + script = llm.generate_script( + video_path=params.video_origin_path, + video_plot=st.session_state.get('video_plot', ''), + video_name=st.session_state.get('video_name', ''), + language=params.video_language, + progress_callback=update_progress + ) + if script is None: + st.error("生成脚本失败,请检查日志") + st.stop() + else: + update_progress(90) + + script = utils.clean_model_output(script) + st.session_state['video_clip_json'] = json.loads(script) + else: + # 从本地加载 + with open(params.video_clip_json_path, 'r', encoding='utf-8') as f: + update_progress(50) + status_text.text("从本地加载中...") + script = f.read() + script = utils.clean_model_output(script) + st.session_state['video_clip_json'] = json.loads(script) + update_progress(100) + status_text.text("从本地加载成功") + + time.sleep(0.5) + progress_bar.progress(100) + status_text.text("脚本生成完成!") + st.success("视频脚本生成成功!") + except Exception as err: + st.error(f"生成过程中发生错误: {str(err)}") + finally: + time.sleep(2) + progress_bar.empty() + status_text.empty() + +def save_script(tr, video_clip_json_details): + """保存视频脚本""" + if not video_clip_json_details: + st.error(tr("请输入视频脚本")) + st.stop() + + with st.spinner(tr("Save Script")): + script_dir = utils.script_dir() + timestamp = time.strftime("%Y-%m%d-%H%M%S") + save_path = os.path.join(script_dir, f"{timestamp}.json") + + try: + data = json.loads(video_clip_json_details) + with open(save_path, 'w', encoding='utf-8') as file: + json.dump(data, file, ensure_ascii=False, indent=4) + st.session_state['video_clip_json'] = data + st.session_state['video_clip_json_path'] = save_path + + # 更新配置 + config.app["video_clip_json_path"] = save_path + + # 显示成功消息 + st.success(tr("Script saved successfully")) + + # 强制重新加载页面以更新选择框 + time.sleep(0.5) # 给一点时间让用户看到成功消息 + st.rerun() + + except Exception as err: + st.error(f"{tr('Failed to save script')}: {str(err)}") + st.stop() + +def crop_video(tr, params): + """裁剪视频""" + progress_bar = st.progress(0) + status_text = st.empty() + + def update_progress(progress): + progress_bar.progress(progress) + status_text.text(f"剪辑进度: {progress}%") + + try: + utils.cut_video(params, update_progress) + time.sleep(0.5) + progress_bar.progress(100) + status_text.text("剪辑完成!") + st.success("视频剪辑成功完成!") + except Exception as e: + st.error(f"剪辑过程中发生错误: {str(e)}") + finally: + time.sleep(2) + progress_bar.empty() + status_text.empty() + +def get_script_params(): + """获取脚本参数""" + return { + 'video_language': st.session_state.get('video_language', ''), + 'video_clip_json_path': st.session_state.get('video_clip_json_path', ''), + 'video_origin_path': st.session_state.get('video_origin_path', ''), + 'video_name': st.session_state.get('video_name', ''), + 'video_plot': st.session_state.get('video_plot', '') + } \ No newline at end of file diff --git a/webui/components/subtitle_settings.py b/webui/components/subtitle_settings.py new file mode 100644 index 0000000..9b94e3c --- /dev/null +++ b/webui/components/subtitle_settings.py @@ -0,0 +1,129 @@ +import streamlit as st +from app.config import config +from webui.utils.cache import get_fonts_cache +import os + +def render_subtitle_panel(tr): + """渲染字幕设置面板""" + with st.container(border=True): + st.write(tr("Subtitle Settings")) + + # 启用字幕选项 + enable_subtitles = st.checkbox(tr("Enable Subtitles"), value=True) + st.session_state['subtitle_enabled'] = enable_subtitles + + if enable_subtitles: + render_font_settings(tr) + render_position_settings(tr) + render_style_settings(tr) + +def render_font_settings(tr): + """渲染字体设置""" + # 获取字体列表 + font_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "resource", "fonts") + font_names = get_fonts_cache(font_dir) + + # 获取保存的字体设置 + saved_font_name = config.ui.get("font_name", "") + saved_font_name_index = 0 + if saved_font_name in font_names: + saved_font_name_index = font_names.index(saved_font_name) + + # 字体选择 + font_name = st.selectbox( + tr("Font"), + options=font_names, + index=saved_font_name_index + ) + config.ui["font_name"] = font_name + st.session_state['font_name'] = font_name + + # 字体大小 + font_cols = st.columns([0.3, 0.7]) + with font_cols[0]: + saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF") + text_fore_color = st.color_picker( + tr("Font Color"), + saved_text_fore_color + ) + config.ui["text_fore_color"] = text_fore_color + st.session_state['text_fore_color'] = text_fore_color + + with font_cols[1]: + saved_font_size = config.ui.get("font_size", 60) + font_size = st.slider( + tr("Font Size"), + min_value=30, + max_value=100, + value=saved_font_size + ) + config.ui["font_size"] = font_size + st.session_state['font_size'] = font_size + +def render_position_settings(tr): + """渲染位置设置""" + subtitle_positions = [ + (tr("Top"), "top"), + (tr("Center"), "center"), + (tr("Bottom"), "bottom"), + (tr("Custom"), "custom"), + ] + + selected_index = st.selectbox( + tr("Position"), + index=2, + options=range(len(subtitle_positions)), + format_func=lambda x: subtitle_positions[x][0], + ) + + subtitle_position = subtitle_positions[selected_index][1] + st.session_state['subtitle_position'] = subtitle_position + + # 自定义位置处理 + if subtitle_position == "custom": + custom_position = st.text_input( + tr("Custom Position (% from top)"), + value="70.0" + ) + try: + custom_position_value = float(custom_position) + if custom_position_value < 0 or custom_position_value > 100: + st.error(tr("Please enter a value between 0 and 100")) + else: + st.session_state['custom_position'] = custom_position_value + except ValueError: + st.error(tr("Please enter a valid number")) + +def render_style_settings(tr): + """渲染样式设置""" + stroke_cols = st.columns([0.3, 0.7]) + + with stroke_cols[0]: + stroke_color = st.color_picker( + tr("Stroke Color"), + value="#000000" + ) + st.session_state['stroke_color'] = stroke_color + + with stroke_cols[1]: + stroke_width = st.slider( + tr("Stroke Width"), + min_value=0.0, + max_value=10.0, + value=1.5, + step=0.1 + ) + st.session_state['stroke_width'] = stroke_width + +def get_subtitle_params(): + """获取字幕参数""" + return { + 'enabled': st.session_state.get('subtitle_enabled', True), + 'font_name': st.session_state.get('font_name', ''), + 'font_size': st.session_state.get('font_size', 60), + 'text_fore_color': st.session_state.get('text_fore_color', '#FFFFFF'), + 'position': st.session_state.get('subtitle_position', 'bottom'), + 'custom_position': st.session_state.get('custom_position', 70.0), + 'stroke_color': st.session_state.get('stroke_color', '#000000'), + 'stroke_width': st.session_state.get('stroke_width', 1.5), + } \ No newline at end of file diff --git a/webui/components/video_settings.py b/webui/components/video_settings.py new file mode 100644 index 0000000..7942bee --- /dev/null +++ b/webui/components/video_settings.py @@ -0,0 +1,47 @@ +import streamlit as st +from app.models.schema import VideoClipParams, VideoAspect + +def render_video_panel(tr): + """渲染视频配置面板""" + with st.container(border=True): + st.write(tr("Video Settings")) + params = VideoClipParams() + render_video_config(tr, params) + +def render_video_config(tr, params): + """渲染视频配置""" + # 视频比例 + video_aspect_ratios = [ + (tr("Portrait"), VideoAspect.portrait.value), + (tr("Landscape"), VideoAspect.landscape.value), + ] + selected_index = st.selectbox( + tr("Video Ratio"), + options=range(len(video_aspect_ratios)), + format_func=lambda x: video_aspect_ratios[x][0], + ) + params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1]) + st.session_state['video_aspect'] = params.video_aspect.value + + # 视频画质 + video_qualities = [ + ("4K (2160p)", "2160p"), + ("2K (1440p)", "1440p"), + ("Full HD (1080p)", "1080p"), + ("HD (720p)", "720p"), + ("SD (480p)", "480p"), + ] + quality_index = st.selectbox( + tr("Video Quality"), + options=range(len(video_qualities)), + format_func=lambda x: video_qualities[x][0], + index=2 # 默认选择 1080p + ) + st.session_state['video_quality'] = video_qualities[quality_index][1] + +def get_video_params(): + """获取视频参数""" + return { + 'video_aspect': st.session_state.get('video_aspect', VideoAspect.portrait.value), + 'video_quality': st.session_state.get('video_quality', '1080p') + } \ No newline at end of file diff --git a/webui/config/settings.py b/webui/config/settings.py new file mode 100644 index 0000000..6ad4db3 --- /dev/null +++ b/webui/config/settings.py @@ -0,0 +1,155 @@ +import os +import tomli +from loguru import logger +from typing import Dict, Any, Optional +from dataclasses import dataclass + +@dataclass +class WebUIConfig: + """WebUI配置类""" + # UI配置 + ui: Dict[str, Any] = None + # 代理配置 + proxy: Dict[str, str] = None + # 应用配置 + app: Dict[str, Any] = None + # Azure配置 + azure: Dict[str, str] = None + # 项目版本 + project_version: str = "0.1.0" + # 项目根目录 + root_dir: str = None + + def __post_init__(self): + """初始化默认值""" + self.ui = self.ui or {} + self.proxy = self.proxy or {} + self.app = self.app or {} + self.azure = self.azure or {} + self.root_dir = self.root_dir or os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + +def load_config(config_path: Optional[str] = None) -> WebUIConfig: + """加载配置文件 + Args: + config_path: 配置文件路径,如果为None则使用默认路径 + Returns: + WebUIConfig: 配置对象 + """ + try: + if config_path is None: + config_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + ".streamlit", + "webui.toml" + ) + + # 如果配置文件不存在,使用示例配置 + if not os.path.exists(config_path): + example_config = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "config.example.toml" + ) + if os.path.exists(example_config): + config_path = example_config + else: + logger.warning(f"配置文件不存在: {config_path}") + return WebUIConfig() + + # 读取配置文件 + with open(config_path, "rb") as f: + config_dict = tomli.load(f) + + # 创建配置对象 + config = WebUIConfig( + ui=config_dict.get("ui", {}), + proxy=config_dict.get("proxy", {}), + app=config_dict.get("app", {}), + azure=config_dict.get("azure", {}), + project_version=config_dict.get("project_version", "0.1.0") + ) + + return config + + except Exception as e: + logger.error(f"加载配置文件失败: {e}") + return WebUIConfig() + +def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool: + """保存配置到文件 + Args: + config: 配置对象 + config_path: 配置文件路径,如果为None则使用默认路径 + Returns: + bool: 是否保存成功 + """ + try: + if config_path is None: + config_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + ".streamlit", + "webui.toml" + ) + + # 确保目录存在 + os.makedirs(os.path.dirname(config_path), exist_ok=True) + + # 转换为字典 + config_dict = { + "ui": config.ui, + "proxy": config.proxy, + "app": config.app, + "azure": config.azure, + "project_version": config.project_version + } + + # 保存配置 + with open(config_path, "w", encoding="utf-8") as f: + import tomli_w + tomli_w.dump(config_dict, f) + + return True + + except Exception as e: + logger.error(f"保存配置文件失败: {e}") + return False + +def get_config() -> WebUIConfig: + """获取全局配置对象 + Returns: + WebUIConfig: 配置对象 + """ + if not hasattr(get_config, "_config"): + get_config._config = load_config() + return get_config._config + +def update_config(config_dict: Dict[str, Any]) -> bool: + """更新配置 + Args: + config_dict: 配置字典 + Returns: + bool: 是否更新成功 + """ + try: + config = get_config() + + # 更新配置 + if "ui" in config_dict: + config.ui.update(config_dict["ui"]) + if "proxy" in config_dict: + config.proxy.update(config_dict["proxy"]) + if "app" in config_dict: + config.app.update(config_dict["app"]) + if "azure" in config_dict: + config.azure.update(config_dict["azure"]) + if "project_version" in config_dict: + config.project_version = config_dict["project_version"] + + # 保存配置 + return save_config(config) + + except Exception as e: + logger.error(f"更新配置失败: {e}") + return False + +# 导出全局配置对象 +config = get_config() \ No newline at end of file diff --git a/webui/i18n/__init__.py b/webui/i18n/__init__.py new file mode 100644 index 0000000..0f05c76 --- /dev/null +++ b/webui/i18n/__init__.py @@ -0,0 +1 @@ +# 空文件,用于标记包 \ No newline at end of file diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index f1bc6b2..0481306 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -2,15 +2,15 @@ "Language": "简体中文", "Translation": { "Video Script Configuration": "**视频脚本配置**", - "Video Script Generate": "生成视频脚本", + "Generate Video Script": "生成视频脚本", "Video Subject": "视频主题(给定一个关键词,:red[AI自动生成]视频文案)", "Script Language": "生成视频脚本的语言(一般情况AI会自动根据你输入的主题语言输出)", "Script Files": "脚本文件", "Generate Video Script and Keywords": "点击使用AI根据**主题**生成 【视频文案】 和 【视频关键词】", "Auto Detect": "自动检测", "Auto Generate": "自动生成", - "Video Name": "视频名称", - "Video Script": "视频脚本(:blue[①使用AI生成 ②从本机加载])", + "Video Theme": "视频主题", + "Generation Prompt": "自定义提示词", "Save Script": "保存脚本", "Crop Video": "裁剪视频", "Video File": "视频文件(:blue[1️⃣支持上传视频文件(限制2G) 2️⃣大文件建议直接导入 ./resource/videos 目录])", @@ -91,7 +91,18 @@ "Picture description": "图片描述", "Narration": "视频文案", "Rebuild": "重新生成", - "Video Script Load": "加载视频脚本", - "Speech Pitch": "语调" + "Load Video Script": "加载视频脚本", + "Speech Pitch": "语调", + "Please Select Script File": "请选择脚本文件", + "Check Format": "脚本格式检查", + "Script Loaded Successfully": "脚本加载成功", + "Script format check passed": "脚本格式检查通过", + "Script format check failed": "脚本格式检查失败", + "Failed to Load Script": "加载脚本失败", + "Failed to Save Script": "保存脚本失败", + "Script saved successfully": "脚本保存成功", + "Video Script": "视频脚本", + "Video Quality": "视频质量", + "Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词" } } \ No newline at end of file diff --git a/webui/utils/__init__.py b/webui/utils/__init__.py new file mode 100644 index 0000000..b6a1870 --- /dev/null +++ b/webui/utils/__init__.py @@ -0,0 +1,20 @@ +from .cache import get_fonts_cache, get_video_files_cache, get_songs_cache +from .file_utils import ( + open_task_folder, cleanup_temp_files, get_file_list, + save_uploaded_file, create_temp_file, get_file_size, ensure_directory +) +from .performance import monitor_performance + +__all__ = [ + 'get_fonts_cache', + 'get_video_files_cache', + 'get_songs_cache', + 'open_task_folder', + 'cleanup_temp_files', + 'get_file_list', + 'save_uploaded_file', + 'create_temp_file', + 'get_file_size', + 'ensure_directory', + 'monitor_performance' +] \ No newline at end of file diff --git a/webui/utils/cache.py b/webui/utils/cache.py new file mode 100644 index 0000000..6cc3b05 --- /dev/null +++ b/webui/utils/cache.py @@ -0,0 +1,33 @@ +import streamlit as st +import os +import glob +from app.utils import utils + +def get_fonts_cache(font_dir): + if 'fonts_cache' not in st.session_state: + fonts = [] + for root, dirs, files in os.walk(font_dir): + for file in files: + if file.endswith(".ttf") or file.endswith(".ttc"): + fonts.append(file) + fonts.sort() + st.session_state['fonts_cache'] = fonts + return st.session_state['fonts_cache'] + +def get_video_files_cache(): + if 'video_files_cache' not in st.session_state: + video_files = [] + for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]: + video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix))) + st.session_state['video_files_cache'] = video_files[::-1] + return st.session_state['video_files_cache'] + +def get_songs_cache(song_dir): + if 'songs_cache' not in st.session_state: + songs = [] + for root, dirs, files in os.walk(song_dir): + for file in files: + if file.endswith(".mp3"): + songs.append(file) + st.session_state['songs_cache'] = songs + return st.session_state['songs_cache'] \ No newline at end of file diff --git a/webui/utils/file_utils.py b/webui/utils/file_utils.py new file mode 100644 index 0000000..458efa6 --- /dev/null +++ b/webui/utils/file_utils.py @@ -0,0 +1,189 @@ +import os +import glob +import time +import platform +import shutil +from uuid import uuid4 +from loguru import logger +from app.utils import utils + +def open_task_folder(root_dir, task_id): + """打开任务文件夹 + Args: + root_dir: 项目根目录 + task_id: 任务ID + """ + try: + sys = platform.system() + path = os.path.join(root_dir, "storage", "tasks", task_id) + if os.path.exists(path): + if sys == 'Windows': + os.system(f"start {path}") + if sys == 'Darwin': + os.system(f"open {path}") + if sys == 'Linux': + os.system(f"xdg-open {path}") + except Exception as e: + logger.error(f"打开任务文件夹失败: {e}") + +def cleanup_temp_files(temp_dir, max_age=3600): + """清理临时文件 + Args: + temp_dir: 临时文件目录 + max_age: 文件最大保存时间(秒) + """ + if os.path.exists(temp_dir): + for file in os.listdir(temp_dir): + file_path = os.path.join(temp_dir, file) + try: + if os.path.getctime(file_path) < time.time() - max_age: + if os.path.isfile(file_path): + os.remove(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + logger.debug(f"已清理临时文件: {file_path}") + except Exception as e: + logger.error(f"清理临时文件失败: {file_path}, 错误: {e}") + +def get_file_list(directory, file_types=None, sort_by='ctime', reverse=True): + """获取指定目录下的文件列表 + Args: + directory: 目录路径 + file_types: 文件类型列表,如 ['.mp4', '.mov'] + sort_by: 排序方式,支持 'ctime'(创建时间), 'mtime'(修改时间), 'size'(文件大小), 'name'(文件名) + reverse: 是否倒序排序 + Returns: + list: 文件信息列表 + """ + if not os.path.exists(directory): + return [] + + files = [] + if file_types: + for file_type in file_types: + files.extend(glob.glob(os.path.join(directory, f"*{file_type}"))) + else: + files = glob.glob(os.path.join(directory, "*")) + + file_list = [] + for file_path in files: + try: + file_stat = os.stat(file_path) + file_info = { + "name": os.path.basename(file_path), + "path": file_path, + "size": file_stat.st_size, + "ctime": file_stat.st_ctime, + "mtime": file_stat.st_mtime + } + file_list.append(file_info) + except Exception as e: + logger.error(f"获取文件信息失败: {file_path}, 错误: {e}") + + # 排序 + if sort_by in ['ctime', 'mtime', 'size', 'name']: + file_list.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse) + + return file_list + +def save_uploaded_file(uploaded_file, save_dir, allowed_types=None): + """保存上传的文件 + Args: + uploaded_file: StreamlitUploadedFile对象 + save_dir: 保存目录 + allowed_types: 允许的文件类型列表,如 ['.mp4', '.mov'] + Returns: + str: 保存后的文件路径,失败返回None + """ + try: + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + file_name, file_extension = os.path.splitext(uploaded_file.name) + + # 检查文件类型 + if allowed_types and file_extension.lower() not in allowed_types: + logger.error(f"不支持的文件类型: {file_extension}") + return None + + # 如果文件已存在,添加时间戳 + save_path = os.path.join(save_dir, uploaded_file.name) + if os.path.exists(save_path): + timestamp = time.strftime("%Y%m%d%H%M%S") + new_file_name = f"{file_name}_{timestamp}{file_extension}" + save_path = os.path.join(save_dir, new_file_name) + + # 保存文件 + with open(save_path, "wb") as f: + f.write(uploaded_file.read()) + + logger.info(f"文件保存成功: {save_path}") + return save_path + + except Exception as e: + logger.error(f"保存上传文件失败: {e}") + return None + +def create_temp_file(prefix='tmp', suffix='', directory=None): + """创建临时文件 + Args: + prefix: 文件名前缀 + suffix: 文件扩展名 + directory: 临时文件目录,默认使用系统临时目录 + Returns: + str: 临时文件路径 + """ + try: + if directory is None: + directory = utils.storage_dir("temp", create=True) + + if not os.path.exists(directory): + os.makedirs(directory) + + temp_file = os.path.join(directory, f"{prefix}-{str(uuid4())}{suffix}") + return temp_file + + except Exception as e: + logger.error(f"创建临时文件失败: {e}") + return None + +def get_file_size(file_path, format='MB'): + """获取文件大小 + Args: + file_path: 文件路径 + format: 返回格式,支持 'B', 'KB', 'MB', 'GB' + Returns: + float: 文件大小 + """ + try: + size_bytes = os.path.getsize(file_path) + + if format.upper() == 'B': + return size_bytes + elif format.upper() == 'KB': + return size_bytes / 1024 + elif format.upper() == 'MB': + return size_bytes / (1024 * 1024) + elif format.upper() == 'GB': + return size_bytes / (1024 * 1024 * 1024) + else: + return size_bytes + + except Exception as e: + logger.error(f"获取文件大小失败: {file_path}, 错误: {e}") + return 0 + +def ensure_directory(directory): + """确保目录存在,如果不存在则创建 + Args: + directory: 目录路径 + Returns: + bool: 是否成功 + """ + try: + if not os.path.exists(directory): + os.makedirs(directory) + return True + except Exception as e: + logger.error(f"创建目录失败: {directory}, 错误: {e}") + return False \ No newline at end of file diff --git a/webui/utils/performance.py b/webui/utils/performance.py new file mode 100644 index 0000000..76b7dcb --- /dev/null +++ b/webui/utils/performance.py @@ -0,0 +1,24 @@ +import time +from loguru import logger + +try: + import psutil + ENABLE_PERFORMANCE_MONITORING = True +except ImportError: + ENABLE_PERFORMANCE_MONITORING = False + logger.warning("psutil not installed. Performance monitoring is disabled.") + +def monitor_performance(): + if not ENABLE_PERFORMANCE_MONITORING: + return {'execution_time': 0, 'memory_usage': 0} + + start_time = time.time() + try: + memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 # MB + except: + memory_usage = 0 + + return { + 'execution_time': time.time() - start_time, + 'memory_usage': memory_usage + } \ No newline at end of file