优化 webui 代码逻辑

This commit is contained in:
linyq 2024-09-24 18:25:02 +08:00
parent 93188e1328
commit 6669b28361
9 changed files with 84 additions and 73 deletions

View File

@ -166,7 +166,7 @@ sudo yum install ImageMagick
``` ```
3. 启动 webui 3. 启动 webui
```shell ```shell
streamlit run ./webui/Main.py --browser.serverAddress=127.0.0.1 --server.enableCORS=True --browser.gatherUsageStats=False streamlit run ./webui/webui.py --browser.serverAddress=127.0.0.1 --server.enableCORS=True --browser.gatherUsageStats=False
``` ```
4. 访问 http://127.0.0.1:8501 4. 访问 http://127.0.0.1:8501

View File

@ -167,7 +167,7 @@ sudo yum install ImageMagick
3. initiate webui 3. initiate webui
```shell ```shell
streamlit run ./webui/Main.py --browser.serverAddress=127.0.0.1 --server.enableCORS=True --browser.gatherUsageStats=False streamlit run ./webui/webui.py --browser.serverAddress=127.0.0.1 --server.enableCORS=True --browser.gatherUsageStats=False
``` ```
4. Access http://127.0.0.1:8501 4. Access http://127.0.0.1:8501

View File

@ -339,7 +339,7 @@ class VideoClipParams(BaseModel):
video_count: Optional[int] = 1 # 视频片段数量 video_count: Optional[int] = 1 # 视频片段数量
video_source: Optional[str] = "local" video_source: Optional[str] = "local"
video_language: Optional[str] = "" # 自动检测 video_language: Optional[str] = "" # 自动检测
video_concat_mode: Optional[VideoConcatMode] = VideoConcatMode.random.value # video_concat_mode: Optional[VideoConcatMode] = VideoConcatMode.random.value
# # 女性 # # 女性
# "zh-CN-XiaoxiaoNeural", # "zh-CN-XiaoxiaoNeural",
@ -366,5 +366,6 @@ class VideoClipParams(BaseModel):
font_size: int = 60 # 文字大小 font_size: int = 60 # 文字大小
stroke_color: Optional[str] = "#000000" # 文字描边颜色 stroke_color: Optional[str] = "#000000" # 文字描边颜色
stroke_width: float = 1.5 # 文字描边宽度 stroke_width: float = 1.5 # 文字描边宽度
custom_position: float = 70.0 # 自定义位置
n_threads: Optional[int] = 2 # 线程数 n_threads: Optional[int] = 2 # 线程数
paragraph_number: Optional[int] = 1 # 段落数量 paragraph_number: Optional[int] = 1 # 段落数量

View File

@ -20,7 +20,9 @@ services:
dockerfile: Dockerfile dockerfile: Dockerfile
container_name: "api" container_name: "api"
ports: ports:
- "8502:8080" - "8502:22"
command: [ "python3", "main.py" ] command: [ "sleep", "48h" ]
volumes: *common-volumes volumes: *common-volumes
environment:
- "VPN_PROXY_URL=http://host.docker.internal:7890"
restart: always restart: always

View File

@ -40,4 +40,4 @@ pause
rem set HF_ENDPOINT=https://hf-mirror.com rem set HF_ENDPOINT=https://hf-mirror.com
streamlit run .\webui\Main.py --browser.gatherUsageStats=False --server.enableCORS=True streamlit run webui.py --browser.gatherUsageStats=False --server.enableCORS=True

View File

