mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 06:08:16 +00:00
feat(字幕): 新增阿里百炼 Fun-ASR 音视频字幕转录功能
- 在 WebUI 中增加 Fun-ASR 转录界面,支持上传多种音视频格式并生成 SRT 字幕
- 新增 `app/services/fun_asr_subtitle.py` 服务模块,实现完整的 REST API 调用流程,包括获取上传凭证、文件上传、提交任务、轮询结果和 SRT 格式转换
- 在配置文件中增加 `[fun_asr]` 配置段,支持保存 API Key
- 添加完整的单元测试,覆盖核心转换逻辑和服务流程
- 为兼容 Python 3.11 以下版本,将 `tomllib` 导入改为尝试导入并回退到 `tomli`
- 在 `defaults.py` 中添加 `from __future__ import annotations` 以支持类型注解
This commit is contained in:
parent
8c129790c7
commit
99dd4193ae
@ -81,6 +81,7 @@ def save_config():
|
||||
_cfg["soulvoice"] = soulvoice
|
||||
_cfg["ui"] = ui
|
||||
_cfg["tts_qwen"] = tts_qwen
|
||||
_cfg["fun_asr"] = fun_asr
|
||||
_cfg["indextts2"] = indextts2
|
||||
_cfg["doubaotts"] = doubaotts
|
||||
f.write(toml.dumps(_cfg))
|
||||
@ -96,6 +97,7 @@ soulvoice = _cfg.get("soulvoice", {})
|
||||
ui = _cfg.get("ui", {})
|
||||
frames = _cfg.get("frames", {})
|
||||
tts_qwen = _cfg.get("tts_qwen", {})
|
||||
fun_asr = _cfg.get("fun_asr", {})
|
||||
indextts2 = _cfg.get("indextts2", {})
|
||||
doubaotts = _cfg.get("doubaotts", {})
|
||||
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Shared config defaults used by both bootstrap and WebUI fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
DEFAULT_OPENAI_COMPATIBLE_BASE_URL = "https://api.siliconflow.cn/v1"
|
||||
DEFAULT_OPENAI_COMPATIBLE_PROVIDER = "openai"
|
||||
|
||||
|
||||
@ -2,7 +2,10 @@ import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import tomllib
|
||||
try:
|
||||
import tomllib
|
||||
except ModuleNotFoundError: # Python < 3.11
|
||||
import tomli as tomllib
|
||||
|
||||
from app.config import config as cfg
|
||||
from app.config.defaults import (
|
||||
|
||||
452
app/services/fun_asr_subtitle.py
Normal file
452
app/services/fun_asr_subtitle.py
Normal file
@ -0,0 +1,452 @@
|
||||
"""Aliyun Bailian Fun-ASR subtitle transcription helpers.
|
||||
|
||||
This module intentionally uses the REST API because the official Fun-ASR
|
||||
recorded-file API supports temporary `oss://` resources only through REST.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from app.utils import utils
|
||||
|
||||
# Root of the DashScope REST API; every endpoint below is built from it.
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com"
# Endpoint queried for a temporary-storage upload policy.
UPLOAD_POLICY_URL = DASHSCOPE_BASE_URL + "/api/v1/uploads"
# Endpoint that accepts transcription task submissions.
TRANSCRIPTION_URL = DASHSCOPE_BASE_URL + "/api/v1/services/audio/asr/transcription"
# Per-task endpoint; fill the placeholder with ``.format(task_id=...)``.
TASK_URL_TEMPLATE = DASHSCOPE_BASE_URL + "/api/v1/tasks/{task_id}"
# Default model identifier sent with policy and transcription requests.
MODEL_NAME = "fun-asr"
# Task statuses that end polling with an error.
TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"}
# Characters that close the current subtitle block when they end a word or chunk.
PUNCTUATION_BREAKS = set(",。!?;,.!?;")
|
||||
|
||||
|
||||
class FunAsrError(RuntimeError):
    """Error type for Fun-ASR transcription failures that the user can act on."""
|
||||
|
||||
|
||||
@dataclass
class UploadPolicy:
    """Temporary-storage upload credentials parsed from the upload-policy response."""

    # Required values: used to build the object key and the multipart upload form.
    upload_host: str
    upload_dir: str
    policy: str
    signature: str
    oss_access_key_id: str
    # Optional form values; these defaults apply when the response omits them.
    x_oss_object_acl: str = "private"
    x_oss_forbid_overwrite: str = "true"
    # Size cap in megabytes; ``None`` means no local size check is performed.
    max_file_size_mb: Optional[float] = None
|
||||
|
||||
|
||||
def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if extra:
|
||||
headers.update(extra)
|
||||
return headers
|
||||
|
||||
|
||||
def _raise_for_http(response: requests.Response, action: str) -> None:
    """Convert any failure of ``response.raise_for_status()`` into a ``FunAsrError``.

    ``action`` names the step being performed and prefixes the error message.
    """
    # Caught broadly on purpose: test doubles may raise arbitrary exception types.
    try:
        response.raise_for_status()
    except Exception as http_error:
        message = f"{action}失败,请检查阿里百炼 API Key、网络或服务状态"
        raise FunAsrError(message) from http_error
|
||||
|
||||
|
||||
def _json(response: requests.Response, action: str) -> dict[str, Any]:
    """Return the response body as a dict.

    The HTTP status is checked first, then the body is decoded and its
    top-level type verified. Every failure is raised as ``FunAsrError``.
    """
    _raise_for_http(response, action)
    try:
        payload = response.json()
    except Exception as decode_error:
        raise FunAsrError(f"{action}返回了无效 JSON") from decode_error
    if isinstance(payload, dict):
        return payload
    raise FunAsrError(f"{action}返回格式无效")
|
||||
|
||||
|
||||
def _require_api_key(api_key: str) -> str:
|
||||
api_key = (api_key or "").strip()
|
||||
if not api_key:
|
||||
raise FunAsrError("请先输入阿里百炼 API Key")
|
||||
return api_key
|
||||
|
||||
|
||||
def _safe_upload_name(local_file: str) -> str:
|
||||
name = os.path.basename(local_file).strip() or f"audio_{int(time.time())}.wav"
|
||||
return name.replace("/", "_").replace("\\", "_")
|
||||
|
||||
|
||||
def _session_get(session, url: str, **kwargs):
|
||||
return session.get(url, **kwargs)
|
||||
|
||||
|
||||
def _session_post(session, url: str, **kwargs):
|
||||
return session.post(url, **kwargs)
|
||||
|
||||
|
||||
def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
    """Request the Bailian temporary-storage upload policy for ``model``.

    Raises ``FunAsrError`` when the API key is blank, the request or its JSON
    body is invalid, or a required policy field is missing or empty.
    """
    token = _require_api_key(api_key)
    response = _session_get(
        session,
        UPLOAD_POLICY_URL,
        params={"action": "getPolicy", "model": model},
        headers=_auth_headers(token),
        timeout=30,
    )
    body = _json(response, "获取临时存储上传凭证")
    fields = body.get("data") or {}

    # Collect every absent or empty required field so the error names them all.
    absent = []
    for name in ("upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"):
        if not fields.get(name):
            absent.append(name)
    if absent:
        raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(absent)}")

    acl = fields.get("x_oss_object_acl") or "private"
    forbid_overwrite = fields.get("x_oss_forbid_overwrite") or "true"
    return UploadPolicy(
        upload_host=str(fields["upload_host"]),
        # The trailing slash is dropped here; the object key adds its own separator.
        upload_dir=str(fields["upload_dir"]).rstrip("/"),
        policy=str(fields["policy"]),
        signature=str(fields["signature"]),
        oss_access_key_id=str(fields["oss_access_key_id"]),
        x_oss_object_acl=str(acl),
        x_oss_forbid_overwrite=str(forbid_overwrite),
        max_file_size_mb=fields.get("max_file_size_mb"),
    )
|
||||
|
||||
|
||||
def _validate_file_size(local_file: str, policy: UploadPolicy) -> None:
    """Raise ``FunAsrError`` when ``local_file`` is larger than the policy allows.

    Nothing is checked when the policy carries no size limit.
    """
    limit_mb = policy.max_file_size_mb
    if limit_mb is None:
        return
    limit_bytes = float(limit_mb) * 1024 * 1024
    actual_bytes = os.path.getsize(local_file)
    if actual_bytes > limit_bytes:
        raise FunAsrError(
            f"文件大小超过阿里百炼临时存储限制: {actual_bytes / 1024 / 1024:.2f}MB > {float(limit_mb):.2f}MB"
        )
|
||||
|
||||
|
||||
def upload_to_temporary_oss(local_file: str, policy: UploadPolicy, session=requests) -> str:
    """Upload a local file to temporary OSS and return its ``oss://...`` URL.

    Args:
        local_file: Path of the media file to upload.
        policy: Upload credentials obtained from ``request_upload_policy``.
        session: Object exposing ``post`` (the ``requests`` module by default).

    Returns:
        ``oss://`` followed by the object key (``<upload_dir>/<file name>``).

    Raises:
        FunAsrError: The file does not exist, exceeds the policy size limit,
            or the upload request fails.
    """
    if not os.path.isfile(local_file):
        raise FunAsrError(f"待转写文件不存在: {local_file}")
    _validate_file_size(local_file, policy)

    # Resolve the object name exactly once. Its fallback is time-based, so
    # deriving it separately for the key and for the multipart filename could
    # produce two different names.
    upload_name = _safe_upload_name(local_file)
    key = f"{policy.upload_dir}/{upload_name}"
    form = {
        "OSSAccessKeyId": policy.oss_access_key_id,
        "policy": policy.policy,
        "Signature": policy.signature,
        "key": key,
        "x-oss-object-acl": policy.x_oss_object_acl,
        "x-oss-forbid-overwrite": policy.x_oss_forbid_overwrite,
        "success_action_status": "200",
    }
    # Keep the file open only for the duration of the request.
    with open(local_file, "rb") as file_obj:
        response = _session_post(
            session,
            policy.upload_host,
            data=form,
            files={"file": (upload_name, file_obj)},
            timeout=120,
        )
    _raise_for_http(response, "上传文件到阿里百炼临时存储")
    return f"oss://{key}"
|
||||
|
||||
|
||||
def submit_transcription_task(
    api_key: str,
    oss_url: str,
    speaker_count: Optional[int] = None,
    model: str = MODEL_NAME,
    session=requests,
) -> str:
    """Submit an asynchronous Fun-ASR task for ``oss_url`` and return its task id.

    Speaker diarization is always requested; ``speaker_count`` is forwarded
    only when it is truthy. Raises ``FunAsrError`` when the key is blank, the
    request fails, or the response carries no task id.
    """
    token = _require_api_key(api_key)

    options: dict[str, Any] = {"diarization_enabled": True}
    if speaker_count:
        options["speaker_count"] = int(speaker_count)

    request_headers = _auth_headers(
        token,
        {
            "X-DashScope-Async": "enable",
            "X-DashScope-OssResourceResolve": "enable",
        },
    )
    request_body = {
        "model": model,
        "input": {"file_urls": [oss_url]},
        "parameters": options,
    }
    response = _session_post(
        session,
        TRANSCRIPTION_URL,
        headers=request_headers,
        json=request_body,
        timeout=30,
    )

    body = _json(response, "提交 Fun-ASR 转写任务")
    output = body.get("output") or {}
    task_id = (output.get("task_id") or "").strip()
    if task_id:
        return task_id
    raise FunAsrError("提交 Fun-ASR 转写任务失败:未返回 task_id")
|
||||
|
||||
|
||||
def poll_transcription_task(
    api_key: str,
    task_id: str,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> dict[str, Any]:
    """Poll a task until it reaches a terminal status and return its first result.

    Args:
        api_key: Bailian API key.
        task_id: Identifier returned by ``submit_transcription_task``.
        poll_interval: Seconds to wait between two status queries.
        timeout: Overall time budget in seconds. A non-positive value raises
            the timeout error without sending any request.
        session: Object exposing ``post`` (the ``requests`` module by default).

    Returns:
        The first entry of ``output.results`` once the task has succeeded.

    Raises:
        FunAsrError: Blank key, failed request, failed task or sub-task,
            empty result list, or the time budget being exhausted.
    """
    api_key = _require_api_key(api_key)
    task_url = TASK_URL_TEMPLATE.format(task_id=task_id)
    headers = _auth_headers(api_key)

    # A monotonic clock keeps the deadline correct even if the system clock is
    # adjusted while polling.
    deadline = time.monotonic() + timeout
    last_status = "PENDING"
    while time.monotonic() < deadline:
        response = _session_post(session, task_url, headers=headers, timeout=30)
        data = _json(response, "查询 Fun-ASR 转写任务")
        output = data.get("output") or {}
        last_status = str(output.get("task_status") or "").upper()

        if last_status == "SUCCEEDED":
            results = output.get("results") or []
            # The overall task can succeed while an individual sub-task failed.
            for result in results:
                subtask_status = str(result.get("subtask_status") or "").upper()
                if subtask_status and subtask_status != "SUCCEEDED":
                    raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
            if not results:
                raise FunAsrError("Fun-ASR 转写成功但未返回结果")
            return results[0]

        if last_status in TERMINAL_FAILED_STATUSES:
            raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")

        # Never sleep past the deadline: wait for the interval or for whatever
        # time is left, whichever is shorter.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval, remaining))

    raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")
|
||||
|
||||
|
||||
def download_transcription_result(transcription_url: str, session=requests) -> dict[str, Any]:
    """Fetch the transcription JSON from ``transcription_url``.

    Raises ``FunAsrError`` when the URL is empty or the download does not
    yield a JSON object.
    """
    if transcription_url:
        response = _session_get(session, transcription_url, timeout=60)
        return _json(response, "下载 Fun-ASR 转写结果")
    raise FunAsrError("Fun-ASR 结果缺少 transcription_url")
|
||||
|
||||
|
||||
def _ms_to_srt_time(ms: float) -> str:
|
||||
total_ms = max(0, int(round(float(ms))))
|
||||
hours = total_ms // 3_600_000
|
||||
total_ms %= 3_600_000
|
||||
minutes = total_ms // 60_000
|
||||
total_ms %= 60_000
|
||||
seconds = total_ms // 1_000
|
||||
milliseconds = total_ms % 1_000
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
|
||||
|
||||
|
||||
def _srt_block(index: int, start_ms: float, end_ms: float, text: str) -> str:
    """Render one SRT cue: index line, time range line, then the stripped text."""
    # A cue that would end at or before its start is stretched to 500 ms.
    stop_ms = start_ms + 500 if end_ms <= start_ms else end_ms
    begin, finish = _ms_to_srt_time(start_ms), _ms_to_srt_time(stop_ms)
    return f"{index}\n{begin} --> {finish}\n{text.strip()}\n"
|
||||
|
||||
|
||||
def _timestamp_ms(value: Any, field_name: str) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise FunAsrError(f"Fun-ASR 转写结果时间戳无效: {field_name}={value!r}") from exc
|
||||
|
||||
|
||||
def _speaker_prefix(speaker_id: Any) -> str:
|
||||
if speaker_id is None or speaker_id == "":
|
||||
return ""
|
||||
try:
|
||||
label = int(speaker_id) + 1
|
||||
except (TypeError, ValueError):
|
||||
label = str(speaker_id)
|
||||
return f"说话人{label}: "
|
||||
|
||||
|
||||
def _iter_sentences(result_json: dict[str, Any]):
|
||||
transcripts = result_json.get("transcripts")
|
||||
if transcripts is None and "sentences" in result_json:
|
||||
transcripts = [{"sentences": result_json.get("sentences") or []}]
|
||||
if not transcripts:
|
||||
raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts")
|
||||
for transcript in transcripts:
|
||||
for sentence in transcript.get("sentences") or []:
|
||||
yield sentence
|
||||
|
||||
|
||||
def _word_text(word: dict[str, Any]) -> str:
|
||||
text = str(word.get("text") or word.get("word") or "")
|
||||
punctuation = str(word.get("punctuation") or "")
|
||||
if punctuation and not text.endswith(punctuation):
|
||||
text += punctuation
|
||||
return text
|
||||
|
||||
|
||||
def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None:
|
||||
text = current.get("text", "").strip()
|
||||
if text:
|
||||
blocks.append(current.copy())
|
||||
|
||||
|
||||
def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
    """Group a sentence's words into subtitle blocks using word-level timestamps.

    A block is closed before a word when that word has a different speaker,
    would push the text over ``max_chars``, or would stretch the block past
    ``max_duration`` seconds. A block is also closed right after a word that
    ends with a break punctuation character. Words without text or without
    both timestamps are skipped.
    """
    finished: list[dict[str, Any]] = []
    open_block: Optional[dict[str, Any]] = None
    limit_ms = max_duration * 1000
    default_speaker = sentence.get("speaker_id")

    for word in sentence.get("words") or []:
        token = _word_text(word)
        if not token:
            continue
        raw_start = word.get("begin_time", word.get("start_time"))
        raw_end = word.get("end_time")
        if raw_start is None or raw_end is None:
            continue
        # A word without its own speaker inherits the sentence's speaker.
        speaker = word.get("speaker_id", default_speaker)
        begin = _timestamp_ms(raw_start, "word.begin_time")
        finish = _timestamp_ms(raw_end, "word.end_time")
        fresh = {"start": begin, "end": finish, "text": token, "speaker_id": speaker}

        if open_block is None:
            open_block = fresh
        elif (
            speaker != open_block.get("speaker_id")
            or len(open_block["text"] + token) > max_chars
            or (finish - open_block["start"]) > limit_ms
        ):
            # This word does not fit: close the running block and start anew.
            _flush_block(finished, open_block)
            open_block = fresh
        else:
            open_block["text"] += token
            open_block["end"] = finish

        # Break punctuation at the end of the word closes the block.
        if open_block and token[-1:] in PUNCTUATION_BREAKS:
            _flush_block(finished, open_block)
            open_block = None

    if open_block:
        _flush_block(finished, open_block)
    return finished
|
||||
|
||||
|
||||
def _split_text(text: str, max_chars: int) -> list[str]:
    """Cut ``text`` into stripped, non-empty chunks.

    A chunk ends right after a break punctuation character or as soon as it
    holds ``max_chars`` characters.
    """
    pieces: list[str] = []
    buffer: list[str] = []
    for ch in text:
        buffer.append(ch)
        if ch in PUNCTUATION_BREAKS or len(buffer) >= max_chars:
            pieces.append("".join(buffer).strip())
            buffer.clear()
    tail = "".join(buffer).strip()
    if tail:
        pieces.append(tail)
    # Chunks made only of whitespace strip down to "" and are dropped here.
    return [piece for piece in pieces if piece]
|
||||
|
||||
|
||||
def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
    """Build subtitle blocks from a sentence's text and its begin/end times.

    The text is split into chunks and the sentence duration (at least 500 ms)
    is shared among them in proportion to chunk length. The last chunk ends at
    the sentence end, and every block lasts at least 200 ms.
    """
    sentence_text = str(sentence.get("text") or "").strip()
    if not sentence_text:
        return []

    raw_end = sentence.get("end_time")
    start_ms = _timestamp_ms(sentence.get("begin_time", 0), "sentence.begin_time")
    if raw_end is not None:
        end_ms = _timestamp_ms(raw_end, "sentence.end_time")
    else:
        end_ms = start_ms + 500

    chunks = _split_text(sentence_text, max_chars)
    if not chunks:
        return []

    span = max(500.0, end_ms - start_ms)
    char_total = max(1, sum(len(piece) for piece in chunks))
    speaker = sentence.get("speaker_id")
    last_index = len(chunks) - 1

    result: list[dict[str, Any]] = []
    cursor = start_ms
    for position, piece in enumerate(chunks):
        if position < last_index:
            piece_end = cursor + span * (len(piece) / char_total)
        else:
            piece_end = end_ms
        result.append(
            {
                "start": cursor,
                "end": max(cursor + 200, piece_end),
                "text": piece,
                "speaker_id": speaker,
            }
        )
        # The next block starts at the proportional end, not the padded one.
        cursor = piece_end
    return result
|
||||
|
||||
|
||||
def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str:
    """Convert downloaded Fun-ASR JSON into fine-grained SRT text.

    Each sentence is split using its word timestamps; a sentence that yields
    no word-level blocks is split from its text instead. Timestamps are read
    as milliseconds. Raises ``FunAsrError`` when no subtitle block is produced.
    """
    blocks: list[dict[str, Any]] = []
    for sentence in _iter_sentences(result_json):
        blocks.extend(
            _blocks_from_words(sentence, max_chars, max_duration)
            or _blocks_from_sentence(sentence, max_chars)
        )

    if not blocks:
        raise FunAsrError("Fun-ASR 转写结果为空:未找到可用字幕内容")

    rendered = [
        _srt_block(
            number,
            block["start"],
            block["end"],
            f"{_speaker_prefix(block.get('speaker_id'))}{block['text']}",
        )
        for number, block in enumerate(blocks, start=1)
    ]
    # Cues are separated by one blank line; the text ends with a single newline.
    return "\n".join(rendered).rstrip() + "\n"
|
||||
|
||||
|
||||
def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
    """Write ``srt_content`` as UTF-8 and return the path that was written.

    An empty ``subtitle_file`` selects a timestamped ``fun_asr_<epoch>.srt``
    inside ``utils.subtitle_dir()``. Missing parent directories are created.
    """
    target = subtitle_file or os.path.join(utils.subtitle_dir(), f"fun_asr_{int(time.time())}.srt")
    folder = os.path.dirname(target)
    if folder:
        os.makedirs(folder, exist_ok=True)
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(srt_content)
    return target
|
||||
|
||||
|
||||
def create_with_fun_asr(
    local_file: str,
    subtitle_file: str = "",
    api_key: str = "",
    speaker_count: Optional[int] = None,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> Optional[str]:
    """Transcribe a local media file with Fun-ASR and write the result as SRT.

    Steps: request an upload policy, upload the file, submit the task, poll
    it, download the result, convert it to SRT and write the file. Returns the
    path of the written file. Every failure is raised as ``FunAsrError``.
    """
    token = _require_api_key(api_key)
    try:
        policy = request_upload_policy(token, session=session)
        oss_url = upload_to_temporary_oss(local_file, policy, session=session)
        task_id = submit_transcription_task(token, oss_url, speaker_count=speaker_count, session=session)
        task_result = poll_transcription_task(
            token,
            task_id,
            poll_interval=poll_interval,
            timeout=timeout,
            session=session,
        )
        result_json = download_transcription_result(task_result.get("transcription_url"), session=session)
        output_file = write_srt_file(fun_asr_result_to_srt(result_json), subtitle_file)
        logger.info(f"Fun-ASR 字幕文件已生成: {output_file}")
        return output_file
    except Exception as exc:
        # Errors that are already user-facing pass through unchanged; anything
        # else is wrapped in a generic message with the cause chained.
        if isinstance(exc, FunAsrError):
            raise
        raise FunAsrError("Fun-ASR 字幕转写失败,请检查文件、网络或阿里百炼服务状态") from exc
|
||||
403
app/services/test_fun_asr_subtitle_unittest.py
Normal file
403
app/services/test_fun_asr_subtitle_unittest.py
Normal file
@ -0,0 +1,403 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ModuleNotFoundError: # Python < 3.11
|
||||
import tomli as tomllib
|
||||
|
||||
from app.config import config as cfg
|
||||
from app.services import fun_asr_subtitle as fasr
|
||||
|
||||
|
||||
class FakeResponse:
    """Minimal stand-in for an HTTP response object used by these tests."""

    def __init__(self, payload=None, status_code=200):
        # Any falsy payload (None, {}, ...) is replaced by a fresh empty dict.
        self.payload = payload if payload else {}
        self.status_code = status_code

    def json(self):
        """Return the stored payload unchanged."""
        return self.payload

    def raise_for_status(self):
        """Raise ``RuntimeError`` for status codes of 400 and above."""
        if self.status_code >= 400:
            message = f"HTTP {self.status_code}"
            raise RuntimeError(message)
|
||||
|
||||
|
||||
class InvalidJsonResponse(FakeResponse):
    """Response double whose body can never be decoded."""

    def json(self):
        # Always fail, standing in for a body that is not valid JSON.
        raise ValueError("invalid json")
|
||||
|
||||
|
||||
class FakeSession:
    """Session double that records every call and serves canned responses.

    ``local_result`` is the payload returned for the transcription download URL.
    """

    def __init__(self, local_result):
        self.calls = []
        self.local_result = local_result

    def get(self, url, **kwargs):
        """Serve the upload policy or the transcription result; otherwise 404."""
        self.calls.append(("GET", url, kwargs))
        if url == fasr.UPLOAD_POLICY_URL:
            policy_fields = {
                "policy": "policy-token",
                "signature": "signature-token",
                "upload_dir": "dashscope-instant/test-dir",
                "upload_host": "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com",
                "oss_access_key_id": "oss-ak",
                "x_oss_object_acl": "private",
                "x_oss_forbid_overwrite": "true",
                "max_file_size_mb": 1,
            }
            return FakeResponse({"data": policy_fields})
        if url == "https://result.example/transcription.json":
            return FakeResponse(self.local_result)
        return FakeResponse({}, 404)

    def post(self, url, **kwargs):
        """Serve upload, task submission and task query; otherwise 404."""
        self.calls.append(("POST", url, kwargs))
        if url == "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com":
            return FakeResponse({})
        if url == fasr.TRANSCRIPTION_URL:
            return FakeResponse({"output": {"task_status": "PENDING", "task_id": "task-123"}})
        if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
            finished_item = {
                "file_url": "oss://dashscope-instant/test-dir/audio.wav",
                "transcription_url": "https://result.example/transcription.json",
                "subtask_status": "SUCCEEDED",
            }
            return FakeResponse({"output": {"task_status": "SUCCEEDED", "results": [finished_item]}})
        return FakeResponse({}, 404)
|
||||
|
||||
|
||||
# Sample result in the `transcripts[*].sentences[*].words[*]` layout, shared by
# the conversion and service tests below.
OFFICIAL_SHAPE_RESULT = {
    "transcripts": [
        {
            "sentences": [
                {
                    "begin_time": 0,
                    "end_time": 3600,
                    "text": "你好欢迎观看今天的内容",
                    "speaker_id": 0,
                    "words": [
                        {"begin_time": begin, "end_time": end, "text": text, "punctuation": mark}
                        for begin, end, text, mark in (
                            (0, 400, "你好", ","),
                            (400, 900, "欢迎", ""),
                            (900, 1300, "观看", ""),
                            (1300, 1800, "今天", ""),
                            (1800, 2400, "的内容", "。"),
                        )
                    ],
                }
            ]
        }
    ]
}
|
||||
|
||||
|
||||
class FunAsrSrtConversionTests(unittest.TestCase):
    """Checks for converting Fun-ASR result JSON into SRT text."""

    @staticmethod
    def _single_sentence(sentence):
        """Wrap one sentence dict in the ``transcripts[*].sentences[*]`` envelope."""
        return {"transcripts": [{"sentences": [sentence]}]}

    def test_official_shape_words_convert_ms_and_speaker_label(self):
        srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)

        expected_cues = (
            "1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,",
            "2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。",
        )
        for cue in expected_cues:
            self.assertIn(cue, srt)
        self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")

    def test_long_word_sequence_splits_into_fine_blocks(self):
        tokens = ["这是", "一个", "很长", "字幕", "需要", "拆分"]
        words = [
            {"begin_time": i * 500, "end_time": (i + 1) * 500, "text": token, "punctuation": ""}
            for i, token in enumerate(tokens)
        ]
        result = self._single_sentence(
            {"begin_time": 0, "end_time": 6000, "speaker_id": 1, "words": words}
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)

        self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
        self.assertIn("1\n00:00:00,000", srt)

    def test_sentence_fallback_uses_ms_without_zero_duration(self):
        result = self._single_sentence(
            {
                "begin_time": 1000,
                "end_time": 3000,
                "text": "没有词级时间戳也可以拆分。",
                "speaker_id": 0,
                "words": [],
            }
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=5)

        self.assertIn("00:00:01,000", srt)
        self.assertIn("说话人1:", srt)
        # No cue may end at the sentence start.
        self.assertNotIn("--> 00:00:01,000\n", srt)

    def test_empty_result_raises_clear_error(self):
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt({"transcripts": []})

    def test_malformed_word_timestamp_raises_fun_asr_error(self):
        bad_word = {"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
        result = self._single_sentence(
            {"begin_time": 0, "end_time": 1000, "speaker_id": 0, "words": [bad_word]}
        )

        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)

    def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
        result = self._single_sentence(
            {
                "begin_time": "bad",
                "end_time": 1000,
                "text": "坏时间",
                "speaker_id": 0,
                "words": [],
            }
        )

        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)
|
||||
|
||||
|
||||
class FunAsrServiceTests(unittest.TestCase):
    """Checks for the REST flow and its error handling, using fake sessions."""

    @staticmethod
    def _policy(max_file_size_mb):
        """Build an upload policy pointing at a placeholder host."""
        return fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=max_file_size_mb,
        )

    def test_create_with_fun_asr_uses_expected_rest_flow(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            work_dir = Path(tmp_dir)
            local_file = work_dir / "audio.wav"
            local_file.write_bytes(b"audio")
            subtitle_file = work_dir / "out.srt"
            session = FakeSession(OFFICIAL_SHAPE_RESULT)

            result_path = fasr.create_with_fun_asr(
                str(local_file),
                subtitle_file=str(subtitle_file),
                api_key="sk-test",
                speaker_count=2,
                poll_interval=0,
                session=session,
            )

            self.assertEqual(str(subtitle_file), result_path)
            self.assertTrue(subtitle_file.exists())
            self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))

        # Step 1: upload policy request.
        method, url, kwargs = session.calls[0]
        self.assertEqual("GET", method)
        self.assertEqual(fasr.UPLOAD_POLICY_URL, url)
        self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, kwargs["params"])
        self.assertEqual("Bearer sk-test", kwargs["headers"]["Authorization"])

        # Step 2: multipart upload to the host named by the policy.
        method, url, kwargs = session.calls[1]
        self.assertEqual("POST", method)
        self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", url)
        form = kwargs["data"]
        self.assertEqual("oss-ak", form["OSSAccessKeyId"])
        self.assertEqual("policy-token", form["policy"])
        self.assertEqual("signature-token", form["Signature"])
        self.assertEqual("dashscope-instant/test-dir/audio.wav", form["key"])
        self.assertEqual("200", form["success_action_status"])

        # Step 3: task submission.
        _, url, kwargs = session.calls[2]
        self.assertEqual(fasr.TRANSCRIPTION_URL, url)
        sent_headers = kwargs["headers"]
        self.assertEqual("enable", sent_headers["X-DashScope-Async"])
        self.assertEqual("enable", sent_headers["X-DashScope-OssResourceResolve"])
        sent_body = kwargs["json"]
        self.assertEqual("fun-asr", sent_body["model"])
        self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], sent_body["input"]["file_urls"])
        self.assertTrue(sent_body["parameters"]["diarization_enabled"])
        self.assertEqual(2, sent_body["parameters"]["speaker_count"])

        # Step 4: task polling.
        method, url, _ = session.calls[3]
        self.assertEqual("POST", method)
        self.assertTrue(url.endswith("/api/v1/tasks/task-123"))

        # Step 5: result download.
        self.assertEqual(("GET", "https://result.example/transcription.json"), session.calls[4][:2])

    def test_upload_policy_size_validation_fails_before_upload(self):
        tiny_limit_policy = self._policy(0.000001)
        with tempfile.NamedTemporaryFile() as f:
            f.write(b"too-large")
            f.flush()
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(f.name, tiny_limit_policy, session=FakeSession({}))

    def test_failed_subtask_raises(self):
        class FailedSubtaskSession(FakeSession):
            def post(self, url, **kwargs):
                if url != fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
                    return super().post(url, **kwargs)
                output = {"task_status": "SUCCEEDED", "results": [{"subtask_status": "FAILED"}]}
                return FakeResponse({"output": output})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSubtaskSession({}))

    def test_missing_api_key_raises_before_request(self):
        session = FakeSession(OFFICIAL_SHAPE_RESULT)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("", session=session)

        self.assertEqual([], session.calls)

    def test_upload_policy_http_error_raises(self):
        class ForbiddenPolicySession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({}, status_code=403)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=ForbiddenPolicySession({}))

    def test_malformed_upload_policy_raises(self):
        class IncompletePolicySession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({"data": {"policy": "missing-required-fields"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=IncompletePolicySession({}))

    def test_upload_http_failure_raises(self):
        class BrokenUploadSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        roomy_policy = self._policy(1)
        with tempfile.NamedTemporaryFile() as f:
            f.write(b"audio")
            f.flush()
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(f.name, roomy_policy, session=BrokenUploadSession({}))

    def test_submit_failure_raises(self):
        class BrokenSubmitSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        with self.assertRaises(fasr.FunAsrError):
            fasr.submit_transcription_task("sk-test", "oss://file", session=BrokenSubmitSession({}))

    def test_poll_timeout_raises(self):
        class AlwaysRunningSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "RUNNING"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task(
                "sk-test", "task-123", poll_interval=0, timeout=-1, session=AlwaysRunningSession({})
            )

    def test_task_failed_status_raises(self):
        class FailedStatusSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "FAILED"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedStatusSession({}))

    def test_missing_transcription_url_raises(self):
        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("", session=FakeSession({}))

    def test_malformed_downloaded_json_raises(self):
        class UndecodableDownloadSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return InvalidJsonResponse()

        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result(
                "https://result.example/bad.json", session=UndecodableDownloadSession({})
            )
|
||||
|
||||
|
||||
class FunAsrConfigTests(unittest.TestCase):
    """Checks that the ``fun_asr`` config section is saved and parses."""

    def test_save_config_persists_fun_asr_section(self):
        # Remember the module-level state so it can be restored afterwards.
        previous_file, previous_section = cfg.config_file, cfg.fun_asr
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir, "config.toml")
                cfg.config_file = str(config_path)
                cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
                cfg.save_config()
                saved = tomllib.loads(config_path.read_text(encoding="utf-8"))
        finally:
            cfg.config_file = previous_file
            cfg.fun_asr = previous_section

        section = saved["fun_asr"]
        self.assertEqual("sk-local", section["api_key"])
        self.assertEqual("fun-asr", section["model"])

    def test_config_example_fun_asr_section_parses(self):
        # The path is relative to the current working directory.
        example_text = Path("config.example.toml").read_text(encoding="utf-8")
        config_data = tomllib.loads(example_text)
        self.assertEqual("fun-asr", config_data["fun_asr"]["model"])
        self.assertIn("api_key", config_data["fun_asr"])
|
||||
|
||||
|
||||
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@ -93,6 +93,12 @@
|
||||
# 访问 https://bailian.console.aliyun.com/?tab=model#/api-key 获取你的 API 密钥
|
||||
api_key = ""
|
||||
model_name = "qwen3-tts-flash"
|
||||
|
||||
[fun_asr]
|
||||
# 阿里百炼 Fun-ASR 字幕转录配置
|
||||
# 访问 https://bailian.console.aliyun.com/?tab=model#/api-key 获取你的 API 密钥
|
||||
api_key = ""
|
||||
model = "fun-asr"
|
||||
|
||||
[indextts2]
|
||||
# IndexTTS2 语音克隆配置
|
||||
|
||||
@ -327,6 +327,8 @@ def short_drama_summary(tr):
|
||||
# 检查是否已经处理过字幕文件
|
||||
if 'subtitle_file_processed' not in st.session_state:
|
||||
st.session_state['subtitle_file_processed'] = False
|
||||
|
||||
render_fun_asr_transcription(tr)
|
||||
|
||||
subtitle_file = st.file_uploader(
|
||||
tr("上传字幕文件"),
|
||||
@ -401,6 +403,95 @@ def short_drama_summary(tr):
|
||||
return video_theme
|
||||
|
||||
|
||||
def render_fun_asr_transcription(tr):
    """Render the Aliyun Bailian Fun-ASR panel that turns uploaded media into an SRT subtitle.

    The panel is a collapsed expander holding an API-key field, a media
    uploader and a trigger button. When the button is pressed and both inputs
    are present, the key is stored in ``config.fun_asr`` and saved, the upload
    is written into the ``fun_asr`` temp directory, and
    ``fun_asr_subtitle.create_with_fun_asr`` is called. On success the
    generated subtitle's path and text are placed in ``st.session_state``; on
    every failure path the subtitle-related session keys are reset and an
    error is shown.

    Args:
        tr: Translation callable supplied by the caller. It is not used in
            this function; every label here is a literal string.
    """

    def _reset_subtitle_state():
        # Forget any subtitle previously registered in the session.
        for state_key, blank_value in (
            ("subtitle_path", None),
            ("subtitle_content", None),
            ("subtitle_file_processed", False),
        ):
            st.session_state[state_key] = blank_value

    def _reject(message):
        # Reset the subtitle state first, then show the error to the user.
        _reset_subtitle_state()
        st.error(message)

    with st.expander("阿里百炼 Fun-ASR 字幕转录", expanded=False):
        st.caption("上传本地音频/视频后,将自动上传到阿里百炼临时存储并通过 fun-asr 生成 SRT 字幕。")
        st.markdown(
            "API Key 获取地址:"
            "[https://bailian.console.aliyun.com/?tab=model#/api-key]"
            "(https://bailian.console.aliyun.com/?tab=model#/api-key)"
        )

        entered_key = st.text_input(
            "阿里百炼 API Key",
            value=config.fun_asr.get("api_key", ""),
            type="password",
            help="请输入你自己的阿里百炼 API Key;保存配置后会写入本地 config.toml",
            key="fun_asr_api_key",
        )

        accepted_extensions = [
            "aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov",
            "mp3", "mp4", "mpeg", "ogg", "opus", "wav", "webm", "wma", "wmv",
        ]
        uploaded_media = st.file_uploader(
            "上传需要转录的音频/视频",
            type=accepted_extensions,
            accept_multiple_files=False,
            key="fun_asr_media_uploader",
        )

        # Nothing more to do until the user presses the button.
        if not st.button("转写生成字幕", key="fun_asr_transcribe"):
            return

        cleaned_key = entered_key.strip()
        if not cleaned_key:
            _reject("请先输入阿里百炼 API Key")
            return
        if uploaded_media is None:
            _reject("请先上传需要转录的音频或视频文件")
            return

        try:
            _reset_subtitle_state()
            from app.services import fun_asr_subtitle

            # Persist the key (and the fixed model name) before transcribing.
            config.fun_asr["api_key"] = cleaned_key
            config.fun_asr["model"] = "fun-asr"
            config.save_config()

            work_dir = utils.temp_dir("fun_asr")
            # Keep only the final path component of the uploaded name.
            upload_name = os.path.basename(uploaded_media.name)
            media_path = os.path.join(work_dir, upload_name)
            if os.path.exists(media_path):
                # Avoid overwriting an earlier upload with the same name.
                stem, suffix = os.path.splitext(upload_name)
                stamp = time.strftime("%Y%m%d%H%M%S")
                media_path = os.path.join(work_dir, f"{stem}_{stamp}{suffix}")

            with open(media_path, "wb") as media_out:
                media_out.write(uploaded_media.getbuffer())

            media_stem = os.path.splitext(os.path.basename(media_path))[0]
            subtitle_name = f"{media_stem}_fun_asr.srt"
            subtitle_path = os.path.join(utils.subtitle_dir(), subtitle_name)

            with st.spinner("正在使用阿里百炼 Fun-ASR 转写字幕,请稍候..."):
                generated_path = fun_asr_subtitle.create_with_fun_asr(
                    local_file=media_path,
                    subtitle_file=subtitle_path,
                    api_key=cleaned_key,
                )

            if not (generated_path and os.path.exists(generated_path)):
                _reject("Fun-ASR 转写失败:未生成字幕文件")
                return

            with open(generated_path, "r", encoding="utf-8") as srt_in:
                subtitle_text = srt_in.read()

            st.session_state["subtitle_path"] = generated_path
            st.session_state["subtitle_content"] = subtitle_text
            st.session_state["subtitle_file_processed"] = True
            st.success(f"字幕转写成功: {os.path.basename(generated_path)}")
        except Exception as exc:
            _reset_subtitle_state()
            logger.error(f"Fun-ASR 字幕转写失败: {traceback.format_exc()}")
            st.error(f"Fun-ASR 字幕转写失败: {str(exc)}")
|
||||
|
||||
|
||||
def render_script_buttons(tr, params):
|
||||
"""渲染脚本操作按钮"""
|
||||
# 获取当前选择的脚本类型
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user