diff --git a/app/config/config.py b/app/config/config.py index ddc091c..e157f94 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -81,6 +81,7 @@ def save_config(): _cfg["soulvoice"] = soulvoice _cfg["ui"] = ui _cfg["tts_qwen"] = tts_qwen + _cfg["fun_asr"] = fun_asr _cfg["indextts2"] = indextts2 _cfg["doubaotts"] = doubaotts f.write(toml.dumps(_cfg)) @@ -96,6 +97,7 @@ soulvoice = _cfg.get("soulvoice", {}) ui = _cfg.get("ui", {}) frames = _cfg.get("frames", {}) tts_qwen = _cfg.get("tts_qwen", {}) +fun_asr = _cfg.get("fun_asr", {}) indextts2 = _cfg.get("indextts2", {}) doubaotts = _cfg.get("doubaotts", {}) diff --git a/app/config/defaults.py b/app/config/defaults.py index 859e121..9a686f2 100644 --- a/app/config/defaults.py +++ b/app/config/defaults.py @@ -1,5 +1,7 @@ """Shared config defaults used by both bootstrap and WebUI fallbacks.""" +from __future__ import annotations + DEFAULT_OPENAI_COMPATIBLE_BASE_URL = "https://api.siliconflow.cn/v1" DEFAULT_OPENAI_COMPATIBLE_PROVIDER = "openai" diff --git a/app/config/test_config_bootstrap_unittest.py b/app/config/test_config_bootstrap_unittest.py index c6844ed..720a934 100644 --- a/app/config/test_config_bootstrap_unittest.py +++ b/app/config/test_config_bootstrap_unittest.py @@ -2,7 +2,10 @@ import tempfile import unittest from pathlib import Path -import tomllib +try: + import tomllib +except ModuleNotFoundError: # Python < 3.11 + import tomli as tomllib from app.config import config as cfg from app.config.defaults import ( diff --git a/app/services/fun_asr_subtitle.py b/app/services/fun_asr_subtitle.py new file mode 100644 index 0000000..7af2637 --- /dev/null +++ b/app/services/fun_asr_subtitle.py @@ -0,0 +1,452 @@ +"""Aliyun Bailian Fun-ASR subtitle transcription helpers. + +This module intentionally uses the REST API because the official Fun-ASR +recorded-file API supports temporary `oss://` resources only through REST. 
+""" + +from __future__ import annotations + +import os +import time +from dataclasses import dataclass +from typing import Any, Optional + +import requests +from loguru import logger + +from app.utils import utils + +DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com" +UPLOAD_POLICY_URL = f"{DASHSCOPE_BASE_URL}/api/v1/uploads" +TRANSCRIPTION_URL = f"{DASHSCOPE_BASE_URL}/api/v1/services/audio/asr/transcription" +TASK_URL_TEMPLATE = f"{DASHSCOPE_BASE_URL}/api/v1/tasks/{{task_id}}" +MODEL_NAME = "fun-asr" +TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"} +PUNCTUATION_BREAKS = set(",。!?;,.!?;") + + +class FunAsrError(RuntimeError): + """Raised for user-actionable Fun-ASR transcription failures.""" + + +@dataclass +class UploadPolicy: + upload_host: str + upload_dir: str + policy: str + signature: str + oss_access_key_id: str + x_oss_object_acl: str = "private" + x_oss_forbid_overwrite: str = "true" + max_file_size_mb: Optional[float] = None + + +def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + if extra: + headers.update(extra) + return headers + + +def _raise_for_http(response: requests.Response, action: str) -> None: + try: + response.raise_for_status() + except Exception as exc: # requests may be mocked with generic exceptions + raise FunAsrError(f"{action}失败,请检查阿里百炼 API Key、网络或服务状态") from exc + + +def _json(response: requests.Response, action: str) -> dict[str, Any]: + _raise_for_http(response, action) + try: + data = response.json() + except Exception as exc: + raise FunAsrError(f"{action}返回了无效 JSON") from exc + if not isinstance(data, dict): + raise FunAsrError(f"{action}返回格式无效") + return data + + +def _require_api_key(api_key: str) -> str: + api_key = (api_key or "").strip() + if not api_key: + raise FunAsrError("请先输入阿里百炼 API Key") + return api_key + + +def _safe_upload_name(local_file: str) -> str: + name = 
def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
    """Request the Bailian temporary-storage upload policy for the target model.

    Args:
        api_key: Bailian API key; blank values are rejected before any request.
        model: Model name sent as the ``model`` query parameter.
        session: Object exposing ``get`` like ``requests`` (injectable for tests).

    Returns:
        An ``UploadPolicy`` built from the ``data`` object of the response.

    Raises:
        FunAsrError: If the key is blank, the HTTP call fails, the payload is
            not a JSON object, ``data`` is not an object, or a required field
            is missing or empty.
    """
    api_key = _require_api_key(api_key)
    action = "获取临时存储上传凭证"
    response = _session_get(
        session,
        UPLOAD_POLICY_URL,
        params={"action": "getPolicy", "model": model},
        headers=_auth_headers(api_key),
        timeout=30,
    )
    data = _json(response, action)
    policy_data = data.get("data") or {}
    # A truthy non-dict ``data`` (list, string, ...) would otherwise surface as
    # a raw AttributeError on the ``.get`` calls below.
    if not isinstance(policy_data, dict):
        raise FunAsrError(f"{action}返回格式无效")

    required = ["upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"]
    missing = [field for field in required if not policy_data.get(field)]
    if missing:
        raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(missing)}")

    # Normalise the optional size limit here so the later size check cannot
    # fail with a ValueError; an unparseable limit is treated as "no limit"
    # and the server stays the final authority on the upload size.
    raw_limit = policy_data.get("max_file_size_mb")
    try:
        max_file_size_mb = None if raw_limit is None else float(raw_limit)
    except (TypeError, ValueError):
        logger.warning(f"忽略无效的 max_file_size_mb: {raw_limit!r}")
        max_file_size_mb = None

    return UploadPolicy(
        upload_host=str(policy_data["upload_host"]),
        # The object key is later built as "<upload_dir>/<name>", so drop any
        # trailing slash to avoid a doubled separator.
        upload_dir=str(policy_data["upload_dir"]).rstrip("/"),
        policy=str(policy_data["policy"]),
        signature=str(policy_data["signature"]),
        oss_access_key_id=str(policy_data["oss_access_key_id"]),
        x_oss_object_acl=str(policy_data.get("x_oss_object_acl") or "private"),
        x_oss_forbid_overwrite=str(policy_data.get("x_oss_forbid_overwrite") or "true"),
        max_file_size_mb=max_file_size_mb,
    )
def poll_transcription_task(
    api_key: str,
    task_id: str,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> dict[str, Any]:
    """Poll a Fun-ASR task until it reaches a terminal status.

    Args:
        api_key: Bailian API key; blank values are rejected before any request.
        task_id: Identifier of the task to query.
        poll_interval: Seconds to wait between two status queries.
        timeout: Overall polling budget in seconds; a budget that is already
            exhausted raises without sending a request.
        session: Object exposing ``post`` like ``requests`` (injectable for tests).

    Returns:
        The first item of ``output.results`` once the task reports ``SUCCEEDED``.

    Raises:
        FunAsrError: If the key is blank, a query fails, the task ends in one of
            ``TERMINAL_FAILED_STATUSES``, any sub-task is not ``SUCCEEDED``, the
            result list is empty, or the budget runs out.
    """
    api_key = _require_api_key(api_key)
    task_url = TASK_URL_TEMPLATE.format(task_id=task_id)
    # Use the monotonic clock: wall-clock adjustments (NTP, DST, manual changes)
    # must neither cut the budget short nor extend it.
    deadline = time.monotonic() + timeout
    last_status = "PENDING"
    while time.monotonic() < deadline:
        response = _session_post(
            session,
            task_url,
            headers=_auth_headers(api_key),
            timeout=30,
        )
        data = _json(response, "查询 Fun-ASR 转写任务")
        output = data.get("output") or {}
        last_status = str(output.get("task_status") or "").upper()

        if last_status == "SUCCEEDED":
            results = output.get("results") or []
            for result in results:
                subtask_status = str(result.get("subtask_status") or "").upper()
                if subtask_status and subtask_status != "SUCCEEDED":
                    raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
            if not results:
                raise FunAsrError("Fun-ASR 转写成功但未返回结果")
            return results[0]

        if last_status in TERMINAL_FAILED_STATUSES:
            raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline, and tolerate a negative interval
        # (time.sleep would raise ValueError on it).
        time.sleep(max(0.0, min(poll_interval, remaining)))

    raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")
speaker_id == "": + return "" + try: + label = int(speaker_id) + 1 + except (TypeError, ValueError): + label = str(speaker_id) + return f"说话人{label}: " + + +def _iter_sentences(result_json: dict[str, Any]): + transcripts = result_json.get("transcripts") + if transcripts is None and "sentences" in result_json: + transcripts = [{"sentences": result_json.get("sentences") or []}] + if not transcripts: + raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts") + for transcript in transcripts: + for sentence in transcript.get("sentences") or []: + yield sentence + + +def _word_text(word: dict[str, Any]) -> str: + text = str(word.get("text") or word.get("word") or "") + punctuation = str(word.get("punctuation") or "") + if punctuation and not text.endswith(punctuation): + text += punctuation + return text + + +def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None: + text = current.get("text", "").strip() + if text: + blocks.append(current.copy()) + + +def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]: + words = sentence.get("words") or [] + blocks: list[dict[str, Any]] = [] + current: Optional[dict[str, Any]] = None + max_duration_ms = max_duration * 1000 + sentence_speaker = sentence.get("speaker_id") + + for word in words: + text = _word_text(word) + if not text: + continue + start = word.get("begin_time", word.get("start_time")) + end = word.get("end_time") + if start is None or end is None: + continue + speaker_id = word.get("speaker_id", sentence_speaker) + start_ms = _timestamp_ms(start, "word.begin_time") + end_ms = _timestamp_ms(end, "word.end_time") + + if current is None: + current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id} + else: + should_split_before = ( + speaker_id != current.get("speaker_id") + or len(current["text"] + text) > max_chars + or (end_ms - current["start"]) > max_duration_ms + ) + if should_split_before: + _flush_block(blocks, 
current) + current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id} + else: + current["text"] += text + current["end"] = end_ms + + if current and text[-1:] in PUNCTUATION_BREAKS: + _flush_block(blocks, current) + current = None + + if current: + _flush_block(blocks, current) + return blocks + + +def _split_text(text: str, max_chars: int) -> list[str]: + chunks: list[str] = [] + current = "" + for char in text: + current += char + if char in PUNCTUATION_BREAKS or len(current) >= max_chars: + chunks.append(current.strip()) + current = "" + if current.strip(): + chunks.append(current.strip()) + return [chunk for chunk in chunks if chunk] + + +def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]: + text = str(sentence.get("text") or "").strip() + if not text: + return [] + start = sentence.get("begin_time", 0) + end = sentence.get("end_time") + start_ms = _timestamp_ms(start, "sentence.begin_time") + end_ms = _timestamp_ms(end, "sentence.end_time") if end is not None else start_ms + 500 + chunks = _split_text(text, max_chars) + if not chunks: + return [] + duration = max(500.0, end_ms - start_ms) + total_chars = max(1, sum(len(chunk) for chunk in chunks)) + cursor = start_ms + blocks: list[dict[str, Any]] = [] + for i, chunk in enumerate(chunks): + if i == len(chunks) - 1: + chunk_end = end_ms + else: + chunk_end = cursor + duration * (len(chunk) / total_chars) + blocks.append( + { + "start": cursor, + "end": max(cursor + 200, chunk_end), + "text": chunk, + "speaker_id": sentence.get("speaker_id"), + } + ) + cursor = chunk_end + return blocks + + +def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str: + """Convert downloaded Fun-ASR JSON into fine-grained SRT. + + Official downloaded schema is `transcripts[*].sentences[*].words[*]`. + Fun-ASR timestamps are milliseconds. 
def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
    """Write SRT text to disk as UTF-8 and return the path that was written.

    Args:
        srt_content: Complete SRT document to store.
        subtitle_file: Destination path. When empty, a file named
            ``fun_asr_<unix-seconds>.srt`` inside ``utils.subtitle_dir()`` is used.

    Returns:
        The path of the written file.
    """
    target = subtitle_file
    if not target:
        base_dir = utils.subtitle_dir()
        target = os.path.join(base_dir, f"fun_asr_{int(time.time())}.srt")

    # A bare file name has no directory component, so there is nothing to create.
    directory = os.path.dirname(target)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(target, "w", encoding="utf-8") as handle:
        handle.write(srt_content)
    return target
