NarratoAI/app/services/test_fun_asr_subtitle_unittest.py
viccy 34d5532119 feat(subtitle): 新增 FireRedASR2 本地 ASR 后端支持
添加 FireRedASR2 本地 ASR 转写后端的完整支持:
1. 新增配置参数与数据模型字段
2. 更新示例配置文件,添加默认本地服务地址
3. 完善任务服务中的转写逻辑,支持 FireRedASR 后端
4. 更新 WebUI 界面,新增对应配置选项
5. 补充中英文多语言翻译
6. 新增本地 FireRedASR 服务的单元测试
2026-06-07 17:58:02 +08:00

598 lines
23 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import tempfile
import unittest
from pathlib import Path
try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib
from app.config import config as cfg
from app.services import fun_asr_subtitle as fasr
class FakeResponse:
    """Minimal in-memory stand-in for an HTTP response used by the fake sessions.

    Attributes:
        payload: Object returned by :meth:`json`; any falsy value becomes ``{}``.
        status_code: Numeric HTTP status inspected by :meth:`raise_for_status`.
        text: The text body exactly as supplied (may be ``None``).
        content: UTF-8 bytes of ``text`` when it is a ``str``, otherwise ``b""``.
    """

    def __init__(self, payload=None, status_code=200, text=None):
        # Falsy payloads (None, {}, [], ...) all collapse to a fresh empty dict.
        self.payload = payload if payload else {}
        self.status_code = status_code
        self.text = text
        if isinstance(text, str):
            self.content = text.encode("utf-8")
        else:
            self.content = b""

    def json(self):
        """Return the canned payload."""
        return self.payload

    def raise_for_status(self):
        """Raise ``RuntimeError`` for status codes of 400 and above; otherwise do nothing."""
        if self.status_code < 400:
            return
        raise RuntimeError(f"HTTP {self.status_code}")
class InvalidJsonResponse(FakeResponse):
    """Fake response whose body can never be decoded as JSON."""

    def json(self):
        """Always raise ``ValueError`` so callers' malformed-JSON handling can be exercised."""
        decode_error = ValueError("invalid json")
        raise decode_error
class FakeSession:
    """Scripted HTTP session that records every request and replays canned replies.

    Each request is appended to ``calls`` as a ``(method, url, kwargs)`` tuple.
    URLs that are not scripted are answered with an empty 404 response.
    """

    def __init__(self, local_result):
        # ``local_result`` is the payload served for the transcription download URL.
        self.calls = []
        self.local_result = local_result

    def get(self, url, **kwargs):
        """Answer the upload-policy request and the transcription-result download."""
        self.calls.append(("GET", url, kwargs))
        if url == fasr.UPLOAD_POLICY_URL:
            policy_fields = {
                "policy": "policy-token",
                "signature": "signature-token",
                "upload_dir": "dashscope-instant/test-dir",
                "upload_host": "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com",
                "oss_access_key_id": "oss-ak",
                "x_oss_object_acl": "private",
                "x_oss_forbid_overwrite": "true",
                "max_file_size_mb": 1,
            }
            return FakeResponse({"data": policy_fields})
        if url == "https://result.example/transcription.json":
            return FakeResponse(self.local_result)
        return FakeResponse({}, 404)

    def post(self, url, **kwargs):
        """Answer the file upload, the task submission and the task status poll."""
        self.calls.append(("POST", url, kwargs))
        if url == "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com":
            return FakeResponse({})
        if url == fasr.TRANSCRIPTION_URL:
            submitted = {"task_status": "PENDING", "task_id": "task-123"}
            return FakeResponse({"output": submitted})
        if url != fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
            return FakeResponse({}, 404)
        finished_subtask = {
            "file_url": "oss://dashscope-instant/test-dir/audio.wav",
            "transcription_url": "https://result.example/transcription.json",
            "subtask_status": "SUCCEEDED",
        }
        return FakeResponse(
            {
                "output": {
                    "task_status": "SUCCEEDED",
                    "results": [finished_subtask],
                }
            }
        )