@ -5,24 +5,26 @@ import json
import time import time
import datetime import datetime
import traceback import traceback
import streamlit as st
from uuid import uuid4
import platform
import streamlit.components.v1 as components
from loguru import logger
# 将项目的根目录添加到系统路径中,以允许从项目导入模块 from app.config import config
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) from app.models.const import FILE_TYPE_VIDEOS
from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
from app.services import task as tm, llm, voice, material
from app.utils import utils
# # 将项目的根目录添加到系统路径中,以允许从项目导入模块
root_dir = os.path.dirname(os.path.realpath(__file__))
if root_dir not in sys.path: if root_dir not in sys.path:
sys.path.append(root_dir) sys.path.append(root_dir)
print("******** sys.path ********") print("******** sys.path ********")
print(sys.path) print(sys.path)
print("") print("")
import streamlit as st
import os
from uuid import uuid4
import platform
import streamlit.components.v1 as components
from loguru import logger
from app.config import config
st.set_page_config( st.set_page_config(
page_title="NarratoAI", page_title="NarratoAI",
page_icon="📽️", page_icon="📽️",
@ -35,11 +37,6 @@ st.set_page_config(
}, },
) )
from app.models.const import FILE_TYPE_IMAGES, FILE_TYPE_VIDEOS
from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
from app.services import task as tm, llm, voice, material
from app.utils import utils
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "") proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "") proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
os.environ["HTTP_PROXY"] = proxy_url_http os.environ["HTTP_PROXY"] = proxy_url_http
@ -278,18 +275,23 @@ with left_panel:
"name": os.path.basename(file), "name": os.path.basename(file),
"size": os.path.getsize(file), "size": os.path.getsize(file),
"file": file, "file": file,
"ctime": os.path.getctime(file) # 获取文件创建时间
}) })
script_path = [(tr("Auto Generate"), ""), ] # 按创建时间降序排序
for code in [file['file'] for file in script_list]: script_list.sort(key=lambda x: x["ctime"], reverse=True)
script_path.append((code, code))
selected_json2 = st.selectbox(tr("Script Files"), # 脚本文件 下拉框
script_path = [(tr("Auto Generate"), ""), ]
for file in script_list:
display_name = file['file'].replace(root_dir, "")
script_path.append((display_name, file['file']))
selected_script_index = st.selectbox(tr("Script Files"),
index=0, index=0,
options=range(len(script_path)), # 使用索引作为内部选项值 options=range(len(script_path)), # 使用索引作为内部选项值
format_func=lambda x: script_path[x][0] # 显示给用户的是标签 format_func=lambda x: script_path[x][0] # 显示给用户的是标签
) )
params.video_clip_json = script_path[selected_json2][1] params.video_clip_json = script_path[selected_script_index][1]
video_json_file = params.video_clip_json video_json_file = params.video_clip_json
# 视频文件处理 # 视频文件处理
@ -310,12 +312,12 @@ with left_panel:
for code in [file['file'] for file in video_list]: for code in [file['file'] for file in video_list]:
video_path.append((code, code)) video_path.append((code, code))
selected_index2 = st.selectbox(tr("Video File"), selected_video_index = st.selectbox(tr("Video File"),
index=0, index=0,
options=range(len(video_path)), # 使用索引作为内部选项值 options=range(len(video_path)), # 使用索引作为内部选项值
format_func=lambda x: video_path[x][0] # 显示给用户的是标签 format_func=lambda x: video_path[x][0] # 显示给用户的是标签
) )
params.video_origin_path = video_path[selected_index2][1] params.video_origin_path = video_path[selected_video_index][1]
config.app["video_origin_path"] = params.video_origin_path config.app["video_origin_path"] = params.video_origin_path
# 从本地上传 mp4 文件 # 从本地上传 mp4 文件
@ -341,8 +343,6 @@ with left_panel:
st.success(tr("File Uploaded Successfully")) st.success(tr("File Uploaded Successfully"))
time.sleep(1) time.sleep(1)
st.rerun() st.rerun()
# params.video_origin_path = video_path[selected_index2][1]
# config.app["video_origin_path"] = params.video_origin_path
# 剧情内容 # 剧情内容
video_plot = st.text_area( video_plot = st.text_area(
@ -351,12 +351,13 @@ with left_panel:
height=180 height=180
) )
# 生成视频脚本
if st.button(tr("Video Script Generate"), key="auto_generate_script"): if st.button(tr("Video Script Generate"), key="auto_generate_script"):
with st.spinner(tr("Video Script Generate")): with st.spinner(tr("Video Script Generate")):
if video_json_file == "" and params.video_origin_path != "": if video_json_file == "" and params.video_origin_path != "":
# 使用大模型生成视频脚本 # 使用大模型生成视频脚本
script = llm.gemini_video2json( script = llm.gemini_video2json(
video_origin_name=params.video_origin_path.split("\\")[-1], video_origin_name=os.path.basename(params.video_origin_path),
video_origin_path=params.video_origin_path, video_origin_path=params.video_origin_path,
video_plot=video_plot, video_plot=video_plot,
language=params.video_language, language=params.video_language,
@ -371,12 +372,14 @@ with left_panel:
cleaned_string = script.strip("```json").strip("```") cleaned_string = script.strip("```json").strip("```")
st.session_state['video_script_list'] = json.loads(cleaned_string) st.session_state['video_script_list'] = json.loads(cleaned_string)
# 视频脚本
video_clip_json_details = st.text_area( video_clip_json_details = st.text_area(
tr("Video Script"), tr("Video Script"),
value=st.session_state['video_clip_json'], value=st.session_state['video_clip_json'],
height=180 height=180
) )
# 保存脚本
button_columns = st.columns(2) button_columns = st.columns(2)
with button_columns[0]: with button_columns[0]:
if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True): if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True):
@ -397,20 +400,23 @@ with left_panel:
try: try:
data = utils.add_new_timestamps(json.loads(input_json)) data = utils.add_new_timestamps(json.loads(input_json))
except Exception as err: except Exception as err:
raise ValueError( st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}")
f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}") st.stop()
# 检查是否是一个列表 # 检查是否是一个列表
if not isinstance(data, list): if not isinstance(data, list):
raise ValueError("JSON is not a list") st.error("JSON is not a list")
st.stop()
# 检查列表中的每个元素是否包含所需的键 # 检查列表中的每个元素是否包含所需的键
required_keys = {"picture", "timestamp", "narration"} required_keys = {"picture", "timestamp", "narration"}
for item in data: for item in data:
if not isinstance(item, dict): if not isinstance(item, dict):
raise ValueError("List 元素不是字典") st.error("List 元素不是字典")
st.stop()
if not required_keys.issubset(item.keys()): if not required_keys.issubset(item.keys()):
raise ValueError("Dict 元素不包含必需的键") st.error("Dict 元素不包含必需的键")
st.stop()
# 存储为新的 JSON 文件 # 存储为新的 JSON 文件
with open(save_path, 'w', encoding='utf-8') as file: with open(save_path, 'w', encoding='utf-8') as file:
@ -441,13 +447,13 @@ with left_panel:
for video_script in video_script_list: for video_script in video_script_list:
try: try:
video_script['path'] = subclip_videos[video_script['timestamp']] video_script['path'] = subclip_videos[video_script['timestamp']]
except KeyError as e: except KeyError as err:
st.error(f"裁剪视频失败") st.error(f"裁剪视频失败 {err}")
# logger.debug(f"当前的脚本为:{st.session_state.video_script_list}") # logger.debug(f"当前的脚本为:{st.session_state.video_script_list}")
else: else:
st.error(tr("请先生成视频脚本")) st.error(tr("请先生成视频脚本"))
# 裁剪视频
with button_columns[1]: with button_columns[1]:
if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True): if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
caijian() caijian()
@ -456,10 +462,10 @@ with left_panel:
with middle_panel: with middle_panel:
with st.container(border=True): with st.container(border=True):
st.write(tr("Video Settings")) st.write(tr("Video Settings"))
video_concat_modes = [ # video_concat_modes = [
(tr("Sequential"), "sequential"), # (tr("Sequential"), "sequential"),
(tr("Random"), "random"), # (tr("Random"), "random"),
] # ]
# video_sources = [ # video_sources = [
# (tr("Pexels"), "pexels"), # (tr("Pexels"), "pexels"),
# (tr("Pixabay"), "pixabay"), # (tr("Pixabay"), "pixabay"),
@ -491,16 +497,17 @@ with middle_panel:
# accept_multiple_files=True, # accept_multiple_files=True,
# ) # )
selected_index = st.selectbox( # selected_index = st.selectbox(
tr("Video Concat Mode"), # tr("Video Concat Mode"),
index=1, # index=1,
options=range(len(video_concat_modes)), # 使用索引作为内部选项值 # options=range(len(video_concat_modes)), # 使用索引作为内部选项值
format_func=lambda x: video_concat_modes[x][0], # 显示给用户的是标签 # format_func=lambda x: video_concat_modes[x][0], # 显示给用户的是标签
) # )
params.video_concat_mode = VideoConcatMode( # params.video_concat_mode = VideoConcatMode(
video_concat_modes[selected_index][1] # video_concat_modes[selected_index][1]
) # )
# 视频比例
video_aspect_ratios = [ video_aspect_ratios = [
(tr("Portrait"), VideoAspect.portrait.value), (tr("Portrait"), VideoAspect.portrait.value),
(tr("Landscape"), VideoAspect.landscape.value), (tr("Landscape"), VideoAspect.landscape.value),
@ -512,14 +519,14 @@ with middle_panel:
) )
params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1]) params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
params.video_clip_duration = st.selectbox( # params.video_clip_duration = st.selectbox(
tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1 # tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
) # )
params.video_count = st.selectbox( # params.video_count = st.selectbox(
tr("Number of Videos Generated Simultaneously"), # tr("Number of Videos Generated Simultaneously"),
options=[1, 2, 3, 4, 5], # options=[1, 2, 3, 4, 5],
index=0, # index=0,
) # )
with st.container(border=True): with st.container(border=True):
st.write(tr("Audio Settings")) st.write(tr("Audio Settings"))
@ -638,7 +645,7 @@ with middle_panel:
index=2, index=2,
) )
# 新侧面板 # 新侧面板
with right_panel: with right_panel:
with st.container(border=True): with st.container(border=True):
st.write(tr("Subtitle Settings")) st.write(tr("Subtitle Settings"))
@ -676,6 +683,7 @@ with right_panel:
if params.custom_position < 0 or params.custom_position > 100: if params.custom_position < 0 or params.custom_position > 100:
st.error(tr("Please enter a value between 0 and 100")) st.error(tr("Please enter a value between 0 and 100"))
except ValueError: except ValueError:
logger.error(f"输入的值无效: {traceback.format_exc()}")
st.error(tr("Please enter a valid number")) st.error(tr("Please enter a valid number"))
font_cols = st.columns([0.3, 0.7]) font_cols = st.columns([0.3, 0.7])

