NarratoAI/app/services/fun_asr_subtitle.py
viccy 99dd4193ae feat(字幕): 新增阿里百炼 Fun-ASR 音视频字幕转录功能
- 在 WebUI 中增加 Fun-ASR 转录界面,支持上传多种音视频格式并生成 SRT 字幕
- 新增 `app/services/fun_asr_subtitle.py` 服务模块,实现完整的 REST API 调用流程,包括获取上传凭证、文件上传、提交任务、轮询结果和 SRT 格式转换
- 在配置文件中增加 `[fun_asr]` 配置段,支持保存 API Key
- 添加完整的单元测试,覆盖核心转换逻辑和服务流程
- 为兼容 Python 3.11 以下版本,将 `tomllib` 导入改为尝试导入并回退到 `tomli`
- 在 `defaults.py` 中添加 `from __future__ import annotations` 以支持类型注解
2026-04-27 18:15:54 +08:00

453 lines
16 KiB
Python

"""Aliyun Bailian Fun-ASR subtitle transcription helpers.
This module intentionally uses the REST API because the official Fun-ASR
recorded-file API supports temporary `oss://` resources only through REST.
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from typing import Any, Optional
import requests
from loguru import logger
from app.utils import utils
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com"
UPLOAD_POLICY_URL = f"{DASHSCOPE_BASE_URL}/api/v1/uploads"
TRANSCRIPTION_URL = f"{DASHSCOPE_BASE_URL}/api/v1/services/audio/asr/transcription"
TASK_URL_TEMPLATE = f"{DASHSCOPE_BASE_URL}/api/v1/tasks/{{task_id}}"
MODEL_NAME = "fun-asr"
TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"}
PUNCTUATION_BREAKS = set(",。!?;,.!?;")
class FunAsrError(RuntimeError):
"""Raised for user-actionable Fun-ASR transcription failures."""
@dataclass
class UploadPolicy:
upload_host: str
upload_dir: str
policy: str
signature: str
oss_access_key_id: str
x_oss_object_acl: str = "private"
x_oss_forbid_overwrite: str = "true"
max_file_size_mb: Optional[float] = None
def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if extra:
headers.update(extra)
return headers
def _raise_for_http(response: requests.Response, action: str) -> None:
try:
response.raise_for_status()
except Exception as exc: # requests may be mocked with generic exceptions
raise FunAsrError(f"{action}失败,请检查阿里百炼 API Key、网络或服务状态") from exc
def _json(response: requests.Response, action: str) -> dict[str, Any]:
_raise_for_http(response, action)
try:
data = response.json()
except Exception as exc:
raise FunAsrError(f"{action}返回了无效 JSON") from exc
if not isinstance(data, dict):
raise FunAsrError(f"{action}返回格式无效")
return data
def _require_api_key(api_key: str) -> str:
api_key = (api_key or "").strip()
if not api_key:
raise FunAsrError("请先输入阿里百炼 API Key")
return api_key
def _safe_upload_name(local_file: str) -> str:
name = os.path.basename(local_file).strip() or f"audio_{int(time.time())}.wav"
return name.replace("/", "_").replace("\\", "_")
def _session_get(session, url: str, **kwargs):
return session.get(url, **kwargs)
def _session_post(session, url: str, **kwargs):
return session.post(url, **kwargs)
def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
"""Request Bailian temporary-storage upload policy for the target model."""
api_key = _require_api_key(api_key)
response = _session_get(
session,
UPLOAD_POLICY_URL,
params={"action": "getPolicy", "model": model},
headers=_auth_headers(api_key),
timeout=30,
)
data = _json(response, "获取临时存储上传凭证")
policy_data = data.get("data") or {}
required = ["upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"]
missing = [field for field in required if not policy_data.get(field)]
if missing:
raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(missing)}")
return UploadPolicy(
upload_host=str(policy_data["upload_host"]),
upload_dir=str(policy_data["upload_dir"]).rstrip("/"),
policy=str(policy_data["policy"]),
signature=str(policy_data["signature"]),
oss_access_key_id=str(policy_data["oss_access_key_id"]),
x_oss_object_acl=str(policy_data.get("x_oss_object_acl") or "private"),
x_oss_forbid_overwrite=str(policy_data.get("x_oss_forbid_overwrite") or "true"),
max_file_size_mb=policy_data.get("max_file_size_mb"),
)
def _validate_file_size(local_file: str, policy: UploadPolicy) -> None:
if policy.max_file_size_mb is None:
return
max_bytes = float(policy.max_file_size_mb) * 1024 * 1024
size = os.path.getsize(local_file)
if size > max_bytes:
raise FunAsrError(
f"文件大小超过阿里百炼临时存储限制: {size / 1024 / 1024:.2f}MB > {float(policy.max_file_size_mb):.2f}MB"
)
def upload_to_temporary_oss(local_file: str, policy: UploadPolicy, session=requests) -> str:
"""Upload local file to temporary OSS and return `oss://...` URL."""
if not os.path.isfile(local_file):
raise FunAsrError(f"待转写文件不存在: {local_file}")
_validate_file_size(local_file, policy)
key = f"{policy.upload_dir}/{_safe_upload_name(local_file)}"
data = {
"OSSAccessKeyId": policy.oss_access_key_id,
"policy": policy.policy,
"Signature": policy.signature,
"key": key,
"x-oss-object-acl": policy.x_oss_object_acl,
"x-oss-forbid-overwrite": policy.x_oss_forbid_overwrite,
"success_action_status": "200",
}
with open(local_file, "rb") as file_obj:
files = {"file": (_safe_upload_name(local_file), file_obj)}
response = _session_post(session, policy.upload_host, data=data, files=files, timeout=120)
_raise_for_http(response, "上传文件到阿里百炼临时存储")
return f"oss://{key}"
def submit_transcription_task(
api_key: str,
oss_url: str,
speaker_count: Optional[int] = None,
model: str = MODEL_NAME,
session=requests,
) -> str:
"""Submit async Fun-ASR task and return task_id."""
api_key = _require_api_key(api_key)
parameters: dict[str, Any] = {"diarization_enabled": True}
if speaker_count:
parameters["speaker_count"] = int(speaker_count)
payload = {
"model": model,
"input": {"file_urls": [oss_url]},
"parameters": parameters,
}
response = _session_post(
session,
TRANSCRIPTION_URL,
headers=_auth_headers(
api_key,
{
"X-DashScope-Async": "enable",
"X-DashScope-OssResourceResolve": "enable",
},
),
json=payload,
timeout=30,
)
data = _json(response, "提交 Fun-ASR 转写任务")
task_id = ((data.get("output") or {}).get("task_id") or "").strip()
if not task_id:
raise FunAsrError("提交 Fun-ASR 转写任务失败:未返回 task_id")
return task_id
def poll_transcription_task(
api_key: str,
task_id: str,
poll_interval: float = 2.0,
timeout: float = 600.0,
session=requests,
) -> dict[str, Any]:
"""Poll task until terminal status and return successful result item."""
api_key = _require_api_key(api_key)
deadline = time.time() + timeout
last_status = "PENDING"
while time.time() < deadline:
response = _session_post(
session,
TASK_URL_TEMPLATE.format(task_id=task_id),
headers=_auth_headers(api_key),
timeout=30,
)
data = _json(response, "查询 Fun-ASR 转写任务")
output = data.get("output") or {}
last_status = str(output.get("task_status") or "").upper()
if last_status == "SUCCEEDED":
results = output.get("results") or []
for result in results:
subtask_status = str(result.get("subtask_status") or "").upper()
if subtask_status and subtask_status != "SUCCEEDED":
raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
if not results:
raise FunAsrError("Fun-ASR 转写成功但未返回结果")
return results[0]
if last_status in TERMINAL_FAILED_STATUSES:
raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")
time.sleep(poll_interval)
raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")
def download_transcription_result(transcription_url: str, session=requests) -> dict[str, Any]:
if not transcription_url:
raise FunAsrError("Fun-ASR 结果缺少 transcription_url")
response = _session_get(session, transcription_url, timeout=60)
return _json(response, "下载 Fun-ASR 转写结果")
def _ms_to_srt_time(ms: float) -> str:
total_ms = max(0, int(round(float(ms))))
hours = total_ms // 3_600_000
total_ms %= 3_600_000
minutes = total_ms // 60_000
total_ms %= 60_000
seconds = total_ms // 1_000
milliseconds = total_ms % 1_000
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _srt_block(index: int, start_ms: float, end_ms: float, text: str) -> str:
if end_ms <= start_ms:
end_ms = start_ms + 500
return f"{index}\n{_ms_to_srt_time(start_ms)} --> {_ms_to_srt_time(end_ms)}\n{text.strip()}\n"
def _timestamp_ms(value: Any, field_name: str) -> float:
try:
return float(value)
except (TypeError, ValueError) as exc:
raise FunAsrError(f"Fun-ASR 转写结果时间戳无效: {field_name}={value!r}") from exc
def _speaker_prefix(speaker_id: Any) -> str:
if speaker_id is None or speaker_id == "":
return ""
try:
label = int(speaker_id) + 1
except (TypeError, ValueError):
label = str(speaker_id)
return f"说话人{label}: "
def _iter_sentences(result_json: dict[str, Any]):
transcripts = result_json.get("transcripts")
if transcripts is None and "sentences" in result_json:
transcripts = [{"sentences": result_json.get("sentences") or []}]
if not transcripts:
raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts")
for transcript in transcripts:
for sentence in transcript.get("sentences") or []:
yield sentence
def _word_text(word: dict[str, Any]) -> str:
text = str(word.get("text") or word.get("word") or "")
punctuation = str(word.get("punctuation") or "")
if punctuation and not text.endswith(punctuation):
text += punctuation
return text
def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None:
text = current.get("text", "").strip()
if text:
blocks.append(current.copy())
def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
words = sentence.get("words") or []
blocks: list[dict[str, Any]] = []
current: Optional[dict[str, Any]] = None
max_duration_ms = max_duration * 1000
sentence_speaker = sentence.get("speaker_id")
for word in words:
text = _word_text(word)
if not text:
continue
start = word.get("begin_time", word.get("start_time"))
end = word.get("end_time")
if start is None or end is None:
continue
speaker_id = word.get("speaker_id", sentence_speaker)
start_ms = _timestamp_ms(start, "word.begin_time")
end_ms = _timestamp_ms(end, "word.end_time")
if current is None:
current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
else:
should_split_before = (
speaker_id != current.get("speaker_id")
or len(current["text"] + text) > max_chars
or (end_ms - current["start"]) > max_duration_ms
)
if should_split_before:
_flush_block(blocks, current)
current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
else:
current["text"] += text
current["end"] = end_ms
if current and text[-1:] in PUNCTUATION_BREAKS:
_flush_block(blocks, current)
current = None
if current:
_flush_block(blocks, current)
return blocks
def _split_text(text: str, max_chars: int) -> list[str]:
chunks: list[str] = []
current = ""
for char in text:
current += char
if char in PUNCTUATION_BREAKS or len(current) >= max_chars:
chunks.append(current.strip())
current = ""
if current.strip():
chunks.append(current.strip())
return [chunk for chunk in chunks if chunk]
def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
text = str(sentence.get("text") or "").strip()
if not text:
return []
start = sentence.get("begin_time", 0)
end = sentence.get("end_time")
start_ms = _timestamp_ms(start, "sentence.begin_time")
end_ms = _timestamp_ms(end, "sentence.end_time") if end is not None else start_ms + 500
chunks = _split_text(text, max_chars)
if not chunks:
return []
duration = max(500.0, end_ms - start_ms)
total_chars = max(1, sum(len(chunk) for chunk in chunks))
cursor = start_ms
blocks: list[dict[str, Any]] = []
for i, chunk in enumerate(chunks):
if i == len(chunks) - 1:
chunk_end = end_ms
else:
chunk_end = cursor + duration * (len(chunk) / total_chars)
blocks.append(
{
"start": cursor,
"end": max(cursor + 200, chunk_end),
"text": chunk,
"speaker_id": sentence.get("speaker_id"),
}
)
cursor = chunk_end
return blocks
def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str:
"""Convert downloaded Fun-ASR JSON into fine-grained SRT.
Official downloaded schema is `transcripts[*].sentences[*].words[*]`.
Fun-ASR timestamps are milliseconds.
"""
blocks: list[dict[str, Any]] = []
for sentence in _iter_sentences(result_json):
sentence_blocks = _blocks_from_words(sentence, max_chars, max_duration)
if not sentence_blocks:
sentence_blocks = _blocks_from_sentence(sentence, max_chars)
blocks.extend(sentence_blocks)
if not blocks:
raise FunAsrError("Fun-ASR 转写结果为空:未找到可用字幕内容")
lines = []
for index, block in enumerate(blocks, start=1):
text = f"{_speaker_prefix(block.get('speaker_id'))}{block['text']}"
lines.append(_srt_block(index, block["start"], block["end"], text))
return "\n".join(lines).rstrip() + "\n"
def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
if not subtitle_file:
subtitle_file = os.path.join(utils.subtitle_dir(), f"fun_asr_{int(time.time())}.srt")
parent = os.path.dirname(subtitle_file)
if parent:
os.makedirs(parent, exist_ok=True)
with open(subtitle_file, "w", encoding="utf-8") as f:
f.write(srt_content)
return subtitle_file
def create_with_fun_asr(
local_file: str,
subtitle_file: str = "",
api_key: str = "",
speaker_count: Optional[int] = None,
poll_interval: float = 2.0,
timeout: float = 600.0,
session=requests,
) -> Optional[str]:
"""Upload local media to Bailian temporary storage and create a Fun-ASR SRT file."""
api_key = _require_api_key(api_key)
try:
policy = request_upload_policy(api_key, session=session)
oss_url = upload_to_temporary_oss(local_file, policy, session=session)
task_id = submit_transcription_task(api_key, oss_url, speaker_count=speaker_count, session=session)
task_result = poll_transcription_task(
api_key,
task_id,
poll_interval=poll_interval,
timeout=timeout,
session=session,
)
transcription_url = task_result.get("transcription_url")
result_json = download_transcription_result(transcription_url, session=session)
srt_content = fun_asr_result_to_srt(result_json)
output_file = write_srt_file(srt_content, subtitle_file)
logger.info(f"Fun-ASR 字幕文件已生成: {output_file}")
return output_file
except FunAsrError:
raise
except Exception as exc:
raise FunAsrError("Fun-ASR 字幕转写失败,请检查文件、网络或阿里百炼服务状态") from exc