# Canned transcription payload: one transcript holding a single sentence with
# five timed words. Times are in milliseconds (the conversion tests expect 400
# to render as 00:00:00,400).
# NOTE(review): the name suggests this mirrors the hosted service's documented
# result shape — confirm against the provider's API reference.
OFFICIAL_SHAPE_RESULT = {
    "transcripts": [
        {
            "sentences": [
                {
                    "begin_time": 0,
                    "end_time": 3600,
                    "text": "你好欢迎观看今天的内容",
                    "speaker_id": 0,
                    "words": [
                        {"begin_time": start, "end_time": stop, "text": token, "punctuation": ""}
                        for start, stop, token in (
                            (0, 400, "你好"),
                            (400, 900, "欢迎"),
                            (900, 1300, "观看"),
                            (1300, 1800, "今天"),
                            (1800, 2400, "的内容"),
                        )
                    ],
                }
            ]
        }
    ]
}
class FunAsrSrtConversionTests(unittest.TestCase):
    """Tests for ``fasr.fun_asr_result_to_srt`` (transcription result to SRT text)."""

    @staticmethod
    def _single_sentence_result(sentence):
        """Wrap one sentence dict in the ``transcripts``/``sentences`` envelope."""
        return {"transcripts": [{"sentences": [sentence]}]}

    def test_official_shape_words_convert_ms_and_speaker_label(self):
        """Word timestamps are rendered as milliseconds and lines carry a speaker label."""
        srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)

        first_block = "1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,"
        second_block = "2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。"
        self.assertIn(first_block, srt)
        self.assertIn(second_block, srt)
        self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")

    def test_long_word_sequence_splits_into_fine_blocks(self):
        """Six half-second words with ``max_chars=4`` yield at least three subtitle blocks."""
        tokens = ["这是", "一个", "很长", "字幕", "需要", "拆分"]
        words = [
            {"begin_time": index * 500, "end_time": (index + 1) * 500, "text": token, "punctuation": ""}
            for index, token in enumerate(tokens)
        ]
        result = self._single_sentence_result(
            {
                "begin_time": 0,
                "end_time": 6000,
                "speaker_id": 1,
                "words": words,
            }
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)

        self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
        self.assertIn("1\n00:00:00,000", srt)

    def test_sentence_fallback_uses_ms_without_zero_duration(self):
        """A sentence with an empty word list is still split, and no block ends at its own start."""
        result = self._single_sentence_result(
            {
                "begin_time": 1000,
                "end_time": 3000,
                "text": "没有词级时间戳也可以拆分。",
                "speaker_id": 0,
                "words": [],
            }
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=5)

        self.assertIn("00:00:01,000", srt)
        self.assertIn("说话人1:", srt)
        self.assertNotIn("--> 00:00:01,000\n", srt)

    def test_empty_result_raises_clear_error(self):
        """A result without any transcript raises ``FunAsrError``."""
        empty_result = {"transcripts": []}
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(empty_result)

    def test_malformed_word_timestamp_raises_fun_asr_error(self):
        """A non-numeric word ``begin_time`` raises ``FunAsrError``."""
        bad_word = {"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
        result = self._single_sentence_result(
            {
                "begin_time": 0,
                "end_time": 1000,
                "speaker_id": 0,
                "words": [bad_word],
            }
        )
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)

    def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
        """A non-numeric sentence ``begin_time`` raises ``FunAsrError``."""
        result = self._single_sentence_result(
            {
                "begin_time": "bad",
                "end_time": 1000,
                "text": "坏时间",
                "speaker_id": 0,
                "words": [],
            }
        )
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)
class FunAsrServiceTests(unittest.TestCase):
    """Exercise the cloud Fun-ASR REST helpers in ``fasr`` against scripted fake sessions."""

    @staticmethod
    def _write_audio(tmp_dir, payload):
        """Create ``audio.wav`` containing *payload* in *tmp_dir* and return its path as ``str``.

        The file is written and closed before the path is handed out, so the
        code under test can reopen it on every platform.
        """
        audio_path = Path(tmp_dir) / "audio.wav"
        audio_path.write_bytes(payload)
        return str(audio_path)

    def test_create_with_fun_asr_uses_expected_rest_flow(self):
        """The full flow issues policy, upload, submit, poll and download requests in that order."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = Path(tmp_dir) / "audio.wav"
            local_file.write_bytes(b"audio")
            subtitle_file = Path(tmp_dir) / "out.srt"
            session = FakeSession(OFFICIAL_SHAPE_RESULT)
            result_path = fasr.create_with_fun_asr(
                str(local_file),
                subtitle_file=str(subtitle_file),
                api_key="sk-test",
                speaker_count=2,
                poll_interval=0,
                session=session,
            )
            self.assertEqual(str(subtitle_file), result_path)
            self.assertTrue(subtitle_file.exists())
            self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))

            # 1) upload policy request
            policy_call = session.calls[0]
            self.assertEqual("GET", policy_call[0])
            self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1])
            self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"])
            self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"])

            # 2) form upload to the host named by the policy
            upload_call = session.calls[1]
            self.assertEqual("POST", upload_call[0])
            self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1])
            upload_data = upload_call[2]["data"]
            self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"])
            self.assertEqual("policy-token", upload_data["policy"])
            self.assertEqual("signature-token", upload_data["Signature"])
            self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"])
            self.assertEqual("200", upload_data["success_action_status"])

            # 3) transcription task submission
            submit_call = session.calls[2]
            self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1])
            headers = submit_call[2]["headers"]
            self.assertEqual("enable", headers["X-DashScope-Async"])
            self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"])
            payload = submit_call[2]["json"]
            self.assertEqual("fun-asr", payload["model"])
            self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"])
            self.assertTrue(payload["parameters"]["diarization_enabled"])
            self.assertEqual(2, payload["parameters"]["speaker_count"])

            # 4) task status poll
            poll_call = session.calls[3]
            self.assertEqual("POST", poll_call[0])
            self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123"))

            # 5) transcription result download
            download_call = session.calls[4]
            self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2])

    def test_upload_policy_size_validation_fails_before_upload(self):
        """A file larger than the policy's ``max_file_size_mb`` is rejected with ``FunAsrError``."""
        policy = fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=0.000001,
        )
        # A still-open NamedTemporaryFile cannot be reopened by name on Windows,
        # so use a closed file inside a temporary directory like the other tests.
        with tempfile.TemporaryDirectory() as tmp_dir:
            audio_path = self._write_audio(tmp_dir, b"too-large")
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(audio_path, policy, session=FakeSession({}))

    def test_failed_subtask_raises(self):
        """A SUCCEEDED task whose only subtask FAILED raises ``FunAsrError``."""

        class FailedSession(FakeSession):
            def post(self, url, **kwargs):
                if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
                    return FakeResponse(
                        {
                            "output": {
                                "task_status": "SUCCEEDED",
                                "results": [{"subtask_status": "FAILED"}],
                            }
                        }
                    )
                return super().post(url, **kwargs)

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({}))

    def test_missing_api_key_raises_before_request(self):
        """An empty API key raises ``FunAsrError`` and no request is recorded."""
        session = FakeSession(OFFICIAL_SHAPE_RESULT)
        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("", session=session)
        self.assertEqual([], session.calls)

    def test_upload_policy_http_error_raises(self):
        """An HTTP 403 on the policy request raises ``FunAsrError``."""

        class PolicyHttpErrorSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({}, status_code=403)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({}))

    def test_malformed_upload_policy_raises(self):
        """A policy response that carries only the ``policy`` field raises ``FunAsrError``."""

        class MalformedPolicySession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({"data": {"policy": "missing-required-fields"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=MalformedPolicySession({}))

    def test_upload_http_failure_raises(self):
        """An HTTP 500 from the upload host raises ``FunAsrError``."""

        class UploadFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        policy = fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=1,
        )
        # Closed file in a temporary directory: reopening a live
        # NamedTemporaryFile by name is not portable to Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            audio_path = self._write_audio(tmp_dir, b"audio")
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(audio_path, policy, session=UploadFailureSession({}))

    def test_submit_failure_raises(self):
        """An HTTP 500 on task submission raises ``FunAsrError``."""

        class SubmitFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        with self.assertRaises(fasr.FunAsrError):
            fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({}))

    def test_poll_timeout_raises(self):
        """A task that stays RUNNING raises ``FunAsrError`` when polled with ``timeout=-1``."""

        class PendingSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "RUNNING"}})

        # NOTE(review): the negative timeout presumably makes the deadline
        # expire immediately — confirm against poll_transcription_task.
        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, timeout=-1, session=PendingSession({}))

    def test_task_failed_status_raises(self):
        """A FAILED task status raises ``FunAsrError``."""

        class FailedTaskSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "FAILED"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({}))

    def test_missing_transcription_url_raises(self):
        """An empty transcription URL raises ``FunAsrError``."""
        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("", session=FakeSession({}))

    def test_malformed_downloaded_json_raises(self):
        """A download whose body fails JSON decoding raises ``FunAsrError``."""

        class MalformedDownloadSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return InvalidJsonResponse()

        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("https://result.example/bad.json", session=MalformedDownloadSession({}))
