NarratoAI/app/services/fun_asr_subtitle.py
2026-06-30 06:15:18 +08:00

979 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Fun-ASR subtitle transcription helpers.
The Bailian path intentionally uses the REST API because the official Fun-ASR
recorded-file API supports temporary `oss://` resources only through REST.
"""
from __future__ import annotations
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import urljoin, urlparse, urlunparse
import requests
from loguru import logger
from app.utils import utils
# Root of the DashScope (Alibaba Bailian) REST API; the endpoints below derive from it.
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com"
# Queried for a temporary-storage upload policy.
UPLOAD_POLICY_URL = f"{DASHSCOPE_BASE_URL}/api/v1/uploads"
# Receives the asynchronous transcription task submission.
TRANSCRIPTION_URL = f"{DASHSCOPE_BASE_URL}/api/v1/services/audio/asr/transcription"
# Task-status URL; ``{task_id}`` is filled in with ``str.format``.
TASK_URL_TEMPLATE = f"{DASHSCOPE_BASE_URL}/api/v1/tasks/{{task_id}}"
# Default model name sent to the Bailian endpoints.
MODEL_NAME = "fun-asr"
# Default ``model`` form field for the local OpenAI-compatible endpoint.
LOCAL_FUN_ASR_OPENAI_MODEL = "sensevoice"
# Default addresses of the local ASR services.
LOCAL_FUN_ASR_API_URL = "http://127.0.0.1:7860"
LOCAL_FIRERED_ASR_API_URL = "http://127.0.0.1:7867"
# Task statuses that end polling with an error.
TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"}
# Characters that close a subtitle block. Both ASCII marks and their CJK
# counterparts are listed; the CJK ones are spelled as escapes so the literal
# cannot be silently normalized to look-alike ASCII characters:
# U+3002 ideographic full stop, U+FF0C/FF01/FF1F/FF1B full-width , ! ? ;
PUNCTUATION_BREAKS = set(",.!?;" "\u3002\uff0c\uff01\uff1f\uff1b")
class FunAsrError(RuntimeError):
    """Signal a Fun-ASR transcription failure that the user can act on."""
@dataclass
class UploadPolicy:
    """Upload credentials parsed from the upload-policy response."""

    # URL the multipart upload is posted to.
    upload_host: str
    # Prefix joined with the file name to form the object key.
    upload_dir: str
    # The remaining required fields are forwarded verbatim as form fields.
    policy: str
    signature: str
    oss_access_key_id: str
    # Optional fields; these defaults are used when the response omits them.
    x_oss_object_acl: str = "private"
    x_oss_forbid_overwrite: str = "true"
    # Size limit in megabytes; ``None`` disables the local size check.
    max_file_size_mb: Optional[float] = None
def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]:
    """Build JSON request headers carrying the bearer token.

    Entries of *extra*, when given, are merged last and so override the defaults.
    """
    base = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    return {**base, **extra} if extra else base
def _raise_for_http(response: requests.Response, action: str) -> None:
    """Turn a failing ``response.raise_for_status()`` into a ``FunAsrError``.

    *action* is the operation name placed at the start of the error message.
    """
    try:
        response.raise_for_status()
    except Exception as error:  # mocked responses may raise non-requests exceptions
        message = f"{action}失败,请检查阿里百炼 API Key、网络或服务状态"
        raise FunAsrError(message) from error
def _json(response: requests.Response, action: str) -> dict[str, Any]:
    """Return the response body decoded as a JSON object.

    Raises:
        FunAsrError: on an HTTP failure, an undecodable body, or a payload
            that is not a dict.
    """
    _raise_for_http(response, action)
    try:
        payload = response.json()
    except Exception as error:
        raise FunAsrError(f"{action}返回了无效 JSON") from error
    if isinstance(payload, dict):
        return payload
    raise FunAsrError(f"{action}返回格式无效")
def _require_api_key(api_key: str) -> str:
    """Return *api_key* stripped of surrounding whitespace.

    Raises:
        FunAsrError: when the key is missing or blank.
    """
    cleaned = (api_key or "").strip()
    if cleaned:
        return cleaned
    raise FunAsrError("请先输入阿里百炼 API Key")
def _safe_upload_name(local_file: str) -> str:
    """Derive the file name used for uploads from a local path.

    Falls back to a time-stamped ``audio_<epoch>.wav`` name when the path has
    no usable base name, and replaces any path separators with underscores.
    """
    base = os.path.basename(local_file).strip()
    if not base:
        base = f"audio_{int(time.time())}.wav"
    for separator in ("/", "\\"):
        base = base.replace(separator, "_")
    return base
def _session_get(session, url: str, **kwargs):
    """Issue a GET through *session* and return whatever its ``get`` returns."""
    send = session.get
    return send(url, **kwargs)
def _session_post(session, url: str, **kwargs):
    """Issue a POST through *session* and return whatever its ``post`` returns."""
    send = session.post
    return send(url, **kwargs)
def _require_local_file(local_file: str) -> None:
    """Ensure *local_file* names an existing regular file.

    Raises:
        FunAsrError: when the path is not an existing file.
    """
    if os.path.isfile(local_file):
        return
    raise FunAsrError(f"待转写文件不存在: {local_file}")
def _normalize_local_api_url(api_url: str = "") -> str:
    """Return the local service URL without trailing slashes and with a scheme.

    An empty *api_url* falls back to ``LOCAL_FUN_ASR_API_URL``; a value without
    ``://`` gets an ``http://`` prefix.

    Raises:
        FunAsrError: when nothing is left after trimming.
    """
    candidate = (api_url or LOCAL_FUN_ASR_API_URL).strip().rstrip("/")
    if not candidate:
        raise FunAsrError("请先填写本地 FunASR-Pack API 地址")
    return candidate if "://" in candidate else f"http://{candidate}"
def _local_base_url(api_url: str = "") -> str:
    """Return the service root of a local API URL.

    The first matching endpoint suffix (``/v1/audio/transcriptions``, ``/v1``
    or ``/asr``) is removed from the path, and params, query and fragment are
    dropped.
    """
    parts = urlparse(_normalize_local_api_url(api_url))
    trimmed = parts.path.rstrip("/")
    known_suffixes = ("/v1/audio/transcriptions", "/v1", "/asr")
    matched = next((s for s in known_suffixes if trimmed.endswith(s)), None)
    if matched is not None:
        trimmed = trimmed[: -len(matched)].rstrip("/")
    root = parts._replace(path=trimmed, params="", query="", fragment="")
    return urlunparse(root).rstrip("/")
def _local_asr_url(api_url: str = "") -> str:
    """Return the ``/asr`` endpoint URL for a local service address.

    A URL already ending in ``/asr`` is returned as normalized; an
    OpenAI-style URL is rewritten from the service root; anything else simply
    gets ``/asr`` appended.
    """
    normalized = _normalize_local_api_url(api_url)
    tail = urlparse(normalized).path.rstrip("/")
    if tail.endswith("/asr"):
        return normalized
    if tail.endswith(("/v1", "/v1/audio/transcriptions")):
        return f"{_local_base_url(normalized)}/asr"
    return f"{normalized}/asr"
def _local_openai_transcriptions_url(api_url: str = "") -> str:
    """Return the ``/v1/audio/transcriptions`` endpoint URL for a local service.

    A URL that already points at the endpoint is returned as normalized, a
    ``/v1`` URL is extended, and any other URL is rebuilt from the service root.
    """
    normalized = _normalize_local_api_url(api_url)
    tail = urlparse(normalized).path.rstrip("/")
    if tail.endswith("/v1/audio/transcriptions"):
        return normalized
    if tail.endswith("/v1"):
        return f"{normalized}/audio/transcriptions"
    return f"{_local_base_url(normalized)}/v1/audio/transcriptions"
def _local_fun_asr_prefers_openai(api_url: str = "") -> bool:
    """Tell whether the URL path ends in ``/v1`` or ``/v1/audio/transcriptions``."""
    normalized = _normalize_local_api_url(api_url)
    tail = urlparse(normalized).path.rstrip("/")
    return tail.endswith(("/v1", "/v1/audio/transcriptions"))
def _is_not_found_response(response: requests.Response) -> bool:
    """Tell whether the response status is 404; a missing status counts as 200."""
    status = getattr(response, "status_code", 200)
    return status == 404
def _absolute_local_download_url(api_url: str, download_url: str) -> str:
    """Resolve a download URL against the local service root.

    A blank value yields ``""`` and a URL that already has a scheme is returned
    unchanged (after trimming); anything else is joined onto the service root.
    """
    target = (download_url or "").strip()
    if not target or urlparse(target).scheme:
        # Either nothing to resolve ("" is returned) or already absolute.
        return target
    service_root = f"{_local_base_url(api_url)}/"
    return urljoin(service_root, target)
def _raise_for_local_http(
    response: requests.Response,
    action: str,
    service_name: str = "本地 FunASR-Pack 服务",
) -> None:
    """Raise ``FunAsrError`` for a failed local-service response.

    For a status of 400 or above the JSON body's ``detail`` value, when
    available, is included in the message. Otherwise the outcome of
    ``response.raise_for_status()`` decides.
    """
    code = getattr(response, "status_code", 200)
    if code and code >= 400:
        try:
            body = response.json()
            detail = str(body.get("detail") or "") if isinstance(body, dict) else ""
        except Exception:
            # The error body is optional context; ignore anything unreadable.
            detail = ""
        suffix = f": {detail}" if detail else ""
        raise FunAsrError(f"{action}失败{suffix},请确认{service_name}可用")
    try:
        response.raise_for_status()
    except Exception as error:
        raise FunAsrError(f"{action}失败,请确认{service_name}可用") from error
def _local_json(
    response: requests.Response,
    action: str,
    service_name: str = "本地 FunASR-Pack 服务",
) -> dict[str, Any]:
    """Return a local-service response decoded as a JSON object.

    Raises:
        FunAsrError: on an HTTP failure, an undecodable body, or a payload
            that is not a dict.
    """
    _raise_for_local_http(response, action, service_name=service_name)
    try:
        payload = response.json()
    except Exception as error:
        raise FunAsrError(f"{action}返回了无效 JSON") from error
    if isinstance(payload, dict):
        return payload
    raise FunAsrError(f"{action}返回格式无效")
def _response_text(response: requests.Response) -> str:
    """Return the response body as text, preferring a UTF-8 decode of the bytes.

    ``requests`` decodes ``text/*`` bodies that carry no charset as ISO-8859-1,
    which garbles UTF-8 (for example Chinese) subtitles. The raw bytes are
    therefore tried as strict UTF-8 first; the response's own ``text`` is used
    only when the bytes are absent, empty, or not valid UTF-8.

    Returns:
        The decoded body; ``""`` when the response exposes neither usable
        ``content`` nor ``text``.
    """
    content = getattr(response, "content", b"")
    if isinstance(content, bytes) and content:
        try:
            return content.decode("utf-8")
        except UnicodeDecodeError:
            pass  # not UTF-8: defer to the response's own decoding below
    text = getattr(response, "text", None)
    if isinstance(text, str):
        return text
    if isinstance(content, bytes):
        # Last resort: never fail on undecodable bytes, substitute them instead.
        return content.decode("utf-8", errors="replace")
    return str(content)
def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
    """Fetch the temporary-storage upload policy for *model*.

    Raises:
        FunAsrError: when the key is blank, the request fails, or a required
            policy field is missing from the response.
    """
    token = _require_api_key(api_key)
    reply = _session_get(
        session,
        UPLOAD_POLICY_URL,
        params={"action": "getPolicy", "model": model},
        headers=_auth_headers(token),
        timeout=30,
    )
    fields = _json(reply, "获取临时存储上传凭证").get("data") or {}

    absent = []
    for name in ("upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"):
        if not fields.get(name):
            absent.append(name)
    if absent:
        raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(absent)}")

    acl = fields.get("x_oss_object_acl") or "private"
    forbid_overwrite = fields.get("x_oss_forbid_overwrite") or "true"
    return UploadPolicy(
        upload_host=str(fields["upload_host"]),
        # Trailing slashes are dropped so the object key can be built with "/".
        upload_dir=str(fields["upload_dir"]).rstrip("/"),
        policy=str(fields["policy"]),
        signature=str(fields["signature"]),
        oss_access_key_id=str(fields["oss_access_key_id"]),
        x_oss_object_acl=str(acl),
        x_oss_forbid_overwrite=str(forbid_overwrite),
        max_file_size_mb=fields.get("max_file_size_mb"),
    )
def _validate_file_size(local_file: str, policy: UploadPolicy) -> None:
    """Reject files larger than the policy's ``max_file_size_mb``.

    Nothing is checked when the policy carries no limit.

    Raises:
        FunAsrError: when the file size exceeds the limit.
    """
    limit_mb = policy.max_file_size_mb
    if limit_mb is None:
        return
    limit_bytes = float(limit_mb) * 1024 * 1024
    actual_bytes = os.path.getsize(local_file)
    if actual_bytes > limit_bytes:
        actual_mb = actual_bytes / 1024 / 1024
        raise FunAsrError(
            f"文件大小超过阿里百炼临时存储限制: {actual_mb:.2f}MB > {float(limit_mb):.2f}MB"
        )
def upload_to_temporary_oss(local_file: str, policy: UploadPolicy, session=requests) -> str:
    """Upload a local file to temporary OSS storage and return its ``oss://`` URL.

    The upload name is derived once, so the object key and the multipart file
    name always agree (the fallback name is time-based and could otherwise
    differ between two derivations).

    Args:
        local_file: Path of the file to upload.
        policy: Upload credentials for the multipart form.
        session: Object providing ``post``; defaults to the ``requests`` module.

    Returns:
        ``oss://`` followed by the object key.

    Raises:
        FunAsrError: when the file is missing, exceeds the policy's size limit,
            or the upload request fails.
    """
    _require_local_file(local_file)
    _validate_file_size(local_file, policy)
    upload_name = _safe_upload_name(local_file)
    key = f"{policy.upload_dir}/{upload_name}"
    form = {
        "OSSAccessKeyId": policy.oss_access_key_id,
        "policy": policy.policy,
        "Signature": policy.signature,
        "key": key,
        "x-oss-object-acl": policy.x_oss_object_acl,
        "x-oss-forbid-overwrite": policy.x_oss_forbid_overwrite,
        "success_action_status": "200",
    }
    # The file stays open only for the duration of the request.
    with open(local_file, "rb") as file_obj:
        response = _session_post(
            session,
            policy.upload_host,
            data=form,
            files={"file": (upload_name, file_obj)},
            timeout=120,
        )
    _raise_for_http(response, "上传文件到阿里百炼临时存储")
    return f"oss://{key}"
def submit_transcription_task(
    api_key: str,
    oss_url: str,
    speaker_count: Optional[int] = None,
    model: str = MODEL_NAME,
    session=requests,
) -> str:
    """Submit an asynchronous Fun-ASR task for *oss_url* and return its task id.

    Diarization is always enabled; ``speaker_count`` is sent only when truthy.

    Raises:
        FunAsrError: when the key is blank, the request fails, or the response
            carries no ``task_id``.
    """
    token = _require_api_key(api_key)

    options: dict[str, Any] = {"diarization_enabled": True}
    if speaker_count:
        options["speaker_count"] = int(speaker_count)
    body = {
        "model": model,
        "input": {"file_urls": [oss_url]},
        "parameters": options,
    }
    request_headers = _auth_headers(
        token,
        {
            "X-DashScope-Async": "enable",
            "X-DashScope-OssResourceResolve": "enable",
        },
    )

    reply = _session_post(
        session,
        TRANSCRIPTION_URL,
        headers=request_headers,
        json=body,
        timeout=30,
    )
    output = _json(reply, "提交 Fun-ASR 转写任务").get("output") or {}
    task_id = (output.get("task_id") or "").strip()
    if task_id:
        return task_id
    raise FunAsrError("提交 Fun-ASR 转写任务失败:未返回 task_id")
def poll_transcription_task(
    api_key: str,
    task_id: str,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> dict[str, Any]:
    """Poll the task until it finishes and return its first result item.

    Raises:
        FunAsrError: when the task or any subtask fails, a successful task
            returns no results, or *timeout* seconds pass without a terminal
            status.
    """
    token = _require_api_key(api_key)
    stop_at = time.time() + timeout
    status_url = TASK_URL_TEMPLATE.format(task_id=task_id)
    last_status = "PENDING"

    while time.time() < stop_at:
        reply = _session_post(
            session,
            status_url,
            headers=_auth_headers(token),
            timeout=30,
        )
        output = _json(reply, "查询 Fun-ASR 转写任务").get("output") or {}
        last_status = str(output.get("task_status") or "").upper()

        if last_status == "SUCCEEDED":
            results = output.get("results") or []
            # A finished task can still contain failed subtasks; check them all first.
            for entry in results:
                subtask_status = str(entry.get("subtask_status") or "").upper()
                if subtask_status and subtask_status != "SUCCEEDED":
                    raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
            if not results:
                raise FunAsrError("Fun-ASR 转写成功但未返回结果")
            return results[0]

        if last_status in TERMINAL_FAILED_STATUSES:
            raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")

        time.sleep(poll_interval)

    raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")
def download_transcription_result(transcription_url: str, session=requests) -> dict[str, Any]:
    """Download the transcription JSON found at *transcription_url*.

    Raises:
        FunAsrError: when the URL is empty or the download is not a JSON object.
    """
    if transcription_url:
        reply = _session_get(session, transcription_url, timeout=60)
        return _json(reply, "下载 Fun-ASR 转写结果")
    raise FunAsrError("Fun-ASR 结果缺少 transcription_url")
def _ms_to_srt_time(ms: float) -> str:
    """Format milliseconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds and negative values clamp to zero.
    """
    remaining = max(0, int(round(float(ms))))
    hours, remaining = divmod(remaining, 3_600_000)
    minutes, remaining = divmod(remaining, 60_000)
    seconds, milliseconds = divmod(remaining, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _srt_block(index: int, start_ms: float, end_ms: float, text: str) -> str:
    """Render one SRT cue: index line, time range line, then the stripped text.

    An end time that does not lie after the start is replaced by start + 500 ms.
    """
    stop_ms = start_ms + 500 if end_ms <= start_ms else end_ms
    time_range = f"{_ms_to_srt_time(start_ms)} --> {_ms_to_srt_time(stop_ms)}"
    return f"{index}\n{time_range}\n{text.strip()}\n"
def _timestamp_ms(value: Any, field_name: str) -> float:
    """Convert a raw timestamp value to ``float``.

    Raises:
        FunAsrError: when the value cannot be converted; *field_name* is
            included in the message.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError) as error:
        raise FunAsrError(f"Fun-ASR 转写结果时间戳无效: {field_name}={value!r}") from error
    return parsed
def _speaker_prefix(speaker_id: Any) -> str:
    """Return the speaker label that prefixes a subtitle line.

    Integer-like ids are shown one-based, any other id is shown as-is, and a
    missing id (``None`` or ``""``) produces no prefix.
    """
    if speaker_id is None or speaker_id == "":
        return ""
    label: Any
    try:
        label = int(speaker_id) + 1
    except (TypeError, ValueError):
        label = str(speaker_id)
    return f"说话人{label}: "
def _iter_sentences(result_json: dict[str, Any]):
    """Yield every sentence dict of a downloaded transcription result.

    A payload without ``transcripts`` but with a top-level ``sentences`` key is
    treated as a single transcript.

    Raises:
        FunAsrError: on first iteration, when no transcripts are present.
    """
    transcripts = result_json.get("transcripts")
    if transcripts is None and "sentences" in result_json:
        transcripts = [{"sentences": result_json.get("sentences") or []}]
    if not transcripts:
        raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts")
    for transcript in transcripts:
        yield from transcript.get("sentences") or []
def _word_text(word: dict[str, Any]) -> str:
    """Return a word's text with its ``punctuation`` appended when not already present.

    The text is read from ``text``, falling back to ``word``.
    """
    body = str(word.get("text") or word.get("word") or "")
    mark = str(word.get("punctuation") or "")
    needs_mark = bool(mark) and not body.endswith(mark)
    return body + mark if needs_mark else body
def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None:
    """Append a shallow copy of *current* to *blocks* unless its text is blank."""
    if current.get("text", "").strip():
        blocks.append(dict(current))
def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
    """Group a sentence's words into subtitle blocks.

    A new block starts when the speaker changes, when adding the word would
    exceed *max_chars*, or when the block would span more than *max_duration*
    seconds. A block also ends right after a word whose last character is a
    punctuation break. Words without text or without both timestamps are
    skipped.
    """
    limit_ms = max_duration * 1000
    default_speaker = sentence.get("speaker_id")
    result: list[dict[str, Any]] = []
    open_block: Optional[dict[str, Any]] = None

    for word in sentence.get("words") or []:
        piece = _word_text(word)
        if not piece:
            continue
        raw_start = word.get("begin_time", word.get("start_time"))
        raw_end = word.get("end_time")
        if raw_start is None or raw_end is None:
            continue
        # A word-level speaker id wins over the sentence-level one.
        speaker = word.get("speaker_id", default_speaker)
        begin = _timestamp_ms(raw_start, "word.begin_time")
        finish = _timestamp_ms(raw_end, "word.end_time")

        must_split = open_block is not None and (
            speaker != open_block.get("speaker_id")
            or len(open_block["text"] + piece) > max_chars
            or (finish - open_block["start"]) > limit_ms
        )
        if must_split:
            _flush_block(result, open_block)
        if open_block is None or must_split:
            open_block = {"start": begin, "end": finish, "text": piece, "speaker_id": speaker}
        else:
            open_block["text"] += piece
            open_block["end"] = finish

        if open_block and piece[-1:] in PUNCTUATION_BREAKS:
            _flush_block(result, open_block)
            open_block = None

    if open_block:
        _flush_block(result, open_block)
    return result
def _split_text(text: str, max_chars: int) -> list[str]:
    """Split *text* into stripped, non-empty chunks.

    A chunk ends after a punctuation break or once it holds *max_chars*
    characters, whichever comes first.
    """
    pieces: list[str] = []
    buffer: list[str] = []
    for char in text:
        buffer.append(char)
        if char in PUNCTUATION_BREAKS or len(buffer) >= max_chars:
            pieces.append("".join(buffer).strip())
            buffer.clear()
    tail = "".join(buffer).strip()
    if tail:
        pieces.append(tail)
    return [piece for piece in pieces if piece]
def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
    """Split a sentence without usable word timings into timed subtitle blocks.

    The sentence duration (at least 500 ms) is shared among the chunks in
    proportion to their length; the last chunk ends at the sentence end and
    every block lasts at least 200 ms.
    """
    text = str(sentence.get("text") or "").strip()
    if not text:
        return []

    start_ms = _timestamp_ms(sentence.get("begin_time", 0), "sentence.begin_time")
    raw_end = sentence.get("end_time")
    end_ms = start_ms + 500 if raw_end is None else _timestamp_ms(raw_end, "sentence.end_time")

    chunks = _split_text(text, max_chars)
    if not chunks:
        return []

    span = max(500.0, end_ms - start_ms)
    char_total = max(1, sum(map(len, chunks)))
    speaker = sentence.get("speaker_id")
    last_position = len(chunks) - 1

    result: list[dict[str, Any]] = []
    cursor = start_ms
    for position, chunk in enumerate(chunks):
        if position == last_position:
            stop = end_ms
        else:
            stop = cursor + span * (len(chunk) / char_total)
        result.append(
            {
                "start": cursor,
                "end": max(cursor + 200, stop),
                "text": chunk,
                "speaker_id": speaker,
            }
        )
        # The next block starts at the proportional boundary, not at the padded end.
        cursor = stop
    return result
def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str:
    """Render a downloaded Fun-ASR result as fine-grained SRT text.

    The downloaded schema is ``transcripts[*].sentences[*].words[*]`` with
    timestamps in milliseconds. Word timings are used when a sentence yields
    any block from them; otherwise the sentence text is split proportionally.

    Raises:
        FunAsrError: when no subtitle content can be produced.
    """
    blocks: list[dict[str, Any]] = []
    for sentence in _iter_sentences(result_json):
        blocks.extend(
            _blocks_from_words(sentence, max_chars, max_duration)
            or _blocks_from_sentence(sentence, max_chars)
        )
    if not blocks:
        raise FunAsrError("Fun-ASR 转写结果为空:未找到可用字幕内容")

    entries = [
        _srt_block(
            number,
            block["start"],
            block["end"],
            f"{_speaker_prefix(block.get('speaker_id'))}{block['text']}",
        )
        for number, block in enumerate(blocks, start=1)
    ]
    return "\n".join(entries).rstrip() + "\n"
def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
    """Write SRT text as UTF-8 and return the path written.

    Without *subtitle_file*, a time-stamped ``fun_asr_<epoch>.srt`` inside
    ``utils.subtitle_dir()`` is used. Missing parent directories are created.
    """
    target = subtitle_file or os.path.join(utils.subtitle_dir(), f"fun_asr_{int(time.time())}.srt")
    folder = os.path.dirname(target)
    if folder:
        os.makedirs(folder, exist_ok=True)
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(srt_content)
    return target
def copy_srt_file(source_file: str, subtitle_file: str = "") -> str:
    """Copy an existing SRT file to the subtitle location and return that path.

    Without *subtitle_file*, a time-stamped ``fun_asr_local_<epoch>.srt``
    inside ``utils.subtitle_dir()`` is used. No copy is made when source and
    target resolve to the same absolute path.

    Raises:
        FunAsrError: when *source_file* does not exist.
    """
    if not os.path.isfile(source_file):
        raise FunAsrError(f"本地 FunASR-Pack 返回的字幕文件不存在: {source_file}")
    target = subtitle_file or os.path.join(
        utils.subtitle_dir(), f"fun_asr_local_{int(time.time())}.srt"
    )
    folder = os.path.dirname(target)
    if folder:
        os.makedirs(folder, exist_ok=True)
    same_file = os.path.abspath(source_file) == os.path.abspath(target)
    if not same_file:
        shutil.copyfile(source_file, target)
    return target
def request_local_fun_asr_health(api_url: str = LOCAL_FUN_ASR_API_URL, session=requests) -> dict[str, Any]:
    """Return the JSON served by the local FunASR-Pack ``/health`` endpoint."""
    health_url = f"{_local_base_url(api_url)}/health"
    reply = _session_get(session, health_url, timeout=10)
    return _local_json(reply, "检查本地 FunASR-Pack 服务")
def request_local_firered_asr_health(
    api_url: str = LOCAL_FIRERED_ASR_API_URL,
    session=requests,
) -> dict[str, Any]:
    """Return the JSON served by the local FireRedASR2-AED-Pack ``/health`` endpoint."""
    health_url = f"{_local_base_url(api_url)}/health"
    reply = _session_get(session, health_url, timeout=10)
    return _local_json(
        reply,
        "检查本地 FireRedASR2-AED-Pack 服务",
        service_name="本地 FireRedASR2-AED-Pack 服务",
    )
def request_local_fun_asr(
    local_file: str,
    api_url: str = LOCAL_FUN_ASR_API_URL,
    hotword: str = "",
    enable_spk: Optional[bool] = None,
    model: str = LOCAL_FUN_ASR_OPENAI_MODEL,
    timeout: float = 600.0,
    session=requests,
) -> dict[str, Any]:
    """Send *local_file* to the local FunASR-Pack API and return its JSON result.

    Two endpoints are known: the ``/asr`` endpoint and the OpenAI-style
    ``/v1/audio/transcriptions``. The one suggested by *api_url* is tried
    first; only a 404 from that first attempt triggers the other one.

    Raises:
        FunAsrError: when the file is missing or the service response is
            unusable.
    """
    _require_local_file(local_file)

    spk_flag = None if enable_spk is None else ("true" if enable_spk else "false")

    rest_form: dict[str, str] = {}
    keyword = hotword.strip()
    if keyword:
        rest_form["hotword"] = keyword
    if spk_flag is not None:
        rest_form["enable_spk"] = spk_flag

    openai_form: dict[str, str] = {
        "model": (model or LOCAL_FUN_ASR_OPENAI_MODEL).strip() or LOCAL_FUN_ASR_OPENAI_MODEL,
        "response_format": "verbose_json",
    }
    if spk_flag is not None:
        openai_form["spk"] = spk_flag

    rest_target = (_local_asr_url(api_url), rest_form)
    openai_target = (_local_openai_transcriptions_url(api_url), openai_form)
    if _local_fun_asr_prefers_openai(api_url):
        attempts = [openai_target, rest_target]
    else:
        attempts = [rest_target, openai_target]

    action = "调用本地 FunASR-Pack ASR API"
    last_response = None
    for position, (url, form) in enumerate(attempts):
        # Reopen the file for every attempt so each upload starts at byte 0.
        with open(local_file, "rb") as file_obj:
            response = _session_post(
                session,
                url,
                data=form,
                files={"file": (_safe_upload_name(local_file), file_obj)},
                timeout=timeout,
            )
        if position == 0 and _is_not_found_response(response):
            last_response = response
            continue
        return _local_json(response, action)
    return _local_json(last_response, action)
def request_local_firered_asr(
    local_file: str,
    api_url: str = LOCAL_FIRERED_ASR_API_URL,
    enable_vad: Optional[bool] = True,
    enable_lid: Optional[bool] = True,
    enable_punc: Optional[bool] = True,
    return_timestamp: Optional[bool] = True,
    timeout: float = 600.0,
    session=requests,
) -> dict[str, Any]:
    """Send *local_file* to the local FireRedASR2-AED-Pack ``/asr`` API.

    Each option is sent as ``"true"``/``"false"``; an option set to ``None``
    is left out of the form.

    Raises:
        FunAsrError: when the file is missing or the service response is
            unusable.
    """
    _require_local_file(local_file)

    switches = (
        ("enable_vad", enable_vad),
        ("enable_lid", enable_lid),
        ("enable_punc", enable_punc),
        ("return_timestamp", return_timestamp),
    )
    form: dict[str, str] = {
        name: "true" if value else "false"
        for name, value in switches
        if value is not None
    }

    with open(local_file, "rb") as file_obj:
        response = _session_post(
            session,
            _local_asr_url(api_url),
            data=form,
            files={"file": (_safe_upload_name(local_file), file_obj)},
            timeout=timeout,
        )
    return _local_json(
        response,
        "调用本地 FireRedASR2-AED-Pack ASR API",
        service_name="本地 FireRedASR2-AED-Pack 服务",
    )
def download_local_srt(
    download_url: str,
    api_url: str = LOCAL_FUN_ASR_API_URL,
    subtitle_file: str = "",
    session=requests,
    service_name: str = "本地 FunASR-Pack 服务",
) -> str:
    """Download an SRT offered by a local ASR service and save it as a subtitle file.

    A relative *download_url* is resolved against the service root of *api_url*.

    Returns:
        The path of the written subtitle file.

    Raises:
        FunAsrError: when the URL is blank, the download fails, or the
            downloaded SRT is empty.
    """
    resolved = _absolute_local_download_url(api_url, download_url)
    if not resolved:
        raise FunAsrError("本地 FunASR-Pack 结果缺少 SRT 下载地址")
    reply = _session_get(session, resolved, timeout=60)
    _raise_for_local_http(reply, "下载本地 SRT", service_name=service_name)
    body = _response_text(reply)
    if body.strip():
        return write_srt_file(body, subtitle_file)
    raise FunAsrError(f"{service_name}返回了空 SRT")
def _local_result_items(result_json: dict[str, Any]):
    """Yield the result dicts of a local ASR response.

    ``raw`` is used when it is a dict or a list (non-dict list entries are
    skipped). Otherwise the response itself is yielded if it has a non-empty
    ``text``.
    """
    raw = result_json.get("raw")
    if isinstance(raw, dict):
        yield raw
        return
    if isinstance(raw, list):
        yield from (entry for entry in raw if isinstance(entry, dict))
        return
    if result_json.get("text"):
        yield result_json
def _openai_segment_ms(value: Any, field_name: str) -> float:
    """Parse a segment timestamp and scale it by 1000 (seconds to milliseconds)."""
    seconds = _timestamp_ms(value, field_name)
    return seconds * 1000
def _blocks_from_openai_segments(result_json: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
    """Build subtitle blocks from the ``segments`` list of an OpenAI-style response.

    Segment times are converted to milliseconds and each segment's text is
    split like a sentence without word timings. Non-dict segments and segments
    without text are skipped.
    """
    segments = result_json.get("segments") or []
    if not isinstance(segments, list):
        return []

    result: list[dict[str, Any]] = []
    for segment in segments:
        if not isinstance(segment, dict):
            continue
        text = str(segment.get("text") or "").strip()
        if not text:
            continue

        raw_start = segment.get("start", segment.get("start_time", 0))
        raw_end = segment.get("end", segment.get("end_time"))
        start_ms = _openai_segment_ms(raw_start, "openai.segment.start")
        if raw_end is None:
            end_ms = start_ms + 500
        else:
            end_ms = _openai_segment_ms(raw_end, "openai.segment.end")

        as_sentence = {
            "begin_time": start_ms,
            "end_time": end_ms,
            "text": text,
            "speaker_id": segment.get("speaker"),
        }
        result.extend(_blocks_from_sentence(as_sentence, max_chars=max_chars))
    return result
def _blocks_from_local_timestamp(item: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
    """Build subtitle blocks from per-character ``timestamp`` pairs.

    Whitespace is ignored. When there are fewer timestamp pairs than visible
    characters, punctuation is assumed to have no pair of its own and inherits
    the previous end time. Characters left without a pair get synthetic timing.
    A character whose pair is malformed is dropped.
    """
    text = str(item.get("text") or "").strip()
    stamps = item.get("timestamp") or []
    if not (text and isinstance(stamps, list)):
        return []

    visible = [char for char in text if char.strip()]
    punctuation_has_stamp = len(stamps) >= len(visible)
    limit_ms = max_duration * 1000

    result: list[dict[str, Any]] = []
    open_block: Optional[dict[str, Any]] = None
    next_stamp = 0
    previous_end = 0.0

    for char in visible:
        is_punctuation = char in PUNCTUATION_BREAKS
        takes_stamp = punctuation_has_stamp or not is_punctuation
        if takes_stamp and next_stamp < len(stamps):
            pair = stamps[next_stamp]
            next_stamp += 1
            if not isinstance(pair, (list, tuple)) or len(pair) < 2:
                continue
            begin = _timestamp_ms(pair[0], "local.timestamp.start")
            finish = _timestamp_ms(pair[1], "local.timestamp.end")
        else:
            # No pair available: punctuation is instantaneous, other characters get 200 ms.
            begin = previous_end
            finish = previous_end if is_punctuation else previous_end + 200
        previous_end = finish

        must_split = open_block is not None and (
            len(open_block["text"] + char) > max_chars
            or (finish - open_block["start"]) > limit_ms
        )
        if must_split:
            _flush_block(result, open_block)
        if open_block is None or must_split:
            open_block = {"start": begin, "end": finish, "text": char}
        else:
            open_block["text"] += char
            open_block["end"] = finish

        if open_block and is_punctuation:
            _flush_block(result, open_block)
            open_block = None

    if open_block:
        _flush_block(result, open_block)
    return result
def local_fun_asr_result_to_srt(
    result_json: dict[str, Any],
    max_chars: int = 20,
    max_duration: float = 3.5,
) -> str:
    """Convert a FunASR-Pack JSON response into SRT when the API SRT is unavailable.

    OpenAI-style ``segments`` are preferred. The ``raw``/``text`` items are
    consulted only when the segments produced no block: a ``verbose_json``
    payload carries both ``segments`` and a top-level ``text``, and converting
    both would emit every subtitle twice, the second copy with synthetic
    timing starting at zero.

    Args:
        result_json: Parsed JSON response of the local service.
        max_chars: Maximum number of characters per subtitle block.
        max_duration: Maximum block duration in seconds for timestamped items.

    Returns:
        The SRT document, ending with a single newline.

    Raises:
        FunAsrError: when no subtitle content can be produced.
    """
    blocks = _blocks_from_openai_segments(result_json, max_chars=max_chars)
    if not blocks:
        for item in _local_result_items(result_json):
            item_blocks = _blocks_from_local_timestamp(item, max_chars, max_duration)
            if not item_blocks:
                text = str(item.get("text") or "").strip()
                if text:
                    # No usable timing at all: estimate 180 ms per character,
                    # with a floor of 1.5 s for the whole text.
                    item_blocks = _blocks_from_sentence(
                        {
                            "begin_time": 0,
                            "end_time": max(1500, len(text) * 180),
                            "text": text,
                        },
                        max_chars=max_chars,
                    )
            blocks.extend(item_blocks)
    if not blocks:
        raise FunAsrError("本地 FunASR-Pack 转写结果为空:未找到可用字幕内容")
    lines = [
        _srt_block(index, block["start"], block["end"], block["text"])
        for index, block in enumerate(blocks, start=1)
    ]
    return "\n".join(lines).rstrip() + "\n"
def firered_asr_result_to_srt(result_json: dict[str, Any]) -> str:
    """Convert a FireRedASR2-AED-Pack JSON response into SRT text.

    Each entry of ``sentences`` becomes one cue. When that yields nothing, the
    response is converted with ``local_fun_asr_result_to_srt`` instead.
    """
    start_keys = ("start_ms", "begin_time", "start_time")
    end_keys = ("end_ms", "end_time")

    blocks: list[dict[str, Any]] = []
    sentences = result_json.get("sentences")
    for sentence in sentences if isinstance(sentences, list) else ():
        if not isinstance(sentence, dict):
            continue
        text = str(sentence.get("text") or "").strip()
        if not text:
            continue
        # The first key that is present wins, even when its value is None.
        raw_start = next((sentence[key] for key in start_keys if key in sentence), 0)
        raw_end = next((sentence[key] for key in end_keys if key in sentence), None)
        start_ms = _timestamp_ms(raw_start, "firered.sentence.start_ms")
        if raw_end is None:
            end_ms = start_ms + 500
        else:
            end_ms = _timestamp_ms(raw_end, "firered.sentence.end_ms")
        blocks.append({"start": start_ms, "end": end_ms, "text": text})

    if not blocks:
        return local_fun_asr_result_to_srt(result_json)

    entries = [
        _srt_block(number, block["start"], block["end"], block["text"])
        for number, block in enumerate(blocks, start=1)
    ]
    return "\n".join(entries).rstrip() + "\n"
def _get_local_srt_download_url(result_json: dict[str, Any]) -> str:
    """Return the SRT download URL advertised by a local ASR response, or ``""``.

    ``downloads["srt"]`` is checked first, then the top-level keys ``srt_url``,
    ``srt_download_url`` and ``download_url`` in that order.
    """
    downloads = result_json.get("downloads") or {}
    if isinstance(downloads, dict):
        nested = downloads.get("srt")
        if nested:
            return str(nested)
    top_level = (
        result_json.get(key)
        for key in ("srt_url", "srt_download_url", "download_url")
    )
    return next((str(value) for value in top_level if value), "")
def create_with_local_fun_asr(
    local_file: str,
    subtitle_file: str = "",
    api_url: str = LOCAL_FUN_ASR_API_URL,
    hotword: str = "",
    enable_spk: Optional[bool] = None,
    timeout: float = 600.0,
    session=requests,
) -> Optional[str]:
    """Transcribe *local_file* with a local FunASR-Pack service and write an SRT file.

    The subtitle source is chosen in this order: an existing ``srt_file`` path
    from the response, a download URL from the response, and finally an SRT
    built from the JSON result.

    Returns:
        The path of the subtitle file.

    Raises:
        FunAsrError: for every failure; unexpected exceptions are wrapped.
    """
    try:
        result_json = request_local_fun_asr(
            local_file=local_file,
            api_url=api_url,
            hotword=hotword,
            enable_spk=enable_spk,
            timeout=timeout,
            session=session,
        )
        srt_file = result_json.get("srt_file")
        has_local_srt = isinstance(srt_file, str) and srt_file and os.path.isfile(srt_file)
        # The download URL is only looked up when no local SRT file can be copied.
        download_url = "" if has_local_srt else _get_local_srt_download_url(result_json)

        if has_local_srt:
            output_file = copy_srt_file(srt_file, subtitle_file)
        elif download_url:
            output_file = download_local_srt(
                download_url,
                api_url=api_url,
                subtitle_file=subtitle_file,
                session=session,
            )
        else:
            output_file = write_srt_file(local_fun_asr_result_to_srt(result_json), subtitle_file)

        logger.info(f"本地 FunASR-Pack 字幕文件已生成: {output_file}")
        return output_file
    except FunAsrError:
        raise
    except Exception as error:
        raise FunAsrError("本地 FunASR-Pack 字幕转写失败,请检查服务地址、文件或模型状态") from error
def create_with_local_firered_asr(
    local_file: str,
    subtitle_file: str = "",
    api_url: str = LOCAL_FIRERED_ASR_API_URL,
    enable_vad: Optional[bool] = True,
    enable_lid: Optional[bool] = True,
    enable_punc: Optional[bool] = True,
    return_timestamp: Optional[bool] = True,
    timeout: float = 600.0,
    session=requests,
) -> Optional[str]:
    """Transcribe *local_file* with a local FireRedASR2-AED-Pack service and write an SRT file.

    The subtitle source is chosen in this order: an existing ``srt_file`` path
    from the response, a download URL from the response, and finally an SRT
    built from the JSON result.

    Returns:
        The path of the subtitle file.

    Raises:
        FunAsrError: for every failure; unexpected exceptions are wrapped.
    """
    service_name = "本地 FireRedASR2-AED-Pack 服务"
    try:
        result_json = request_local_firered_asr(
            local_file=local_file,
            api_url=api_url,
            enable_vad=enable_vad,
            enable_lid=enable_lid,
            enable_punc=enable_punc,
            return_timestamp=return_timestamp,
            timeout=timeout,
            session=session,
        )
        srt_file = result_json.get("srt_file")
        has_local_srt = isinstance(srt_file, str) and srt_file and os.path.isfile(srt_file)
        # The download URL is only looked up when no local SRT file can be copied.
        download_url = "" if has_local_srt else _get_local_srt_download_url(result_json)

        if has_local_srt:
            output_file = copy_srt_file(srt_file, subtitle_file)
        elif download_url:
            output_file = download_local_srt(
                download_url,
                api_url=api_url,
                subtitle_file=subtitle_file,
                session=session,
                service_name=service_name,
            )
        else:
            output_file = write_srt_file(firered_asr_result_to_srt(result_json), subtitle_file)

        logger.info(f"本地 FireRedASR2-AED-Pack 字幕文件已生成: {output_file}")
        return output_file
    except FunAsrError:
        raise
    except Exception as error:
        raise FunAsrError("本地ASR字幕转写失败请检查 FireRedASR2-AED-Pack 服务地址、文件或模型状态") from error
def create_with_fun_asr(
    local_file: str,
    subtitle_file: str = "",
    api_key: str = "",
    speaker_count: Optional[int] = None,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> Optional[str]:
    """Transcribe *local_file* with Bailian Fun-ASR and write the result as an SRT file.

    Steps: fetch an upload policy, upload the file to temporary storage, submit
    the task, poll until it finishes, download the result and convert it.

    Returns:
        The path of the subtitle file.

    Raises:
        FunAsrError: for every failure; unexpected exceptions are wrapped.
    """
    token = _require_api_key(api_key)
    try:
        upload_policy = request_upload_policy(token, session=session)
        oss_url = upload_to_temporary_oss(local_file, upload_policy, session=session)
        task_id = submit_transcription_task(
            token,
            oss_url,
            speaker_count=speaker_count,
            session=session,
        )
        task_result = poll_transcription_task(
            token,
            task_id,
            poll_interval=poll_interval,
            timeout=timeout,
            session=session,
        )
        result_json = download_transcription_result(
            task_result.get("transcription_url"),
            session=session,
        )
        output_file = write_srt_file(fun_asr_result_to_srt(result_json), subtitle_file)
        logger.info(f"Fun-ASR 字幕文件已生成: {output_file}")
        return output_file
    except FunAsrError:
        raise
    except Exception as error:
        raise FunAsrError("Fun-ASR 字幕转写失败,请检查文件、网络或阿里百炼服务状态") from error