NarratoAI/app/services/test_fun_asr_subtitle_unittest.py
viccy 99dd4193ae feat(字幕): 新增阿里百炼 Fun-ASR 音视频字幕转录功能
- 在 WebUI 中增加 Fun-ASR 转录界面,支持上传多种音视频格式并生成 SRT 字幕
- 新增 `app/services/fun_asr_subtitle.py` 服务模块,实现完整的 REST API 调用流程,包括获取上传凭证、文件上传、提交任务、轮询结果和 SRT 格式转换
- 在配置文件中增加 `[fun_asr]` 配置段,支持保存 API Key
- 添加完整的单元测试,覆盖核心转换逻辑和服务流程
- 为兼容 Python 3.11 以下版本,将 `tomllib` 导入改为尝试导入并回退到 `tomli`
- 在 `defaults.py` 中添加 `from __future__ import annotations` 以支持类型注解
2026-04-27 18:15:54 +08:00

404 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import tempfile
import unittest
from pathlib import Path
try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib
from app.config import config as cfg
from app.services import fun_asr_subtitle as fasr
class FakeResponse:
    """Minimal stand-in for an HTTP response object used by the fake sessions.

    Args:
        payload: Object returned by :meth:`json`. ``None`` (the default) is
            replaced by an empty dict.
        status_code: HTTP status code reported by the fake response.
    """

    def __init__(self, payload=None, status_code=200):
        # Compare against None instead of relying on truthiness so that falsy
        # but valid JSON bodies ([], 0, "") are handed back unchanged.
        self.payload = {} if payload is None else payload
        self.status_code = status_code

    def json(self):
        """Return the payload supplied at construction time."""
        return self.payload

    def raise_for_status(self):
        """Raise ``RuntimeError`` when the status code is 400 or above."""
        if self.status_code >= 400:
            raise RuntimeError(f"HTTP {self.status_code}")
class InvalidJsonResponse(FakeResponse):
    """Fake response whose body cannot be decoded as JSON."""

    def json(self):
        """Always raise ``ValueError`` instead of returning a payload."""
        message = "invalid json"
        raise ValueError(message)
class FakeSession:
    """In-memory replacement for an HTTP session that records every request.

    Each call to :meth:`get` or :meth:`post` is appended to ``calls`` as a
    ``(method, url, kwargs)`` tuple and answered with a canned
    ``FakeResponse``. URLs that are not recognised receive an empty 404.

    Args:
        local_result: Payload served when the transcription result URL is
            fetched.
    """

    # URLs the fake recognises besides the constants exposed by ``fasr``.
    _UPLOAD_HOST = "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com"
    _RESULT_URL = "https://result.example/transcription.json"

    def __init__(self, local_result):
        self.calls = []
        self.local_result = local_result

    def get(self, url, **kwargs):
        """Record a GET request and return the matching canned response."""
        self.calls.append(("GET", url, kwargs))

        if url == fasr.UPLOAD_POLICY_URL:
            policy_fields = {
                "policy": "policy-token",
                "signature": "signature-token",
                "upload_dir": "dashscope-instant/test-dir",
                "upload_host": self._UPLOAD_HOST,
                "oss_access_key_id": "oss-ak",
                "x_oss_object_acl": "private",
                "x_oss_forbid_overwrite": "true",
                "max_file_size_mb": 1,
            }
            return FakeResponse({"data": policy_fields})

        if url == self._RESULT_URL:
            return FakeResponse(self.local_result)

        return FakeResponse({}, 404)

    def post(self, url, **kwargs):
        """Record a POST request and return the matching canned response."""
        self.calls.append(("POST", url, kwargs))

        # File upload to the temporary storage host.
        if url == self._UPLOAD_HOST:
            return FakeResponse({})

        # Task submission: the task is accepted but not finished yet.
        if url == fasr.TRANSCRIPTION_URL:
            accepted = {"task_status": "PENDING", "task_id": "task-123"}
            return FakeResponse({"output": accepted})

        # Task status query: the task has finished and points at the result.
        if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
            finished_subtask = {
                "file_url": "oss://dashscope-instant/test-dir/audio.wav",
                "transcription_url": self._RESULT_URL,
                "subtask_status": "SUCCEEDED",
            }
            finished = {"task_status": "SUCCEEDED", "results": [finished_subtask]}
            return FakeResponse({"output": finished})

        return FakeResponse({}, 404)