class LocalFunAsrServiceTests(unittest.TestCase):
    """Tests for the local Fun-ASR HTTP backend helpers in ``fasr``."""

    def test_request_local_fun_asr_posts_file_and_options(self):
        """The request is POSTed to ``/asr`` with the options as form fields plus a ``file`` part."""

        class LocalSession:
            def __init__(self):
                self.calls = []

            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"text": "你好", "srt_file": "/tmp/out.srt"})

        with tempfile.TemporaryDirectory() as tmp_dir:
            audio = Path(tmp_dir) / "audio.wav"
            audio.write_bytes(b"audio")
            session = LocalSession()

            result = fasr.request_local_fun_asr(
                str(audio),
                api_url="127.0.0.1:7860",
                hotword="NarratoAI",
                enable_spk=True,
                timeout=123,
                session=session,
            )

            self.assertEqual("你好", result["text"])
            method, url, kwargs = session.calls[0]
            self.assertEqual("POST", method)
            # The scheme-less api_url is expected to gain "http://" and the "/asr" path.
            self.assertEqual("http://127.0.0.1:7860/asr", url)
            self.assertEqual({"hotword": "NarratoAI", "enable_spk": "true"}, kwargs["data"])
            self.assertEqual(123, kwargs["timeout"])
            self.assertIn("file", kwargs["files"])

    def test_create_with_local_fun_asr_copies_pack_srt_file(self):
        """When the reply names an existing ``srt_file``, its content ends up in the target file."""

        class LocalSession:
            def __init__(self, srt_file):
                self.srt_file = srt_file
                self.calls = []

            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"text": "你好", "srt_file": str(self.srt_file)})

        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)
            audio = workdir / "audio.wav"
            audio.write_bytes(b"audio")
            pack_srt = workdir / "pack.srt"
            pack_srt.write_text("1\n00:00:00,000 --> 00:00:01,000\n你好\n", encoding="utf-8")
            subtitle_file = workdir / "out.srt"

            result_path = fasr.create_with_local_fun_asr(
                str(audio),
                subtitle_file=str(subtitle_file),
                api_url="http://127.0.0.1:7860",
                session=LocalSession(pack_srt),
            )

            self.assertEqual(str(subtitle_file), result_path)
            expected_text = pack_srt.read_text(encoding="utf-8")
            actual_text = subtitle_file.read_text(encoding="utf-8")
            self.assertEqual(expected_text, actual_text)

    def test_create_with_local_fun_asr_downloads_relative_srt(self):
        """A relative ``downloads.srt`` link is fetched from the service's host."""

        class LocalSession:
            def __init__(self):
                self.calls = []

            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"text": "你好", "downloads": {"srt": "/download/result.srt"}})

            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse(text="1\n00:00:00,000 --> 00:00:01,000\n你好\n")

        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)
            audio = workdir / "audio.wav"
            audio.write_bytes(b"audio")
            subtitle_file = workdir / "out.srt"
            session = LocalSession()

            result_path = fasr.create_with_local_fun_asr(
                str(audio),
                subtitle_file=str(subtitle_file),
                api_url="http://127.0.0.1:7860/asr",
                session=session,
            )

            self.assertEqual(str(subtitle_file), result_path)
            second_call_url = session.calls[1][1]
            self.assertEqual("http://127.0.0.1:7860/download/result.srt", second_call_url)
            self.assertIn("你好", subtitle_file.read_text(encoding="utf-8"))

    def test_local_fun_asr_result_to_srt_uses_raw_timestamps(self):
        """Entries under ``raw`` with ``timestamp`` pairs drive the SRT timing."""
        raw_entry = {
            "text": "你好,世界。",
            "timestamp": [[0, 300], [300, 600], [600, 900], [900, 1200]],
        }
        result = {"raw": [raw_entry]}

        srt = fasr.local_fun_asr_result_to_srt(result, max_chars=20)

        self.assertIn("00:00:00,000 --> 00:00:00,600\n你好,", srt)
        self.assertIn("世界。", srt)
