mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-02 06:38:37 +00:00
- 在 WebUI 中增加 Fun-ASR 转录界面,支持上传多种音视频格式并生成 SRT 字幕 - 新增 `app/services/fun_asr_subtitle.py` 服务模块,实现完整的 REST API 调用流程,包括获取上传凭证、文件上传、提交任务、轮询结果和 SRT 格式转换 - 在配置文件中增加 `[fun_asr]` 配置段,支持保存 API Key - 添加完整的单元测试,覆盖核心转换逻辑和服务流程 - 为兼容 Python 3.11 以下版本,将 `tomllib` 导入改为尝试导入并回退到 `tomli` - 在 `defaults.py` 中添加 `from __future__ import annotations` 以支持类型注解
404 lines
16 KiB
Python
404 lines
16 KiB
Python
import tempfile
|
||
import unittest
|
||
from pathlib import Path
|
||
|
||
try:
|
||
import tomllib
|
||
except ModuleNotFoundError: # Python < 3.11
|
||
import tomli as tomllib
|
||
|
||
from app.config import config as cfg
|
||
from app.services import fun_asr_subtitle as fasr
|
||
|
||
|
||
class FakeResponse:
    """Minimal stand-in for an HTTP response object used by the fake sessions.

    Only the members exercised in this file are provided: ``json()``,
    ``raise_for_status()`` and the ``status_code`` attribute.
    """

    def __init__(self, payload=None, status_code=200):
        """Store the canned body and status code.

        Args:
            payload: Object handed back by ``json()``. ``None`` means
                "no body" and is replaced by an empty dict.
            status_code: HTTP status code to simulate.
        """
        # Substitute the default only for None so that deliberately falsy
        # bodies (for example an empty list) are returned unchanged.
        self.payload = {} if payload is None else payload
        self.status_code = status_code

    def json(self):
        """Return the canned payload."""
        return self.payload

    def raise_for_status(self):
        """Raise ``RuntimeError`` when the status code is 400 or above."""
        if self.status_code >= 400:
            raise RuntimeError(f"HTTP {self.status_code}")
|
||
|
||
|
||
class InvalidJsonResponse(FakeResponse):
    """Fake response whose body cannot be decoded as JSON."""

    def json(self):
        """Always raise ``ValueError`` instead of returning a payload."""
        raise ValueError("invalid json")
|
||
|
||
|
||
class FakeSession:
    """In-memory replacement for an HTTP session with scripted responses.

    Each request is appended to ``calls`` as a ``(method, url, kwargs)``
    tuple so tests can inspect what was sent. URLs that are not scripted
    below are answered with an empty 404 response.
    """

    def __init__(self, local_result):
        """Remember the payload served for the transcription result URL."""
        self.local_result = local_result
        self.calls = []

    def get(self, url, **kwargs):
        """Record the GET and answer with a canned response for known URLs."""
        self.calls.append(("GET", url, kwargs))

        if url == fasr.UPLOAD_POLICY_URL:
            # Canned upload-policy document.
            policy_fields = {
                "policy": "policy-token",
                "signature": "signature-token",
                "upload_dir": "dashscope-instant/test-dir",
                "upload_host": "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com",
                "oss_access_key_id": "oss-ak",
                "x_oss_object_acl": "private",
                "x_oss_forbid_overwrite": "true",
                "max_file_size_mb": 1,
            }
            response = FakeResponse({"data": policy_fields})
        elif url == "https://result.example/transcription.json":
            # The URL advertised by the scripted task-status response in post().
            response = FakeResponse(self.local_result)
        else:
            response = FakeResponse({}, 404)
        return response

    def post(self, url, **kwargs):
        """Record the POST and answer with a canned response for known URLs."""
        self.calls.append(("POST", url, kwargs))

        if url == "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com":
            # Same host as the "upload_host" returned by get(); reply with an empty 200.
            response = FakeResponse({})
        elif url == fasr.TRANSCRIPTION_URL:
            submitted = {"task_status": "PENDING", "task_id": "task-123"}
            response = FakeResponse({"output": submitted})
        elif url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
            # Status query for the task id handed out above: report success.
            finished_item = {
                "file_url": "oss://dashscope-instant/test-dir/audio.wav",
                "transcription_url": "https://result.example/transcription.json",
                "subtask_status": "SUCCEEDED",
            }
            response = FakeResponse(
                {
                    "output": {
                        "task_status": "SUCCEEDED",
                        "results": [finished_item],
                    }
                }
            )
        else:
            response = FakeResponse({}, 404)
        return response
