Compare commits

...

7 Commits
v0.7.8 ... main

Author SHA1 Message Date
viccy
c0b72ec603 chore: bump project version to 0.7.9 and polish the README
- Bump the project version from 0.7.8 to 0.7.9
- Improve the layout and structure of README.md for better readability
- Update the feature list and latest news, adding notes on the 0.7.9 release
- Remove outdated promotional content and update sponsor badges
2026-04-27 18:51:49 +08:00
viccy
99dd4193ae feat(subtitles): add Aliyun Bailian Fun-ASR audio/video subtitle transcription
- Add a Fun-ASR transcription page to the WebUI, supporting uploads of common audio/video formats and SRT subtitle generation
- Add the `app/services/fun_asr_subtitle.py` service module, implementing the full REST API flow: fetching upload credentials, uploading the file, submitting the task, polling for results, and converting to SRT
- Add a `[fun_asr]` section to the config file for storing the API key
- Add full unit tests covering the core conversion logic and the service flow
- For compatibility with Python versions below 3.11, wrap the `tomllib` import in a try block that falls back to `tomli`
- Add `from __future__ import annotations` to `defaults.py` to support its type annotations
2026-04-27 18:15:54 +08:00
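The bullets above outline a five-step REST flow: fetch upload credentials, upload the file, submit the task, poll for the result, and convert to SRT. As a rough sketch of that control flow (the `FakeClient` class and its method names are hypothetical stand-ins, not the project's or DashScope's real API):

```python
class FakeClient:
    """Hypothetical stand-in for the DashScope endpoints used by the flow."""
    def __init__(self):
        self.polls = 0
    def get_policy(self):
        return {"upload_dir": "dashscope-instant/demo"}
    def upload(self, policy, name):
        return f"oss://{policy['upload_dir']}/{name}"
    def submit(self, oss_url):
        return "task-1"
    def poll(self, task_id):
        self.polls += 1
        return "SUCCEEDED" if self.polls >= 2 else "RUNNING"

def transcribe(client, name):
    # policy -> upload -> submit -> poll until a terminal status
    policy = client.get_policy()
    oss_url = client.upload(policy, name)
    task_id = client.submit(oss_url)
    while (status := client.poll(task_id)) not in {"SUCCEEDED", "FAILED"}:
        pass
    return status, oss_url

print(transcribe(FakeClient(), "audio.wav"))
# ('SUCCEEDED', 'oss://dashscope-instant/demo/audio.wav')
```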
viccy
8c129790c7
Merge pull request #237 from aw123456dew/feature/doubao-tts
add doubao tts
2026-04-08 15:14:10 +08:00
viccy
de33c6d0bd
Merge pull request #238 from aw123456dew/feature/export-jianying-draft
add export jianying draft feature
2026-04-08 15:13:02 +08:00
aw123456dew
852f5ae34c fix: jianying draft export failure due to floating-point precision in audio duration 2026-04-07 17:13:43 +08:00
aw123456dew
d45c1858c9 add export jianying draft feature 2026-04-07 11:33:12 +08:00
aw123456dew
71dfc99839 add doubao tts 2026-04-07 09:10:50 +08:00
16 changed files with 1811 additions and 36 deletions

View File

@ -1,38 +1,47 @@
<div align="center">
<h1 align="center" style="font-size: 2cm;"> NarratoAI 😎📽️ </h1>
<h1 align="center"> NarratoAI 😎📽️ </h1>
<h3 align="center">All-in-one AI video narration + automated editing tool 🎬🎞️ </h3>
<p align="center">
📖 <a href="README-en.md">English</a> | 简体中文 | <a href="https://www.narratoai.cn">☁️ <b>Cloud Edition (NarratoAI.cn)</b></a>
</p>
<h3>📖 <a href="README-en.md">English</a> | 简体中文 </h3>
<div align="center">
<br>
> **🔥 Highly recommended: the new VibeCut paradigm, [speclip.com](https://speclip.com)**
>
> **A true video editing agent: cut videos the way you chat (vibecoding).**
> **[👉 Click to download Speclip for free](https://speclip.com)**
[//]: # ( <a href="https://trendshift.io/repositories/8731" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8731" alt="harry0703%2FNarratoAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>)
</div>
<br>
NarratoAI is an automated video narration tool: based on LLMs, it delivers a one-stop workflow for script writing, automated video editing, voiceover, and subtitle generation, enabling efficient content creation.
<br>
> **🔥 Highly recommended: the new VibeCut paradigm, [Speclip](https://speclip.com), a true editing agent. [👉 Click to download for free](https://speclip.com)**
NarratoAI is an automated video narration tool: based on LLMs, it delivers a one-stop workflow for script writing, automated video editing, voiceover, and subtitle generation, enabling efficient content creation. Available as a locally deployable open-source edition and a [cloud-hosted edition](https://www.narratoai.cn).
[![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/linyqh/NarratoAI)
[![GitHub license](https://img.shields.io/github/license/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/blob/main/LICENSE)
[![GitHub issues](https://img.shields.io/github/issues/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/issues)
[![GitHub stars](https://img.shields.io/github/stars/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/stargazers)
<br>
<a href="https://discord.com/invite/V2pbAqqQNb" target="_blank">💬 Join the Discord open-source community for project updates and the latest news.</a>
[![GitHub stars](https://img.shields.io/github/stars/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/stargazers) [![GitHub issues](https://img.shields.io/github/issues/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/issues) [![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/linyqh/NarratoAI)
<br>
<a href="https://github.com/linyqh/NarratoAI/wiki" target="_blank">💬 Join the open-source community for project updates and the latest news</a>
<br>
<h2><a href="https://p9mf6rjv3c.feishu.cn/wiki/SP8swLLZki5WRWkhuFvc2CyInDg?from=from_copylink" target="_blank">🎉🎉🎉 Official Documentation 🎉🎉🎉</a> </h2>
<h3>Home</h3>
### UI Preview
![](docs/index-zh.png)
</div>
## License
This project is for learning and research purposes only and may not be used commercially. For a commercial license, please contact the author.
## Latest News
- 2026.04.27 Released v0.7.9, adding **one-click Fun-ASR subtitle transcription**
- 2026.04.03 Released v0.7.8, refactoring the documentary frame-by-frame analysis pipeline, unifying shared services, and optimizing frame extraction, caching, vision concurrency, and script generation
- 2026.03.27 Released v0.7.7; for security reasons, the LiteLLM dependency was removed in favor of a unified OpenAI-compatible request path
- 2025.11.20 Released v0.7.5, adding [IndexTTS2](https://github.com/index-tts/index-tts) voice cloning support
@ -48,17 +57,7 @@ NarratoAI is an automated video narration tool: based on LLMs, it handles script writing,
- 2024.11.10 Released v0.3.5; optimized the video editing pipeline,
## Special Perks 🎉
> 1
> **Developer-exclusive perk: an all-in-one AI platform with free credit on sign-up**
>
> Tired of wiring up AI models one by one? Meet 302.AI, an enterprise-grade AI resource hub. A single integration gives you access to hundreds of AI models spanning language, image, audio, and video, billed pay-as-you-go to greatly reduce development costs.
>
> Sign up through my referral link below to **instantly receive 1 USD in free credit** and start your AI development journey.
>
> **Sign up now:** [https://share.302.ai/I9P6mP](https://share.302.ai/I9P6mP)
---
> 2
> SiliconFlow is now fully supported: sign up and get 20 million free tokens (16 CNY of platform credit); editing a 10-minute video costs only about 0.1 CNY!
>
> 🔥 Claim the perk:
@ -97,7 +96,7 @@ _**1. NarratoAI is completely free software; recently on social media (Douyin, B
- [x] One-click material merging
- [x] One-click transcription
- [x] One-click cache cleanup
- [ ] Support exporting Jianying drafts
- [x] Support exporting Jianying drafts
- [X] Support short-drama narration
- [ ] Protagonist face matching
- [ ] Automatic matching based on voiceover, script, and video material
@ -169,7 +168,9 @@ streamlit run webui.py --server.maxUploadSize=2048
</div>
## Sponsors
[![Powered by DartNode](https://dartnode.com/branding/DN-Open-Source-sm.png)](https://dartnode.com "Powered by DartNode - Free VPS for Open Source")
<a href="https://dartnode.com">
<img src="https://dartnode.com/_branding/white_color_full.png" alt="Powered by DartNode" style="background-color: #333; padding: 10px; border-radius: 4px;">
</a>
## License 📝

View File

@ -81,7 +81,9 @@ def save_config():
_cfg["soulvoice"] = soulvoice
_cfg["ui"] = ui
_cfg["tts_qwen"] = tts_qwen
_cfg["fun_asr"] = fun_asr
_cfg["indextts2"] = indextts2
_cfg["doubaotts"] = doubaotts
f.write(toml.dumps(_cfg))
@ -95,7 +97,9 @@ soulvoice = _cfg.get("soulvoice", {})
ui = _cfg.get("ui", {})
frames = _cfg.get("frames", {})
tts_qwen = _cfg.get("tts_qwen", {})
fun_asr = _cfg.get("fun_asr", {})
indextts2 = _cfg.get("indextts2", {})
doubaotts = _cfg.get("doubaotts", {})
hostname = socket.gethostname()

View File

@ -1,5 +1,7 @@
"""Shared config defaults used by both bootstrap and WebUI fallbacks."""
from __future__ import annotations
DEFAULT_OPENAI_COMPATIBLE_BASE_URL = "https://api.siliconflow.cn/v1"
DEFAULT_OPENAI_COMPATIBLE_PROVIDER = "openai"

View File

@ -2,7 +2,10 @@ import tempfile
import unittest
from pathlib import Path
import tomllib
try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib
from app.config import config as cfg
from app.config.defaults import (
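The hunk above swaps a hard `import tomllib` for a guarded import so the tests also run on Python < 3.11. A minimal standalone sketch of the same fallback pattern:

```python
try:
    import tomllib  # stdlib since Python 3.11
except ModuleNotFoundError:  # Python < 3.11
    import tomli as tomllib  # third-party drop-in replacement

# Either module exposes the same loads() API for parsing TOML text.
config = tomllib.loads('[fun_asr]\napi_key = "sk-demo"\n')
print(config["fun_asr"]["api_key"])  # sk-demo
```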

View File

@ -196,6 +196,7 @@ class VideoClipParams(BaseModel):
tts_volume: Optional[float] = Field(default=AudioVolumeDefaults.TTS_VOLUME, description="解说语音音量(后处理)")
original_volume: Optional[float] = Field(default=AudioVolumeDefaults.ORIGINAL_VOLUME, description="视频原声音量")
bgm_volume: Optional[float] = Field(default=AudioVolumeDefaults.BGM_VOLUME, description="背景音乐音量")
draft_name: Optional[str] = Field(default="", description="剪映草稿名称")

View File

@ -0,0 +1,452 @@
"""Aliyun Bailian Fun-ASR subtitle transcription helpers.
This module intentionally uses the REST API because the official Fun-ASR
recorded-file API supports temporary `oss://` resources only through REST.
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from typing import Any, Optional
import requests
from loguru import logger
from app.utils import utils
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com"
UPLOAD_POLICY_URL = f"{DASHSCOPE_BASE_URL}/api/v1/uploads"
TRANSCRIPTION_URL = f"{DASHSCOPE_BASE_URL}/api/v1/services/audio/asr/transcription"
TASK_URL_TEMPLATE = f"{DASHSCOPE_BASE_URL}/api/v1/tasks/{{task_id}}"
MODEL_NAME = "fun-asr"
TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"}
PUNCTUATION_BREAKS = set(",。!?;,.!?;")
class FunAsrError(RuntimeError):
"""Raised for user-actionable Fun-ASR transcription failures."""
@dataclass
class UploadPolicy:
upload_host: str
upload_dir: str
policy: str
signature: str
oss_access_key_id: str
x_oss_object_acl: str = "private"
x_oss_forbid_overwrite: str = "true"
max_file_size_mb: Optional[float] = None
def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if extra:
headers.update(extra)
return headers
def _raise_for_http(response: requests.Response, action: str) -> None:
try:
response.raise_for_status()
except Exception as exc: # requests may be mocked with generic exceptions
raise FunAsrError(f"{action}失败,请检查阿里百炼 API Key、网络或服务状态") from exc
def _json(response: requests.Response, action: str) -> dict[str, Any]:
_raise_for_http(response, action)
try:
data = response.json()
except Exception as exc:
raise FunAsrError(f"{action}返回了无效 JSON") from exc
if not isinstance(data, dict):
raise FunAsrError(f"{action}返回格式无效")
return data
def _require_api_key(api_key: str) -> str:
api_key = (api_key or "").strip()
if not api_key:
raise FunAsrError("请先输入阿里百炼 API Key")
return api_key
def _safe_upload_name(local_file: str) -> str:
name = os.path.basename(local_file).strip() or f"audio_{int(time.time())}.wav"
return name.replace("/", "_").replace("\\", "_")
def _session_get(session, url: str, **kwargs):
return session.get(url, **kwargs)
def _session_post(session, url: str, **kwargs):
return session.post(url, **kwargs)
def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
"""Request Bailian temporary-storage upload policy for the target model."""
api_key = _require_api_key(api_key)
response = _session_get(
session,
UPLOAD_POLICY_URL,
params={"action": "getPolicy", "model": model},
headers=_auth_headers(api_key),
timeout=30,
)
data = _json(response, "获取临时存储上传凭证")
policy_data = data.get("data") or {}
required = ["upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"]
missing = [field for field in required if not policy_data.get(field)]
if missing:
raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(missing)}")
return UploadPolicy(
upload_host=str(policy_data["upload_host"]),
upload_dir=str(policy_data["upload_dir"]).rstrip("/"),
policy=str(policy_data["policy"]),
signature=str(policy_data["signature"]),
oss_access_key_id=str(policy_data["oss_access_key_id"]),
x_oss_object_acl=str(policy_data.get("x_oss_object_acl") or "private"),
x_oss_forbid_overwrite=str(policy_data.get("x_oss_forbid_overwrite") or "true"),
max_file_size_mb=policy_data.get("max_file_size_mb"),
)
def _validate_file_size(local_file: str, policy: UploadPolicy) -> None:
if policy.max_file_size_mb is None:
return
max_bytes = float(policy.max_file_size_mb) * 1024 * 1024
size = os.path.getsize(local_file)
if size > max_bytes:
raise FunAsrError(
f"文件大小超过阿里百炼临时存储限制: {size / 1024 / 1024:.2f}MB > {float(policy.max_file_size_mb):.2f}MB"
)
def upload_to_temporary_oss(local_file: str, policy: UploadPolicy, session=requests) -> str:
"""Upload local file to temporary OSS and return `oss://...` URL."""
if not os.path.isfile(local_file):
raise FunAsrError(f"待转写文件不存在: {local_file}")
_validate_file_size(local_file, policy)
key = f"{policy.upload_dir}/{_safe_upload_name(local_file)}"
data = {
"OSSAccessKeyId": policy.oss_access_key_id,
"policy": policy.policy,
"Signature": policy.signature,
"key": key,
"x-oss-object-acl": policy.x_oss_object_acl,
"x-oss-forbid-overwrite": policy.x_oss_forbid_overwrite,
"success_action_status": "200",
}
with open(local_file, "rb") as file_obj:
files = {"file": (_safe_upload_name(local_file), file_obj)}
response = _session_post(session, policy.upload_host, data=data, files=files, timeout=120)
_raise_for_http(response, "上传文件到阿里百炼临时存储")
return f"oss://{key}"
def submit_transcription_task(
api_key: str,
oss_url: str,
speaker_count: Optional[int] = None,
model: str = MODEL_NAME,
session=requests,
) -> str:
"""Submit async Fun-ASR task and return task_id."""
api_key = _require_api_key(api_key)
parameters: dict[str, Any] = {"diarization_enabled": True}
if speaker_count:
parameters["speaker_count"] = int(speaker_count)
payload = {
"model": model,
"input": {"file_urls": [oss_url]},
"parameters": parameters,
}
response = _session_post(
session,
TRANSCRIPTION_URL,
headers=_auth_headers(
api_key,
{
"X-DashScope-Async": "enable",
"X-DashScope-OssResourceResolve": "enable",
},
),
json=payload,
timeout=30,
)
data = _json(response, "提交 Fun-ASR 转写任务")
task_id = ((data.get("output") or {}).get("task_id") or "").strip()
if not task_id:
raise FunAsrError("提交 Fun-ASR 转写任务失败:未返回 task_id")
return task_id
def poll_transcription_task(
api_key: str,
task_id: str,
poll_interval: float = 2.0,
timeout: float = 600.0,
session=requests,
) -> dict[str, Any]:
"""Poll task until terminal status and return successful result item."""
api_key = _require_api_key(api_key)
deadline = time.time() + timeout
last_status = "PENDING"
while time.time() < deadline:
response = _session_post(
session,
TASK_URL_TEMPLATE.format(task_id=task_id),
headers=_auth_headers(api_key),
timeout=30,
)
data = _json(response, "查询 Fun-ASR 转写任务")
output = data.get("output") or {}
last_status = str(output.get("task_status") or "").upper()
if last_status == "SUCCEEDED":
results = output.get("results") or []
for result in results:
subtask_status = str(result.get("subtask_status") or "").upper()
if subtask_status and subtask_status != "SUCCEEDED":
raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
if not results:
raise FunAsrError("Fun-ASR 转写成功但未返回结果")
return results[0]
if last_status in TERMINAL_FAILED_STATUSES:
raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")
time.sleep(poll_interval)
raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")
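The polling loop above follows a common deadline pattern: compute the deadline once, re-check the task status until it is terminal, and report the last seen status on timeout. A generic standalone sketch (names are illustrative, not the project's API):

```python
import time

def poll_until(check, poll_interval=0.01, timeout=1.0):
    """Poll check() until it returns a terminal status or the deadline passes."""
    deadline = time.time() + timeout
    last = "PENDING"
    while time.time() < deadline:
        last = check()
        if last in {"SUCCEEDED", "FAILED"}:
            return last
        time.sleep(poll_interval)
    raise TimeoutError(f"task timed out, last status: {last}")

statuses = iter(["PENDING", "RUNNING", "SUCCEEDED"])
print(poll_until(lambda: next(statuses)))  # SUCCEEDED
```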
def download_transcription_result(transcription_url: str, session=requests) -> dict[str, Any]:
if not transcription_url:
raise FunAsrError("Fun-ASR 结果缺少 transcription_url")
response = _session_get(session, transcription_url, timeout=60)
return _json(response, "下载 Fun-ASR 转写结果")
def _ms_to_srt_time(ms: float) -> str:
total_ms = max(0, int(round(float(ms))))
hours = total_ms // 3_600_000
total_ms %= 3_600_000
minutes = total_ms // 60_000
total_ms %= 60_000
seconds = total_ms // 1_000
milliseconds = total_ms % 1_000
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _srt_block(index: int, start_ms: float, end_ms: float, text: str) -> str:
if end_ms <= start_ms:
end_ms = start_ms + 500
return f"{index}\n{_ms_to_srt_time(start_ms)} --> {_ms_to_srt_time(end_ms)}\n{text.strip()}\n"
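`_ms_to_srt_time` above is why the tests later assert that 400 ms renders as `00:00:00,400` rather than `00:06:40,000`: Fun-ASR timestamps are milliseconds, not seconds. An equivalent standalone version using `divmod`:

```python
def ms_to_srt_time(ms: float) -> str:
    """Format milliseconds as an SRT timestamp (HH:MM:SS,mmm)."""
    total_ms = max(0, int(round(float(ms))))
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    seconds, milliseconds = divmod(total_ms, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

print(ms_to_srt_time(3_723_456))  # 01:02:03,456
```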
def _timestamp_ms(value: Any, field_name: str) -> float:
try:
return float(value)
except (TypeError, ValueError) as exc:
raise FunAsrError(f"Fun-ASR 转写结果时间戳无效: {field_name}={value!r}") from exc
def _speaker_prefix(speaker_id: Any) -> str:
if speaker_id is None or speaker_id == "":
return ""
try:
label = int(speaker_id) + 1
except (TypeError, ValueError):
label = str(speaker_id)
return f"说话人{label}: "
def _iter_sentences(result_json: dict[str, Any]):
transcripts = result_json.get("transcripts")
if transcripts is None and "sentences" in result_json:
transcripts = [{"sentences": result_json.get("sentences") or []}]
if not transcripts:
raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts")
for transcript in transcripts:
for sentence in transcript.get("sentences") or []:
yield sentence
def _word_text(word: dict[str, Any]) -> str:
text = str(word.get("text") or word.get("word") or "")
punctuation = str(word.get("punctuation") or "")
if punctuation and not text.endswith(punctuation):
text += punctuation
return text
def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None:
text = current.get("text", "").strip()
if text:
blocks.append(current.copy())
def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
words = sentence.get("words") or []
blocks: list[dict[str, Any]] = []
current: Optional[dict[str, Any]] = None
max_duration_ms = max_duration * 1000
sentence_speaker = sentence.get("speaker_id")
for word in words:
text = _word_text(word)
if not text:
continue
start = word.get("begin_time", word.get("start_time"))
end = word.get("end_time")
if start is None or end is None:
continue
speaker_id = word.get("speaker_id", sentence_speaker)
start_ms = _timestamp_ms(start, "word.begin_time")
end_ms = _timestamp_ms(end, "word.end_time")
if current is None:
current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
else:
should_split_before = (
speaker_id != current.get("speaker_id")
or len(current["text"] + text) > max_chars
or (end_ms - current["start"]) > max_duration_ms
)
if should_split_before:
_flush_block(blocks, current)
current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
else:
current["text"] += text
current["end"] = end_ms
if current and text[-1:] in PUNCTUATION_BREAKS:
_flush_block(blocks, current)
current = None
if current:
_flush_block(blocks, current)
return blocks
def _split_text(text: str, max_chars: int) -> list[str]:
chunks: list[str] = []
current = ""
for char in text:
current += char
if char in PUNCTUATION_BREAKS or len(current) >= max_chars:
chunks.append(current.strip())
current = ""
if current.strip():
chunks.append(current.strip())
return [chunk for chunk in chunks if chunk]
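`_split_text` above breaks a sentence at CJK or ASCII sentence punctuation, with a hard wrap at `max_chars`. A self-contained copy for experimenting with the rule:

```python
PUNCTUATION_BREAKS = set(",。!?;,.!?;")

def split_text(text: str, max_chars: int) -> list[str]:
    """Split on sentence punctuation, or hard-wrap once max_chars is reached."""
    chunks, current = [], ""
    for char in text:
        current += char
        if char in PUNCTUATION_BREAKS or len(current) >= max_chars:
            chunks.append(current.strip())
            current = ""
    if current.strip():
        chunks.append(current.strip())
    return [c for c in chunks if c]

print(split_text("你好,欢迎观看今天的内容。", 6))
# ['你好,', '欢迎观看今天', '的内容。']
```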
def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
text = str(sentence.get("text") or "").strip()
if not text:
return []
start = sentence.get("begin_time", 0)
end = sentence.get("end_time")
start_ms = _timestamp_ms(start, "sentence.begin_time")
end_ms = _timestamp_ms(end, "sentence.end_time") if end is not None else start_ms + 500
chunks = _split_text(text, max_chars)
if not chunks:
return []
duration = max(500.0, end_ms - start_ms)
total_chars = max(1, sum(len(chunk) for chunk in chunks))
cursor = start_ms
blocks: list[dict[str, Any]] = []
for i, chunk in enumerate(chunks):
if i == len(chunks) - 1:
chunk_end = end_ms
else:
chunk_end = cursor + duration * (len(chunk) / total_chars)
blocks.append(
{
"start": cursor,
"end": max(cursor + 200, chunk_end),
"text": chunk,
"speaker_id": sentence.get("speaker_id"),
}
)
cursor = chunk_end
return blocks
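The loop above gives each chunk a share of the sentence duration proportional to its character count, pinning the last chunk to the sentence end and enforcing a minimum block length. A condensed standalone sketch of just that allocation (the tuple output is a simplification of the dict blocks used above):

```python
def allocate_times(chunks, start_ms, end_ms):
    """Distribute a sentence's duration across chunks by character count;
    the last chunk is pinned to the sentence end, blocks are at least 200 ms."""
    duration = max(500.0, end_ms - start_ms)
    total_chars = max(1, sum(len(c) for c in chunks))
    cursor, blocks = start_ms, []
    for i, chunk in enumerate(chunks):
        if i == len(chunks) - 1:
            chunk_end = end_ms
        else:
            chunk_end = cursor + duration * (len(chunk) / total_chars)
        blocks.append((cursor, max(cursor + 200, chunk_end), chunk))
        cursor = chunk_end
    return blocks

print(allocate_times(["你好,", "再见。"], 1000, 3000))
```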
def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str:
"""Convert downloaded Fun-ASR JSON into fine-grained SRT.
Official downloaded schema is `transcripts[*].sentences[*].words[*]`.
Fun-ASR timestamps are milliseconds.
"""
blocks: list[dict[str, Any]] = []
for sentence in _iter_sentences(result_json):
sentence_blocks = _blocks_from_words(sentence, max_chars, max_duration)
if not sentence_blocks:
sentence_blocks = _blocks_from_sentence(sentence, max_chars)
blocks.extend(sentence_blocks)
if not blocks:
raise FunAsrError("Fun-ASR 转写结果为空:未找到可用字幕内容")
lines = []
for index, block in enumerate(blocks, start=1):
text = f"{_speaker_prefix(block.get('speaker_id'))}{block['text']}"
lines.append(_srt_block(index, block["start"], block["end"], text))
return "\n".join(lines).rstrip() + "\n"
def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
if not subtitle_file:
subtitle_file = os.path.join(utils.subtitle_dir(), f"fun_asr_{int(time.time())}.srt")
parent = os.path.dirname(subtitle_file)
if parent:
os.makedirs(parent, exist_ok=True)
with open(subtitle_file, "w", encoding="utf-8") as f:
f.write(srt_content)
return subtitle_file
def create_with_fun_asr(
local_file: str,
subtitle_file: str = "",
api_key: str = "",
speaker_count: Optional[int] = None,
poll_interval: float = 2.0,
timeout: float = 600.0,
session=requests,
) -> Optional[str]:
"""Upload local media to Bailian temporary storage and create a Fun-ASR SRT file."""
api_key = _require_api_key(api_key)
try:
policy = request_upload_policy(api_key, session=session)
oss_url = upload_to_temporary_oss(local_file, policy, session=session)
task_id = submit_transcription_task(api_key, oss_url, speaker_count=speaker_count, session=session)
task_result = poll_transcription_task(
api_key,
task_id,
poll_interval=poll_interval,
timeout=timeout,
session=session,
)
transcription_url = task_result.get("transcription_url")
result_json = download_transcription_result(transcription_url, session=session)
srt_content = fun_asr_result_to_srt(result_json)
output_file = write_srt_file(srt_content, subtitle_file)
logger.info(f"Fun-ASR 字幕文件已生成: {output_file}")
return output_file
except FunAsrError:
raise
except Exception as exc:
raise FunAsrError("Fun-ASR 字幕转写失败,请检查文件、网络或阿里百炼服务状态") from exc

View File

@ -0,0 +1,241 @@
import json
import os
import subprocess
import time
from os import path
from loguru import logger
from app.config import config
from app.models import const
from app.models.schema import VideoClipParams
from app.services import voice, clip_video, update_script
from app.services import state as sm
from app.utils import utils
def get_audio_duration_ffprobe(audio_file: str) -> float:
"""
Get the precise duration of an audio file via ffprobe
Args:
audio_file: path to the audio file
Returns:
float: audio duration, accurate to the microsecond
"""
try:
cmd = [
'ffprobe',
'-v', 'error',
'-show_entries', 'format=duration',
'-of', 'csv=p=0',
audio_file
]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
duration = float(result.stdout.strip())
logger.debug(f"使用ffprobe获取音频时长: {duration:.6f}")
return duration
except subprocess.CalledProcessError as e:
logger.error(f"ffprobe执行失败: {e.stderr}")
raise
except Exception as e:
logger.error(f"获取音频时长失败: {str(e)}")
raise
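`ffprobe ... -of csv=p=0` prints the bare duration as a decimal string, which the function above converts with `float()`. A sketch of just the parsing step, so it can be tested without ffmpeg installed (the explicit `N/A` guard is an extra precaution, not in the original):

```python
def parse_ffprobe_duration(stdout: str) -> float:
    """Parse `ffprobe -show_entries format=duration -of csv=p=0` output."""
    value = stdout.strip()
    if not value or value == "N/A":
        raise ValueError(f"ffprobe returned no duration: {stdout!r}")
    return float(value)

print(parse_ffprobe_duration("12.345678\n"))  # 12.345678
```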
def start_export_jianying_draft(task_id: str, params: VideoClipParams):
"""
Background task that exports to a Jianying draft
Args:
task_id: task ID
params: video parameters
"""
logger.info(f"\n\n## 开始导出到剪映草稿任务: {task_id}")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=0)
"""
1. Load the editing script
"""
logger.info("\n\n## 1. 加载视频脚本")
video_script_path = path.join(params.video_clip_json_path)
if path.exists(video_script_path):
try:
with open(video_script_path, "r", encoding="utf-8") as f:
list_script = json.load(f)
video_list = [i['narration'] for i in list_script]
video_ost = [i['OST'] for i in list_script]
time_list = [i['timestamp'] for i in list_script]
video_script = " ".join(video_list)
logger.debug(f"解说完整脚本: \n{video_script}")
logger.debug(f"解说 OST 列表: \n{video_ost}")
logger.debug(f"解说时间戳列表: \n{time_list}")
except Exception as e:
logger.error(f"无法读取视频json脚本请检查脚本格式是否正确")
raise ValueError("无法读取视频json脚本请检查脚本格式是否正确")
else:
logger.error(f"解说脚本文件不存在: {video_script_path},请先点击【保存脚本】按钮保存脚本后再生成视频")
raise ValueError("解说脚本文件不存在!请先点击【保存脚本】按钮保存脚本后再生成视频。")
"""
2. Generate audio material with TTS
"""
logger.info("\n\n## 2. 根据OST设置生成音频列表")
tts_segments = [
segment for segment in list_script
if segment['OST'] in [0, 2]
]
logger.debug(f"需要生成TTS的片段数: {len(tts_segments)}")
tts_results = voice.tts_multiple(
task_id=task_id,
list_script=tts_segments, # only pass the segments that need TTS
tts_engine=params.tts_engine,
voice_name=params.voice_name,
voice_rate=params.voice_rate,
voice_pitch=params.voice_pitch,
)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
"""
3. Unified video clipping - differentiated clipping strategy based on OST type
"""
logger.info("\n\n## 3. 统一视频裁剪基于OST类型")
video_clip_result = clip_video.clip_video_unified(
video_origin_path=params.video_origin_path,
script_list=list_script,
tts_results=tts_results
)
tts_clip_result = {tts_result['_id']: tts_result['audio_file'] for tts_result in tts_results}
subclip_clip_result = {
tts_result['_id']: tts_result['subtitle_file'] for tts_result in tts_results
}
new_script_list = update_script.update_script_timestamps(list_script, video_clip_result, tts_clip_result, subclip_clip_result)
logger.info(f"统一裁剪完成,处理了 {len(video_clip_result)} 个视频片段")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=60)
"""
4. Export to the Jianying draft
"""
logger.info("\n\n## 4. 导出到剪映草稿")
try:
import pyJianYingDraft
from pyJianYingDraft import DraftFolder, VideoSegment, AudioSegment, trange, TrackType
jianying_draft_path = config.ui.get("jianying_draft_path", "")
if not jianying_draft_path:
raise ValueError("剪映草稿路径未配置")
# Create a DraftFolder instance
draft_folder = DraftFolder(jianying_draft_path)
# Use the draft name from params; fall back to a default name if empty
draft_name = getattr(params, 'draft_name', "")
logger.debug(f"从params获取的草稿名称: '{draft_name}' (类型: {type(draft_name)})")
if not draft_name:
draft_name = f"NarratoAI_{int(time.time())}"
logger.debug(f"使用默认草稿名称: '{draft_name}'")
# Create a new draft
script = draft_folder.create_draft(draft_name, 1920, 1080)
# Add a video track and an audio track
script.add_track(TrackType.video, '视频轨道')
script.add_track(TrackType.audio, '音频轨道')
# Process the script data
current_time = 0
output_dir = utils.task_dir(task_id)
for item in new_script_list:
# Read the timing info
start_time = float(item.get('start_time', 0.0))
duration = float(item.get('duration', 0.0))
timestamp = item.get('timestamp', '')
logger.info(f"处理片段: OST={item['OST']}, start_time={start_time}, duration={duration}, timestamp={timestamp}")
# Build the audio file path
audio_file = ""
if timestamp:
timestamp_formatted = timestamp.replace(':', '_')
audio_file = os.path.join(
output_dir,
f"audio_{timestamp_formatted}.mp3"
)
# Check whether a clipped video file exists
video_file = item.get('video', '')
if video_file and not os.path.exists(video_file):
video_file = ""
# Add the video segment
if video_file:
# Use the clipped video file
# For clipped video, the second argument of target_timerange is the duration
video_segment = VideoSegment(
video_file,
trange(f"{current_time}s", f"{duration}s")
)
else:
# Use the original video file
# source_timerange is the portion cut from the original video
# target_timerange is the segment's position on the timeline
video_segment = VideoSegment(
params.video_origin_path,
trange(f"{current_time}s", f"{duration}s"),
source_timerange=trange(f"{start_time}s", f"{duration}s")
)
script.add_segment(video_segment, '视频轨道')
# Handle audio
if item['OST'] in [0, 2]: # segments that need TTS
if os.path.exists(audio_file):
# Use ffprobe to get the precise audio duration, avoiding mismatches caused by TTS engine differences
actual_audio_duration = get_audio_duration_ffprobe(audio_file)
logger.info(f"音频文件实际时长: {actual_audio_duration:.6f}秒, 脚本时长(视频): {duration:.3f}")
# Use the smaller of the actual audio duration and the video duration so the segment never exceeds the material
# When TTS speed is adjusted the audio may run longer or shorter than the video; taking the minimum avoids overrunning the material
safe_duration = min(actual_audio_duration, duration)
logger.info(f"使用时长: {safe_duration:.6f}秒 (取音频和视频时长的较小值)")
audio_segment = AudioSegment(
audio_file,
trange(f"{current_time}s", f"{safe_duration}s")
)
script.add_segment(audio_segment, '音频轨道')
else:
logger.warning(f"音频文件不存在: {audio_file}")
# OST=1 segments keep the original audio; no extra audio track is needed
# Advance the current time
current_time += duration
# Save the draft
script.save()
draft_path = os.path.join(jianying_draft_path, draft_name)
logger.success(f"成功导出到剪映草稿: {draft_name}")
logger.info(f"草稿已保存到: {draft_path}")
# Update the task state
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, draft_path=draft_path, draft_name=draft_name)
return {"draft_path": draft_path, "draft_name": draft_name}
except ImportError as e:
logger.error(f"导入pyJianYingDraft失败: {e}")
raise ImportError(f"pyJianYingDraft库导入失败: {e}\n请确保已正确安装该库")
except Exception as e:
logger.error(f"导出到剪映草稿失败: {e}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
raise Exception(f"导出到剪映草稿失败: {e}")
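The export loop above keeps a running `current_time`, places each segment at that offset, advances by the segment's video duration, and caps TTS audio at `min(audio, video)`. A condensed sketch of just that timeline bookkeeping (the function and tuple shapes are hypothetical simplifications of the draft-building code):

```python
def plan_timeline(segments):
    """segments: list of (video_duration_s, audio_duration_s or None).
    Returns (start, video_len, audio_len) per segment on the timeline."""
    current, plan = 0.0, []
    for video_dur, audio_dur in segments:
        # Cap the audio at the video duration so it never outruns the clip.
        safe_audio = None if audio_dur is None else min(audio_dur, video_dur)
        plan.append((current, video_dur, safe_audio))
        current += video_dur
    return plan

print(plan_timeline([(3.0, 3.2), (2.0, None), (4.0, 3.5)]))
# [(0.0, 3.0, 3.0), (3.0, 2.0, None), (5.0, 4.0, 3.5)]
```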

View File

@ -0,0 +1,403 @@
import tempfile
import unittest
from pathlib import Path
try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib
from app.config import config as cfg
from app.services import fun_asr_subtitle as fasr
class FakeResponse:
def __init__(self, payload=None, status_code=200):
self.payload = payload or {}
self.status_code = status_code
def json(self):
return self.payload
def raise_for_status(self):
if self.status_code >= 400:
raise RuntimeError(f"HTTP {self.status_code}")
class InvalidJsonResponse(FakeResponse):
def json(self):
raise ValueError("invalid json")
class FakeSession:
def __init__(self, local_result):
self.calls = []
self.local_result = local_result
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
if url == fasr.UPLOAD_POLICY_URL:
return FakeResponse(
{
"data": {
"policy": "policy-token",
"signature": "signature-token",
"upload_dir": "dashscope-instant/test-dir",
"upload_host": "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com",
"oss_access_key_id": "oss-ak",
"x_oss_object_acl": "private",
"x_oss_forbid_overwrite": "true",
"max_file_size_mb": 1,
}
}
)
if url == "https://result.example/transcription.json":
return FakeResponse(self.local_result)
return FakeResponse({}, 404)
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
if url == "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com":
return FakeResponse({})
if url == fasr.TRANSCRIPTION_URL:
return FakeResponse({"output": {"task_status": "PENDING", "task_id": "task-123"}})
if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
return FakeResponse(
{
"output": {
"task_status": "SUCCEEDED",
"results": [
{
"file_url": "oss://dashscope-instant/test-dir/audio.wav",
"transcription_url": "https://result.example/transcription.json",
"subtask_status": "SUCCEEDED",
}
],
}
}
)
return FakeResponse({}, 404)
OFFICIAL_SHAPE_RESULT = {
"transcripts": [
{
"sentences": [
{
"begin_time": 0,
"end_time": 3600,
"text": "你好欢迎观看今天的内容",
"speaker_id": 0,
"words": [
{"begin_time": 0, "end_time": 400, "text": "你好", "punctuation": ""},
{"begin_time": 400, "end_time": 900, "text": "欢迎", "punctuation": ""},
{"begin_time": 900, "end_time": 1300, "text": "观看", "punctuation": ""},
{"begin_time": 1300, "end_time": 1800, "text": "今天", "punctuation": ""},
{"begin_time": 1800, "end_time": 2400, "text": "的内容", "punctuation": ""},
],
}
]
}
]
}
class FunAsrSrtConversionTests(unittest.TestCase):
def test_official_shape_words_convert_ms_and_speaker_label(self):
srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)
self.assertIn("1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,", srt)
self.assertIn("2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。", srt)
self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")
def test_long_word_sequence_splits_into_fine_blocks(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": 0,
"end_time": 6000,
"speaker_id": 1,
"words": [
{"begin_time": i * 500, "end_time": (i + 1) * 500, "text": token, "punctuation": ""}
for i, token in enumerate(["这是", "一个", "很长", "字幕", "需要", "拆分"])
],
}
]
}
]
}
srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)
self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
self.assertIn("1\n00:00:00,000", srt)
def test_sentence_fallback_uses_ms_without_zero_duration(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": 1000,
"end_time": 3000,
"text": "没有词级时间戳也可以拆分。",
"speaker_id": 0,
"words": [],
}
]
}
]
}
srt = fasr.fun_asr_result_to_srt(result, max_chars=5)
self.assertIn("00:00:01,000", srt)
self.assertIn("说话人1:", srt)
self.assertNotIn("--> 00:00:01,000\n", srt)
def test_empty_result_raises_clear_error(self):
with self.assertRaises(fasr.FunAsrError):
fasr.fun_asr_result_to_srt({"transcripts": []})
def test_malformed_word_timestamp_raises_fun_asr_error(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": 0,
"end_time": 1000,
"speaker_id": 0,
"words": [
{"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
],
}
]
}
]
}
with self.assertRaises(fasr.FunAsrError):
fasr.fun_asr_result_to_srt(result)
def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": "bad",
"end_time": 1000,
"text": "坏时间",
"speaker_id": 0,
"words": [],
}
]
}
]
}
with self.assertRaises(fasr.FunAsrError):
fasr.fun_asr_result_to_srt(result)
class FunAsrServiceTests(unittest.TestCase):
def test_create_with_fun_asr_uses_expected_rest_flow(self):
with tempfile.TemporaryDirectory() as tmp_dir:
local_file = Path(tmp_dir) / "audio.wav"
local_file.write_bytes(b"audio")
subtitle_file = Path(tmp_dir) / "out.srt"
session = FakeSession(OFFICIAL_SHAPE_RESULT)
result_path = fasr.create_with_fun_asr(
str(local_file),
subtitle_file=str(subtitle_file),
api_key="sk-test",
speaker_count=2,
poll_interval=0,
session=session,
)
self.assertEqual(str(subtitle_file), result_path)
self.assertTrue(subtitle_file.exists())
self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))
policy_call = session.calls[0]
self.assertEqual("GET", policy_call[0])
self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1])
self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"])
self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"])
upload_call = session.calls[1]
self.assertEqual("POST", upload_call[0])
self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1])
upload_data = upload_call[2]["data"]
self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"])
self.assertEqual("policy-token", upload_data["policy"])
self.assertEqual("signature-token", upload_data["Signature"])
self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"])
self.assertEqual("200", upload_data["success_action_status"])
submit_call = session.calls[2]
self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1])
headers = submit_call[2]["headers"]
self.assertEqual("enable", headers["X-DashScope-Async"])
self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"])
payload = submit_call[2]["json"]
self.assertEqual("fun-asr", payload["model"])
self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"])
self.assertTrue(payload["parameters"]["diarization_enabled"])
self.assertEqual(2, payload["parameters"]["speaker_count"])
poll_call = session.calls[3]
self.assertEqual("POST", poll_call[0])
self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123"))
download_call = session.calls[4]
self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2])
def test_upload_policy_size_validation_fails_before_upload(self):
policy = fasr.UploadPolicy(
upload_host="https://upload.example",
upload_dir="dashscope-instant/test",
policy="p",
signature="s",
oss_access_key_id="ak",
max_file_size_mb=0.000001,
)
with tempfile.NamedTemporaryFile() as f:
f.write(b"too-large")
f.flush()
with self.assertRaises(fasr.FunAsrError):
fasr.upload_to_temporary_oss(f.name, policy, session=FakeSession({}))
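The size-validation test above expects oversized files to be rejected before any upload request is sent. Below is a minimal sketch of such a pre-flight check, assuming the policy limit arrives in megabytes as `max_file_size_mb`; `check_upload_size` and the local `FunAsrError` are illustrative stand-ins, not the module's real definitions.

```python
import os
import tempfile

class FunAsrError(RuntimeError):
    """Illustrative stand-in for fasr.FunAsrError."""

def check_upload_size(path: str, max_file_size_mb: float) -> None:
    # Reject oversized files before any network I/O happens.
    size_mb = os.path.getsize(path) / (1024 * 1024)
    if size_mb > max_file_size_mb:
        raise FunAsrError(
            f"file is {size_mb:.6f} MB, exceeds the {max_file_size_mb} MB policy limit"
        )

# Demo mirroring the test: a tiny limit rejects even a 9-byte file.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"too-large")
    sample_path = f.name
try:
    check_upload_size(sample_path, 0.000001)
    rejected = False
except FunAsrError:
    rejected = True
finally:
    os.unlink(sample_path)
print(rejected)  # True
```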
def test_failed_subtask_raises(self):
class FailedSession(FakeSession):
def post(self, url, **kwargs):
if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
return FakeResponse(
{
"output": {
"task_status": "SUCCEEDED",
"results": [{"subtask_status": "FAILED"}],
}
}
)
return super().post(url, **kwargs)
with self.assertRaises(fasr.FunAsrError):
fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({}))
def test_missing_api_key_raises_before_request(self):
session = FakeSession(OFFICIAL_SHAPE_RESULT)
with self.assertRaises(fasr.FunAsrError):
fasr.request_upload_policy("", session=session)
self.assertEqual([], session.calls)
def test_upload_policy_http_error_raises(self):
class PolicyHttpErrorSession(FakeSession):
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
return FakeResponse({}, status_code=403)
with self.assertRaises(fasr.FunAsrError):
fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({}))
def test_malformed_upload_policy_raises(self):
class MalformedPolicySession(FakeSession):
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
return FakeResponse({"data": {"policy": "missing-required-fields"}})
with self.assertRaises(fasr.FunAsrError):
fasr.request_upload_policy("sk-test", session=MalformedPolicySession({}))
def test_upload_http_failure_raises(self):
class UploadFailureSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({}, status_code=500)
policy = fasr.UploadPolicy(
upload_host="https://upload.example",
upload_dir="dashscope-instant/test",
policy="p",
signature="s",
oss_access_key_id="ak",
max_file_size_mb=1,
)
with tempfile.NamedTemporaryFile() as f:
f.write(b"audio")
f.flush()
with self.assertRaises(fasr.FunAsrError):
fasr.upload_to_temporary_oss(f.name, policy, session=UploadFailureSession({}))
def test_submit_failure_raises(self):
class SubmitFailureSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({}, status_code=500)
with self.assertRaises(fasr.FunAsrError):
fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({}))
def test_poll_timeout_raises(self):
class PendingSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({"output": {"task_status": "RUNNING"}})
with self.assertRaises(fasr.FunAsrError):
fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, timeout=-1, session=PendingSession({}))
def test_task_failed_status_raises(self):
class FailedTaskSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({"output": {"task_status": "FAILED"}})
with self.assertRaises(fasr.FunAsrError):
fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({}))
def test_missing_transcription_url_raises(self):
with self.assertRaises(fasr.FunAsrError):
fasr.download_transcription_result("", session=FakeSession({}))
def test_malformed_downloaded_json_raises(self):
class MalformedDownloadSession(FakeSession):
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
return InvalidJsonResponse()
with self.assertRaises(fasr.FunAsrError):
fasr.download_transcription_result("https://result.example/bad.json", session=MalformedDownloadSession({}))
class FunAsrConfigTests(unittest.TestCase):
def test_save_config_persists_fun_asr_section(self):
original_config_file = cfg.config_file
original_fun_asr = cfg.fun_asr
try:
with tempfile.TemporaryDirectory() as tmp_dir:
config_path = Path(tmp_dir) / "config.toml"
cfg.config_file = str(config_path)
cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
cfg.save_config()
saved = tomllib.loads(config_path.read_text(encoding="utf-8"))
finally:
cfg.config_file = original_config_file
cfg.fun_asr = original_fun_asr
self.assertEqual("sk-local", saved["fun_asr"]["api_key"])
self.assertEqual("fun-asr", saved["fun_asr"]["model"])
def test_config_example_fun_asr_section_parses(self):
config_data = tomllib.loads(Path("config.example.toml").read_text(encoding="utf-8"))
self.assertEqual("fun-asr", config_data["fun_asr"]["model"])
self.assertIn("api_key", config_data["fun_asr"])
if __name__ == "__main__":
unittest.main()


@ -1116,6 +1116,125 @@ def should_use_azure_speech_services(voice_name: str) -> bool:
return False
def doubaotts_tts(text: str, voice_name: str, voice_file: str, speed: float = 1.0) -> Union[SubMaker, None]:
"""
使用豆包语音 TTS 生成语音
"""
# 读取配置
doubaotts_cfg = getattr(config, "doubaotts", {}) or {}
appid = doubaotts_cfg.get("appid", "")
token = doubaotts_cfg.get("token", "")
ak = doubaotts_cfg.get("ak", "")
sk = doubaotts_cfg.get("sk", "")
cluster = doubaotts_cfg.get("cluster", "volcano_tts")
if not appid or not token:
logger.error("豆包语音 TTS 配置未完成")
return None
# Prepare parameters
voice_type = voice_name
safe_speed = float(max(0.2, min(3.0, speed)))
text = text.strip()
# Build the request payload
import uuid
reqid = str(uuid.uuid4())
# Read advanced parameters
volume = doubaotts_cfg.get("volume", 1.0)
pitch = doubaotts_cfg.get("pitch", 1.0)
silence_duration = doubaotts_cfg.get("silence_duration", 0.125)
payload = {
"app": {
"appid": appid,
"token": token,
"cluster": cluster
},
"user": {
"uid": "NarratoAI"
},
"audio": {
"voice_type": voice_type,
"encoding": "mp3",
"rate": 24000,
"speed_ratio": safe_speed,
"volume_ratio": float(volume),
"pitch_ratio": float(pitch)
},
"request": {
"reqid": reqid,
"text": text,
"text_type": "plain",
"operation": "query"
}
}
# If a trailing-silence duration is configured, add it to the request
if silence_duration > 0:
payload["audio"]["silence_duration"] = float(silence_duration)
# API endpoint
url = "https://openspeech.bytedance.com/api/v1/tts"
# Build request headers (Bearer token auth)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer;{token}"
}
for i in range(3):
try:
logger.info(f"=== 豆包语音 TTS 请求参数 (第 {i+1} 次调用) ===")
# Send the request
import requests
# Apply proxy settings
proxies = None
proxy_enabled = config.proxy.get("enabled", False)
if proxy_enabled:
proxy_url = config.proxy.get("https", config.proxy.get("http", ""))
if proxy_url:
proxies = {"https": proxy_url, "http": proxy_url}
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=60)
if response.status_code == 200:
result = response.json()
if result.get("code") == 3000:
# Success
audio_data = result.get("data", "")
if audio_data:
# Decode the base64 audio payload
import base64
audio_bytes = base64.b64decode(audio_data)
# Write to file
with open(voice_file, "wb") as f:
f.write(audio_bytes)
logger.success(f"豆包语音 TTS 合成成功: {voice_file}")
# Create a SubMaker object (simplified; no word-level timestamps)
sub_maker = new_sub_maker()
return sub_maker
else:
logger.error("豆包语音 TTS 响应中无音频数据")
else:
logger.error(f"豆包语音 TTS 失败: {result.get('message', '未知错误')}")
else:
logger.error(f"豆包语音 TTS API 请求失败: {response.status_code}, {response.text}")
if i < 2:
time.sleep(1)
except Exception as e:
logger.error(f"豆包语音 TTS 错误: {str(e)}")
if i < 2:
time.sleep(3)
return None
def tts(
text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str, tts_engine: str
) -> Union[SubMaker, None]:
@ -1147,6 +1266,10 @@ def tts(
if tts_engine == "indextts2":
logger.info("分发到 IndexTTS2")
return indextts2_tts(text, voice_name, voice_file, speed=voice_rate)
if tts_engine == "doubaotts":
logger.info("分发到豆包语音 TTS")
return doubaotts_tts(text, voice_name, voice_file, speed=voice_rate)
# Fallback for unknown engine - default to azure v1
logger.warning(f"未知的 TTS 引擎: '{tts_engine}', 将默认使用 Edge TTS (Azure V1)。")
@ -1606,8 +1729,8 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
f"或者使用其他 tts 引擎")
continue
else:
# SoulVoice, Qwen3, and IndexTTS2 engines do not generate subtitle files
if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2":
# SoulVoice, Qwen3, IndexTTS2, and Doubao TTS engines do not generate subtitle files
if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2" or tts_engine == "doubaotts":
# Get the duration of the actual audio file
duration = get_audio_duration_from_file(audio_file)
if duration <= 0:
@ -1615,8 +1738,27 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
duration = get_audio_duration(sub_maker)
if duration <= 0:
# Final fallback: estimate from text length
duration = max(1.0, len(text) / 3.0)
logger.warning(f"无法获取音频时长,使用文本估算: {duration:.2f}")
# For English text, use a more accurate estimate:
# English speech averages about 150-180 words per minute, i.e. 2.5-3 words per second;
# Chinese runs at roughly 3-4 characters per second
import re
# Count English words
english_words = len(re.findall(r'\b\w+\b', text))
# Count Chinese characters
chinese_chars = len(re.findall(r'[\u4e00-\u9fa5]', text))
if english_words > chinese_chars:
# Mostly English text
# Assume ~0.35 s per word on average
estimated_duration = max(1.0, english_words * 0.35)
else:
# Mostly Chinese text
# Assume ~0.3 s per character on average
estimated_duration = max(1.0, chinese_chars * 0.3)
# Keep the estimate within a sensible range
duration = max(1.0, estimated_duration)
logger.warning(f"无法获取音频时长,使用文本估算: {duration:.2f}秒 (英文单词: {english_words}, 中文字符: {chinese_chars})")
# Do not create a subtitle file
subtitle_file = ""
else:
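The text-length fallback in the hunk above can be exercised on its own. A sketch of the same heuristic as a standalone helper; `estimate_duration_from_text` is an illustrative name, not a function in the codebase.

```python
import re

def estimate_duration_from_text(text: str) -> float:
    """Fallback duration estimate mirroring the heuristic above:
    ~0.35 s per English word, ~0.3 s per Chinese character, floor of 1 s."""
    english_words = len(re.findall(r'\b\w+\b', text))
    chinese_chars = len(re.findall(r'[\u4e00-\u9fa5]', text))
    if english_words > chinese_chars:
        estimated = english_words * 0.35   # mostly English text
    else:
        estimated = chinese_chars * 0.3    # mostly Chinese text
    return max(1.0, estimated)

print(estimate_duration_from_text("hello world"))  # 1.0 (0.7 s estimate, floored)
```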
@ -1658,8 +1800,6 @@ def get_audio_duration_from_file(audio_file: str) -> float:
# The actual file also contains header information, so adjust the coefficient
estimated_duration = max(1.0, file_size / 20000) # adjusted to a more conservative estimate
# For Chinese speech, a secondary correction based on text length could apply:
# typical Chinese speech runs at about 3-4 characters per second
logger.warning(f"使用文件大小估算音频时长: {estimated_duration:.2f}")
return estimated_duration
except Exception as e:


@ -93,6 +93,12 @@
# Visit https://bailian.console.aliyun.com/?tab=model#/api-key to get your API key
api_key = ""
model_name = "qwen3-tts-flash"
[fun_asr]
# Alibaba Bailian Fun-ASR subtitle transcription settings
# Visit https://bailian.console.aliyun.com/?tab=model#/api-key to get your API key
api_key = ""
model = "fun-asr"
[indextts2]
# IndexTTS2 语音克隆配置
@ -114,9 +120,25 @@
do_sample = true
num_beams = 3
repetition_penalty = 10.0
[doubaotts]
# Doubao TTS settings
# How to apply:
# 1. Open https://console.volcengine.com/iam/keymanage and create an Access Key and Secret Key
# 2. Open https://www.volcengine.com/product/voice-tech and click "Use now"
# 3. In the API service center, find Speech Synthesis under Audio Generation to get the APPID and Token
ak = ""
sk = ""
appid = ""
token = ""
cluster = "volcano_tts"
# Advanced parameters
volume = 1.0
pitch = 1.0
silence_duration = 0.125
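The keys above map directly onto the request body that `doubaotts_tts` assembles. A hedged sketch of that mapping as a pure function; `build_doubao_payload` is an illustrative helper, not part of the codebase.

```python
import uuid

def build_doubao_payload(cfg: dict, text: str, voice_type: str, speed: float = 1.0) -> dict:
    """Assemble the JSON body for the openspeech /api/v1/tts endpoint,
    mirroring how doubaotts_tts() builds it from the [doubaotts] config."""
    safe_speed = float(max(0.2, min(3.0, speed)))  # clamp like the service does
    payload = {
        "app": {"appid": cfg["appid"], "token": cfg["token"],
                "cluster": cfg.get("cluster", "volcano_tts")},
        "user": {"uid": "NarratoAI"},
        "audio": {
            "voice_type": voice_type,
            "encoding": "mp3",
            "rate": 24000,
            "speed_ratio": safe_speed,
            "volume_ratio": float(cfg.get("volume", 1.0)),
            "pitch_ratio": float(cfg.get("pitch", 1.0)),
        },
        "request": {"reqid": str(uuid.uuid4()), "text": text.strip(),
                    "text_type": "plain", "operation": "query"},
    }
    # Only attach trailing silence when it is configured to a positive value
    silence = float(cfg.get("silence_duration", 0.125))
    if silence > 0:
        payload["audio"]["silence_duration"] = silence
    return payload

payload = build_doubao_payload({"appid": "app", "token": "tok"}, "你好", "BV700_streaming", speed=5.0)
```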
[ui]
# TTS engine selection (edge_tts, azure_speech, soulvoice, tencent_tts, tts_qwen)
# TTS engine selection (edge_tts, azure_speech, soulvoice, tencent_tts, tts_qwen, doubaotts)
tts_engine = "edge_tts"
# Edge TTS settings
@ -130,6 +152,10 @@
azure_volume = 80
azure_rate = 1.0
azure_pitch = 0
# Doubao TTS settings
doubaotts_voice_type = "BV700_V2_streaming"
doubaotts_rate = 1.0
##########################################
# Proxy and network settings

View File

@ -1 +1 @@
0.7.8
0.7.9


@ -35,3 +35,6 @@ tenacity>=9.0.0
# torch>=2.0.0
# torchvision>=0.15.0
# torchaudio>=2.0.0
# Jianying draft export dependency
pyJianYingDraft>=0.1.0

webui.py

@ -1,6 +1,7 @@
import streamlit as st
import os
import sys
import time
from loguru import logger
from app.config import config
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
@ -221,6 +222,145 @@ def render_generate_button():
time.sleep(0.5)
def get_voice_name_for_tts_engine(tts_engine: str) -> str:
"""Return the user-selected voice for the given TTS engine."""
if tts_engine == 'doubaotts':
return st.session_state.get('voice_name', config.ui.get('doubaotts_voice_type', 'BV700_streaming'))
elif tts_engine == 'azure_speech':
return st.session_state.get('voice_name', config.ui.get('azure_voice_name', 'zh-CN-XiaoxiaoMultilingualNeural'))
else:
return st.session_state.get('voice_name', config.ui.get('edge_voice_name', 'zh-CN-XiaoxiaoNeural-Female'))
def get_jianying_export_params() -> VideoClipParams:
"""Build the parameters for exporting to a Jianying draft."""
tts_engine = st.session_state.get('tts_engine', 'azure')
voice_name = get_voice_name_for_tts_engine(tts_engine)
voice_rate = st.session_state.get('voice_rate', 1.0)
voice_pitch = st.session_state.get('voice_pitch', 1.0)
return VideoClipParams(
video_clip_json_path=st.session_state['video_clip_json_path'],
video_origin_path=st.session_state['video_origin_path'],
tts_engine=tts_engine,
voice_name=voice_name,
voice_rate=voice_rate,
voice_pitch=voice_pitch,
n_threads=config.app.get('n_threads', 4),
video_aspect=VideoAspect.landscape,
subtitle_enabled=st.session_state.get('subtitle_enabled', False),
font_name=st.session_state.get('font_name', 'Microsoft YaHei'),
font_size=st.session_state.get('font_size', 24),
text_fore_color=st.session_state.get('text_fore_color', '#FFFFFF'),
subtitle_position=st.session_state.get('subtitle_position', 'bottom'),
custom_position=st.session_state.get('custom_position', 70.0),
tts_volume=st.session_state.get('tts_volume', 1.0),
original_volume=st.session_state.get('original_volume', 0.7),
bgm_volume=st.session_state.get('bgm_volume', 0.3),
draft_name=st.session_state.get('draft_name_input', f"NarratoAI_{int(time.time())}")
)
def render_export_jianying_button():
"""Render the export-to-Jianying-draft button and its handling logic."""
import os
import time
import uuid
from loguru import logger
# Initialize session state
if 'show_jianying_export_form' not in st.session_state:
st.session_state['show_jianying_export_form'] = False
if 'jianying_export_result' not in st.session_state:
st.session_state['jianying_export_result'] = None
if 'jianying_export_error' not in st.session_state:
st.session_state['jianying_export_error'] = None
if st.button("📤 导出到剪映草稿", use_container_width=True, type="secondary"):
config.save_config()
if not st.session_state.get('video_clip_json_path'):
st.error("脚本文件不能为空")
return
if not st.session_state.get('video_origin_path'):
st.error("视频文件不能为空")
return
jianying_draft_path = config.ui.get("jianying_draft_path", "")
if not jianying_draft_path:
st.error("请在基础设置中配置剪映草稿地址")
return
if not os.path.exists(jianying_draft_path):
st.error(f"剪映草稿文件夹不存在: {jianying_draft_path}")
return
# Show the export form
st.session_state['show_jianying_export_form'] = True
st.session_state['jianying_export_result'] = None
st.session_state['jianying_export_error'] = None
# Render the export form
if st.session_state['show_jianying_export_form']:
st.markdown("---")
st.subheader("导出到剪映草稿")
draft_name = st.text_input(
"请输入剪映草稿名称",
value=f"NarratoAI_{int(time.time())}",
key="draft_name_input"
)
if st.button("确认导出", key="confirm_export"):
if not draft_name:
st.error("请输入草稿名称")
return
# Create a task ID
task_id = str(uuid.uuid4())
st.session_state['task_id'] = task_id
# Build the parameters
try:
params = get_jianying_export_params()
except Exception as e:
logger.error(f"构建参数失败: {e}")
st.error(f"参数构建失败: {e}")
return
with st.spinner("正在导出到剪映草稿,请稍候..."):
try:
from app.services import jianying_task
# Run the export-to-Jianying-draft task
result = jianying_task.start_export_jianying_draft(task_id, params)
# Log the result
logger.info(f"成功导出到剪映草稿: {result['draft_name']}")
logger.info(f"草稿已保存到: {result['draft_path']}")
# Save the result to session state
st.session_state['jianying_export_result'] = result
st.session_state['jianying_export_error'] = None
st.session_state['show_jianying_export_form'] = False
st.success(f"✅ 成功导出到剪映草稿: {result['draft_name']}")
st.info(f"📁 草稿已保存到: {result['draft_path']}")
except Exception as e:
logger.error(f"导出到剪映草稿失败: {e}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
st.session_state['jianying_export_error'] = str(e)
st.session_state['jianying_export_result'] = None
st.error(f"❌ 导出到剪映草稿失败: {e}")
if st.button("取消", key="cancel_export"):
st.session_state['show_jianying_export_form'] = False
st.session_state['jianying_export_result'] = None
st.session_state['jianying_export_error'] = None
st.rerun()
def main():
"""Main entry point."""
@ -285,6 +425,7 @@ def main():
# Render the generate button and its handling logic last
render_generate_button()
render_export_jianying_button()
if __name__ == "__main__":


@ -26,7 +26,8 @@ def get_tts_engine_options():
"azure_speech": "Azure Speech Services",
"tencent_tts": "腾讯云 TTS",
"qwen3_tts": "通义千问 Qwen3 TTS",
"indextts2": "IndexTTS2 语音克隆"
"indextts2": "IndexTTS2 语音克隆",
"doubaotts": "豆包语音 TTS"
}
@ -62,6 +63,12 @@ def get_tts_engine_descriptions():
"features": "零样本语音克隆,上传参考音频即可合成相同音色的语音,需要本地或私有部署",
"use_case": "下载地址https://pan.quark.cn/s/0767c9bcefd5",
"registration": None
},
"doubaotts": {
"title": "豆包语音 TTS",
"features": "火山引擎豆包语音合成,支持多种音色和情感,国内访问速度快",
"use_case": "需要高质量中文语音合成的用户",
"registration": "https://www.volcengine.com/product/voice-tech"
}
}
@ -147,6 +154,8 @@ def render_tts_settings(tr):
render_qwen3_tts_settings(tr)
elif selected_engine == "indextts2":
render_indextts2_tts_settings(tr)
elif selected_engine == "doubaotts":
render_doubaotts_settings(tr)
# 4. Voice preview
render_voice_preview_new(tr, selected_engine)
@ -703,6 +712,250 @@ def render_indextts2_tts_settings(tr):
config.ui["voice_name"] = f"indextts2:{reference_audio}"
def render_doubaotts_settings(tr):
"""Render the Doubao TTS settings panel."""
# Access Key input
ak = st.text_input(
"Access Key",
value=config.doubaotts.get("ak", ""),
help="火山引擎 Access Key"
)
# Secret Key input
sk = st.text_input(
"Secret Key",
value=config.doubaotts.get("sk", ""),
type="password",
help="火山引擎 Secret Key"
)
# AppID input
appid = st.text_input(
"AppID",
value=config.doubaotts.get("appid", ""),
help="豆包语音应用 AppID"
)
# Token input
token = st.text_input(
"Token",
value=config.doubaotts.get("token", ""),
type="password",
help="豆包语音应用 Token"
)
# Cluster setting
cluster = st.text_input(
"集群",
value=config.doubaotts.get("cluster", "volcano_tts"),
help="业务集群,标准音色使用 volcano_tts"
)
# Voice selection
# Online voice list (taken from the official docs)
voice_options = {
"BV700_V2_streaming": "灿灿 2.0",
"BV705_streaming": "炀炀",
"BV701_V2_streaming": "擎苍 2.0",
"BV001_V2_streaming": "通用女声 2.0",
"BV700_streaming": "灿灿",
"BV406_V2_streaming": "超自然音色-梓梓2.0",
"BV406_streaming": "超自然音色-梓梓",
"BV407_V2_streaming": "超自然音色-燃燃2.0",
"BV407_streaming": "超自然音色-燃燃",
"BV001_streaming": "通用女声",
"BV002_streaming": "通用男声",
"BV701_streaming": "擎苍",
"BV123_streaming": "阳光青年",
"BV120_streaming": "反卷青年",
"BV119_streaming": "通用赘婿",
"BV115_streaming": "古风少御",
"BV107_streaming": "霸气青叔",
"BV100_streaming": "质朴青年",
"BV104_streaming": "温柔淑女",
"BV004_streaming": "开朗青年",
"BV113_streaming": "甜宠少御",
"BV102_streaming": "儒雅青年",
"BV405_streaming": "甜美小源",
"BV007_streaming": "亲切女声",
"BV009_streaming": "知性女声",
"BV419_streaming": "诚诚",
"BV415_streaming": "童童",
"BV008_streaming": "亲切男声",
"BV408_streaming": "译制片男声",
"BV426_streaming": "懒小羊",
"BV428_streaming": "清新文艺女声",
"BV403_streaming": "鸡汤女声",
"BV158_streaming": "智慧老者",
"BV157_streaming": "慈爱姥姥",
"BR001_streaming": "说唱小哥",
"BV410_streaming": "活力解说男",
"BV411_streaming": "影视解说小帅",
"BV437_streaming": "解说小帅-多情感",
"BV412_streaming": "影视解说小美",
"BV159_streaming": "纨绔青年",
"BV418_streaming": "直播一姐",
"BV142_streaming": "沉稳解说男",
"BV143_streaming": "潇洒青年",
"BV056_streaming": "阳光男声",
"BV005_streaming": "活泼女声",
"BV064_streaming": "小萝莉",
"BV051_streaming": "奶气萌娃",
"BV063_streaming": "动漫海绵",
"BV417_streaming": "动漫海星",
"BV050_streaming": "动漫小新",
"BV061_streaming": "天才童声",
"BV401_streaming": "促销男声",
"BV402_streaming": "促销女声",
"BV006_streaming": "磁性男声",
"BV011_streaming": "新闻女声",
"BV012_streaming": "新闻男声",
"BV034_streaming": "知性姐姐-双语",
"BV033_streaming": "温柔小哥",
"BV511_streaming": "慵懒女声-Ava",
"BV505_streaming": "议论女声-Alicia",
"BV138_streaming": "情感女声-Lawrence",
"BV027_streaming": "美式女声-Amelia",
"BV502_streaming": "讲述女声-Amanda",
"BV503_streaming": "活力女声-Ariana",
"BV504_streaming": "活力男声-Jackson",
"BV421_streaming": "天才少女",
"BV702_streaming": "Stefan",
"BV506_streaming": "天真萌娃-Lily",
"BV040_streaming": "亲切女声-Anna",
"BV516_streaming": "澳洲男声-Henry",
"BV520_streaming": "元气少女",
"BV521_streaming": "萌系少女",
"BV522_streaming": "气质女声",
"BV524_streaming": "日语男声",
"BV531_streaming": "活力男声Carlos巴西地区",
"BV530_streaming": "活力女声(巴西地区)",
"BV065_streaming": "气质御姐(墨西哥地区)",
"BV021_streaming": "东北老铁",
"BV020_streaming": "东北丫头",
"BV704_streaming": "方言灿灿",
"BV210_streaming": "西安佟掌柜",
"BV217_streaming": "沪上阿姐",
"BV213_streaming": "广西表哥",
"BV025_streaming": "甜美台妹",
"BV227_streaming": "台普男声",
"BV026_streaming": "港剧男神",
"BV424_streaming": "广东女仔",
"BV212_streaming": "相声演员",
"BV019_streaming": "重庆小伙",
"BV221_streaming": "四川甜妹儿",
"BV423_streaming": "重庆幺妹儿",
"BV214_streaming": "乡村企业家",
"BV226_streaming": "湖南妹坨",
"BV216_streaming": "长沙靓女"
}
saved_voice_type = config.ui.get("doubaotts_voice_type", "BV700_streaming")
if saved_voice_type not in voice_options:
voice_options[saved_voice_type] = f"自定义音色 ({saved_voice_type})"
selected_voice_display = st.selectbox(
"音色选择",
options=list(voice_options.values()),
index=list(voice_options.keys()).index(saved_voice_type) if saved_voice_type in voice_options else 0,
help="选择豆包语音 TTS 音色"
)
# Resolve the actual voice ID from the display name
voice_type = list(voice_options.keys())[
list(voice_options.values()).index(selected_voice_display)
]
# Advanced parameters expander
with st.expander("🔧 高级参数", expanded=False):
col1, col2 = st.columns(2)
with col1:
# Speech rate
voice_rate = st.slider(
"语速调节",
min_value=0.2,
max_value=3.0,
value=config.ui.get("doubaotts_rate", 1.0),
step=0.1,
help="调节语音速度 (0.2-3.0)"
)
# Volume
voice_volume = st.slider(
"音量调节",
min_value=0.1,
max_value=2.0,
value=config.doubaotts.get("volume", 1.0),
step=0.1,
help="调节语音音量 (0.1-2.0)"
)
with col2:
# Pitch
voice_pitch = st.slider(
"音高调节",
min_value=0.5,
max_value=1.5,
value=config.doubaotts.get("pitch", 1.0),
step=0.1,
help="调节语音音高 (0.5-1.5)"
)
# Trailing silence duration
silence_duration = st.slider(
"句尾静音时长 (秒)",
min_value=0.0,
max_value=2.0,
value=config.doubaotts.get("silence_duration", 0.125),
step=0.05,
help="调节句尾静音时长 (0.0-2.0秒)"
)
# Show the API key application steps
with st.expander("💡 豆包语音 TTS API Key申请流程", expanded=False):
st.write("**申请步骤:**")
st.write("1. 打开 [https://console.volcengine.com/iam/keymanage](https://console.volcengine.com/iam/keymanage)")
st.write("2. 新建 Access Key 和 Secret Key")
st.write("3. 打开 [https://www.volcengine.com/product/voice-tech](https://www.volcengine.com/product/voice-tech)")
st.write("4. 点击立即使用")
st.write("5. 在最左边的API服务中心找到音频生成下面的语音合成注意是语音合成不是语音合成大模型")
st.write("6. 翻到最下面获取 APPID 和 Access Token")
st.write("")
st.info("💡 请将获取到的 Access Key、Secret Key、AppID 和 Token 填写到上方的配置中")
# Save the configuration
config.doubaotts["ak"] = ak
config.doubaotts["sk"] = sk
config.doubaotts["appid"] = appid
config.doubaotts["token"] = token
config.doubaotts["cluster"] = cluster
config.doubaotts["volume"] = voice_volume
config.doubaotts["pitch"] = voice_pitch
config.doubaotts["silence_duration"] = silence_duration
config.ui["doubaotts_voice_type"] = voice_type
config.ui["doubaotts_rate"] = voice_rate
config.ui["voice_name"] = voice_type # 兼容性
st.session_state['voice_rate'] = voice_rate # 确保语速参数被保存到session state
# Show configuration status
if ak and sk and appid and token:
st.success("✅ 豆包语音 TTS 配置已设置")
else:
missing = []
if not ak:
missing.append("Access Key")
if not sk:
missing.append("Secret Key")
if not appid:
missing.append("AppID")
if not token:
missing.append("Token")
if missing:
st.warning(f"⚠️ 请配置: {', '.join(missing)}")
def render_voice_preview_new(tr, selected_engine):
"""Render the new voice preview feature."""
if st.button("🎵 试听语音合成", use_container_width=True):
@ -746,6 +999,11 @@ def render_voice_preview_new(tr, selected_engine):
voice_name = f"indextts2:{reference_audio}"
voice_rate = 1.0 # IndexTTS2 does not support rate adjustment
voice_pitch = 1.0 # IndexTTS2 does not support pitch adjustment
elif selected_engine == "doubaotts":
voice_type = config.ui.get("doubaotts_voice_type", "BV700_streaming")
voice_name = voice_type
voice_rate = config.ui.get("doubaotts_rate", 1.0)
voice_pitch = 1.0 # Doubao TTS does not support pitch adjustment
if not voice_name:
st.error("请先配置语音设置")


@ -217,6 +217,15 @@ def render_proxy_settings(tr):
config.proxy["http"] = ""
config.proxy["https"] = ""
# Jianying draft path settings
st.subheader("剪映草稿设置")
jianying_draft_path = st.text_input(
"剪映草稿文件夹路径",
value=config.ui.get("jianying_draft_path", ""),
help="剪映草稿文件夹路径例如C:\\Users\\用户名\\Documents\\JianyingPro Drafts"
)
config.ui["jianying_draft_path"] = jianying_draft_path
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
"""Test the vision model connection


@ -327,6 +327,8 @@ def short_drama_summary(tr):
# Check whether a subtitle file has already been processed
if 'subtitle_file_processed' not in st.session_state:
st.session_state['subtitle_file_processed'] = False
render_fun_asr_transcription(tr)
subtitle_file = st.file_uploader(
tr("上传字幕文件"),
@ -401,6 +403,95 @@ def short_drama_summary(tr):
return video_theme
def render_fun_asr_transcription(tr):
"""Transcribe local audio/video into subtitles with Alibaba Bailian Fun-ASR."""
def clear_fun_asr_subtitle_state():
st.session_state['subtitle_path'] = None
st.session_state['subtitle_content'] = None
st.session_state['subtitle_file_processed'] = False
with st.expander("阿里百炼 Fun-ASR 字幕转录", expanded=False):
st.caption("上传本地音频/视频后,将自动上传到阿里百炼临时存储并通过 fun-asr 生成 SRT 字幕。")
st.markdown(
"API Key 获取地址:"
"[https://bailian.console.aliyun.com/?tab=model#/api-key]"
"(https://bailian.console.aliyun.com/?tab=model#/api-key)"
)
api_key = st.text_input(
"阿里百炼 API Key",
value=config.fun_asr.get("api_key", ""),
type="password",
help="请输入你自己的阿里百炼 API Key保存配置后会写入本地 config.toml",
key="fun_asr_api_key",
)
uploaded_media = st.file_uploader(
"上传需要转录的音频/视频",
type=[
"aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov",
"mp3", "mp4", "mpeg", "ogg", "opus", "wav", "webm", "wma", "wmv",
],
accept_multiple_files=False,
key="fun_asr_media_uploader",
)
if st.button("转写生成字幕", key="fun_asr_transcribe"):
if not api_key.strip():
clear_fun_asr_subtitle_state()
st.error("请先输入阿里百炼 API Key")
return
if uploaded_media is None:
clear_fun_asr_subtitle_state()
st.error("请先上传需要转录的音频或视频文件")
return
try:
clear_fun_asr_subtitle_state()
from app.services import fun_asr_subtitle
config.fun_asr["api_key"] = api_key.strip()
config.fun_asr["model"] = "fun-asr"
config.save_config()
temp_dir = utils.temp_dir("fun_asr")
safe_filename = os.path.basename(uploaded_media.name)
media_path = os.path.join(temp_dir, safe_filename)
file_name, file_extension = os.path.splitext(safe_filename)
if os.path.exists(media_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
media_path = os.path.join(temp_dir, f"{file_name}_{timestamp}{file_extension}")
with open(media_path, "wb") as f:
f.write(uploaded_media.getbuffer())
subtitle_name = f"{os.path.splitext(os.path.basename(media_path))[0]}_fun_asr.srt"
subtitle_path = os.path.join(utils.subtitle_dir(), subtitle_name)
with st.spinner("正在使用阿里百炼 Fun-ASR 转写字幕,请稍候..."):
generated_path = fun_asr_subtitle.create_with_fun_asr(
local_file=media_path,
subtitle_file=subtitle_path,
api_key=api_key.strip(),
)
if not generated_path or not os.path.exists(generated_path):
clear_fun_asr_subtitle_state()
st.error("Fun-ASR 转写失败:未生成字幕文件")
return
with open(generated_path, "r", encoding="utf-8") as f:
subtitle_content = f.read()
st.session_state['subtitle_path'] = generated_path
st.session_state['subtitle_content'] = subtitle_content
st.session_state['subtitle_file_processed'] = True
st.success(f"字幕转写成功: {os.path.basename(generated_path)}")
except Exception as e:
clear_fun_asr_subtitle_state()
logger.error(f"Fun-ASR 字幕转写失败: {traceback.format_exc()}")
st.error(f"Fun-ASR 字幕转写失败: {str(e)}")
def render_script_buttons(tr, params):
"""Render the script action buttons."""
# Get the currently selected script type