Optimize webui task logic (30%); add script check/repair utilities

linyq 2024-09-25 18:32:38 +08:00
parent d6663fde21
commit 990994e9cd
6 changed files with 269 additions and 124 deletions

View File

@@ -3,7 +3,7 @@ from enum import Enum
 from typing import Any, List, Optional

 import pydantic
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 # Ignore specific Pydantic warnings
 warnings.filterwarnings(
@@ -330,42 +330,39 @@ class BgmUploadResponse(BaseResponse):

 class VideoClipParams(BaseModel):
-    video_subject: Optional[str] = "春天的花海让人心旷神怡"
+    """
+    NarratoAI data model
+    """
+    video_clip_json: Optional[list] = Field(default=[], description="video clip script content generated by the LLM")
+    video_clip_json_path: Optional[str] = Field(default="", description="path to the LLM-generated video clip script")
+    video_origin_path: Optional[str] = Field(default="", description="path to the source video")
+    video_aspect: Optional[VideoAspect] = Field(default=VideoAspect.portrait.value, description="video aspect ratio")
+    video_language: Optional[str] = Field(default="zh-CN", description="video language")
-    video_clip_json: Optional[str] = ""  # video clip script
-    video_origin_path: Optional[str] = ""  # source video path
-    video_aspect: Optional[VideoAspect] = VideoAspect.portrait.value  # video aspect ratio
-    video_clip_duration: Optional[int] = 5  # duration of each clip
-    video_count: Optional[int] = 1  # number of clips
-    video_source: Optional[str] = "local"
-    video_language: Optional[str] = ""  # auto-detect
+    # video_clip_duration: Optional[int] = 5  # duration of each clip
+    # video_count: Optional[int] = 1  # number of clips
+    # video_source: Optional[str] = "local"
+    # video_concat_mode: Optional[VideoConcatMode] = VideoConcatMode.random.value

     # # female voices
     # "zh-CN-XiaoxiaoNeural",
     # "zh-CN-XiaoyiNeural",
     # # male voices
     # "zh-CN-YunjianNeural",
     # "zh-CN-YunyangNeural",
     # "zh-CN-YunxiNeural",
-    voice_name: Optional[str] = "zh-CN-YunjianNeural"  # voice name, pick from the list above
-    voice_volume: Optional[float] = 1.0  # voice volume
-    voice_rate: Optional[float] = 1.0  # speech rate
+    voice_name: Optional[str] = Field(default="zh-CN-YunjianNeural", description="voice name")
+    voice_volume: Optional[float] = Field(default=1.0, description="voice volume")
+    voice_rate: Optional[float] = Field(default=1.0, description="speech rate")
-    bgm_name: Optional[str] = "random"  # background music name
-    bgm_type: Optional[str] = "random"  # background music type
-    bgm_file: Optional[str] = ""  # background music file
-    bgm_volume: Optional[float] = 0.2
+    bgm_name: Optional[str] = Field(default="random", description="background music name")
+    bgm_type: Optional[str] = Field(default="random", description="background music type")
+    bgm_file: Optional[str] = Field(default="", description="background music file")
+    bgm_volume: Optional[float] = Field(default=0.2, description="background music volume")
-    subtitle_enabled: Optional[bool] = True  # enable subtitles
-    subtitle_position: Optional[str] = "bottom"  # top, bottom, center
-    font_name: Optional[str] = "STHeitiMedium.ttc"  # font name
-    text_fore_color: Optional[str] = "#FFFFFF"  # text foreground color
-    text_background_color: Optional[str] = "transparent"  # text background color
+    subtitle_enabled: Optional[bool] = Field(default=True, description="enable subtitles")
+    subtitle_position: Optional[str] = Field(default="bottom", description="subtitle position")  # top, bottom, center
+    font_name: Optional[str] = Field(default="STHeitiMedium.ttc", description="font name")
+    text_fore_color: Optional[str] = Field(default="#FFFFFF", description="text foreground color")
+    text_background_color: Optional[str] = Field(default="transparent", description="text background color")
-    font_size: int = 60  # font size
-    stroke_color: Optional[str] = "#000000"  # text stroke color
-    stroke_width: float = 1.5  # text stroke width
-    custom_position: float = 70.0  # custom subtitle position
-    n_threads: Optional[int] = 2  # number of threads
-    paragraph_number: Optional[int] = 1  # number of paragraphs
+    font_size: int = Field(default=60, description="font size")
+    stroke_color: Optional[str] = Field(default="#000000", description="text stroke color")
+    stroke_width: float = Field(default=1.5, description="text stroke width")
+    custom_position: float = Field(default=70.0, description="custom subtitle position")
+    # n_threads: Optional[int] = 2  # number of threads
+    # paragraph_number: Optional[int] = 1  # number of paragraphs
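For orientation, a minimal sketch of how the reworked model behaves; the import path follows the hunk above, while the sample paths are invented:

from app.models.schema import VideoClipParams

params = VideoClipParams(
    video_clip_json_path="resource/scripts/demo.json",  # hypothetical path
    video_origin_path="resource/videos/demo.mp4",       # hypothetical path
)
# Field(default=..., description=...) keeps the old default values but also
# attaches a human-readable description to the generated JSON schema.
print(params.voice_name)    # "zh-CN-YunjianNeural", the unchanged default
print(params.video_aspect)  # portrait, the unchanged default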

View File

@@ -352,7 +352,8 @@ def compress_video(input_path: str, output_path: str):
         input_path: path to the input video file
         output_path: path for the compressed video output file
     """
-    ffmpeg_path = "E:\\projects\\NarratoAI_v0.1.2\\lib\\ffmpeg\\ffmpeg-7.0-essentials_build\\ffmpeg.exe"  # hard-coded full path to ffmpeg
+    # Resolve the full path to ffmpeg
+    ffmpeg_path = os.getenv("FFMPEG_PATH") or config.app.get("ffmpeg_path")

     # If a compressed video file already exists, use it directly
     if os.path.exists(output_path):
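Reading the binary location from FFMPEG_PATH with a config fallback removes the hard-coded Windows path, but both sources may still be unset. A minimal sketch of a stricter resolver, assuming shutil.which as a last resort (the helper name and error message are not part of this commit):

import os
import shutil
from typing import Optional

def resolve_ffmpeg_path(config_value: Optional[str] = None) -> str:
    # Try the environment first, then the app config, then the system PATH.
    path = os.getenv("FFMPEG_PATH") or config_value or shutil.which("ffmpeg")
    if not path or not os.path.exists(path):
        raise FileNotFoundError(
            "ffmpeg not found: set FFMPEG_PATH, configure app.ffmpeg_path, "
            "or install ffmpeg on the system PATH"
        )
    return path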

View File

@@ -326,17 +326,20 @@ def start(task_id, params: VideoParams, stop_at: str = "video"):

 def start_subclip(task_id, params: VideoClipParams, subclip_path_videos):
     """
     Background task: automatically clip the video according to the script
+    task_id: task ID
+    params: clipping parameters
+    subclip_path_videos: paths of the clipped video files
     """
     logger.info(f"\n\n## Start task: {task_id}")
     sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)

+    # TTS voice name
     voice_name = voice.parse_voice_name(params.voice_name)
-    paragraph_number = params.paragraph_number
-    n_threads = params.n_threads
     max_clip_duration = params.video_clip_duration

     logger.info("\n\n## 1. Load the video JSON script")
-    video_script_path = path.join(params.video_clip_json)
+    video_script_path = path.join(params.video_clip_json_path)

     # Check whether the JSON file exists
     if path.exists(video_script_path):
         try:
@@ -430,7 +433,7 @@ def start_subclip(task_id, params: VideoClipParams, subclip_path_videos):
            video_ost_list=video_ost,
            list_script=list_script,
            video_aspect=params.video_aspect,
-           threads=n_threads
+           threads=1  # only single-threaded for now
        )
        _progress += 50 / params.video_count / 2

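For orientation, a hedged sketch of how this background task might be invoked; the task id is random, the paths are invented, and the assumption that subclip_path_videos is a plain list of file paths goes beyond what the hunk shows:

from uuid import uuid4

from app.models.schema import VideoClipParams
from app.services import task as tm

params = VideoClipParams(
    video_clip_json_path="resource/scripts/demo.json",  # hypothetical script path
    video_origin_path="resource/videos/demo.mp4",       # hypothetical video path
)
# Pre-cut segments, assumed here to be one file per script item.
subclips = ["storage/tasks/demo/subclip_1.mp4"]
tm.start_subclip(task_id=str(uuid4()), params=params, subclip_path_videos=subclips)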
app/utils/check_script.py (new file)

@@ -0,0 +1,198 @@
import json
from loguru import logger
import os
from datetime import datetime, timedelta
import re


def time_to_seconds(time_str):
    """Convert an 'MM:SS' string to seconds."""
    time_obj = datetime.strptime(time_str, "%M:%S")
    return timedelta(minutes=time_obj.minute, seconds=time_obj.second).total_seconds()


def seconds_to_time_str(seconds):
    """Convert seconds to an 'MM:SS' string."""
    minutes, seconds = divmod(int(seconds), 60)
    return f"{minutes:02d}:{seconds:02d}"


def check_script(file_path, total_duration):
    """Validate a clip-script JSON file against the format rules."""
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    errors = []
    ost_narrations = set()
    last_end_time = 0

    logger.info(f"Checking file: {file_path}")
    logger.info(f"Total video duration: {total_duration:.2f} s")
    logger.info("=" * 50)

    for i, item in enumerate(data, 1):
        logger.info(f"\nChecking item {i}:")

        # Check that all required fields are present
        required_fields = ['picture', 'timestamp', 'narration', 'OST', 'new_timestamp']
        for field in required_fields:
            if field not in item:
                errors.append(f"Item {i} is missing the {field} field")
                logger.info(f"  - error: missing the {field} field")
            else:
                logger.info(f"  - {field}: {item[field]}")

        # Rules for OST == False (voice-over items)
        if item.get('OST') == False:
            if not item.get('narration'):
                errors.append(f"Item {i} has OST false but an empty narration")
                logger.info("  - error: OST is false but narration is empty")
            elif len(item['narration']) > 30:
                errors.append(f"Item {i} has OST false but a narration longer than 30 characters")
                logger.info(f"  - error: OST is false but narration exceeds 30 characters (current: {len(item['narration'])})")
            else:
                logger.info("  - OST is false, narration check passed")

        # Rules for OST == True (original-audio items)
        if item.get('OST') == True:
            if not item.get('narration', '').startswith('原声播放_'):
                errors.append(f"Item {i} has OST true but its narration is not in the '原声播放_xxx' format")
                logger.info("  - error: OST is true but narration is not in the '原声播放_xxx' format")
            elif item['narration'] in ost_narrations:
                errors.append(f"Item {i} has OST true but its narration '{item['narration']}' is not unique")
                logger.info(f"  - error: OST is true but narration '{item['narration']}' is not unique")
            else:
                logger.info("  - OST is true, narration check passed")
                ost_narrations.add(item['narration'])

        # Check that timestamps do not overlap
        if 'timestamp' in item:
            start, end = map(time_to_seconds, item['timestamp'].split('-'))
            if start < last_end_time:
                errors.append(f"Item {i} timestamp '{item['timestamp']}' overlaps the previous item")
                logger.info(f"  - error: timestamp '{item['timestamp']}' overlaps the previous item")
            else:
                logger.info(f"  - timestamp '{item['timestamp']}' check passed")
            last_end_time = end

            # Check that the timestamp does not exceed the total duration
            if end > total_duration:
                errors.append(f"Item {i} timestamp '{item['timestamp']}' exceeds the total duration {total_duration:.2f}")
                logger.info(f"  - error: timestamp '{item['timestamp']}' exceeds the total duration {total_duration:.2f}")
            else:
                logger.info("  - timestamp is within the total duration")

    # Check that new_timestamp values are contiguous
    logger.info("\nChecking new_timestamp continuity:")
    last_end_time = 0
    for i, item in enumerate(data, 1):
        if 'new_timestamp' in item:
            start, end = map(time_to_seconds, item['new_timestamp'].split('-'))
            if start != last_end_time:
                errors.append(f"Item {i} new_timestamp '{item['new_timestamp']}' is not contiguous with the previous item")
                logger.info(f"  - error: item {i} new_timestamp '{item['new_timestamp']}' is not contiguous with the previous item")
            else:
                logger.info(f"  - item {i} new_timestamp '{item['new_timestamp']}' continuity check passed")
            last_end_time = end

    if errors:
        logger.info("Check result: failed")
        logger.info("The following errors were found:")
        for error in errors:
            logger.info(f"- {error}")
        fix_script(file_path, data, errors)
    else:
        logger.info("Check result: passed")
        logger.info("All items conform to the rules.")


def fix_script(file_path, data, errors):
    """Auto-fix what can be fixed and flag the rest for manual repair."""
    logger.info("\nStarting script repair...")
    fixed_data = []
    for i, item in enumerate(data, 1):
        if item['OST'] == False and (not item['narration'] or len(item['narration']) > 30):
            if not item['narration']:
                logger.info(f"Item {i} has an empty narration and needs manual repair.")
                fixed_data.append(item)
            else:
                logger.info(f"Fixing item {i}: narration exceeds 30 characters...")
                fixed_items = split_narration(item)
                fixed_data.extend(fixed_items)
        else:
            fixed_data.append(item)

    for error in errors:
        if not error.startswith("Item") or "OST false" not in error:
            logger.info(f"Needs manual repair: {error}")

    # Build the new file name
    file_name, file_ext = os.path.splitext(file_path)
    new_file_path = f"{file_name}_revise{file_ext}"

    # Save the fixed data to the new file
    with open(new_file_path, 'w', encoding='utf-8') as f:
        json.dump(fixed_data, f, ensure_ascii=False, indent=4)

    logger.info(f"\nScript repair finished; saved to the new file: {new_file_path}")


def split_narration(item):
    """Split an over-long narration into chunks and spread its timestamps evenly."""
    narration = item['narration']
    chunks = smart_split(narration)

    start_time, end_time = map(time_to_seconds, item['timestamp'].split('-'))
    new_start_time, new_end_time = map(time_to_seconds, item['new_timestamp'].split('-'))

    total_duration = end_time - start_time
    new_total_duration = new_end_time - new_start_time

    chunk_duration = total_duration / len(chunks)
    new_chunk_duration = new_total_duration / len(chunks)

    fixed_items = []
    for i, chunk in enumerate(chunks):
        new_item = item.copy()
        new_item['narration'] = chunk

        chunk_start = start_time + i * chunk_duration
        chunk_end = chunk_start + chunk_duration
        new_item['timestamp'] = f"{seconds_to_time_str(chunk_start)}-{seconds_to_time_str(chunk_end)}"

        new_chunk_start = new_start_time + i * new_chunk_duration
        new_chunk_end = new_chunk_start + new_chunk_duration
        new_item['new_timestamp'] = f"{seconds_to_time_str(new_chunk_start)}-{seconds_to_time_str(new_chunk_end)}"

        fixed_items.append(new_item)
    return fixed_items


def smart_split(text, target_length=30):
    """Split text near target_length, keeping punctuation attached to its clause."""
    # Split on Chinese/ASCII sentence punctuation while keeping the punctuation
    segments = re.findall(r'[^,。!?,!?]+[,。!?,!?]?', text)

    result = []
    current_chunk = ""
    for segment in segments:
        if len(current_chunk) + len(segment) <= target_length:
            current_chunk += segment
        else:
            if current_chunk:
                result.append(current_chunk.strip())
            current_chunk = segment
    if current_chunk:
        result.append(current_chunk.strip())

    # If any chunk still exceeds the target length, hard-split it
    final_result = []
    for chunk in result:
        if len(chunk) > target_length:
            sub_chunks = [chunk[i:i + target_length] for i in range(0, len(chunk), target_length)]
            final_result.extend(sub_chunks)
        else:
            final_result.append(chunk)
    return final_result


if __name__ == "__main__":
    file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/2024-0923-085036.json"
    total_duration = 280
    check_script(file_path, total_duration)
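To make the rules concrete, here is a hedged sample script that should pass every check above; all field values are invented for illustration:

import json

# Two items that satisfy the rules: all five required fields, a voice-over
# narration of at most 30 characters (OST false), a unique '原声播放_xxx'
# narration (OST true), non-overlapping timestamps within the total duration,
# and contiguous new_timestamp values starting at 00:00.
sample_script = [
    {
        "picture": "Wide shot of a flower field in spring",
        "timestamp": "00:00-00:05",
        "narration": "春天的花海让人心旷神怡",
        "OST": False,
        "new_timestamp": "00:00-00:05",
    },
    {
        "picture": "The host speaks to the camera",
        "timestamp": "00:10-00:15",
        "narration": "原声播放_1",
        "OST": True,
        "new_timestamp": "00:05-00:10",
    },
]

with open("demo_script.json", "w", encoding="utf-8") as f:
    json.dump(sample_script, f, ensure_ascii=False, indent=4)

check_script("demo_script.json", total_duration=280)  # logs "Check result: passed"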

View File

@@ -24,3 +24,4 @@ opencv-python~=4.9.0.80
 # https://techcommunity.microsoft.com/t5/ai-azure-ai-services-blog/9-more-realistic-ai-voices-for-conversations-now-generally/ba-p/4099471
 azure-cognitiveservices-speech~=1.37.0
 git-changelog~=2.5.2
+watchdog==5.0.2

webui.py

@@ -1,29 +1,5 @@
-import sys
-import os
-import glob
-import json
-import time
-import datetime
-import traceback
 import streamlit as st
-from uuid import uuid4
-import platform
-import streamlit.components.v1 as components
-from loguru import logger
 from app.config import config
-from app.models.const import FILE_TYPE_VIDEOS
-from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
-from app.services import task as tm, llm, voice, material
-from app.utils import utils
-
-# Add the project root to sys.path so modules can be imported from the project
-root_dir = os.path.dirname(os.path.realpath(__file__))
-if root_dir not in sys.path:
-    sys.path.append(root_dir)
-
-print("******** sys.path ********")
-print(sys.path)
-print("*" * 20)

 st.set_page_config(
     page_title="NarratoAI",
@@ -37,6 +13,31 @@ st.set_page_config(
     },
 )

+import sys
+import os
+import glob
+import json
+import time
+import datetime
+import traceback
+from uuid import uuid4
+import platform
+import streamlit.components.v1 as components
+from loguru import logger
+
+from app.models.const import FILE_TYPE_VIDEOS
+from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
+from app.services import task as tm, llm, voice, material
+from app.utils import utils
+
+# Add the project root to sys.path so modules can be imported from the project
+root_dir = os.path.dirname(os.path.realpath(__file__))
+if root_dir not in sys.path:
+    sys.path.append(root_dir)
+
+print("******** sys.path ********")
+print(sys.path)
+print("*" * 20)

 proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
 proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
 os.environ["HTTP_PROXY"] = proxy_url_http
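The import shuffle is not cosmetic: Streamlit requires st.set_page_config() to be the first Streamlit command a script executes, so every other import (and any st.* side effect it might trigger) has to move below it. A minimal reproduction of the constraint:

import streamlit as st

st.set_page_config(page_title="demo")  # must precede any other st.* call
st.write("ok")  # calling this before set_page_config raises StreamlitAPIException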
@@ -59,8 +60,6 @@ i18n_dir = os.path.join(root_dir, "webui", "i18n")
 config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
 system_locale = utils.get_system_locale()

-if 'video_subject' not in st.session_state:
-    st.session_state['video_subject'] = ''
 if 'video_clip_json' not in st.session_state:
     st.session_state['video_clip_json'] = ''
 if 'video_plot' not in st.session_state:
@@ -189,16 +188,7 @@ with st.expander(tr("Basic Settings"), expanded=False):
         if HTTPS_PROXY:
             config.proxy["https"] = HTTPS_PROXY

     with middle_config_panel:
-        # openai
-        # moonshot (Moonshot AI)
-        # oneapi
-        # g4f
-        # azure
-        # qwen (Tongyi Qianwen)
-        # gemini
-        # ollama
-
         llm_providers = ['Gemini']
         saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
         saved_llm_provider_index = 0
@@ -470,6 +460,7 @@ with left_panel:
             else:
                 st.error(tr("请先生成视频脚本"))

+    # Crop the video
     with button_columns[1]:
         if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
@@ -479,50 +470,6 @@ with left_panel:

 with middle_panel:
     with st.container(border=True):
         st.write(tr("Video Settings"))
-        # video_concat_modes = [
-        #     (tr("Sequential"), "sequential"),
-        #     (tr("Random"), "random"),
-        # ]
-        # video_sources = [
-        #     (tr("Pexels"), "pexels"),
-        #     (tr("Pixabay"), "pixabay"),
-        #     (tr("Local file"), "local"),
-        #     (tr("TikTok"), "douyin"),
-        #     (tr("Bilibili"), "bilibili"),
-        #     (tr("Xiaohongshu"), "xiaohongshu"),
-        # ]
-        #
-        # saved_video_source_name = config.app.get("video_source", "pexels")
-        # saved_video_source_index = [v[1] for v in video_sources].index(
-        #     saved_video_source_name
-        # )
-        #
-        # selected_index = st.selectbox(
-        #     tr("Video Source"),
-        #     options=range(len(video_sources)),
-        #     format_func=lambda x: video_sources[x][0],
-        #     index=saved_video_source_index,
-        # )
-        # params.video_source = video_sources[selected_index][1]
-        # config.app["video_source"] = params.video_source
-        #
-        # if params.video_source == "local":
-        #     _supported_types = FILE_TYPE_VIDEOS + FILE_TYPE_IMAGES
-        #     uploaded_files = st.file_uploader(
-        #         "Upload Local Files",
-        #         type=["mp4", "mov", "avi", "flv", "mkv", "jpg", "jpeg", "png"],
-        #         accept_multiple_files=True,
-        #     )
-        # selected_index = st.selectbox(
-        #     tr("Video Concat Mode"),
-        #     index=1,
-        #     options=range(len(video_concat_modes)),  # the index is the internal option value
-        #     format_func=lambda x: video_concat_modes[x][0],  # the label is what the user sees
-        # )
-        # params.video_concat_mode = VideoConcatMode(
-        #     video_concat_modes[selected_index][1]
-        # )
-
         # Video aspect ratio
         video_aspect_ratios = [
@@ -582,8 +529,9 @@ with middle_panel:
             params.voice_name = voice_name
             config.ui["voice_name"] = voice_name

+        # Preview the speech synthesis
         if st.button(tr("Play Voice")):
-            play_content = params.video_subject
+            play_content = "这是一段试听语言"
             if not play_content:
                 play_content = params.video_script
             if not play_content:
@@ -779,6 +727,7 @@ with st.expander(tr("Video Check"), expanded=False):
             caijian()
             st.rerun()

+# Start button
 start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
 if start_button:
     config.save_config()
@@ -800,10 +749,6 @@ if start_button:
         st.error(tr("视频文件不能为空"))
         scroll_to_bottom()
         st.stop()

-    if llm_provider != 'g4f' and not config.app.get(f"{llm_provider}_api_key", ""):
-        st.error(tr("请输入 LLM API 密钥"))
-        scroll_to_bottom()
-        st.stop()

     log_container = st.empty()
     log_records = []