mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 09:52:49 +00:00
feat: 新增 IndexTTS2 零样本语音克隆引擎支持
添加 IndexTTS2 TTS 引擎配置和实现,支持零样本语音克隆功能。包括配置保存加载、API 调用、参考音频上传、高级参数设置(温度、top_p、top_k、束搜索、重复惩罚等),并在 WebUI 中提供完整的配置界面和使用说明。
This commit is contained in:
parent
d75c2e000f
commit
cda5760e37
@ -52,6 +52,7 @@ def save_config():
|
||||
_cfg["soulvoice"] = soulvoice
|
||||
_cfg["ui"] = ui
|
||||
_cfg["tts_qwen"] = tts_qwen
|
||||
_cfg["indextts2"] = indextts2
|
||||
f.write(toml.dumps(_cfg))
|
||||
|
||||
|
||||
@ -65,6 +66,7 @@ soulvoice = _cfg.get("soulvoice", {})
|
||||
ui = _cfg.get("ui", {})
|
||||
frames = _cfg.get("frames", {})
|
||||
tts_qwen = _cfg.get("tts_qwen", {})
|
||||
indextts2 = _cfg.get("indextts2", {})
|
||||
|
||||
hostname = socket.gethostname()
|
||||
|
||||
|
||||
@ -1107,6 +1107,10 @@ def tts(
|
||||
if tts_engine == "edge_tts":
|
||||
logger.info("分发到 Edge TTS")
|
||||
return azure_tts_v1(text, voice_name, voice_rate, voice_pitch, voice_file)
|
||||
|
||||
if tts_engine == "indextts2":
|
||||
logger.info("分发到 IndexTTS2")
|
||||
return indextts2_tts(text, voice_name, voice_file, speed=voice_rate)
|
||||
|
||||
# Fallback for unknown engine - default to azure v1
|
||||
logger.warning(f"未知的 TTS 引擎: '{tts_engine}', 将默认使用 Edge TTS (Azure V1)。")
|
||||
@ -1541,8 +1545,8 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
|
||||
f"或者使用其他 tts 引擎")
|
||||
continue
|
||||
else:
|
||||
# SoulVoice 引擎不生成字幕文件
|
||||
if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine):
|
||||
# SoulVoice、Qwen3、IndexTTS2 引擎不生成字幕文件
|
||||
if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2":
|
||||
# 获取实际音频文件的时长
|
||||
duration = get_audio_duration_from_file(audio_file)
|
||||
if duration <= 0:
|
||||
@ -1943,4 +1947,127 @@ def parse_soulvoice_voice(voice_name: str) -> str:
|
||||
return voice_name
|
||||
|
||||
|
||||
def parse_indextts2_voice(voice_name: str) -> str:
|
||||
"""
|
||||
解析 IndexTTS2 语音名称
|
||||
支持格式:indextts2:reference_audio_path
|
||||
返回参考音频文件路径
|
||||
"""
|
||||
if voice_name.startswith("indextts2:"):
|
||||
return voice_name[10:] # 移除 "indextts2:" 前缀
|
||||
return voice_name
|
||||
|
||||
|
||||
def indextts2_tts(text: str, voice_name: str, voice_file: str, speed: float = 1.0) -> Union[SubMaker, None]:
|
||||
"""
|
||||
使用 IndexTTS2 API 进行零样本语音克隆
|
||||
|
||||
Args:
|
||||
text: 要转换的文本
|
||||
voice_name: 参考音频路径(格式:indextts2:path/to/audio.wav)
|
||||
voice_file: 输出音频文件路径
|
||||
speed: 语音速度(此引擎暂不支持速度调节)
|
||||
|
||||
Returns:
|
||||
SubMaker: 包含时间戳信息的字幕制作器,失败时返回 None
|
||||
"""
|
||||
# 获取配置
|
||||
api_url = config.indextts2.get("api_url", "http://192.168.3.6:8081/tts")
|
||||
infer_mode = config.indextts2.get("infer_mode", "普通推理")
|
||||
temperature = config.indextts2.get("temperature", 1.0)
|
||||
top_p = config.indextts2.get("top_p", 0.8)
|
||||
top_k = config.indextts2.get("top_k", 30)
|
||||
do_sample = config.indextts2.get("do_sample", True)
|
||||
num_beams = config.indextts2.get("num_beams", 3)
|
||||
repetition_penalty = config.indextts2.get("repetition_penalty", 10.0)
|
||||
|
||||
# 解析参考音频路径
|
||||
reference_audio_path = parse_indextts2_voice(voice_name)
|
||||
|
||||
if not reference_audio_path or not os.path.exists(reference_audio_path):
|
||||
logger.error(f"IndexTTS2 参考音频文件不存在: {reference_audio_path}")
|
||||
return None
|
||||
|
||||
# 准备请求数据
|
||||
files = {
|
||||
'prompt_audio': open(reference_audio_path, 'rb')
|
||||
}
|
||||
|
||||
data = {
|
||||
'text': text.strip(),
|
||||
'infer_mode': infer_mode,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
'top_k': top_k,
|
||||
'do_sample': do_sample,
|
||||
'num_beams': num_beams,
|
||||
'repetition_penalty': repetition_penalty,
|
||||
}
|
||||
|
||||
# 重试机制
|
||||
for attempt in range(3):
|
||||
try:
|
||||
logger.info(f"第 {attempt + 1} 次调用 IndexTTS2 API")
|
||||
|
||||
# 设置代理
|
||||
proxies = {}
|
||||
if config.proxy.get("http"):
|
||||
proxies = {
|
||||
'http': config.proxy.get("http"),
|
||||
'https': config.proxy.get("https", config.proxy.get("http"))
|
||||
}
|
||||
|
||||
# 调用 API
|
||||
response = requests.post(
|
||||
api_url,
|
||||
files=files,
|
||||
data=data,
|
||||
proxies=proxies,
|
||||
timeout=120 # IndexTTS2 推理可能需要较长时间
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# 保存音频文件
|
||||
with open(voice_file, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
logger.info(f"IndexTTS2 成功生成音频: {voice_file}, 大小: {len(response.content)} 字节")
|
||||
|
||||
# IndexTTS2 不支持精确字幕生成,返回简单的 SubMaker 对象
|
||||
sub_maker = SubMaker()
|
||||
# 估算音频时长(基于文本长度)
|
||||
estimated_duration_ms = max(1000, int(len(text) * 200))
|
||||
sub_maker.create_sub((0, estimated_duration_ms * 10000), text)
|
||||
|
||||
return sub_maker
|
||||
|
||||
else:
|
||||
logger.error(f"IndexTTS2 API 调用失败: {response.status_code} - {response.text}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"IndexTTS2 API 调用超时 (尝试 {attempt + 1}/3)")
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"IndexTTS2 API 网络错误: {str(e)} (尝试 {attempt + 1}/3)")
|
||||
except Exception as e:
|
||||
logger.error(f"IndexTTS2 TTS 处理错误: {str(e)} (尝试 {attempt + 1}/3)")
|
||||
finally:
|
||||
# 确保关闭文件
|
||||
try:
|
||||
files['prompt_audio'].close()
|
||||
except:
|
||||
pass
|
||||
|
||||
if attempt < 2: # 不是最后一次尝试
|
||||
time.sleep(2) # 等待2秒后重试
|
||||
# 重新打开文件用于下次重试
|
||||
if attempt < 2:
|
||||
try:
|
||||
files['prompt_audio'] = open(reference_audio_path, 'rb')
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error("IndexTTS2 TTS 生成失败,已达到最大重试次数")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from tkinter import N
|
||||
from venv import logger
|
||||
import streamlit as st
|
||||
import os
|
||||
@ -26,7 +27,8 @@ def get_tts_engine_options():
|
||||
"edge_tts": "Edge TTS",
|
||||
"azure_speech": "Azure Speech Services",
|
||||
"tencent_tts": "腾讯云 TTS",
|
||||
"qwen3_tts": "通义千问 Qwen3 TTS"
|
||||
"qwen3_tts": "通义千问 Qwen3 TTS",
|
||||
"indextts2": "IndexTTS2 语音克隆"
|
||||
}
|
||||
|
||||
|
||||
@ -56,6 +58,12 @@ def get_tts_engine_descriptions():
|
||||
"features": "阿里云通义千问语音合成,音质优秀,支持多种音色",
|
||||
"use_case": "需要高质量中文语音合成的用户",
|
||||
"registration": "https://dashscope.aliyuncs.com/"
|
||||
},
|
||||
"indextts2": {
|
||||
"title": "IndexTTS2 语音克隆",
|
||||
"features": "零样本语音克隆,上传参考音频即可合成相同音色的语音,需要本地或私有部署",
|
||||
"use_case": "下载地址:https://pan.quark.cn/s/0767c9bcefd5",
|
||||
"registration": None
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,6 +147,8 @@ def render_tts_settings(tr):
|
||||
render_tencent_tts_settings(tr)
|
||||
elif selected_engine == "qwen3_tts":
|
||||
render_qwen3_tts_settings(tr)
|
||||
elif selected_engine == "indextts2":
|
||||
render_indextts2_tts_settings(tr)
|
||||
|
||||
# 4. 试听功能
|
||||
render_voice_preview_new(tr, selected_engine)
|
||||
@ -562,6 +572,139 @@ def render_qwen3_tts_settings(tr):
|
||||
config.ui["qwen3_rate"] = voice_rate
|
||||
config.ui["voice_name"] = voice_type #兼容性
|
||||
|
||||
|
||||
def render_indextts2_tts_settings(tr):
|
||||
"""渲染 IndexTTS2 TTS 设置"""
|
||||
import os
|
||||
|
||||
# API 地址配置
|
||||
api_url = st.text_input(
|
||||
"API 地址",
|
||||
value=config.indextts2.get("api_url", "http://127.0.0.1:8081/tts"),
|
||||
help="IndexTTS2 API 服务地址"
|
||||
)
|
||||
|
||||
# 参考音频文件路径
|
||||
reference_audio = st.text_input(
|
||||
"参考音频路径",
|
||||
value=config.indextts2.get("reference_audio", ""),
|
||||
help="用于语音克隆的参考音频文件路径(WAV 格式,建议 3-10 秒)"
|
||||
)
|
||||
|
||||
# 文件上传功能
|
||||
uploaded_file = st.file_uploader(
|
||||
"或上传参考音频文件",
|
||||
type=["wav", "mp3"],
|
||||
help="上传一段清晰的音频用于语音克隆"
|
||||
)
|
||||
|
||||
if uploaded_file is not None:
|
||||
# 保存上传的文件
|
||||
import tempfile
|
||||
temp_dir = tempfile.gettempdir()
|
||||
audio_path = os.path.join(temp_dir, f"indextts2_ref_{uploaded_file.name}")
|
||||
with open(audio_path, "wb") as f:
|
||||
f.write(uploaded_file.getbuffer())
|
||||
reference_audio = audio_path
|
||||
st.success(f"✅ 音频已上传: {audio_path}")
|
||||
|
||||
# 推理模式
|
||||
infer_mode = st.selectbox(
|
||||
"推理模式",
|
||||
options=["普通推理", "快速推理"],
|
||||
index=0 if config.indextts2.get("infer_mode", "普通推理") == "普通推理" else 1,
|
||||
help="普通推理质量更高但速度较慢,快速推理速度更快但质量略低"
|
||||
)
|
||||
|
||||
# 高级参数折叠面板
|
||||
with st.expander("🔧 高级参数", expanded=False):
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
temperature = st.slider(
|
||||
"采样温度 (Temperature)",
|
||||
min_value=0.1,
|
||||
max_value=2.0,
|
||||
value=float(config.indextts2.get("temperature", 1.0)),
|
||||
step=0.1,
|
||||
help="控制随机性,值越高输出越随机,值越低越确定"
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=float(config.indextts2.get("top_p", 0.8)),
|
||||
step=0.05,
|
||||
help="nucleus 采样的概率阈值,值越小结果越确定"
|
||||
)
|
||||
|
||||
top_k = st.slider(
|
||||
"Top K",
|
||||
min_value=0,
|
||||
max_value=100,
|
||||
value=int(config.indextts2.get("top_k", 30)),
|
||||
step=5,
|
||||
help="top-k 采样的 k 值,0 表示不使用 top-k"
|
||||
)
|
||||
|
||||
with col2:
|
||||
num_beams = st.slider(
|
||||
"束搜索 (Num Beams)",
|
||||
min_value=1,
|
||||
max_value=10,
|
||||
value=int(config.indextts2.get("num_beams", 3)),
|
||||
step=1,
|
||||
help="束搜索的 beam 数量,值越大质量可能越好但速度越慢"
|
||||
)
|
||||
|
||||
repetition_penalty = st.slider(
|
||||
"重复惩罚 (Repetition Penalty)",
|
||||
min_value=1.0,
|
||||
max_value=20.0,
|
||||
value=float(config.indextts2.get("repetition_penalty", 10.0)),
|
||||
step=0.5,
|
||||
help="值越大越能避免重复,但过大可能导致不自然"
|
||||
)
|
||||
|
||||
do_sample = st.checkbox(
|
||||
"启用采样",
|
||||
value=config.indextts2.get("do_sample", True),
|
||||
help="启用采样可以获得更自然的语音"
|
||||
)
|
||||
|
||||
# 显示使用说明
|
||||
with st.expander("💡 IndexTTS2 使用说明", expanded=False):
|
||||
st.markdown("""
|
||||
**零样本语音克隆**
|
||||
|
||||
1. **准备参考音频**:上传或指定一段清晰的音频文件(建议 3-10 秒)
|
||||
2. **设置 API 地址**:确保 IndexTTS2 服务正常运行
|
||||
3. **开始合成**:系统会自动使用参考音频的音色合成新语音
|
||||
|
||||
**注意事项**:
|
||||
- 参考音频质量直接影响合成效果
|
||||
- 建议使用无背景噪音的清晰音频
|
||||
- 文本长度建议控制在合理范围内
|
||||
- 首次合成可能需要较长时间
|
||||
""")
|
||||
|
||||
# 保存配置
|
||||
config.indextts2["api_url"] = api_url
|
||||
config.indextts2["reference_audio"] = reference_audio
|
||||
config.indextts2["infer_mode"] = infer_mode
|
||||
config.indextts2["temperature"] = temperature
|
||||
config.indextts2["top_p"] = top_p
|
||||
config.indextts2["top_k"] = top_k
|
||||
config.indextts2["num_beams"] = num_beams
|
||||
config.indextts2["repetition_penalty"] = repetition_penalty
|
||||
config.indextts2["do_sample"] = do_sample
|
||||
|
||||
# 保存 voice_name 用于兼容性
|
||||
if reference_audio:
|
||||
config.ui["voice_name"] = f"indextts2:{reference_audio}"
|
||||
|
||||
|
||||
def render_voice_preview_new(tr, selected_engine):
|
||||
"""渲染新的语音试听功能"""
|
||||
if st.button("🎵 试听语音合成", use_container_width=True):
|
||||
@ -599,6 +742,12 @@ def render_voice_preview_new(tr, selected_engine):
|
||||
voice_name = f"qwen3:{vt}"
|
||||
voice_rate = config.ui.get("qwen3_rate", 1.0)
|
||||
voice_pitch = 1.0 # Qwen3 TTS 不支持音调调节
|
||||
elif selected_engine == "indextts2":
|
||||
reference_audio = config.indextts2.get("reference_audio", "")
|
||||
if reference_audio:
|
||||
voice_name = f"indextts2:{reference_audio}"
|
||||
voice_rate = 1.0 # IndexTTS2 不支持速度调节
|
||||
voice_pitch = 1.0 # IndexTTS2 不支持音调调节
|
||||
|
||||
if not voice_name:
|
||||
st.error("请先配置语音设置")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user