|
||
|
||
|
||
# Sample transcription payload shared by the tests below: one sentence with
# word-level entries. The tests assert that these time values are handled as
# milliseconds.
OFFICIAL_SHAPE_RESULT = {
    "transcripts": [
        {
            "sentences": [
                {
                    "begin_time": 0,
                    "end_time": 3600,
                    "text": "你好欢迎观看今天的内容",
                    "speaker_id": 0,
                    "words": [
                        {"begin_time": start, "end_time": stop, "text": token, "punctuation": mark}
                        for start, stop, token, mark in (
                            (0, 400, "你好", ","),
                            (400, 900, "欢迎", ""),
                            (900, 1300, "观看", ""),
                            (1300, 1800, "今天", ""),
                            (1800, 2400, "的内容", "。"),
                        )
                    ],
                }
            ]
        }
    ]
}
|
||
|
||
|
||
class FunAsrSrtConversionTests(unittest.TestCase):
    """Checks for ``fasr.fun_asr_result_to_srt``."""

    @staticmethod
    def _wrap_sentence(sentence):
        # Embed a single sentence dict in the transcripts/sentences envelope.
        return {"transcripts": [{"sentences": [sentence]}]}

    def test_official_shape_words_convert_ms_and_speaker_label(self):
        # Word timestamps are milliseconds; speaker_id 0 is labelled "说话人1".
        srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)

        first_block = "1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,"
        second_block = "2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。"
        self.assertIn(first_block, srt)
        self.assertIn(second_block, srt)
        self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")

    def test_long_word_sequence_splits_into_fine_blocks(self):
        # Six two-character words, each 500 ms long, against a 4-character limit.
        tokens = ["这是", "一个", "很长", "字幕", "需要", "拆分"]
        words = [
            {"begin_time": index * 500, "end_time": (index + 1) * 500, "text": token, "punctuation": ""}
            for index, token in enumerate(tokens)
        ]
        result = self._wrap_sentence(
            {
                "begin_time": 0,
                "end_time": 6000,
                "speaker_id": 1,
                "words": words,
            }
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)

        self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
        self.assertIn("1\n00:00:00,000", srt)

    def test_sentence_fallback_uses_ms_without_zero_duration(self):
        # No word-level entries: only the sentence-level timestamps are available.
        result = self._wrap_sentence(
            {
                "begin_time": 1000,
                "end_time": 3000,
                "text": "没有词级时间戳也可以拆分。",
                "speaker_id": 0,
                "words": [],
            }
        )

        srt = fasr.fun_asr_result_to_srt(result, max_chars=5)

        self.assertIn("00:00:01,000", srt)
        self.assertIn("说话人1:", srt)
        # No block may end exactly at the sentence start time.
        self.assertNotIn("--> 00:00:01,000\n", srt)

    def test_empty_result_raises_clear_error(self):
        empty_result = {"transcripts": []}
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(empty_result)

    def test_malformed_word_timestamp_raises_fun_asr_error(self):
        broken_word = {"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
        result = self._wrap_sentence(
            {
                "begin_time": 0,
                "end_time": 1000,
                "speaker_id": 0,
                "words": [broken_word],
            }
        )

        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)

    def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
        result = self._wrap_sentence(
            {
                "begin_time": "bad",
                "end_time": 1000,
                "text": "坏时间",
                "speaker_id": 0,
                "words": [],
            }
        )

        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)