class LocalFireRedAsrServiceTests(unittest.TestCase):
    """Tests for the local FireRedASR HTTP backend helpers in ``fasr``."""

    def test_request_local_firered_asr_posts_file_and_options(self):
        """The request is POSTed to ``/asr`` with boolean options sent as lowercase strings."""

        class LocalSession:
            def __init__(self):
                self.calls = []

            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"text": "你好", "srt_url": "/outputs/out.srt"})

        with tempfile.TemporaryDirectory() as tmp_dir:
            audio = Path(tmp_dir) / "audio.wav"
            audio.write_bytes(b"audio")
            session = LocalSession()

            result = fasr.request_local_firered_asr(
                str(audio),
                api_url="127.0.0.1:7867",
                enable_vad=True,
                enable_lid=False,
                enable_punc=True,
                return_timestamp=True,
                timeout=456,
                session=session,
            )

            self.assertEqual("你好", result["text"])
            method, url, kwargs = session.calls[0]
            self.assertEqual("POST", method)
            self.assertEqual("http://127.0.0.1:7867/asr", url)
            expected_form = {
                "enable_vad": "true",
                "enable_lid": "false",
                "enable_punc": "true",
                "return_timestamp": "true",
            }
            self.assertEqual(expected_form, kwargs["data"])
            self.assertEqual(456, kwargs["timeout"])
            self.assertIn("file", kwargs["files"])

    def test_create_with_local_firered_asr_downloads_srt_url(self):
        """A relative ``srt_url`` in the reply is fetched from the service's host."""

        class LocalSession:
            def __init__(self):
                self.calls = []

            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"text": "你好", "srt_url": "/outputs/result.srt"})

            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse(text="1\n00:00:00,000 --> 00:00:01,000\n你好\n")

        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)
            audio = workdir / "audio.wav"
            audio.write_bytes(b"audio")
            subtitle_file = workdir / "out.srt"
            session = LocalSession()

            result_path = fasr.create_with_local_firered_asr(
                str(audio),
                subtitle_file=str(subtitle_file),
                api_url="http://127.0.0.1:7867",
                session=session,
            )

            self.assertEqual(str(subtitle_file), result_path)
            second_call_url = session.calls[1][1]
            self.assertEqual("http://127.0.0.1:7867/outputs/result.srt", second_call_url)
            self.assertIn("你好", subtitle_file.read_text(encoding="utf-8"))

    def test_firered_asr_result_to_srt_uses_sentence_timestamps(self):
        """Sentence ``start_ms``/``end_ms`` values become the SRT block times."""
        sentences = [
            {"text": "你好。", "start_ms": 40, "end_ms": 900},
            {"text": "欢迎观看。", "start_ms": 900, "end_ms": 2100},
        ]

        srt = fasr.firered_asr_result_to_srt({"sentences": sentences})

        self.assertIn("1\n00:00:00,040 --> 00:00:00,900\n你好。", srt)
        self.assertIn("2\n00:00:00,900 --> 00:00:02,100\n欢迎观看。", srt)