View File

@ -47,4 +47,4 @@ done
# 等待所有后台任务完成 # 等待所有后台任务完成
wait wait
echo "所有文件已成功下载到指定目录" echo "所有文件已成功下载到指定目录"
streamlit run ./webui/Main.py --browser.serverAddress="0.0.0.0" --server.enableCORS=True --server.maxUploadSize=2048 --browser.gatherUsageStats=False streamlit run webui.py --browser.serverAddress="0.0.0.0" --server.enableCORS=True --server.maxUploadSize=2048 --browser.gatherUsageStats=False

View File

@ -73,7 +73,7 @@
"Please Enter the LLM API Key": "Please enter the **LLM API Key**", "Please Enter the LLM API Key": "Please enter the **LLM API Key**",
"Please Enter the Pexels API Key": "Please enter the **Pexels API Key**", "Please Enter the Pexels API Key": "Please enter the **Pexels API Key**",
"Please Enter the Pixabay API Key": "Please enter the **Pixabay API Key**", "Please Enter the Pixabay API Key": "Please enter the **Pixabay API Key**",
"Get Help": "One-stop AI video commentary + automated editing tool\uD83C\uDF89\uD83C\uDF89\uD83C\uDF89\n\nFor any questions or suggestions, you can join the **community channel** for help or discussion: https://discord.gg/WBKChhmZ", "Get Help": "One-stop AI video commentary + automated editing tool\uD83C\uDF89\uD83C\uDF89\uD83C\uDF89\n\nFor any questions or suggestions, you can join the **community channel** for help or discussion: https://github.com/linyqh/NarratoAI/wiki",
"Video Source": "Video Source", "Video Source": "Video Source",
"TikTok": "TikTok (Support is coming soon)", "TikTok": "TikTok (Support is coming soon)",
"Bilibili": "Bilibili (Support is coming soon)", "Bilibili": "Bilibili (Support is coming soon)",