|
||
|
||
|
||
class FunAsrServiceTests(unittest.TestCase):
    """Exercise the ``fasr`` REST helpers against in-memory fake sessions."""

    @staticmethod
    def _make_policy(max_file_size_mb):
        # Shared UploadPolicy fixture; only the size limit differs between tests.
        return fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=max_file_size_mb,
        )

    def test_create_with_fun_asr_uses_expected_rest_flow(self):
        # End-to-end run: policy -> upload -> submit -> poll -> download.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = Path(tmp_dir) / "audio.wav"
            local_file.write_bytes(b"audio")
            subtitle_file = Path(tmp_dir) / "out.srt"
            session = FakeSession(OFFICIAL_SHAPE_RESULT)

            result_path = fasr.create_with_fun_asr(
                str(local_file),
                subtitle_file=str(subtitle_file),
                api_key="sk-test",
                speaker_count=2,
                poll_interval=0,
                session=session,
            )

            self.assertEqual(str(subtitle_file), result_path)
            self.assertTrue(subtitle_file.exists())
            self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))

            # Call 0: fetch the upload policy.
            policy_call = session.calls[0]
            self.assertEqual("GET", policy_call[0])
            self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1])
            self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"])
            self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"])

            # Call 1: form upload to the host named in the policy.
            upload_call = session.calls[1]
            self.assertEqual("POST", upload_call[0])
            self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1])
            upload_data = upload_call[2]["data"]
            self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"])
            self.assertEqual("policy-token", upload_data["policy"])
            self.assertEqual("signature-token", upload_data["Signature"])
            self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"])
            self.assertEqual("200", upload_data["success_action_status"])

            # Call 2: submit the transcription task.
            submit_call = session.calls[2]
            self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1])
            headers = submit_call[2]["headers"]
            self.assertEqual("enable", headers["X-DashScope-Async"])
            self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"])
            payload = submit_call[2]["json"]
            self.assertEqual("fun-asr", payload["model"])
            self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"])
            self.assertTrue(payload["parameters"]["diarization_enabled"])
            self.assertEqual(2, payload["parameters"]["speaker_count"])

            # Call 3: poll the task status.
            poll_call = session.calls[3]
            self.assertEqual("POST", poll_call[0])
            self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123"))

            # Call 4: download the transcription JSON.
            download_call = session.calls[4]
            self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2])

    def test_upload_policy_size_validation_fails_before_upload(self):
        policy = self._make_policy(max_file_size_mb=0.000001)
        session = FakeSession({})
        # Use a closed file inside a temporary directory rather than an open
        # NamedTemporaryFile: an open one cannot be reopened by name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = Path(tmp_dir) / "audio.wav"
            local_file.write_bytes(b"too-large")
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(str(local_file), policy, session=session)

        # FakeSession answers 404 for this policy's host, so an attempted upload
        # could also end in FunAsrError. Checking that nothing was sent ensures
        # the size check itself rejected the file.
        self.assertEqual([], session.calls)

    def test_failed_subtask_raises(self):
        class FailedSession(FakeSession):
            def post(self, url, **kwargs):
                # Overall task succeeds but its only subtask failed.
                if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
                    return FakeResponse(
                        {
                            "output": {
                                "task_status": "SUCCEEDED",
                                "results": [{"subtask_status": "FAILED"}],
                            }
                        }
                    )
                return super().post(url, **kwargs)

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({}))

    def test_missing_api_key_raises_before_request(self):
        session = FakeSession(OFFICIAL_SHAPE_RESULT)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("", session=session)

        self.assertEqual([], session.calls)

    def test_upload_policy_http_error_raises(self):
        class PolicyHttpErrorSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({}, status_code=403)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({}))

    def test_malformed_upload_policy_raises(self):
        class MalformedPolicySession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({"data": {"policy": "missing-required-fields"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=MalformedPolicySession({}))

    def test_upload_http_failure_raises(self):
        class UploadFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        policy = self._make_policy(max_file_size_mb=1)
        # Closed file in a temporary directory so the code under test can open
        # it by name on every platform.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = Path(tmp_dir) / "audio.wav"
            local_file.write_bytes(b"audio")
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(str(local_file), policy, session=UploadFailureSession({}))

    def test_submit_failure_raises(self):
        class SubmitFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        with self.assertRaises(fasr.FunAsrError):
            fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({}))

    def test_poll_timeout_raises(self):
        class PendingSession(FakeSession):
            def post(self, url, **kwargs):
                # Never leaves the RUNNING state.
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "RUNNING"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, timeout=-1, session=PendingSession({}))

    def test_task_failed_status_raises(self):
        class FailedTaskSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "FAILED"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({}))

    def test_missing_transcription_url_raises(self):
        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("", session=FakeSession({}))

    def test_malformed_downloaded_json_raises(self):
        class MalformedDownloadSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return InvalidJsonResponse()

        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("https://result.example/bad.json", session=MalformedDownloadSession({}))
|
||
|
||
|
||
class FunAsrConfigTests(unittest.TestCase):
    """Checks that the ``fun_asr`` config section is saved and parseable."""

    def test_save_config_persists_fun_asr_section(self):
        # Remember the shared config state so it can be restored afterwards.
        previous_file, previous_section = cfg.config_file, cfg.fun_asr
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                target = Path(tmp_dir) / "config.toml"
                cfg.config_file = str(target)
                cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
                cfg.save_config()
                # Parse while the temporary directory still exists.
                toml_text = target.read_text(encoding="utf-8")
                saved = tomllib.loads(toml_text)
        finally:
            cfg.config_file = previous_file
            cfg.fun_asr = previous_section

        section = saved["fun_asr"]
        self.assertEqual("sk-local", section["api_key"])
        self.assertEqual("fun-asr", section["model"])

    def test_config_example_fun_asr_section_parses(self):
        # Relative path: resolved against the current working directory.
        # NOTE(review): presumably the suite is run from the repository root - confirm.
        example_text = Path("config.example.toml").read_text(encoding="utf-8")
        config_data = tomllib.loads(example_text)
        section = config_data["fun_asr"]
        self.assertEqual("fun-asr", section["model"])
        self.assertIn("api_key", section)
|
||
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|