class FakeResponse:
    """Minimal response stand-in exposing ``json`` and ``raise_for_status``."""

    def __init__(self, payload=None, status_code=200):
        # Any falsy payload (None, empty dict, ...) is stored as a new empty dict.
        self.payload = payload if payload else {}
        self.status_code = status_code

    def json(self):
        """Return the canned payload unchanged."""
        return self.payload

    def raise_for_status(self):
        """Raise ``RuntimeError`` when the status code is 400 or higher."""
        is_error = self.status_code >= 400
        if is_error:
            raise RuntimeError(f"HTTP {self.status_code}")
{"task_status": "PENDING", "task_id": "task-123"}}) + if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"): + return FakeResponse( + { + "output": { + "task_status": "SUCCEEDED", + "results": [ + { + "file_url": "oss://dashscope-instant/test-dir/audio.wav", + "transcription_url": "https://result.example/transcription.json", + "subtask_status": "SUCCEEDED", + } + ], + } + } + ) + return FakeResponse({}, 404) + + +OFFICIAL_SHAPE_RESULT = { + "transcripts": [ + { + "sentences": [ + { + "begin_time": 0, + "end_time": 3600, + "text": "你好欢迎观看今天的内容", + "speaker_id": 0, + "words": [ + {"begin_time": 0, "end_time": 400, "text": "你好", "punctuation": ","}, + {"begin_time": 400, "end_time": 900, "text": "欢迎", "punctuation": ""}, + {"begin_time": 900, "end_time": 1300, "text": "观看", "punctuation": ""}, + {"begin_time": 1300, "end_time": 1800, "text": "今天", "punctuation": ""}, + {"begin_time": 1800, "end_time": 2400, "text": "的内容", "punctuation": "。"}, + ], + } + ] + } + ] +} + + +class FunAsrSrtConversionTests(unittest.TestCase): + def test_official_shape_words_convert_ms_and_speaker_label(self): + srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5) + + self.assertIn("1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,", srt) + self.assertIn("2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。", srt) + self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds") + + def test_long_word_sequence_splits_into_fine_blocks(self): + result = { + "transcripts": [ + { + "sentences": [ + { + "begin_time": 0, + "end_time": 6000, + "speaker_id": 1, + "words": [ + {"begin_time": i * 500, "end_time": (i + 1) * 500, "text": token, "punctuation": ""} + for i, token in enumerate(["这是", "一个", "很长", "字幕", "需要", "拆分"]) + ], + } + ] + } + ] + } + srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10) + + self.assertGreaterEqual(srt.count("\n说话人2:"), 3) + self.assertIn("1\n00:00:00,000", srt) + + def 
test_sentence_fallback_uses_ms_without_zero_duration(self): + result = { + "transcripts": [ + { + "sentences": [ + { + "begin_time": 1000, + "end_time": 3000, + "text": "没有词级时间戳也可以拆分。", + "speaker_id": 0, + "words": [], + } + ] + } + ] + } + srt = fasr.fun_asr_result_to_srt(result, max_chars=5) + + self.assertIn("00:00:01,000", srt) + self.assertIn("说话人1:", srt) + self.assertNotIn("--> 00:00:01,000\n", srt) + + def test_empty_result_raises_clear_error(self): + with self.assertRaises(fasr.FunAsrError): + fasr.fun_asr_result_to_srt({"transcripts": []}) + + def test_malformed_word_timestamp_raises_fun_asr_error(self): + result = { + "transcripts": [ + { + "sentences": [ + { + "begin_time": 0, + "end_time": 1000, + "speaker_id": 0, + "words": [ + {"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""} + ], + } + ] + } + ] + } + + with self.assertRaises(fasr.FunAsrError): + fasr.fun_asr_result_to_srt(result) + + def test_malformed_sentence_timestamp_raises_fun_asr_error(self): + result = { + "transcripts": [ + { + "sentences": [ + { + "begin_time": "bad", + "end_time": 1000, + "text": "坏时间", + "speaker_id": 0, + "words": [], + } + ] + } + ] + } + + with self.assertRaises(fasr.FunAsrError): + fasr.fun_asr_result_to_srt(result) + + +class FunAsrServiceTests(unittest.TestCase): + def test_create_with_fun_asr_uses_expected_rest_flow(self): + with tempfile.TemporaryDirectory() as tmp_dir: + local_file = Path(tmp_dir) / "audio.wav" + local_file.write_bytes(b"audio") + subtitle_file = Path(tmp_dir) / "out.srt" + session = FakeSession(OFFICIAL_SHAPE_RESULT) + + result_path = fasr.create_with_fun_asr( + str(local_file), + subtitle_file=str(subtitle_file), + api_key="sk-test", + speaker_count=2, + poll_interval=0, + session=session, + ) + + self.assertEqual(str(subtitle_file), result_path) + self.assertTrue(subtitle_file.exists()) + self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8")) + + policy_call = session.calls[0] + self.assertEqual("GET", 
policy_call[0]) + self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1]) + self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"]) + self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"]) + + upload_call = session.calls[1] + self.assertEqual("POST", upload_call[0]) + self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1]) + upload_data = upload_call[2]["data"] + self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"]) + self.assertEqual("policy-token", upload_data["policy"]) + self.assertEqual("signature-token", upload_data["Signature"]) + self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"]) + self.assertEqual("200", upload_data["success_action_status"]) + + submit_call = session.calls[2] + self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1]) + headers = submit_call[2]["headers"] + self.assertEqual("enable", headers["X-DashScope-Async"]) + self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"]) + payload = submit_call[2]["json"] + self.assertEqual("fun-asr", payload["model"]) + self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"]) + self.assertTrue(payload["parameters"]["diarization_enabled"]) + self.assertEqual(2, payload["parameters"]["speaker_count"]) + + poll_call = session.calls[3] + self.assertEqual("POST", poll_call[0]) + self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123")) + + download_call = session.calls[4] + self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2]) + + def test_upload_policy_size_validation_fails_before_upload(self): + policy = fasr.UploadPolicy( + upload_host="https://upload.example", + upload_dir="dashscope-instant/test", + policy="p", + signature="s", + oss_access_key_id="ak", + max_file_size_mb=0.000001, + ) + with tempfile.NamedTemporaryFile() as f: + f.write(b"too-large") + f.flush() + with 
self.assertRaises(fasr.FunAsrError): + fasr.upload_to_temporary_oss(f.name, policy, session=FakeSession({})) + + def test_failed_subtask_raises(self): + class FailedSession(FakeSession): + def post(self, url, **kwargs): + if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"): + return FakeResponse( + { + "output": { + "task_status": "SUCCEEDED", + "results": [{"subtask_status": "FAILED"}], + } + } + ) + return super().post(url, **kwargs) + + with self.assertRaises(fasr.FunAsrError): + fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({})) + + def test_missing_api_key_raises_before_request(self): + session = FakeSession(OFFICIAL_SHAPE_RESULT) + + with self.assertRaises(fasr.FunAsrError): + fasr.request_upload_policy("", session=session) + + self.assertEqual([], session.calls) + + def test_upload_policy_http_error_raises(self): + class PolicyHttpErrorSession(FakeSession): + def get(self, url, **kwargs): + self.calls.append(("GET", url, kwargs)) + return FakeResponse({}, status_code=403) + + with self.assertRaises(fasr.FunAsrError): + fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({})) + + def test_malformed_upload_policy_raises(self): + class MalformedPolicySession(FakeSession): + def get(self, url, **kwargs): + self.calls.append(("GET", url, kwargs)) + return FakeResponse({"data": {"policy": "missing-required-fields"}}) + + with self.assertRaises(fasr.FunAsrError): + fasr.request_upload_policy("sk-test", session=MalformedPolicySession({})) + + def test_upload_http_failure_raises(self): + class UploadFailureSession(FakeSession): + def post(self, url, **kwargs): + self.calls.append(("POST", url, kwargs)) + return FakeResponse({}, status_code=500) + + policy = fasr.UploadPolicy( + upload_host="https://upload.example", + upload_dir="dashscope-instant/test", + policy="p", + signature="s", + oss_access_key_id="ak", + max_file_size_mb=1, + ) + with tempfile.NamedTemporaryFile() as f: + f.write(b"audio") + 
f.flush() + with self.assertRaises(fasr.FunAsrError): + fasr.upload_to_temporary_oss(f.name, policy, session=UploadFailureSession({})) + + def test_submit_failure_raises(self): + class SubmitFailureSession(FakeSession): + def post(self, url, **kwargs): + self.calls.append(("POST", url, kwargs)) + return FakeResponse({}, status_code=500) + + with self.assertRaises(fasr.FunAsrError): + fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({})) + + def test_poll_timeout_raises(self): + class PendingSession(FakeSession): + def post(self, url, **kwargs): + self.calls.append(("POST", url, kwargs)) + return FakeResponse({"output": {"task_status": "RUNNING"}}) + + with self.assertRaises(fasr.FunAsrError): + fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, timeout=-1, session=PendingSession({})) + + def test_task_failed_status_raises(self): + class FailedTaskSession(FakeSession): + def post(self, url, **kwargs): + self.calls.append(("POST", url, kwargs)) + return FakeResponse({"output": {"task_status": "FAILED"}}) + + with self.assertRaises(fasr.FunAsrError): + fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({})) + + def test_missing_transcription_url_raises(self): + with self.assertRaises(fasr.FunAsrError): + fasr.download_transcription_result("", session=FakeSession({})) + + def test_malformed_downloaded_json_raises(self): + class MalformedDownloadSession(FakeSession): + def get(self, url, **kwargs): + self.calls.append(("GET", url, kwargs)) + return InvalidJsonResponse() + + with self.assertRaises(fasr.FunAsrError): + fasr.download_transcription_result("https://result.example/bad.json", session=MalformedDownloadSession({})) + + +class FunAsrConfigTests(unittest.TestCase): + def test_save_config_persists_fun_asr_section(self): + original_config_file = cfg.config_file + original_fun_asr = cfg.fun_asr + try: + with tempfile.TemporaryDirectory() as tmp_dir: + config_path = 
def render_fun_asr_transcription(tr):
    """Render the Fun-ASR panel that turns an uploaded audio/video file into an SRT subtitle.

    On success the generated subtitle path and content are stored in
    ``st.session_state`` (``subtitle_path``, ``subtitle_content``,
    ``subtitle_file_processed``); on every failure path these keys are reset.
    The entered API key is written to the ``fun_asr`` config section and saved
    before the transcription starts.

    Args:
        tr: Translation callable supplied by the caller.
            NOTE(review): it is not used here; all labels are hard-coded while
            the neighbouring uploader uses ``tr(...)`` — confirm whether these
            strings should be routed through it.
    """
    def clear_fun_asr_subtitle_state():
        # Reset everything a previous (or half-finished) run may have stored.
        st.session_state['subtitle_path'] = None
        st.session_state['subtitle_content'] = None
        st.session_state['subtitle_file_processed'] = False

    with st.expander("阿里百炼 Fun-ASR 字幕转录", expanded=False):
        st.caption("上传本地音频/视频后,将自动上传到阿里百炼临时存储并通过 fun-asr 生成 SRT 字幕。")
        st.markdown(
            "API Key 获取地址:"
            "[https://bailian.console.aliyun.com/?tab=model#/api-key]"
            "(https://bailian.console.aliyun.com/?tab=model#/api-key)"
        )

        api_key = st.text_input(
            "阿里百炼 API Key",
            value=config.fun_asr.get("api_key", ""),
            type="password",
            help="请输入你自己的阿里百炼 API Key;保存配置后会写入本地 config.toml",
            key="fun_asr_api_key",
        )
        uploaded_media = st.file_uploader(
            "上传需要转录的音频/视频",
            type=[
                "aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov",
                "mp3", "mp4", "mpeg", "ogg", "opus", "wav", "webm", "wma", "wmv",
            ],
            accept_multiple_files=False,
            key="fun_asr_media_uploader",
        )

        if st.button("转写生成字幕", key="fun_asr_transcribe"):
            if not api_key.strip():
                clear_fun_asr_subtitle_state()
                st.error("请先输入阿里百炼 API Key")
                return
            if uploaded_media is None:
                clear_fun_asr_subtitle_state()
                st.error("请先上传需要转录的音频或视频文件")
                return

            # Tracked outside the try block so the cleanup below can always see it.
            media_path = None
            try:
                clear_fun_asr_subtitle_state()
                from app.services import fun_asr_subtitle

                config.fun_asr["api_key"] = api_key.strip()
                config.fun_asr["model"] = "fun-asr"
                config.save_config()

                temp_dir = utils.temp_dir("fun_asr")
                # Keep only the final path component of the client-supplied name.
                safe_filename = os.path.basename(uploaded_media.name)
                media_path = os.path.join(temp_dir, safe_filename)
                file_name, file_extension = os.path.splitext(safe_filename)
                if os.path.exists(media_path):
                    # Do not clobber a file with the same name that is already there.
                    timestamp = time.strftime("%Y%m%d%H%M%S")
                    media_path = os.path.join(temp_dir, f"{file_name}_{timestamp}{file_extension}")

                with open(media_path, "wb") as f:
                    f.write(uploaded_media.getbuffer())

                subtitle_name = f"{os.path.splitext(os.path.basename(media_path))[0]}_fun_asr.srt"
                subtitle_path = os.path.join(utils.subtitle_dir(), subtitle_name)

                with st.spinner("正在使用阿里百炼 Fun-ASR 转写字幕,请稍候..."):
                    generated_path = fun_asr_subtitle.create_with_fun_asr(
                        local_file=media_path,
                        subtitle_file=subtitle_path,
                        api_key=api_key.strip(),
                    )

                if not generated_path or not os.path.exists(generated_path):
                    clear_fun_asr_subtitle_state()
                    st.error("Fun-ASR 转写失败:未生成字幕文件")
                    return

                with open(generated_path, "r", encoding="utf-8") as f:
                    subtitle_content = f.read()

                st.session_state['subtitle_path'] = generated_path
                st.session_state['subtitle_content'] = subtitle_content
                st.session_state['subtitle_file_processed'] = True
                st.success(f"字幕转写成功: {os.path.basename(generated_path)}")
            except Exception as e:
                clear_fun_asr_subtitle_state()
                logger.error(f"Fun-ASR 字幕转写失败: {traceback.format_exc()}")
                st.error(f"Fun-ASR 字幕转写失败: {str(e)}")
            finally:
                # The local copy is only needed for the upload; only the subtitle
                # is kept in session state. Without this, every run leaves a
                # potentially large media file behind in the temp directory.
                if media_path and os.path.exists(media_path):
                    try:
                        os.remove(media_path)
                    except OSError as cleanup_error:
                        # Best effort: a failed cleanup must not mask the result.
                        logger.warning(f"清理 Fun-ASR 临时文件失败: {cleanup_error}")