View File

@ -73,7 +73,7 @@
"Please Enter the LLM API Key": "请先填写大模型 **API Key**", "Please Enter the LLM API Key": "请先填写大模型 **API Key**",
"Please Enter the Pexels API Key": "请先填写 **Pexels API Key**", "Please Enter the Pexels API Key": "请先填写 **Pexels API Key**",
"Please Enter the Pixabay API Key": "请先填写 **Pixabay API Key**", "Please Enter the Pixabay API Key": "请先填写 **Pixabay API Key**",
"Get Help": "一站式 AI 影视解说+自动化剪辑工具\uD83C\uDF89\uD83C\uDF89\uD83C\uDF89\n\n有任何问题或建议可以加入 **社区频道** 求助或讨论https://discord.gg/WBKChhmZ", "Get Help": "一站式 AI 影视解说+自动化剪辑工具\uD83C\uDF89\uD83C\uDF89\uD83C\uDF89\n\n有任何问题或建议可以加入 **社区频道** 求助或讨论https://github.com/linyqh/NarratoAI/wiki",
"Video Source": "视频来源", "Video Source": "视频来源",
"TikTok": "抖音 (TikTok 支持中,敬请期待)", "TikTok": "抖音 (TikTok 支持中,敬请期待)",
"Bilibili": "哔哩哔哩 (Bilibili 支持中,敬请期待)", "Bilibili": "哔哩哔哩 (Bilibili 支持中,敬请期待)",