# Transcription result fixture: one transcript holding a single sentence made
# of five consecutive words. The conversion tests treat begin_time/end_time
# as milliseconds.
# NOTE(review): every word carries an empty "punctuation" value and "text"
# contains no punctuation marks, while
# test_official_shape_words_convert_ms_and_speaker_label expects "," and "。"
# in the generated SRT - confirm the fixture content is complete.
_OFFICIAL_WORD_SPANS = (
    (0, 400, "你好"),
    (400, 900, "欢迎"),
    (900, 1300, "观看"),
    (1300, 1800, "今天"),
    (1800, 2400, "的内容"),
)

OFFICIAL_SHAPE_RESULT = {
    "transcripts": [
        {
            "sentences": [
                {
                    "begin_time": 0,
                    "end_time": 3600,
                    "text": "你好欢迎观看今天的内容",
                    "speaker_id": 0,
                    "words": [
                        {
                            "begin_time": start_ms,
                            "end_time": stop_ms,
                            "text": token,
                            "punctuation": "",
                        }
                        for start_ms, stop_ms, token in _OFFICIAL_WORD_SPANS
                    ],
                }
            ]
        }
    ]
}
class FunAsrSrtConversionTests(unittest.TestCase):
    """Checks for ``fasr.fun_asr_result_to_srt`` on hand-built result payloads."""

    @staticmethod
    def _single_sentence_result(sentence):
        """Wrap one sentence dict in the ``transcripts``/``sentences`` envelope."""
        return {"transcripts": [{"sentences": [sentence]}]}

    def _assert_conversion_fails(self, result):
        """Assert that converting ``result`` with default options raises ``FunAsrError``."""
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)

    def test_official_shape_words_convert_ms_and_speaker_label(self):
        """Word timestamps are read as milliseconds and blocks carry a speaker label."""
        srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)

        expected_blocks = (
            "1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,",
            "2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。",
        )
        for block in expected_blocks:
            self.assertIn(block, srt)
        # 400 ms misread as 400 s would show up as 00:06:40.
        self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")

    def test_long_word_sequence_splits_into_fine_blocks(self):
        """Six two-character words with ``max_chars=4`` yield at least three blocks."""
        tokens = ["这是", "一个", "很长", "字幕", "需要", "拆分"]
        words = []
        for position, token in enumerate(tokens):
            # Each word occupies its own consecutive 500 ms slot.
            words.append(
                {
                    "begin_time": position * 500,
                    "end_time": (position + 1) * 500,
                    "text": token,
                    "punctuation": "",
                }
            )
        result = self._single_sentence_result(
            {
                "begin_time": 0,
                "end_time": 6000,
                "speaker_id": 1,
                "words": words,
            }
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)

        self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
        self.assertIn("1\n00:00:00,000", srt)

    def test_sentence_fallback_uses_ms_without_zero_duration(self):
        """A sentence without word timestamps still converts, with no block ending at its start time."""
        sentence = {
            "begin_time": 1000,
            "end_time": 3000,
            "text": "没有词级时间戳也可以拆分。",
            "speaker_id": 0,
            "words": [],
        }

        srt = fasr.fun_asr_result_to_srt(self._single_sentence_result(sentence), max_chars=5)

        self.assertIn("00:00:01,000", srt)
        self.assertIn("说话人1:", srt)
        # A block ending at 00:00:01,000 would have zero duration.
        self.assertNotIn("--> 00:00:01,000\n", srt)

    def test_empty_result_raises_clear_error(self):
        """A result without any transcripts raises ``FunAsrError``."""
        self._assert_conversion_fails({"transcripts": []})

    def test_malformed_word_timestamp_raises_fun_asr_error(self):
        """A non-numeric word ``begin_time`` raises ``FunAsrError``."""
        bad_word = {"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
        sentence = {
            "begin_time": 0,
            "end_time": 1000,
            "speaker_id": 0,
            "words": [bad_word],
        }
        self._assert_conversion_fails(self._single_sentence_result(sentence))

    def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
        """A non-numeric sentence ``begin_time`` raises ``FunAsrError``."""
        sentence = {
            "begin_time": "bad",
            "end_time": 1000,
            "text": "坏时间",
            "speaker_id": 0,
            "words": [],
        }
        self._assert_conversion_fails(self._single_sentence_result(sentence))
class FunAsrServiceTests(unittest.TestCase):
    """Exercises the ``fasr`` REST helpers against fake HTTP sessions."""

    @staticmethod
    def _make_policy(max_file_size_mb):
        """Build an ``UploadPolicy`` with fixed dummy credentials and the given size limit."""
        return fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=max_file_size_mb,
        )

    def test_create_with_fun_asr_uses_expected_rest_flow(self):
        """The full flow issues policy, upload, submit, poll and download requests in order."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = Path(tmp_dir) / "audio.wav"
            local_file.write_bytes(b"audio")
            subtitle_file = Path(tmp_dir) / "out.srt"
            session = FakeSession(OFFICIAL_SHAPE_RESULT)

            result_path = fasr.create_with_fun_asr(
                str(local_file),
                subtitle_file=str(subtitle_file),
                api_key="sk-test",
                speaker_count=2,
                poll_interval=0,
                session=session,
            )

            self.assertEqual(str(subtitle_file), result_path)
            self.assertTrue(subtitle_file.exists())
            self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))

            # 1. Upload policy request.
            policy_call = session.calls[0]
            self.assertEqual("GET", policy_call[0])
            self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1])
            self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"])
            self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"])

            # 2. Form upload to the host named by the policy.
            upload_call = session.calls[1]
            self.assertEqual("POST", upload_call[0])
            self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1])
            upload_data = upload_call[2]["data"]
            self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"])
            self.assertEqual("policy-token", upload_data["policy"])
            self.assertEqual("signature-token", upload_data["Signature"])
            self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"])
            self.assertEqual("200", upload_data["success_action_status"])

            # 3. Transcription task submission.
            submit_call = session.calls[2]
            self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1])
            headers = submit_call[2]["headers"]
            self.assertEqual("enable", headers["X-DashScope-Async"])
            self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"])
            payload = submit_call[2]["json"]
            self.assertEqual("fun-asr", payload["model"])
            self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"])
            self.assertTrue(payload["parameters"]["diarization_enabled"])
            self.assertEqual(2, payload["parameters"]["speaker_count"])

            # 4. Task status poll.
            poll_call = session.calls[3]
            self.assertEqual("POST", poll_call[0])
            self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123"))

            # 5. Result download.
            download_call = session.calls[4]
            self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2])

    def test_upload_policy_size_validation_fails_before_upload(self):
        """A file larger than the policy limit is rejected without any upload request."""
        policy = self._make_policy(max_file_size_mb=0.000001)
        session = FakeSession({})

        # A file inside a TemporaryDirectory is closed before it is handed
        # over; an open NamedTemporaryFile cannot be reopened by name on
        # Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            oversized = Path(tmp_dir) / "audio.wav"
            oversized.write_bytes(b"too-large")
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(str(oversized), policy, session=session)

        # FakeSession answers the unknown upload host with a 404, so without
        # this check an upload-time HTTP failure could also satisfy the
        # assertRaises above.
        upload_requests = [call for call in session.calls if call[0] == "POST"]
        self.assertEqual([], upload_requests)

    def test_failed_subtask_raises(self):
        """A succeeded task whose only subtask failed raises ``FunAsrError``."""

        class FailedSession(FakeSession):
            def post(self, url, **kwargs):
                if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
                    return FakeResponse(
                        {
                            "output": {
                                "task_status": "SUCCEEDED",
                                "results": [{"subtask_status": "FAILED"}],
                            }
                        }
                    )
                return super().post(url, **kwargs)

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({}))

    def test_missing_api_key_raises_before_request(self):
        """An empty API key raises ``FunAsrError`` and sends no request."""
        session = FakeSession(OFFICIAL_SHAPE_RESULT)
        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("", session=session)
        self.assertEqual([], session.calls)

    def test_upload_policy_http_error_raises(self):
        """An HTTP 403 on the policy request raises ``FunAsrError``."""

        class PolicyHttpErrorSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({}, status_code=403)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({}))

    def test_malformed_upload_policy_raises(self):
        """A policy response lacking required fields raises ``FunAsrError``."""

        class MalformedPolicySession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({"data": {"policy": "missing-required-fields"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=MalformedPolicySession({}))

    def test_upload_http_failure_raises(self):
        """An HTTP 500 on the upload request raises ``FunAsrError``."""

        class UploadFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        policy = self._make_policy(max_file_size_mb=1)

        # Closed file in a TemporaryDirectory: the code under test receives
        # only the path, and an open NamedTemporaryFile cannot be reopened by
        # name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            audio_file = Path(tmp_dir) / "audio.wav"
            audio_file.write_bytes(b"audio")
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(str(audio_file), policy, session=UploadFailureSession({}))

    def test_submit_failure_raises(self):
        """An HTTP 500 on task submission raises ``FunAsrError``."""

        class SubmitFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        with self.assertRaises(fasr.FunAsrError):
            fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({}))

    def test_poll_timeout_raises(self):
        """A task that stays RUNNING raises ``FunAsrError`` when polled with ``timeout=-1``."""

        class PendingSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "RUNNING"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task(
                "sk-test",
                "task-123",
                poll_interval=0,
                timeout=-1,
                session=PendingSession({}),
            )

    def test_task_failed_status_raises(self):
        """A task reporting FAILED raises ``FunAsrError``."""

        class FailedTaskSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "FAILED"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({}))

    def test_missing_transcription_url_raises(self):
        """An empty transcription URL raises ``FunAsrError``."""
        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("", session=FakeSession({}))

    def test_malformed_downloaded_json_raises(self):
        """A result body that fails JSON decoding raises ``FunAsrError``."""

        class MalformedDownloadSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return InvalidJsonResponse()

        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result(
                "https://result.example/bad.json",
                session=MalformedDownloadSession({}),
            )
class FunAsrConfigTests(unittest.TestCase):
    """Checks that the ``[fun_asr]`` configuration section is saved and parseable."""

    @staticmethod
    def _example_config_path():
        """Locate ``config.example.toml`` without depending on the working directory.

        Walks upward from this test file and returns the first ancestor
        directory's ``config.example.toml``. If none exists, falls back to the
        path relative to the current working directory.
        """
        file_name = "config.example.toml"
        for directory in Path(__file__).resolve().parents:
            candidate = directory / file_name
            if candidate.is_file():
                return candidate
        return Path(file_name)

    def test_save_config_persists_fun_asr_section(self):
        """``save_config`` writes the ``fun_asr`` values into the TOML file."""
        original_config_file = cfg.config_file
        original_fun_asr = cfg.fun_asr
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / "config.toml"
                # Redirect the save target so the real config file is untouched.
                cfg.config_file = str(config_path)
                cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
                cfg.save_config()
                # Read back before the temporary directory is removed.
                saved = tomllib.loads(config_path.read_text(encoding="utf-8"))
        finally:
            # Restore module state even when saving or parsing fails.
            cfg.config_file = original_config_file
            cfg.fun_asr = original_fun_asr

        self.assertEqual("sk-local", saved["fun_asr"]["api_key"])
        self.assertEqual("fun-asr", saved["fun_asr"]["model"])

    def test_config_example_fun_asr_section_parses(self):
        """The example config parses and declares the ``fun_asr`` model and API key."""
        example_text = self._example_config_path().read_text(encoding="utf-8")
        config_data = tomllib.loads(example_text)

        self.assertEqual("fun-asr", config_data["fun_asr"]["model"])
        self.assertIn("api_key", config_data["fun_asr"])
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()