class FunAsrConfigTests(unittest.TestCase):
    """Check persistence of the ``fun_asr`` config section and the shipped example config."""

    @staticmethod
    def _example_config_path():
        """Locate ``config.example.toml`` independently of the working directory.

        The working directory is tried first, which is the original lookup.
        When the file is not there, fall back to the directory two levels
        above this test module.
        """
        candidate = Path("config.example.toml")
        if candidate.is_file():
            return candidate
        # NOTE(review): assumes this module lives at <repo>/app/services/ and the
        # example config sits at <repo>/ — confirm against the repository layout.
        return Path(__file__).resolve().parents[2] / "config.example.toml"

    def test_save_config_persists_fun_asr_section(self):
        """``cfg.save_config()`` writes the ``fun_asr`` table to the configured file."""
        original_config_file = cfg.config_file
        original_fun_asr = cfg.fun_asr
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / "config.toml"
                # Redirect the module-level config to a throwaway file.
                cfg.config_file = str(config_path)
                cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
                cfg.save_config()
                # Parse before the temporary directory is removed.
                saved = tomllib.loads(config_path.read_text(encoding="utf-8"))
        finally:
            # Always restore the shared module state, even when saving fails.
            cfg.config_file = original_config_file
            cfg.fun_asr = original_fun_asr
        self.assertEqual("sk-local", saved["fun_asr"]["api_key"])
        self.assertEqual("fun-asr", saved["fun_asr"]["model"])

    def test_config_example_fun_asr_section_parses(self):
        """The example config parses and its ``fun_asr`` table has the expected defaults."""
        example_text = self._example_config_path().read_text(encoding="utf-8")
        config_data = tomllib.loads(example_text)
        self.assertEqual("local", config_data["fun_asr"]["backend"])
        self.assertEqual("http://127.0.0.1:7860", config_data["fun_asr"]["api_url"])
        self.assertEqual("http://127.0.0.1:7867", config_data["fun_asr"]["firered_api_url"])
        self.assertEqual("fun-asr", config_data["fun_asr"]["model"])
        self.assertIn("api_key", config_data["fun_asr"])
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()