Mirror of https://github.com/linyqh/NarratoAI.git, synced 2026-05-03 23:29:19 +00:00
Compare commits (26 commits)

- c0b72ec603
- 99dd4193ae
- 8c129790c7
- de33c6d0bd
- 852f5ae34c
- d45c1858c9
- 71dfc99839
- be653c5748
- d5c63cf4b4
- e53156f4f2
- abc9db22e5
- 4e2560651f
- a8b6a5bb6b
- d678bf62b1
- ac63fea953
- df034d104b
- ad02059e5d
- 4d21c43b89
- 8201911b82
- 3d76bff442
- 40a48cc9ff
- c83841a2e0
- f9539eac8c
- 1d148370c5
- 093c8aa329
- 1057bd215c
8  .gitignore (vendored)

```diff
@@ -39,9 +39,15 @@ bug清单.md
 task.md
 .claude/*
 .serena/*
+.worktrees/
 
 # OpenSpec: 忽略活动的变更提案,但保留归档和规范
 openspec/*
 AGENTS.md
 CLAUDE.md
 tests/*
+!tests/test_documentary_frame_analysis_service.py
+!tests/test_video_processor_documentary_unittest.py
+!tests/test_script_service_documentary_unittest.py
+!tests/test_generate_narration_script_documentary_unittest.py
+!tests/test_generate_script_docu_unittest.py
```
README-en.md

```diff
@@ -33,6 +33,7 @@ NarratoAI is an automated video narration tool that provides an all-in-one solut
 </div>
 
 ## Latest News
+- 2026.04.03 Released version 0.7.8, refactored the documentary frame-analysis pipeline with a shared service and improved extraction, caching, vision batching, and narration generation
 - 2025.05.11 Released new version 0.6.0, supports **short drama commentary** and optimized editing process
 - 2025.03.06 Released new version 0.5.2, supports DeepSeek R1 and DeepSeek V3 models for short drama mixing
 - 2024.12.16 Released new version 0.3.9, supports Alibaba Qwen2-VL model for video understanding; supports short drama mixing
```
54  README.md

```diff
@@ -1,38 +1,48 @@
 
 <div align="center">
-<h1 align="center" style="font-size: 2cm;"> NarratoAI 😎📽️ </h1>
+<h1 align="center"> NarratoAI 😎📽️ </h1>
 <h3 align="center">一站式 AI 影视解说+自动化剪辑工具🎬🎞️ </h3>
 
+<p align="center">
+📖 <a href="README-en.md">English</a> | 简体中文 | <a href="https://www.narratoai.cn">☁️ <b>云端版入口 (NarratoAI.cn)</b></a>
+</p>
+
-<h3>📖 <a href="README-en.md">English</a> | 简体中文 </h3>
 <div align="center">
+<br>
+
+> **🔥 隆重推荐:VibeCut 的新范式 —— [speclip.com](https://speclip.com)**
+>
+> **一个真正意义上的视频剪辑 Agent!像聊天(vibecoding)一样剪辑视频。**
+> **[👉 点击立即免费下载 Speclip](https://speclip.com)**
+
-[//]: # ( <a href="https://trendshift.io/repositories/8731" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8731" alt="harry0703%2FNarratoAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>)
 </div>
-<br>
-NarratoAI 是一个自动化影视解说工具,基于LLM实现文案撰写、自动化视频剪辑、配音和字幕生成的一站式流程,助力高效内容创作。
 <br>
 
-> **🔥 隆重推荐:VibeCut 的新范式 —— [Speclip](https://speclip.com) ,一个真正意义上的剪辑 Agent**
+NarratoAI 是一款自动化影视解说工具,基于 LLM 实现文案撰写、自动化视频剪辑、配音和字幕生成的一站式流程,助力高效内容创作。支持本地部署开源版及 [云端托管版](https://www.narratoai.cn)。
 
-[](https://github.com/linyqh/NarratoAI)
+<br>
-[](https://github.com/linyqh/NarratoAI/blob/main/LICENSE)
-[](https://github.com/linyqh/NarratoAI/issues)
-[](https://github.com/linyqh/NarratoAI/stargazers)
 
-<a href="https://discord.com/invite/V2pbAqqQNb" target="_blank">💬 加入 discord 开源社区,获取项目动态和最新资讯。</a>
+[](https://github.com/linyqh/NarratoAI/stargazers) [](https://github.com/linyqh/NarratoAI/issues) [](https://github.com/linyqh/NarratoAI)
 
+<br>
+
+<a href="https://github.com/linyqh/NarratoAI/wiki" target="_blank">💬 加入开源社群,获取项目动态和最新资讯</a>
+
+<br>
+
 <h2><a href="https://p9mf6rjv3c.feishu.cn/wiki/SP8swLLZki5WRWkhuFvc2CyInDg?from=from_copylink" target="_blank">🎉🎉🎉 官方文档 🎉🎉🎉</a> </h2>
-<h3>首页</h3>
+### 界面预览
+
 ![]()
 
 </div>
 
 ## 许可证
 本项目仅供学习和研究使用,不得商用。如需商业授权,请联系作者。
 
 ## 最新资讯
+- 2026.04.27 发布新版本 0.7.9,新增 **Fun-ASR一键转录字幕**
+- 2026.04.03 发布新版本 0.7.8,重构纪录片逐帧分析链路,统一共享服务并优化抽帧、缓存、视觉并发与文案生成流程
 - 2026.03.27 发布新版本 0.7.7,出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路
 - 2025.11.20 发布新版本 0.7.5,新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持
 - 2025.10.15 发布新版本 0.7.3,升级大模型供应商管理能力
@@ -47,17 +57,7 @@ NarratoAI 是一个自动化影视解说工具,基于LLM实现文案撰写、
 - 2024.11.10 发布新版本 v0.3.5;优化视频剪辑流程,
 
 ## 重磅福利 🎉
-> 1️⃣
-> **开发者专属福利:一站式AI平台,注册即送体验金!**
->
-> 还在为接入各种AI模型烦恼吗?向您推荐 302.AI,一个企业级的AI资源中心。一次接入,即可调用上百种AI模型,涵盖语言、图像、音视频等,按量付费,极大降低开发成本。
->
-> 通过下方我的专属链接注册,**立获1美元免费体验金**,助您轻松开启AI开发之旅。
->
-> **立即注册领取:** [https://share.302.ai/I9P6mP](https://share.302.ai/I9P6mP)
-
----
-> 2️⃣
 > 即日起全面支持硅基流动!注册即享2000万免费Token(价值16元平台配额),剪辑10分钟视频仅需0.1元!
 >
 > 🔥 快速领福利:
@@ -96,7 +96,7 @@ _**1. NarratoAI 是一款完全免费的软件,近期在社交媒体(抖音,B
 - [x] 一键合并素材
 - [x] 一键转录
 - [x] 一键清理缓存
-- [ ] 支持导出剪映草稿
+- [x] 支持导出剪映草稿
 - [X] 支持短剧解说
 - [ ] 主角人脸匹配
 - [ ] 支持根据口播,文案,视频素材自动匹配
@@ -168,7 +168,9 @@ streamlit run webui.py --server.maxUploadSize=2048
 </div>
 
 ## 赞助
-[](https://dartnode.com "Powered by DartNode - Free VPS for Open Source")
+<a href="https://dartnode.com">
+  <img src="https://dartnode.com/_branding/white_color_full.png" alt="Powered by DartNode" style="background-color: #333; padding: 10px; border-radius: 4px;">
+</a>
 
 ## 许可证 📝
```
```diff
@@ -81,7 +81,9 @@ def save_config():
         _cfg["soulvoice"] = soulvoice
         _cfg["ui"] = ui
         _cfg["tts_qwen"] = tts_qwen
+        _cfg["fun_asr"] = fun_asr
         _cfg["indextts2"] = indextts2
+        _cfg["doubaotts"] = doubaotts
         f.write(toml.dumps(_cfg))
@@ -95,7 +97,9 @@ soulvoice = _cfg.get("soulvoice", {})
 ui = _cfg.get("ui", {})
 frames = _cfg.get("frames", {})
 tts_qwen = _cfg.get("tts_qwen", {})
+fun_asr = _cfg.get("fun_asr", {})
 indextts2 = _cfg.get("indextts2", {})
+doubaotts = _cfg.get("doubaotts", {})
 
 hostname = socket.gethostname()
```
```diff
@@ -1,5 +1,7 @@
 """Shared config defaults used by both bootstrap and WebUI fallbacks."""
 
+from __future__ import annotations
+
 DEFAULT_OPENAI_COMPATIBLE_BASE_URL = "https://api.siliconflow.cn/v1"
 DEFAULT_OPENAI_COMPATIBLE_PROVIDER = "openai"
```
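The `from __future__ import annotations` added above enables postponed annotation evaluation (PEP 563): annotations are stored as strings instead of being evaluated at definition time, which is what lets PEP 604 unions like `int | None` appear in signatures without requiring Python 3.10+ at import time. A tiny standalone demonstration:

```python
from __future__ import annotations

def f(x: int | None = None) -> str | None:
    # The annotation above is never evaluated at def-time;
    # it is stored as the literal string 'int | None'.
    return None if x is None else str(x)

print(f.__annotations__["x"])  # → int | None (a str, not a types.UnionType)
print(f(42))  # → 42
```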
```diff
@@ -2,7 +2,10 @@ import tempfile
 import unittest
 from pathlib import Path
 
-import tomllib
+try:
+    import tomllib
+except ModuleNotFoundError:  # Python < 3.11
+    import tomli as tomllib
 
 from app.config import config as cfg
 from app.config.defaults import (
```
```diff
@@ -196,6 +196,7 @@ class VideoClipParams(BaseModel):
     tts_volume: Optional[float] = Field(default=AudioVolumeDefaults.TTS_VOLUME, description="解说语音音量(后处理)")
     original_volume: Optional[float] = Field(default=AudioVolumeDefaults.ORIGINAL_VOLUME, description="视频原声音量")
     bgm_volume: Optional[float] = Field(default=AudioVolumeDefaults.BGM_VOLUME, description="背景音乐音量")
+    draft_name: Optional[str] = Field(default="", description="剪映草稿名称")
 
 
```
13  app/services/documentary/__init__.py (new file)

```python
from app.services.documentary.frame_analysis_models import (
    DocumentaryAnalysisConfig,
    FrameBatchResult,
)
from app.services.documentary.frame_analysis_service import (
    DocumentaryFrameAnalysisService,
)

__all__ = [
    "DocumentaryAnalysisConfig",
    "FrameBatchResult",
    "DocumentaryFrameAnalysisService",
]
```
33  app/services/documentary/frame_analysis_models.py (new file)

```python
from dataclasses import dataclass, field


@dataclass(slots=True)
class DocumentaryAnalysisConfig:
    video_path: str
    frame_interval_seconds: float
    vision_batch_size: int
    vision_llm_provider: str
    vision_model_name: str
    custom_prompt: str = ""
    max_concurrency: int = 2

    def __post_init__(self) -> None:
        if self.frame_interval_seconds <= 0:
            raise ValueError("frame_interval_seconds must be > 0")
        if self.vision_batch_size <= 0:
            raise ValueError("vision_batch_size must be > 0")
        if self.max_concurrency <= 0:
            raise ValueError("max_concurrency must be > 0")


@dataclass(slots=True)
class FrameBatchResult:
    batch_index: int
    status: str
    time_range: str
    raw_response: str
    frame_paths: list[str] = field(default_factory=list)
    frame_observations: list[dict] = field(default_factory=list)
    overall_activity_summary: str = ""
    fallback_summary: str = ""
    error_message: str = ""
```
761  app/services/documentary/frame_analysis_service.py (new file)

````python
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Callable

from loguru import logger

from app.config import config
from app.services.documentary.frame_analysis_models import FrameBatchResult
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
from app.services.llm.migration_adapter import create_vision_analyzer
from app.utils import utils, video_processor


class DocumentaryFrameAnalysisService:
    PROMPT_TEMPLATE = """
我提供了 {frame_count} 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。
首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
然后,基于所有帧的分析,请用简洁的语言总结整个视频片段中发生的主要活动或事件流程。
请务必使用 JSON 格式输出。
JSON 必须包含以下键:
- frame_observations: 数组,且长度必须为 {frame_count}
- overall_activity_summary: 字符串,描述整个批次主要活动
示例结构:
{{
    "frame_observations": [
        {{"timestamp": "00:00:00,000", "observation": "画面描述"}}
    ],
    "overall_activity_summary": "本批次主要活动总结"
}}
请务必不要遗漏视频帧,我提供了 {frame_count} 张视频帧,frame_observations 必须包含 {frame_count} 个元素
请只返回 JSON 字符串,不要附加解释文字。
""".strip()

    async def generate_documentary_script(
        self,
        *,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int | float | None = None,
        vision_batch_size: int | None = None,
        vision_llm_provider: str | None = None,
        progress_callback: Callable[[float, str], None] | None = None,
        vision_api_key: str | None = None,
        vision_model_name: str | None = None,
        vision_base_url: str | None = None,
        max_concurrency: int | None = None,
    ) -> list[dict]:
        progress = progress_callback or (lambda _p, _m: None)
        analysis_result = await self.analyze_video(
            video_path=video_path,
            video_theme=video_theme,
            custom_prompt=custom_prompt,
            frame_interval_input=frame_interval_input,
            vision_batch_size=vision_batch_size,
            vision_llm_provider=vision_llm_provider,
            progress_callback=progress_callback,
            vision_api_key=vision_api_key,
            vision_model_name=vision_model_name,
            vision_base_url=vision_base_url,
            max_concurrency=max_concurrency,
        )
        analysis_json_path = analysis_result["analysis_json_path"]

        progress(80, "正在生成解说文案...")
        text_provider = config.app.get("text_llm_provider", "openai").lower()
        text_api_key = config.app.get(f"text_{text_provider}_api_key")
        text_model = config.app.get(f"text_{text_provider}_model_name")
        text_base_url = config.app.get(f"text_{text_provider}_base_url")
        if not text_api_key or not text_model:
            raise ValueError(
                f"未配置 {text_provider} 的文本模型参数。"
                f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name"
            )

        markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
        narration_input = self._build_narration_input(
            markdown_output=markdown_output,
            video_theme=video_theme,
            custom_prompt=custom_prompt,
        )
        narration_raw = generate_narration(
            narration_input,
            text_api_key,
            base_url=text_base_url,
            model=text_model,
        )
        narration_items = self._parse_narration_items(narration_raw)

        final_script = [{**item, "OST": 2} for item in narration_items]
        progress(100, "脚本生成完成")
        return final_script

    async def analyze_video(
        self,
        *,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int | float | None = None,
        vision_batch_size: int | None = None,
        vision_llm_provider: str | None = None,
        progress_callback: Callable[[float, str], None] | None = None,
        vision_api_key: str | None = None,
        vision_model_name: str | None = None,
        vision_base_url: str | None = None,
        max_concurrency: int | None = None,
    ) -> dict[str, Any]:
        progress = progress_callback or (lambda _p, _m: None)

        if not video_path or not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        frame_interval_seconds = self._resolve_frame_interval(frame_interval_input)
        batch_size = self._resolve_batch_size(vision_batch_size)
        concurrency = self._resolve_max_concurrency(max_concurrency)
        provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()

        api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key")
        model_name = (
            vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name")
        )
        base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "")
        if not api_key or not model_name:
            raise ValueError(
                f"未配置 {provider} 的 API Key 或模型名称。"
                f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
            )

        progress(10, "正在提取关键帧...")
        keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds)
        progress(25, f"关键帧准备完成,共 {len(keyframe_files)} 帧")

        progress(30, "正在初始化视觉分析器...")
        analyzer = create_vision_analyzer(
            provider=provider,
            api_key=api_key,
            model=model_name,
            base_url=base_url,
        )

        batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size)
        if not batches:
            raise RuntimeError("未能构建任何关键帧批次")

        progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...")
        batch_results = await self._analyze_batches(
            analyzer=analyzer,
            batches=batches,
            custom_prompt=custom_prompt,
            video_theme=video_theme,
            max_concurrency=concurrency,
            progress_callback=progress,
        )

        progress(65, "正在整理分析结果...")
        sorted_batches = self._sort_batch_results(batch_results)
        artifact = self._build_analysis_artifact(
            sorted_batches,
            video_path=video_path,
            frame_interval_seconds=frame_interval_seconds,
            vision_batch_size=batch_size,
            vision_llm_provider=provider,
            vision_model_name=model_name,
            max_concurrency=concurrency,
        )
        analysis_json_path = self._save_analysis_artifact(artifact)
        video_clip_json = self._build_video_clip_json(sorted_batches)

        progress(75, "逐帧分析完成")
        return {
            "analysis_json_path": analysis_json_path,
            "analysis_artifact": artifact,
            "video_clip_json": video_clip_json,
            "keyframe_files": keyframe_files,
        }

    def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]:
        parsed = self._repair_narration_payload(narration_raw)

        items: list[dict[str, Any]] = []
        if isinstance(parsed, dict):
            raw_items = parsed.get("items")
            if isinstance(raw_items, list):
                items = [item for item in raw_items if isinstance(item, dict)]

        if not items:
            raise ValueError("解说文案格式错误,无法解析JSON或缺少items字段")

        return items

    def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str:
        context_lines: list[str] = []
        if (video_theme or "").strip():
            context_lines.append(f"视频主题:{video_theme.strip()}")
        if (custom_prompt or "").strip():
            context_lines.append(f"补充创作要求:{custom_prompt.strip()}")

        if not context_lines:
            return markdown_output

        context_block = "\n".join(f"- {line}" for line in context_lines)
        return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n"

    def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None:
        def load_json_candidate(payload: str) -> dict[str, Any] | None:
            try:
                parsed = json.loads(payload)
                return parsed if isinstance(parsed, dict) else None
            except Exception:
                return None

        cleaned = (narration_raw or "").strip()
        if not cleaned:
            return None

        candidates: list[str] = [cleaned]
        candidates.append(cleaned.replace("{{", "{").replace("}}", "}"))

        json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL)
        if json_block:
            candidates.append(json_block.group(1).strip())

        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start >= 0 and end > start:
            candidates.append(cleaned[start : end + 1])

        for candidate in candidates:
            parsed = load_json_candidate(candidate)
            if parsed is not None:
                return parsed

        fixed = cleaned.replace("{{", "{").replace("}}", "}")
        fixed_start = fixed.find("{")
        fixed_end = fixed.rfind("}")
        if fixed_start >= 0 and fixed_end > fixed_start:
            fixed = fixed[fixed_start : fixed_end + 1]

        fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE)
        fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE)
        fixed = re.sub(r",\s*}", "}", fixed)
        fixed = re.sub(r",\s*]", "]", fixed)
        fixed = re.sub(r"'([^']*)'\s*:", r'"\1":', fixed)
        fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed)
        fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed)

        return load_json_candidate(fixed)
````
```python
    def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
        interval = frame_interval_input
        if interval in (None, ""):
            interval = config.frames.get("frame_interval_input", 3)
        try:
            value = float(interval)
        except (TypeError, ValueError):
            value = 3.0
        if value <= 0:
            raise ValueError("frame_interval_input must be > 0")
        return value

    def _resolve_batch_size(self, vision_batch_size: int | None) -> int:
        size = vision_batch_size or config.frames.get("vision_batch_size", 10)
        try:
            value = int(size)
        except (TypeError, ValueError):
            value = 10
        if value <= 0:
            raise ValueError("vision_batch_size must be > 0")
        return value

    def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
        value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2)
        try:
            parsed = int(value)
        except (TypeError, ValueError):
            parsed = 1
        return max(1, parsed)

    def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]:
        keyframes_root = os.path.join(utils.temp_dir(), "keyframes")
        os.makedirs(keyframes_root, exist_ok=True)
        cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds)
        cache_dir = os.path.join(keyframes_root, cache_key)
        os.makedirs(cache_dir, exist_ok=True)

        cached_files = self._collect_keyframe_paths(cache_dir)
        if cached_files:
            logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)} 帧")
            return cached_files

        processor = video_processor.VideoProcessor(video_path)
        extracted = processor.extract_frames_by_interval_with_fallback(
            output_dir=cache_dir,
            interval_seconds=frame_interval_seconds,
        )
        keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg"))
        if not keyframe_files:
            keyframe_files = self._collect_keyframe_paths(cache_dir)
        if not keyframe_files:
            raise RuntimeError("未提取到任何关键帧")

        logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)} 帧")
        return keyframe_files

    def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str:
        try:
            video_mtime = os.path.getmtime(video_path)
        except OSError:
            video_mtime = 0

        legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
        payload = "|".join(
            [
                str(video_path),
                str(video_mtime),
                str(frame_interval_seconds),
                "documentary-keyframes-v2",
            ]
        )
        return f"{legacy_prefix}_{utils.md5(payload)}"

    @staticmethod
    def _collect_keyframe_paths(cache_dir: str) -> list[str]:
        if not os.path.exists(cache_dir):
            return []
        return sorted(
            os.path.join(cache_dir, name)
            for name in os.listdir(cache_dir)
            if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
        )

    @staticmethod
    def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
        return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)]
```
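`_chunk_keyframes` is plain list slicing; a standalone sketch showing the tail-batch behaviour (the frame names are invented, matching the cached-keyframe naming pattern only loosely):

```python
def chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
    # range() steps by batch_size; slicing never raises past the end,
    # so the final batch is simply shorter when the count doesn't divide evenly.
    return [keyframe_files[i : i + batch_size] for i in range(0, len(keyframe_files), batch_size)]


frames = [f"keyframe_{i:06d}_000000000.jpg" for i in range(7)]
batches = chunk_keyframes(frames, 3)
print([len(b) for b in batches])  # → [3, 3, 1]
```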
```python
    async def _analyze_batches(
        self,
        *,
        analyzer: Any,
        batches: list[list[str]],
        custom_prompt: str,
        video_theme: str,
        max_concurrency: int,
        progress_callback: Callable[[float, str], None],
    ) -> list[FrameBatchResult]:
        semaphore = asyncio.Semaphore(max(1, max_concurrency))
        total = len(batches)
        done = 0
        done_lock = asyncio.Lock()

        batch_time_ranges: list[str] = []
        previous_batch_files: list[str] | None = None
        for batch_files in batches:
            _, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files)
            batch_time_ranges.append(time_range)
            previous_batch_files = batch_files

        async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult:
            nonlocal done
            prompt = self._build_batch_prompt(
                frame_count=len(frame_paths),
                video_theme=video_theme,
                custom_prompt=custom_prompt,
            )
            try:
                async with semaphore:
                    raw_results = await analyzer.analyze_images(
                        images=frame_paths,
                        prompt=prompt,
                        batch_size=max(1, len(frame_paths)),
                        max_concurrency=1,
                    )
                raw_response, error_message = self._extract_batch_response(raw_results)
                if error_message:
                    return self._build_failed_batch_result(
                        batch_index=batch_index,
                        raw_response=raw_response,
                        error_message=error_message,
                        frame_paths=frame_paths,
                        time_range=time_range,
                    )
                return self._parse_batch_response(
                    batch_index=batch_index,
                    raw_response=raw_response,
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            except Exception as exc:
                return self._build_failed_batch_result(
                    batch_index=batch_index,
                    raw_response="",
                    error_message=str(exc),
                    frame_paths=frame_paths,
                    time_range=time_range,
                )
            finally:
                async with done_lock:
                    done += 1
                    progress = 40 + (done / max(1, total)) * 25
                    progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...")

        tasks = [
            run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index])
            for index, batch_files in enumerate(batches)
        ]
        return await asyncio.gather(*tasks)
```
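The concurrency pattern above — one `asyncio.Semaphore` shared by all batch tasks, results collected with `asyncio.gather` — can be shown in isolation. `fake_vision_call` is a stand-in for the real `analyzer.analyze_images` round trip:

```python
import asyncio


async def fake_vision_call(index: int) -> int:
    await asyncio.sleep(0)  # stand-in for the network round trip
    return index


async def run_batches(n: int, max_concurrency: int) -> list[int]:
    semaphore = asyncio.Semaphore(max(1, max_concurrency))

    async def run_single(index: int) -> int:
        # At most `max_concurrency` coroutines sit inside this block at once;
        # the rest wait on the semaphore.
        async with semaphore:
            return await fake_vision_call(index)

    # gather() returns results in task-submission order regardless of
    # completion order, so per-batch results stay aligned with their indices.
    return await asyncio.gather(*(run_single(i) for i in range(n)))


print(asyncio.run(run_batches(5, 2)))  # → [0, 1, 2, 3, 4]
```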
|
|
||||||
|
def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str:
|
||||||
|
prompt = self._build_analysis_prompt(frame_count=frame_count)
|
||||||
|
extra_lines: list[str] = []
|
||||||
|
if (video_theme or "").strip():
|
||||||
|
extra_lines.append(f"视频主题:{video_theme.strip()}")
|
||||||
|
if (custom_prompt or "").strip():
|
||||||
|
extra_lines.append(custom_prompt.strip())
|
||||||
|
if not extra_lines:
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
        extras = "\n".join(f"- {line}" for line in extra_lines)
        return f"{prompt}\n\n补充分析要求:\n{extras}"

    def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]:
        if not raw_results:
            return "", "Batch response is empty"

        first_result = raw_results[0] if isinstance(raw_results, list) else raw_results
        if isinstance(first_result, dict):
            raw_response = str(first_result.get("response", "") or "")
            error_message = str(first_result.get("error", "") or "")
            if error_message:
                if not raw_response:
                    raw_response = error_message
                return raw_response, error_message
            if not raw_response.strip():
                return raw_response, "Batch response is empty"
            return raw_response, ""

        raw_response = str(first_result or "")
        if not raw_response.strip():
            return raw_response, "Batch response is empty"
        return raw_response, ""

    def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]:
        return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index))

    def _build_analysis_artifact(
        self,
        batch_results: list[FrameBatchResult],
        *,
        video_path: str,
        frame_interval_seconds: float,
        vision_batch_size: int,
        vision_llm_provider: str,
        vision_model_name: str,
        max_concurrency: int,
    ) -> dict[str, Any]:
        sorted_batches = self._sort_batch_results(batch_results)

        batch_dicts: list[dict[str, Any]] = []
        frame_observations: list[dict[str, Any]] = []
        overall_activity_summaries: list[dict[str, Any]] = []

        for batch in sorted_batches:
            batch_payload = {
                "batch_index": batch.batch_index,
                "status": batch.status,
                "time_range": batch.time_range,
                "raw_response": batch.raw_response,
                "frame_paths": list(batch.frame_paths),
                "frame_observations": list(batch.frame_observations),
                "overall_activity_summary": batch.overall_activity_summary,
                "fallback_summary": batch.fallback_summary,
                "error_message": batch.error_message,
            }
            batch_dicts.append(batch_payload)

            for observation in batch.frame_observations:
                observation_payload = dict(observation)
                observation_payload["batch_index"] = batch.batch_index
                observation_payload["time_range"] = batch.time_range
                frame_observations.append(observation_payload)

            summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip()
            if summary_text:
                overall_activity_summaries.append(
                    {
                        "batch_index": batch.batch_index,
                        "time_range": batch.time_range,
                        "summary": summary_text,
                    }
                )

        return {
            "artifact_version": "documentary-frame-analysis-v2",
            "generated_at": datetime.now().isoformat(),
            "video_path": video_path,
            "frame_interval_seconds": frame_interval_seconds,
            "vision_batch_size": vision_batch_size,
            "vision_llm_provider": vision_llm_provider,
            "vision_model_name": vision_model_name,
            "vision_max_concurrency": max_concurrency,
            "batches": batch_dicts,
            # 向后兼容旧解析器结构
            "frame_observations": frame_observations,
            "overall_activity_summaries": overall_activity_summaries,
        }

    def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str:
        analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
        os.makedirs(analysis_dir, exist_ok=True)

        filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        file_path = os.path.join(analysis_dir, filename)
        suffix = 1
        while os.path.exists(file_path):
            filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json"
            file_path = os.path.join(analysis_dir, filename)
            suffix += 1

        with open(file_path, "w", encoding="utf-8") as fp:
            json.dump(artifact, fp, ensure_ascii=False, indent=2)
        logger.info(f"分析结果已保存到: {file_path}")
        return file_path

    def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]:
        clips: list[dict] = []
        for batch in self._sort_batch_results(batch_results):
            picture = self._build_batch_picture(batch)
            clips.append(
                {
                    "timestamp": batch.time_range,
                    "picture": picture,
                    "narration": "",
                    "OST": 2,
                }
            )
        return clips

    def _build_batch_picture(self, batch: FrameBatchResult) -> str:
        summary = (batch.overall_activity_summary or "").strip()
        if summary:
            return summary

        fallback = (batch.fallback_summary or "").strip()
        if fallback:
            return fallback

        observation_lines = []
        for frame in batch.frame_observations:
            timestamp = str(frame.get("timestamp", "") or "").strip()
            observation = str(frame.get("observation", "") or "").strip()
            if timestamp and observation:
                observation_lines.append(f"{timestamp}: {observation}")
            elif observation:
                observation_lines.append(observation)
        if observation_lines:
            return " ".join(observation_lines)

        raw_response = (batch.raw_response or "").strip()
        if raw_response:
            return raw_response[:200]
        return "该批次分析失败,未返回可用描述。"

    def _time_range_sort_key(self, time_range: str) -> tuple[int, str]:
        start = (time_range or "").split("-", 1)[0].strip()
        return self._timestamp_to_milliseconds(start), time_range

    @staticmethod
    def _timestamp_to_milliseconds(timestamp: str) -> int:
        text = (timestamp or "").strip()
        try:
            if "," in text:
                time_part, ms_part = text.split(",", 1)
                milliseconds = int(ms_part)
            else:
                time_part = text
                milliseconds = 0

            parts = [int(part) for part in time_part.split(":") if part]
            while len(parts) < 3:
                parts.insert(0, 0)
            hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
            return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
        except Exception:
            return 0

    def _get_batch_timestamps(
        self,
        batch_files: list[str],
        prev_batch_files: list[str] | None = None,
    ) -> tuple[str, str, str]:
        if not batch_files:
            return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"

        if len(batch_files) == 1 and prev_batch_files:
            first_frame = os.path.basename(prev_batch_files[-1])
            last_frame = os.path.basename(batch_files[0])
        else:
            first_frame = os.path.basename(batch_files[0])
            last_frame = os.path.basename(batch_files[-1])

        first_timestamp = self._timestamp_from_keyframe_name(first_frame)
        last_timestamp = self._timestamp_from_keyframe_name(last_frame)
        return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}"

    def _timestamp_from_keyframe_name(self, filename: str) -> str:
        match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename)
        if not match:
            return "00:00:00,000"
        token = match.group(1)
        hours = int(token[0:2])
        minutes = int(token[2:4])
        seconds = int(token[4:6])
        milliseconds = int(token[6:9])
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

    def _build_analysis_prompt(self, frame_count: int) -> str:
        return self.PROMPT_TEMPLATE.format(frame_count=frame_count)

    def _build_failed_batch_result(
        self,
        *,
        batch_index: int,
        raw_response: str,
        error_message: str,
        frame_paths: list[str],
        time_range: str,
    ) -> FrameBatchResult:
        fallback_summary = (raw_response or "").strip()[:200]
        if not fallback_summary:
            fallback_summary = f"Batch {batch_index} analysis failed: {error_message or 'unknown error'}"

        return FrameBatchResult(
            batch_index=batch_index,
            status="failed",
            time_range=time_range,
            raw_response=raw_response,
            frame_paths=list(frame_paths),
            fallback_summary=fallback_summary,
            error_message=error_message,
        )

    def _build_cache_key(
        self,
        video_path: str,
        interval_seconds: float,
        prompt_version: str,
        model_name: str,
        batch_size: int,
        max_concurrency: int,
    ) -> str:
        try:
            video_mtime = os.path.getmtime(video_path)
        except OSError:
            video_mtime = 0

        legacy_prefix = utils.md5(f"{video_path}{video_mtime}")

        payload = "|".join(
            [
                str(video_path),
                str(video_mtime),
                str(interval_seconds),
                str(prompt_version),
                str(model_name),
                str(batch_size),
                str(max_concurrency),
                "documentary-frame-analysis-v2",
            ]
        )
        return f"{legacy_prefix}_{utils.md5(payload)}"

    def _strip_code_fence(self, response_text: str) -> str:
        cleaned = (response_text or "").strip()
        cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
        cleaned = re.sub(r"\s*```$", "", cleaned)
        return cleaned.strip()

    def _parse_batch_response(
        self,
        *,
        batch_index: int,
        raw_response: str,
        frame_paths: list[str],
        time_range: str,
    ) -> FrameBatchResult:
        try:
            payload = json.loads(self._strip_code_fence(raw_response))
        except Exception as exc:
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response=raw_response,
                error_message=str(exc),
                frame_paths=frame_paths,
                time_range=time_range,
            )

        validation_error = self._validate_batch_payload_contract(payload, expected_frame_count=len(frame_paths))
        if validation_error:
            return self._build_failed_batch_result(
                batch_index=batch_index,
                raw_response=raw_response,
                error_message=validation_error,
                frame_paths=frame_paths,
                time_range=time_range,
            )

        raw_observations = payload["frame_observations"]

        frame_observations: list[dict] = []
        for index, frame_path in enumerate(frame_paths):
            entry = raw_observations[index] if index < len(raw_observations) else {}
            if isinstance(entry, dict):
                observation = str(entry.get("observation", "") or "")
                timestamp = str(entry.get("timestamp", "") or "")
            else:
                observation = str(entry or "")
                timestamp = ""
            frame_observations.append(
                {
                    "frame_path": frame_path,
                    "timestamp": timestamp,
                    "observation": observation,
                }
            )

        raw_summary = payload.get("overall_activity_summary", "")
        if isinstance(raw_summary, str):
            summary = raw_summary
        elif raw_summary is None:
            summary = ""
        else:
            summary = str(raw_summary)

        return FrameBatchResult(
            batch_index=batch_index,
            status="success",
            time_range=time_range,
            raw_response=raw_response,
            frame_paths=list(frame_paths),
            frame_observations=frame_observations,
            overall_activity_summary=summary,
        )

    def _validate_batch_payload_contract(self, payload: object, *, expected_frame_count: int) -> str:
        if not isinstance(payload, dict):
            return "Batch response JSON payload must be an object"

        if "frame_observations" not in payload or not isinstance(payload["frame_observations"], list):
            return "Batch response must include frame_observations as a list"

        if len(payload["frame_observations"]) < expected_frame_count:
            return (
                "Batch response frame_observations length is shorter than provided frame_paths: "
                f"{len(payload['frame_observations'])} < {expected_frame_count}"
            )

        return ""
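Both timestamp parsers in this diff lean on the same SRT-style `HH:MM:SS,mmm` convention, with malformed input sorting first as `0`. A minimal standalone sketch of that conversion (illustrative top-level name, mirroring the `_timestamp_to_milliseconds` helper above):

```python
def timestamp_to_milliseconds(timestamp: str) -> int:
    # Parse an SRT-style "HH:MM:SS,mmm" timestamp into total milliseconds.
    text = (timestamp or "").strip()
    try:
        if "," in text:
            time_part, ms_part = text.split(",", 1)
            milliseconds = int(ms_part)
        else:
            time_part, milliseconds = text, 0
        parts = [int(part) for part in time_part.split(":") if part]
        while len(parts) < 3:
            parts.insert(0, 0)  # left-pad missing hour/minute fields
        hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
        return (hours * 3600 + minutes * 60 + seconds) * 1000 + milliseconds
    except Exception:
        return 0  # malformed timestamps sort to the front instead of raising

print(timestamp_to_milliseconds("00:01:02,500"))  # 62500
```

Returning `0` on any parse failure is what keeps batch sorting total even when a vision model emits a garbled time range.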
452  app/services/fun_asr_subtitle.py  Normal file
@@ -0,0 +1,452 @@
"""Aliyun Bailian Fun-ASR subtitle transcription helpers.

This module intentionally uses the REST API because the official Fun-ASR
recorded-file API supports temporary `oss://` resources only through REST.
"""

from __future__ import annotations

import os
import time
from dataclasses import dataclass
from typing import Any, Optional

import requests
from loguru import logger

from app.utils import utils

DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com"
UPLOAD_POLICY_URL = f"{DASHSCOPE_BASE_URL}/api/v1/uploads"
TRANSCRIPTION_URL = f"{DASHSCOPE_BASE_URL}/api/v1/services/audio/asr/transcription"
TASK_URL_TEMPLATE = f"{DASHSCOPE_BASE_URL}/api/v1/tasks/{{task_id}}"
MODEL_NAME = "fun-asr"
TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"}
PUNCTUATION_BREAKS = set(",。!?;,.!?;")


class FunAsrError(RuntimeError):
    """Raised for user-actionable Fun-ASR transcription failures."""


@dataclass
class UploadPolicy:
    upload_host: str
    upload_dir: str
    policy: str
    signature: str
    oss_access_key_id: str
    x_oss_object_acl: str = "private"
    x_oss_forbid_overwrite: str = "true"
    max_file_size_mb: Optional[float] = None


def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]:
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    if extra:
        headers.update(extra)
    return headers


def _raise_for_http(response: requests.Response, action: str) -> None:
    try:
        response.raise_for_status()
    except Exception as exc:  # requests may be mocked with generic exceptions
        raise FunAsrError(f"{action}失败,请检查阿里百炼 API Key、网络或服务状态") from exc


def _json(response: requests.Response, action: str) -> dict[str, Any]:
    _raise_for_http(response, action)
    try:
        data = response.json()
    except Exception as exc:
        raise FunAsrError(f"{action}返回了无效 JSON") from exc
    if not isinstance(data, dict):
        raise FunAsrError(f"{action}返回格式无效")
    return data


def _require_api_key(api_key: str) -> str:
    api_key = (api_key or "").strip()
    if not api_key:
        raise FunAsrError("请先输入阿里百炼 API Key")
    return api_key


def _safe_upload_name(local_file: str) -> str:
    name = os.path.basename(local_file).strip() or f"audio_{int(time.time())}.wav"
    return name.replace("/", "_").replace("\\", "_")


def _session_get(session, url: str, **kwargs):
    return session.get(url, **kwargs)


def _session_post(session, url: str, **kwargs):
    return session.post(url, **kwargs)


def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
    """Request Bailian temporary-storage upload policy for the target model."""
    api_key = _require_api_key(api_key)
    response = _session_get(
        session,
        UPLOAD_POLICY_URL,
        params={"action": "getPolicy", "model": model},
        headers=_auth_headers(api_key),
        timeout=30,
    )
    data = _json(response, "获取临时存储上传凭证")
    policy_data = data.get("data") or {}
    required = ["upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"]
    missing = [field for field in required if not policy_data.get(field)]
    if missing:
        raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(missing)}")

    return UploadPolicy(
        upload_host=str(policy_data["upload_host"]),
        upload_dir=str(policy_data["upload_dir"]).rstrip("/"),
        policy=str(policy_data["policy"]),
        signature=str(policy_data["signature"]),
        oss_access_key_id=str(policy_data["oss_access_key_id"]),
        x_oss_object_acl=str(policy_data.get("x_oss_object_acl") or "private"),
        x_oss_forbid_overwrite=str(policy_data.get("x_oss_forbid_overwrite") or "true"),
        max_file_size_mb=policy_data.get("max_file_size_mb"),
    )


def _validate_file_size(local_file: str, policy: UploadPolicy) -> None:
    if policy.max_file_size_mb is None:
        return
    max_bytes = float(policy.max_file_size_mb) * 1024 * 1024
    size = os.path.getsize(local_file)
    if size > max_bytes:
        raise FunAsrError(
            f"文件大小超过阿里百炼临时存储限制: {size / 1024 / 1024:.2f}MB > {float(policy.max_file_size_mb):.2f}MB"
        )


def upload_to_temporary_oss(local_file: str, policy: UploadPolicy, session=requests) -> str:
    """Upload local file to temporary OSS and return `oss://...` URL."""
    if not os.path.isfile(local_file):
        raise FunAsrError(f"待转写文件不存在: {local_file}")
    _validate_file_size(local_file, policy)

    key = f"{policy.upload_dir}/{_safe_upload_name(local_file)}"
    data = {
        "OSSAccessKeyId": policy.oss_access_key_id,
        "policy": policy.policy,
        "Signature": policy.signature,
        "key": key,
        "x-oss-object-acl": policy.x_oss_object_acl,
        "x-oss-forbid-overwrite": policy.x_oss_forbid_overwrite,
        "success_action_status": "200",
    }
    with open(local_file, "rb") as file_obj:
        files = {"file": (_safe_upload_name(local_file), file_obj)}
        response = _session_post(session, policy.upload_host, data=data, files=files, timeout=120)
    _raise_for_http(response, "上传文件到阿里百炼临时存储")
    return f"oss://{key}"


def submit_transcription_task(
    api_key: str,
    oss_url: str,
    speaker_count: Optional[int] = None,
    model: str = MODEL_NAME,
    session=requests,
) -> str:
    """Submit async Fun-ASR task and return task_id."""
    api_key = _require_api_key(api_key)
    parameters: dict[str, Any] = {"diarization_enabled": True}
    if speaker_count:
        parameters["speaker_count"] = int(speaker_count)

    payload = {
        "model": model,
        "input": {"file_urls": [oss_url]},
        "parameters": parameters,
    }
    response = _session_post(
        session,
        TRANSCRIPTION_URL,
        headers=_auth_headers(
            api_key,
            {
                "X-DashScope-Async": "enable",
                "X-DashScope-OssResourceResolve": "enable",
            },
        ),
        json=payload,
        timeout=30,
    )
    data = _json(response, "提交 Fun-ASR 转写任务")
    task_id = ((data.get("output") or {}).get("task_id") or "").strip()
    if not task_id:
        raise FunAsrError("提交 Fun-ASR 转写任务失败:未返回 task_id")
    return task_id


def poll_transcription_task(
    api_key: str,
    task_id: str,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> dict[str, Any]:
    """Poll task until terminal status and return successful result item."""
    api_key = _require_api_key(api_key)
    deadline = time.time() + timeout
    last_status = "PENDING"
    while time.time() < deadline:
        response = _session_post(
            session,
            TASK_URL_TEMPLATE.format(task_id=task_id),
            headers=_auth_headers(api_key),
            timeout=30,
        )
        data = _json(response, "查询 Fun-ASR 转写任务")
        output = data.get("output") or {}
        last_status = str(output.get("task_status") or "").upper()

        if last_status == "SUCCEEDED":
            results = output.get("results") or []
            for result in results:
                subtask_status = str(result.get("subtask_status") or "").upper()
                if subtask_status and subtask_status != "SUCCEEDED":
                    raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
            if not results:
                raise FunAsrError("Fun-ASR 转写成功但未返回结果")
            return results[0]

        if last_status in TERMINAL_FAILED_STATUSES:
            raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")

        time.sleep(poll_interval)

    raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")


def download_transcription_result(transcription_url: str, session=requests) -> dict[str, Any]:
    if not transcription_url:
        raise FunAsrError("Fun-ASR 结果缺少 transcription_url")
    response = _session_get(session, transcription_url, timeout=60)
    return _json(response, "下载 Fun-ASR 转写结果")


def _ms_to_srt_time(ms: float) -> str:
    total_ms = max(0, int(round(float(ms))))
    hours = total_ms // 3_600_000
    total_ms %= 3_600_000
    minutes = total_ms // 60_000
    total_ms %= 60_000
    seconds = total_ms // 1_000
    milliseconds = total_ms % 1_000
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"


def _srt_block(index: int, start_ms: float, end_ms: float, text: str) -> str:
    if end_ms <= start_ms:
        end_ms = start_ms + 500
    return f"{index}\n{_ms_to_srt_time(start_ms)} --> {_ms_to_srt_time(end_ms)}\n{text.strip()}\n"


def _timestamp_ms(value: Any, field_name: str) -> float:
    try:
        return float(value)
    except (TypeError, ValueError) as exc:
        raise FunAsrError(f"Fun-ASR 转写结果时间戳无效: {field_name}={value!r}") from exc


def _speaker_prefix(speaker_id: Any) -> str:
    if speaker_id is None or speaker_id == "":
        return ""
    try:
        label = int(speaker_id) + 1
    except (TypeError, ValueError):
        label = str(speaker_id)
    return f"说话人{label}: "


def _iter_sentences(result_json: dict[str, Any]):
    transcripts = result_json.get("transcripts")
    if transcripts is None and "sentences" in result_json:
        transcripts = [{"sentences": result_json.get("sentences") or []}]
    if not transcripts:
        raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts")
    for transcript in transcripts:
        for sentence in transcript.get("sentences") or []:
            yield sentence


def _word_text(word: dict[str, Any]) -> str:
    text = str(word.get("text") or word.get("word") or "")
    punctuation = str(word.get("punctuation") or "")
    if punctuation and not text.endswith(punctuation):
        text += punctuation
    return text


def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None:
    text = current.get("text", "").strip()
    if text:
        blocks.append(current.copy())


def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
    words = sentence.get("words") or []
    blocks: list[dict[str, Any]] = []
    current: Optional[dict[str, Any]] = None
    max_duration_ms = max_duration * 1000
    sentence_speaker = sentence.get("speaker_id")

    for word in words:
        text = _word_text(word)
        if not text:
            continue
        start = word.get("begin_time", word.get("start_time"))
        end = word.get("end_time")
        if start is None or end is None:
            continue
        speaker_id = word.get("speaker_id", sentence_speaker)
        start_ms = _timestamp_ms(start, "word.begin_time")
        end_ms = _timestamp_ms(end, "word.end_time")

        if current is None:
            current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
        else:
            should_split_before = (
                speaker_id != current.get("speaker_id")
                or len(current["text"] + text) > max_chars
                or (end_ms - current["start"]) > max_duration_ms
            )
            if should_split_before:
                _flush_block(blocks, current)
                current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
            else:
                current["text"] += text
                current["end"] = end_ms

        if current and text[-1:] in PUNCTUATION_BREAKS:
            _flush_block(blocks, current)
            current = None

    if current:
        _flush_block(blocks, current)
    return blocks


def _split_text(text: str, max_chars: int) -> list[str]:
    chunks: list[str] = []
    current = ""
    for char in text:
        current += char
        if char in PUNCTUATION_BREAKS or len(current) >= max_chars:
            chunks.append(current.strip())
            current = ""
    if current.strip():
        chunks.append(current.strip())
    return [chunk for chunk in chunks if chunk]


def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
    text = str(sentence.get("text") or "").strip()
    if not text:
        return []
    start = sentence.get("begin_time", 0)
    end = sentence.get("end_time")
    start_ms = _timestamp_ms(start, "sentence.begin_time")
    end_ms = _timestamp_ms(end, "sentence.end_time") if end is not None else start_ms + 500
    chunks = _split_text(text, max_chars)
    if not chunks:
        return []
    duration = max(500.0, end_ms - start_ms)
    total_chars = max(1, sum(len(chunk) for chunk in chunks))
    cursor = start_ms
    blocks: list[dict[str, Any]] = []
    for i, chunk in enumerate(chunks):
        if i == len(chunks) - 1:
            chunk_end = end_ms
        else:
            chunk_end = cursor + duration * (len(chunk) / total_chars)
        blocks.append(
            {
                "start": cursor,
                "end": max(cursor + 200, chunk_end),
                "text": chunk,
                "speaker_id": sentence.get("speaker_id"),
            }
        )
        cursor = chunk_end
    return blocks


def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str:
    """Convert downloaded Fun-ASR JSON into fine-grained SRT.

    Official downloaded schema is `transcripts[*].sentences[*].words[*]`.
    Fun-ASR timestamps are milliseconds.
    """
    blocks: list[dict[str, Any]] = []
    for sentence in _iter_sentences(result_json):
        sentence_blocks = _blocks_from_words(sentence, max_chars, max_duration)
        if not sentence_blocks:
            sentence_blocks = _blocks_from_sentence(sentence, max_chars)
        blocks.extend(sentence_blocks)

    if not blocks:
        raise FunAsrError("Fun-ASR 转写结果为空:未找到可用字幕内容")

    lines = []
    for index, block in enumerate(blocks, start=1):
        text = f"{_speaker_prefix(block.get('speaker_id'))}{block['text']}"
        lines.append(_srt_block(index, block["start"], block["end"], text))
    return "\n".join(lines).rstrip() + "\n"


def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
    if not subtitle_file:
        subtitle_file = os.path.join(utils.subtitle_dir(), f"fun_asr_{int(time.time())}.srt")
    parent = os.path.dirname(subtitle_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(subtitle_file, "w", encoding="utf-8") as f:
        f.write(srt_content)
    return subtitle_file


def create_with_fun_asr(
    local_file: str,
    subtitle_file: str = "",
    api_key: str = "",
    speaker_count: Optional[int] = None,
    poll_interval: float = 2.0,
    timeout: float = 600.0,
    session=requests,
) -> Optional[str]:
    """Upload local media to Bailian temporary storage and create a Fun-ASR SRT file."""
    api_key = _require_api_key(api_key)
    try:
        policy = request_upload_policy(api_key, session=session)
        oss_url = upload_to_temporary_oss(local_file, policy, session=session)
        task_id = submit_transcription_task(api_key, oss_url, speaker_count=speaker_count, session=session)
        task_result = poll_transcription_task(
            api_key,
            task_id,
            poll_interval=poll_interval,
            timeout=timeout,
            session=session,
        )
        transcription_url = task_result.get("transcription_url")
        result_json = download_transcription_result(transcription_url, session=session)
        srt_content = fun_asr_result_to_srt(result_json)
        output_file = write_srt_file(srt_content, subtitle_file)
        logger.info(f"Fun-ASR 字幕文件已生成: {output_file}")
        return output_file
    except FunAsrError:
        raise
    except Exception as exc:
        raise FunAsrError("Fun-ASR 字幕转写失败,请检查文件、网络或阿里百炼服务状态") from exc
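The SRT helpers at the bottom of the module boil down to a fixed `HH:MM:SS,mmm` timestamp format plus a three-line block layout. A minimal standalone sketch of the same conversion (illustrative top-level names; `divmod` replaces the repeated `//`/`%` pairs, behavior unchanged):

```python
def ms_to_srt_time(ms: float) -> str:
    # Clamp negatives to zero, then split milliseconds into H/M/S/ms fields.
    total_ms = max(0, int(round(float(ms))))
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    seconds, milliseconds = divmod(total_ms, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

# An SRT block is: index, "start --> end" timing line, then the subtitle text.
block = f"1\n{ms_to_srt_time(0)} --> {ms_to_srt_time(2500)}\n你好,世界\n"
print(block)  # 1 / 00:00:00,000 --> 00:00:02,500 / 你好,世界
```

Clamping to zero and zero-padding each field keeps the output valid SRT even for edge-case timestamps, which is why `_srt_block` also forces a 500 ms minimum duration.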
@@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path):
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    def time_to_milliseconds(time_text):
        time_text = (time_text or "").strip()
        if not time_text:
            return 0
        try:
            if "," in time_text:
                hhmmss, ms = time_text.split(",", 1)
                milliseconds = int(ms)
            else:
                hhmmss = time_text
                milliseconds = 0

            parts = [int(part) for part in hhmmss.split(":") if part]
            while len(parts) < 3:
                parts.insert(0, 0)
            hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
            return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
        except Exception:
            return 0

    def batch_sort_key(batch):
        time_range = batch.get("time_range", "")
        start = time_range.split("-", 1)[0].strip()
        return time_to_milliseconds(start), batch.get("batch_index", 0)

    markdown = ""

    # 新结构:按批次保存完整分析产物
    if isinstance(data.get("batches"), list):
        ordered_batches = sorted(data.get("batches", []), key=batch_sort_key)

        for i, batch in enumerate(ordered_batches, 1):
            time_range = batch.get("time_range", "")
            summary = (
                batch.get("overall_activity_summary")
                or batch.get("summary")
                or batch.get("fallback_summary")
                or ""
            )
            observations = batch.get("frame_observations") or batch.get("observations") or []

            markdown += f"## 片段 {i}\n"
            markdown += f"- 时间范围:{time_range}\n"
            markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n"
            markdown += "- 详细描述:\n"

            for frame in observations:
                timestamp = frame.get("timestamp", "")
                observation = frame.get("observation", "")
                markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"

            markdown += "\n"

        return markdown

    # 兼容旧结构
    summaries = data.get('overall_activity_summaries', [])
    frame_observations = data.get('frame_observations', [])

    batch_frames = {}
    for frame in frame_observations:
        batch_index = frame.get('batch_index')
        if batch_index not in batch_frames:
            batch_frames[batch_index] = []
        batch_frames[batch_index].append(frame)

    for i, summary in enumerate(summaries, 1):
        batch_index = summary.get('batch_index')
        time_range = summary.get('time_range', '')
        batch_summary = summary.get('summary', '')

        markdown += f"## 片段 {i}\n"
        markdown += f"- 时间范围:{time_range}\n"

        markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else f"- 片段描述:\n"

        markdown += "- 详细描述:\n"
|
markdown += "- 详细描述:\n"
|
||||||
|
|
||||||
# 添加该批次的帧观察详情
|
|
||||||
frames = batch_frames.get(batch_index, [])
|
frames = batch_frames.get(batch_index, [])
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
timestamp = frame.get('timestamp', '')
|
timestamp = frame.get('timestamp', '')
|
||||||
observation = frame.get('observation', '')
|
observation = frame.get('observation', '')
|
||||||
|
|
||||||
# 直接使用原始文本,不进行分割
|
|
||||||
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
|
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
|
||||||
|
|
||||||
markdown += "\n"
|
markdown += "\n"
|
||||||
|
|
||||||
return markdown
|
return markdown
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
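The `time_to_milliseconds` helper added in this hunk is self-contained and easy to sanity-check outside the app; the sketch below copies it verbatim from the diff and shows the values it produces for the timestamp formats the batch sorter feeds it:

```python
def time_to_milliseconds(time_text):
    # "HH:MM:SS,mmm" (or shorter "MM:SS") -> total milliseconds; 0 on bad input
    time_text = (time_text or "").strip()
    if not time_text:
        return 0
    try:
        if "," in time_text:
            hhmmss, ms = time_text.split(",", 1)
            milliseconds = int(ms)
        else:
            hhmmss = time_text
            milliseconds = 0

        parts = [int(part) for part in hhmmss.split(":") if part]
        while len(parts) < 3:
            parts.insert(0, 0)  # left-pad missing hour/minute fields
        hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
        return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
    except Exception:
        return 0


print(time_to_milliseconds("00:01:02,500"))  # 62500
print(time_to_milliseconds("01:02"))         # 62000  (parsed as MM:SS)
print(time_to_milliseconds("bad input"))     # 0
```

Returning 0 on any malformed input means unparseable batches sort first but never crash the markdown rendering.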
app/services/jianying_task.py (new file, 241 lines)
@@ -0,0 +1,241 @@
+import json
+import os
+import subprocess
+import time
+from os import path
+from loguru import logger
+
+from app.config import config
+from app.models import const
+from app.models.schema import VideoClipParams
+from app.services import voice, clip_video, update_script
+from app.services import state as sm
+from app.utils import utils
+
+
+def get_audio_duration_ffprobe(audio_file: str) -> float:
+    """
+    使用ffprobe获取音频文件的精确时长(秒)
+
+    Args:
+        audio_file: 音频文件路径
+
+    Returns:
+        float: 音频时长(秒),精确到微秒
+    """
+    try:
+        cmd = [
+            'ffprobe',
+            '-v', 'error',
+            '-show_entries', 'format=duration',
+            '-of', 'csv=p=0',
+            audio_file
+        ]
+        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+        duration = float(result.stdout.strip())
+        logger.debug(f"使用ffprobe获取音频时长: {duration:.6f}秒")
+        return duration
+    except subprocess.CalledProcessError as e:
+        logger.error(f"ffprobe执行失败: {e.stderr}")
+        raise
+    except Exception as e:
+        logger.error(f"获取音频时长失败: {str(e)}")
+        raise
+
+
+def start_export_jianying_draft(task_id: str, params: VideoClipParams):
+    """
+    导出到剪映草稿的后台任务
+
+    Args:
+        task_id: 任务ID
+        params: 视频参数
+    """
+    logger.info(f"\n\n## 开始导出到剪映草稿任务: {task_id}")
+    sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=0)
+
+    """
+    1. 加载剪辑脚本
+    """
+    logger.info("\n\n## 1. 加载视频脚本")
+    video_script_path = path.join(params.video_clip_json_path)
+
+    if path.exists(video_script_path):
+        try:
+            with open(video_script_path, "r", encoding="utf-8") as f:
+                list_script = json.load(f)
+                video_list = [i['narration'] for i in list_script]
+                video_ost = [i['OST'] for i in list_script]
+                time_list = [i['timestamp'] for i in list_script]
+
+                video_script = " ".join(video_list)
+                logger.debug(f"解说完整脚本: \n{video_script}")
+                logger.debug(f"解说 OST 列表: \n{video_ost}")
+                logger.debug(f"解说时间戳列表: \n{time_list}")
+        except Exception as e:
+            logger.error(f"无法读取视频json脚本,请检查脚本格式是否正确")
+            raise ValueError("无法读取视频json脚本,请检查脚本格式是否正确")
+    else:
+        logger.error(f"解说脚本文件不存在: {video_script_path},请先点击【保存脚本】按钮保存脚本后再生成视频")
+        raise ValueError("解说脚本文件不存在!请先点击【保存脚本】按钮保存脚本后再生成视频。")
+
+    """
+    2. 使用 TTS 生成音频素材
+    """
+    logger.info("\n\n## 2. 根据OST设置生成音频列表")
+    tts_segments = [
+        segment for segment in list_script
+        if segment['OST'] in [0, 2]
+    ]
+    logger.debug(f"需要生成TTS的片段数: {len(tts_segments)}")
+
+    tts_results = voice.tts_multiple(
+        task_id=task_id,
+        list_script=tts_segments,  # 只传入需要TTS的片段
+        tts_engine=params.tts_engine,
+        voice_name=params.voice_name,
+        voice_rate=params.voice_rate,
+        voice_pitch=params.voice_pitch,
+    )
+
+    sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
+
+    """
+    3. 统一视频裁剪 - 基于OST类型的差异化裁剪策略
+    """
+    logger.info("\n\n## 3. 统一视频裁剪(基于OST类型)")
+    video_clip_result = clip_video.clip_video_unified(
+        video_origin_path=params.video_origin_path,
+        script_list=list_script,
+        tts_results=tts_results
+    )
+
+    tts_clip_result = {tts_result['_id']: tts_result['audio_file'] for tts_result in tts_results}
+    subclip_clip_result = {
+        tts_result['_id']: tts_result['subtitle_file'] for tts_result in tts_results
+    }
+    new_script_list = update_script.update_script_timestamps(list_script, video_clip_result, tts_clip_result, subclip_clip_result)
+
+    logger.info(f"统一裁剪完成,处理了 {len(video_clip_result)} 个视频片段")
+
+    sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=60)
+
+    """
+    4. 导出到剪映草稿
+    """
+    logger.info("\n\n## 4. 导出到剪映草稿")
+
+    try:
+        import pyJianYingDraft
+        from pyJianYingDraft import DraftFolder, VideoSegment, AudioSegment, trange, TrackType
+        jianying_draft_path = config.ui.get("jianying_draft_path", "")
+        if not jianying_draft_path:
+            raise ValueError("剪映草稿路径未配置")
+
+        # 创建DraftFolder实例
+        draft_folder = DraftFolder(jianying_draft_path)
+
+        # 使用从参数中获取的草稿名称,如果为空则使用默认名称
+        draft_name = getattr(params, 'draft_name', "")
+        logger.debug(f"从params获取的草稿名称: '{draft_name}' (类型: {type(draft_name)})")
+        if not draft_name:
+            draft_name = f"NarratoAI_{int(time.time())}"
+            logger.debug(f"使用默认草稿名称: '{draft_name}'")
+
+        # 创建新草稿
+        script = draft_folder.create_draft(draft_name, 1920, 1080)
+
+        # 添加视频轨道和音频轨道
+        script.add_track(TrackType.video, '视频轨道')
+        script.add_track(TrackType.audio, '音频轨道')
+
+        # 处理脚本数据
+        current_time = 0
+        output_dir = utils.task_dir(task_id)
+
+        for item in new_script_list:
+            # 获取时间信息
+            start_time = float(item.get('start_time', 0.0))
+            duration = float(item.get('duration', 0.0))
+            timestamp = item.get('timestamp', '')
+
+            logger.info(f"处理片段: OST={item['OST']}, start_time={start_time}, duration={duration}, timestamp={timestamp}")
+
+            # 生成音频文件路径
+            audio_file = ""
+            if timestamp:
+                timestamp_formatted = timestamp.replace(':', '_')
+                audio_file = os.path.join(
+                    output_dir,
+                    f"audio_{timestamp_formatted}.mp3"
+                )
+
+            # 检查是否有裁剪后的视频文件
+            video_file = item.get('video', '')
+            if video_file and not os.path.exists(video_file):
+                video_file = ""
+
+            # 添加视频片段
+            if video_file:
+                # 使用裁剪后的视频文件
+                # 对于裁剪后的视频,target_timerange的第二个参数是持续时间
+                video_segment = VideoSegment(
+                    video_file,
+                    trange(f"{current_time}s", f"{duration}s")
+                )
+            else:
+                # 使用原始视频文件
+                # source_timerange是从原始视频中截取的部分
+                # target_timerange是片段在时间轴上的位置
+                video_segment = VideoSegment(
+                    params.video_origin_path,
+                    trange(f"{current_time}s", f"{duration}s"),
+                    source_timerange=trange(f"{start_time}s", f"{duration}s")
+                )
+            script.add_segment(video_segment, '视频轨道')
+
+            # 处理音频
+            if item['OST'] in [0, 2]:  # 需要TTS的片段
+                if os.path.exists(audio_file):
+                    # 使用ffprobe获取精确的音频时长,避免因TTS引擎差异导致时长不匹配
+                    actual_audio_duration = get_audio_duration_ffprobe(audio_file)
+                    logger.info(f"音频文件实际时长: {actual_audio_duration:.6f}秒, 脚本时长(视频): {duration:.3f}秒")
+
+                    # 使用音频实际时长和视频时长中的较小值,确保不超过素材时长
+                    # 当TTS语速调整时,音频可能比视频长或短,取较小值可以避免超出素材
+                    safe_duration = min(actual_audio_duration, duration)
+                    logger.info(f"使用时长: {safe_duration:.6f}秒 (取音频和视频时长的较小值)")
+
+                    audio_segment = AudioSegment(
+                        audio_file,
+                        trange(f"{current_time}s", f"{safe_duration}s")
+                    )
+                    script.add_segment(audio_segment, '音频轨道')
+                else:
+                    logger.warning(f"音频文件不存在: {audio_file}")
+            # OST=1的片段保留原声,不需要添加额外音频
+
+            # 更新当前时间
+            current_time += duration
+
+        # 保存草稿
+        script.save()
+
+        draft_path = os.path.join(jianying_draft_path, draft_name)
+
+        logger.success(f"成功导出到剪映草稿: {draft_name}")
+        logger.info(f"草稿已保存到: {draft_path}")
+
+        # 更新任务状态
+        sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, draft_path=draft_path, draft_name=draft_name)
+
+        return {"draft_path": draft_path, "draft_name": draft_name}
+
+    except ImportError as e:
+        logger.error(f"导入pyJianYingDraft失败: {e}")
+        raise ImportError(f"pyJianYingDraft库导入失败: {e}\n请确保已正确安装该库")
+    except Exception as e:
+        logger.error(f"导出到剪映草稿失败: {e}")
+        import traceback
+        logger.error(f"错误详情: {traceback.format_exc()}")
+        raise Exception(f"导出到剪映草稿失败: {e}")
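The per-segment bookkeeping in `start_export_jianying_draft` (sequential `current_time` accumulation plus the `min(audio, video)` clamp) can be sketched in isolation. The function and field names below are illustrative, not part of the module:

```python
# Sketch: how segments land on the draft timeline. Each clip starts where the
# previous one ended; TTS audio is clamped to the clip's video duration so a
# slower TTS voice can never overrun the video segment.
def layout(segments):
    current_time = 0.0
    placed = []
    for seg in segments:
        duration = float(seg["duration"])
        audio_len = seg.get("audio_duration")
        # OST=1 segments keep the original soundtrack -> no extra audio track
        safe = min(audio_len, duration) if audio_len is not None else None
        placed.append({"start": current_time, "video_len": duration, "audio_len": safe})
        current_time += duration
    return placed


timeline = layout([
    {"duration": 4.0, "audio_duration": 4.6},  # TTS ran long -> clamped to 4.0
    {"duration": 3.0, "audio_duration": 2.4},  # TTS ran short -> kept at 2.4
    {"duration": 5.0},                         # OST=1, original audio kept
])
print(timeline)
```

This is why the task queries ffprobe for the real audio length first: the clamp only works if the measured duration, not the scripted one, is compared against the video slot.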
@@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider):
                              images: List[Union[str, Path, PIL.Image.Image]],
                              prompt: str,
                              batch_size: int = 10,
+                             max_concurrency: int = 1,
                              **kwargs) -> List[str]:
         """
         分析图片并返回结果
@@ -116,6 +117,7 @@ class VisionModelProvider(BaseLLMProvider):
             images: 图片路径列表或PIL图片对象列表
             prompt: 分析提示词
             batch_size: 批处理大小
+            max_concurrency: 最大并发批次数(实现支持时生效)
             **kwargs: 其他参数
 
         Returns:
@@ -5,7 +5,6 @@
 """
 
 import asyncio
-import json
 from typing import List, Dict, Any, Optional, Union
 from pathlib import Path
 import PIL.Image
@@ -13,6 +12,7 @@ from loguru import logger
 
 from .unified_service import UnifiedLLMService
 from .exceptions import LLMServiceError
+from .manager import LLMServiceManager
 # 导入新的提示词管理系统
 from app.services.prompts import PromptManager
 
@@ -110,41 +110,11 @@ class LegacyLLMAdapter:
                 temperature=1.5,
                 response_format="json"
             )
-            # 使用增强的JSON解析器
-            from webui.tools.generate_short_summary import parse_and_fix_json
-            parsed_result = parse_and_fix_json(result)
-
-            if not parsed_result:
-                logger.error("无法解析LLM返回的JSON数据")
-                # 返回一个基本的JSON结构而不是错误字符串
-                return json.dumps({
-                    "items": [
-                        {
-                            "_id": 1,
-                            "timestamp": "00:00:00-00:00:10",
-                            "picture": "解析失败,请检查LLM输出",
-                            "narration": "解说文案生成失败,请重试"
-                        }
-                    ]
-                }, ensure_ascii=False)
-
-            # 确保返回的是JSON字符串
-            return json.dumps(parsed_result, ensure_ascii=False)
+            return result if isinstance(result, str) else str(result)
 
         except Exception as e:
             logger.error(f"生成解说文案失败: {str(e)}")
-            # 返回一个基本的JSON结构而不是错误字符串
-            return json.dumps({
-                "items": [
-                    {
-                        "_id": 1,
-                        "timestamp": "00:00:00-00:00:10",
-                        "picture": "生成失败",
-                        "narration": f"解说文案生成失败: {str(e)}"
-                    }
-                ]
-            }, ensure_ascii=False)
+            raise
 
 
 class VisionAnalyzerAdapter:
@@ -155,11 +125,29 @@ class VisionAnalyzerAdapter:
         self.api_key = api_key
         self.model = model
         self.base_url = base_url
 
+    def _build_provider_with_explicit_settings(self):
+        provider_name = (self.provider or "").lower()
+        if not LLMServiceManager.is_registered():
+            from .providers import register_all_providers
+
+            register_all_providers()
+
+        provider_class = LLMServiceManager._vision_providers.get(provider_name)
+        if provider_class is None:
+            raise LLMServiceError(f"视觉模型提供商未注册: {provider_name}")
+
+        return provider_class(
+            api_key=self.api_key,
+            model_name=self.model,
+            base_url=self.base_url,
+        )
+
     async def analyze_images(self,
                              images: List[Union[str, Path, PIL.Image.Image]],
                              prompt: str,
-                             batch_size: int = 10) -> List[Dict[str, Any]]:
+                             batch_size: int = 10,
+                             max_concurrency: int = 1) -> List[Dict[str, Any]]:
         """
         分析图片 - 兼容原有接口
 
@@ -167,17 +155,20 @@ class VisionAnalyzerAdapter:
             images: 图片列表
             prompt: 分析提示词
             batch_size: 批处理大小
+            max_concurrency: 最大并发批次数
 
         Returns:
             分析结果列表,格式与旧实现兼容
         """
         try:
-            # 使用统一服务分析图片
-            results = await UnifiedLLMService.analyze_images(
+            provider = self._build_provider_with_explicit_settings()
+            results = await provider.analyze_images(
                 images=images,
                 prompt=prompt,
-                provider=self.provider,
-                batch_size=batch_size
+                batch_size=batch_size,
+                max_concurrency=max_concurrency,
+                api_key=self.api_key,
+                api_base=self.base_url,
             )
 
             # 转换为旧格式以保持向后兼容性
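The `_build_provider_with_explicit_settings` change above amounts to a registry lookup that constructs the provider from the adapter's own credentials instead of global config. A minimal stand-in (the registry dict and `FakeVision` class here are hypothetical, not the project's types) behaves like this:

```python
class FakeVision:
    # Stand-in for a registered vision provider class.
    def __init__(self, api_key, model_name, base_url=None):
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = base_url


# Stand-in for LLMServiceManager._vision_providers.
VISION_PROVIDERS = {"openai": FakeVision}


def build_provider(provider, api_key, model, base_url=None):
    # Normalize the name, then instantiate with the explicit settings only:
    # the registry supplies the class, never the credentials.
    provider_class = VISION_PROVIDERS.get((provider or "").lower())
    if provider_class is None:
        raise LookupError(f"vision provider not registered: {provider}")
    return provider_class(api_key=api_key, model_name=model, base_url=base_url)


p = build_provider("OpenAI", api_key="k", model="m", base_url="https://x/v1")
print(p.model_name, p.base_url)  # m https://x/v1
```

The test `ExplicitVisionAdapterSettingsTests` later in this diff pins down exactly this behavior against the real adapter.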
@@ -4,6 +4,7 @@ OpenAI 兼容提供商实现
 使用 OpenAI 官方 SDK 调用 OpenAI 兼容接口,支持文本和视觉模型。
 """
 
+import asyncio
 import io
 import base64
 import re
@@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider):
                              images: List[Union[str, Path, PIL.Image.Image]],
                              prompt: str,
                              batch_size: int = 10,
+                             max_concurrency: int = 1,
                              **kwargs,
                              ) -> List[str]:
         logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
 
         processed_images = self._prepare_images(images)
-        results: List[str] = []
+        if not processed_images:
+            return []
 
-        for i in range(0, len(processed_images), batch_size):
-            batch = processed_images[i : i + batch_size]
-            logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片")
-            try:
-                result = await self._analyze_batch(batch, prompt, **kwargs)
-                results.append(result)
-            except Exception as exc:
-                logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}")
-                results.append(f"批次处理失败: {exc}")
+        bounded_concurrency = max(1, int(max_concurrency))
+        semaphore = asyncio.Semaphore(bounded_concurrency)
+        batches = [
+            (index // batch_size, processed_images[index : index + batch_size])
+            for index in range(0, len(processed_images), batch_size)
+        ]
 
-        return results
+        async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
+            logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张图片")
+            async with semaphore:
+                try:
+                    result = await self._analyze_batch(batch, prompt, **kwargs)
+                    return batch_index, result
+                except Exception as exc:
+                    logger.error(f"批次 {batch_index + 1} 处理失败: {exc}")
+                    return batch_index, f"批次处理失败: {exc}"
+
+        completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches))
+        completed.sort(key=lambda item: item[0])
+        return [result for _, result in completed]
 
     async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
         content = [{"type": "text", "text": prompt}]
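The serial batch loop replaced above becomes a bounded fan-out: a semaphore caps in-flight requests, failures degrade to placeholder strings per batch, and results are tagged with their batch index so output order matches input order. A standalone sketch of the same shape (illustrative names, no provider involved):

```python
import asyncio


async def analyze_batches(batches, analyze_one, max_concurrency=2):
    # At most `max_concurrency` batches run at once; every result carries its
    # batch index so the final list follows input order, not completion order.
    semaphore = asyncio.Semaphore(max(1, int(max_concurrency)))

    async def run(index, batch):
        async with semaphore:
            try:
                return index, await analyze_one(batch)
            except Exception as exc:
                # A failed batch yields a placeholder instead of aborting the run.
                return index, f"batch failed: {exc}"

    completed = await asyncio.gather(*(run(i, b) for i, b in enumerate(batches)))
    completed.sort(key=lambda item: item[0])
    return [result for _, result in completed]


async def demo():
    async def fake(batch):
        # Smaller batches finish sooner, exercising the index-based ordering.
        await asyncio.sleep(0.01 * len(batch))
        if not batch:
            raise ValueError("empty batch")
        return f"analyzed {len(batch)} frames"

    return await analyze_batches([["a", "b"], ["c"], []], fake)


print(asyncio.run(demo()))
```

The concurrency tests added later in this diff (`OpenAICompatVisionConcurrencyTests`) assert exactly these two properties on the real provider: stable ordering and a respected concurrency ceiling.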
@@ -1,10 +1,14 @@
 """OpenAI 兼容 provider 的最小回归测试。"""
 
+import asyncio
 import unittest
+from unittest.mock import patch
 
 from app.config import config
 from app.services.llm.base import TextModelProvider
 from app.services.llm.manager import LLMServiceManager
+from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
+from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
 from app.services.llm.providers import register_all_providers
 
 
@@ -63,5 +67,128 @@ class OpenAICompatManagerTests(unittest.TestCase):
         self.assertEqual("https://new.example/v1", provider.base_url)
 
 
+class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
+    async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
+        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
+        provider._prepare_images = lambda images: list(images)
+
+        async def fake_analyze_batch(batch, prompt, **kwargs):
+            delays = {"a": 0.03, "c": 0.01, "e": 0.0}
+            await asyncio.sleep(delays[batch[0]])
+            return f"batch-{batch[0]}"
+
+        provider._analyze_batch = fake_analyze_batch
+
+        result = await provider.analyze_images(
+            images=["a", "b", "c", "d", "e", "f"],
+            prompt="prompt",
+            batch_size=2,
+            max_concurrency=2,
+        )
+
+        self.assertEqual(["batch-a", "batch-c", "batch-e"], result)
+
+    async def test_analyze_images_respects_max_concurrency_limit(self):
+        provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
+        provider._prepare_images = lambda images: list(images)
+
+        in_flight = 0
+        max_in_flight = 0
+
+        async def fake_analyze_batch(batch, prompt, **kwargs):
+            nonlocal in_flight, max_in_flight
+            in_flight += 1
+            max_in_flight = max(max_in_flight, in_flight)
+            await asyncio.sleep(0.02)
+            in_flight -= 1
+            return f"batch-{batch[0]}"
+
+        provider._analyze_batch = fake_analyze_batch
+
+        result = await provider.analyze_images(
+            images=["a", "b", "c", "d", "e", "f"],
+            prompt="prompt",
+            batch_size=1,
+            max_concurrency=2,
+        )
+
+        self.assertEqual(6, len(result))
+        self.assertEqual(2, max_in_flight)
+
+
+class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
+    class _CapturingVisionProvider:
+        last_init: tuple[str, str, str | None] | None = None
+        last_call_kwargs: dict | None = None
+
+        def __init__(self, api_key: str, model_name: str, base_url: str | None = None):
+            self.api_key = api_key
+            self.model_name = model_name
+            self.base_url = base_url
+            ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url)
+
+        async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs):
+            ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs)
+            return [f"{self.model_name}|{self.api_key}|{self.base_url}"]
+
+    def setUp(self):
+        _reset_manager_state()
+        self._original_app = dict(config.app)
+
+    def tearDown(self):
+        _reset_manager_state()
+        config.app.clear()
+        config.app.update(self._original_app)
+
+    async def test_adapter_uses_explicit_settings_instead_of_global_config(self):
+        LLMServiceManager.register_vision_provider("openai", self._CapturingVisionProvider)
+        config.app["vision_openai_api_key"] = "config-key"
+        config.app["vision_openai_model_name"] = "config-model"
+        config.app["vision_openai_base_url"] = "https://config.example/v1"
+
+        adapter = VisionAnalyzerAdapter(
+            provider="openai",
+            api_key="explicit-key",
+            model="explicit-model",
+            base_url="https://explicit.example/v1",
+        )
+        result = await adapter.analyze_images(
+            images=["/tmp/keyframe_000001_000000100.jpg"],
+            prompt="描述画面",
+            batch_size=1,
+            max_concurrency=1,
+        )
+
+        self.assertEqual(
+            ("explicit-key", "explicit-model", "https://explicit.example/v1"),
+            self._CapturingVisionProvider.last_init,
+        )
+        self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"])
+        self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"])
+        self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"])
+
+
+class LegacyNarrationAdapterBehaviorTests(unittest.TestCase):
+    def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self):
+        raw_payload = "not-json-at-all ::: ???"
+
+        with patch(
+            "app.services.llm.migration_adapter.PromptManager.get_prompt",
+            return_value="prompt",
+        ), patch(
+            "app.services.llm.migration_adapter._run_async_safely",
+            return_value=raw_payload,
+        ):
+            result = LegacyLLMAdapter.generate_narration(
+                markdown_content="markdown",
+                api_key="test-key",
+                base_url="https://example.com/v1",
+                model="test-model",
+            )
+
+        self.assertEqual(raw_payload, result)
+        self.assertNotIn('"items"', result)
+
+
 if __name__ == "__main__":
     unittest.main()
@ -1,324 +1,40 @@
|
|||||||
import os
|
from typing import Any, Callable
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import asyncio
|
|
||||||
import requests
|
|
||||||
from app.utils import video_processor
|
|
||||||
from loguru import logger
|
|
||||||
from typing import List, Dict, Any, Callable
|
|
||||||
|
|
||||||
from app.utils import utils, gemini_analyzer, video_processor
|
from loguru import logger
|
||||||
from app.utils.script_generator import ScriptProcessor
|
|
||||||
from app.config import config
|
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
|
||||||
|
|
||||||
|
|
||||||
class ScriptGenerator:
|
class ScriptGenerator:
|
||||||
def __init__(self):
|
def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None):
|
||||||
self.temp_dir = utils.temp_dir()
|
self.documentary_service = documentary_service or DocumentaryFrameAnalysisService()
|
||||||
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
|
|
||||||
|
|
||||||
async def generate_script(
|
async def generate_script(
|
||||||
self,
|
self,
|
||||||
video_path: str,
|
video_path: str,
|
||||||
video_theme: str = "",
|
video_theme: str = "",
|
||||||
custom_prompt: str = "",
|
custom_prompt: str = "",
|
||||||
frame_interval_input: int = 5,
|
frame_interval_input: int | None = None,
|
||||||
skip_seconds: int = 0,
|
skip_seconds: int = 0,
|
||||||
threshold: int = 30,
|
threshold: int = 30,
|
||||||
vision_batch_size: int = 5,
|
vision_batch_size: int | None = None,
|
||||||
vision_llm_provider: str = "gemini",
|
vision_llm_provider: str | None = None,
|
||||||
progress_callback: Callable[[float, str], None] = None
|
progress_callback: Callable[[float, str], None] | None = None,
|
||||||
) -> List[Dict[Any, Any]]:
|
) -> list[dict[Any, Any]]:
|
||||||
"""
|
callback = progress_callback or (lambda _p, _m: None)
|
||||||
生成视频脚本的核心逻辑
|
if skip_seconds != 0 or threshold != 30:
|
||||||
|
logger.warning(
|
||||||
Args:
|
"ScriptGenerator documentary path received "
|
||||||
video_path: 视频文件路径
|
f"skip_seconds={skip_seconds} threshold={threshold}; "
|
||||||
video_theme: 视频主题
|
"the shared documentary frame pipeline does not currently apply these parameters."
|
||||||
custom_prompt: 自定义提示词
|
|
||||||
skip_seconds: 跳过开始的秒数
|
|
||||||
threshold: 差异<EFBFBD><EFBFBD><EFBFBD>值
|
|
||||||
vision_batch_size: 视觉处理批次大小
|
|
||||||
vision_llm_provider: 视觉模型提供商
|
|
||||||
progress_callback: 进度回调函数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: 生成的视频脚本
|
|
||||||
"""
|
|
||||||
if progress_callback is None:
|
|
||||||
progress_callback = lambda p, m: None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 提取关键帧
|
|
||||||
progress_callback(10, "正在提取关键帧...")
|
|
||||||
keyframe_files = await self._extract_keyframes(
|
|
||||||
video_path,
|
|
||||||
skip_seconds,
|
|
||||||
threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用统一的 LLM 接口(支持所有 provider)
|
|
||||||
script = await self._process_with_llm(
|
|
||||||
keyframe_files,
|
|
||||||
video_theme,
|
|
||||||
custom_prompt,
|
|
||||||
vision_batch_size,
|
|
||||||
vision_llm_provider,
|
|
||||||
progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
return json.loads(script) if isinstance(script, str) else script
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Generate script failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _extract_keyframes(
|
|
||||||
self,
|
|
||||||
video_path: str,
|
|
||||||
skip_seconds: int,
|
|
||||||
threshold: int
|
|
||||||
) -> List[str]:
|
|
||||||
"""提取视频关键帧"""
|
|
||||||
video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
|
|
||||||
video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)
|
|
||||||
|
|
||||||
# 检查缓存
|
|
||||||
keyframe_files = []
|
|
||||||
if os.path.exists(video_keyframes_dir):
|
|
||||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
|
||||||
if filename.endswith('.jpg'):
|
|
||||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
|
||||||
|
|
||||||
if keyframe_files:
|
|
||||||
logger.info(f"Using cached keyframes: {video_keyframes_dir}")
|
|
||||||
return keyframe_files
|
|
||||||
|
|
||||||
# 提取新的关键帧
|
|
||||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
processor = video_processor.VideoProcessor(video_path)
|
|
||||||
processor.process_video_pipeline(
|
|
||||||
output_dir=video_keyframes_dir,
|
|
||||||
skip_seconds=skip_seconds,
|
|
||||||
threshold=threshold
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
return await self.documentary_service.generate_documentary_script(
|
||||||
if filename.endswith('.jpg'):
|
video_path=video_path,
|
||||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
video_theme=video_theme,
|
||||||
|
custom_prompt=custom_prompt,
|
||||||
return keyframe_files
|
frame_interval_input=frame_interval_input,
|
||||||
|
vision_batch_size=vision_batch_size,
|
||||||
except Exception as e:
|
vision_llm_provider=vision_llm_provider,
|
||||||
if os.path.exists(video_keyframes_dir):
|
progress_callback=callback,
|
||||||
import shutil
|
|
||||||
shutil.rmtree(video_keyframes_dir)
|
|
||||||
raise
|
|
||||||
|
|
||||||
    async def _process_with_llm(
        self,
        keyframe_files: List[str],
        video_theme: str,
        custom_prompt: str,
        vision_batch_size: int,
        vision_llm_provider: str,
        progress_callback: Callable[[float, str], None]
    ) -> str:
        """Process video frames through the unified LLM interface."""
        progress_callback(30, "正在初始化视觉分析器...")

        # Use the new LLM migration adapter (supports all providers)
        from app.services.llm.migration_adapter import create_vision_analyzer

        # Load configuration
        text_provider = config.app.get('text_llm_provider', 'openai').lower()
        vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
        vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
        vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')

        if not vision_api_key or not vision_model:
            raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型")

        # Create the unified vision analyzer
        analyzer = create_vision_analyzer(
            provider=vision_llm_provider,
            api_key=vision_api_key,
            model=vision_model,
            base_url=vision_base_url
        )

        progress_callback(40, "正在分析关键帧...")

        # Run the analysis asynchronously
        results = await analyzer.analyze_images(
            images=keyframe_files,
            prompt=config.app.get('vision_analysis_prompt'),
            batch_size=vision_batch_size
        )

        progress_callback(60, "正在整理分析结果...")

        # Merge analysis results from all batches
        frame_analysis = ""
        prev_batch_files = None

        for result in results:
            if 'error' in result:
                logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
                continue

            batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
            first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)

            # Append the timestamped analysis result
            frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
            frame_analysis += result['response']
            frame_analysis += "\n"

            prev_batch_files = batch_files

        if not frame_analysis.strip():
            raise Exception("未能生成有效的帧分析结果")

        progress_callback(70, "正在生成脚本...")

        # Build the frame-content list
        frame_content_list = []
        prev_batch_files = None

        for result in results:
            if 'error' in result:
                continue

            batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
            _, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)

            frame_content = {
                "timestamp": timestamp_range,
                "picture": result['response'],
                "narration": "",
                "OST": 2
            }
            frame_content_list.append(frame_content)
            prev_batch_files = batch_files

        if not frame_content_list:
            raise Exception("没有有效的帧内容可以处理")

        progress_callback(90, "正在生成文案...")

        # Load the text-generation configuration
        text_provider = config.app.get('text_llm_provider', 'gemini').lower()
        text_api_key = config.app.get(f'text_{text_provider}_api_key')
        text_model = config.app.get(f'text_{text_provider}_model_name')
        text_base_url = config.app.get(f'text_{text_provider}_base_url')

        # Pick the right processor for the provider type
        if text_provider == 'gemini(openai)':
            # Use the OpenAI-compatible Gemini proxy
            from app.utils.script_generator import GeminiOpenAIGenerator
            generator = GeminiOpenAIGenerator(
                model_name=text_model,
                api_key=text_api_key,
                prompt=custom_prompt,
                base_url=text_base_url
            )
            processor = ScriptProcessor(
                model_name=text_model,
                api_key=text_api_key,
                base_url=text_base_url,
                prompt=custom_prompt,
                video_theme=video_theme
            )
            processor.generator = generator
        else:
            # Use the standard processor (including native Gemini)
            processor = ScriptProcessor(
                model_name=text_model,
                api_key=text_api_key,
                base_url=text_base_url,
                prompt=custom_prompt,
                video_theme=video_theme
            )

        return processor.process_frames(frame_content_list)

    def _get_batch_files(
        self,
        keyframe_files: List[str],
        result: Dict[str, Any],
        batch_size: int
    ) -> List[str]:
        """Return the image files belonging to the current batch."""
        batch_start = result['batch_index'] * batch_size
        batch_end = min(batch_start + batch_size, len(keyframe_files))
        return keyframe_files[batch_start:batch_end]

    def _get_batch_timestamps(
        self,
        batch_files: List[str],
        prev_batch_files: List[str] = None
    ) -> tuple[str, str, str]:
        """Return the timestamp range of a batch of files, with millisecond precision."""
        if not batch_files:
            logger.warning("Empty batch files")
            return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"

        if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
            first_frame = os.path.basename(prev_batch_files[-1])
            last_frame = os.path.basename(batch_files[0])
        else:
            first_frame = os.path.basename(batch_files[0])
            last_frame = os.path.basename(batch_files[-1])

        first_time = first_frame.split('_')[2].replace('.jpg', '')
        last_time = last_frame.split('_')[2].replace('.jpg', '')

        def format_timestamp(time_str: str) -> str:
            """Convert a time string to HH:MM:SS,mmm format."""
            try:
                if len(time_str) < 4:
                    logger.warning(f"Invalid timestamp format: {time_str}")
                    return "00:00:00,000"

                # Handle the millisecond part
                if ',' in time_str:
                    time_part, ms_part = time_str.split(',')
                    ms = int(ms_part)
                else:
                    time_part = time_str
                    ms = 0

                # Handle hours, minutes, seconds
                parts = time_part.split(':')
                if len(parts) == 3:  # HH:MM:SS
                    h, m, s = map(int, parts)
                elif len(parts) == 2:  # MM:SS
                    h = 0
                    m, s = map(int, parts)
                else:  # SS
                    h = 0
                    m = 0
                    s = int(parts[0])

                # Carry over overflow
                if s >= 60:
                    m += s // 60
                    s = s % 60
                if m >= 60:
                    h += m // 60
                    m = m % 60

                return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

            except Exception as e:
                logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
                return "00:00:00,000"

        first_timestamp = format_timestamp(first_time)
        last_timestamp = format_timestamp(last_time)
        timestamp_range = f"{first_timestamp}-{last_timestamp}"

        return first_timestamp, last_timestamp, timestamp_range
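The nested `format_timestamp` helper above normalizes loosely formatted frame timestamps (`SS`, `MM:SS`, or `HH:MM:SS`, each with an optional `,mmm` part) into SRT-style `HH:MM:SS,mmm`, carrying overflowing seconds into minutes and minutes into hours. A minimal standalone sketch of the same normalization (the function name here is illustrative, not from the repo):

```python
def normalize_timestamp(time_str: str) -> str:
    """Normalize 'SS', 'MM:SS', or 'HH:MM:SS' (optionally ',mmm') to HH:MM:SS,mmm."""
    if ',' in time_str:
        time_part, ms_part = time_str.split(',')
        ms = int(ms_part)
    else:
        time_part, ms = time_str, 0
    parts = [int(p) for p in time_part.split(':')]
    h, m, s = [0] * (3 - len(parts)) + parts  # left-pad missing fields with zeros
    # Carry overflow upward, e.g. 75 seconds -> 1 minute 15 seconds
    m, s = m + s // 60, s % 60
    h, m = h + m // 60, m % 60
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

print(normalize_timestamp("75"))         # 00:01:15,000
print(normalize_timestamp("1:02:03,5"))  # 01:02:03,005
```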
app/services/test_fun_asr_subtitle_unittest.py (new file, 403 lines)
@@ -0,0 +1,403 @@
import tempfile
import unittest
from pathlib import Path

try:
    import tomllib
except ModuleNotFoundError:  # Python < 3.11
    import tomli as tomllib

from app.config import config as cfg
from app.services import fun_asr_subtitle as fasr


class FakeResponse:
    def __init__(self, payload=None, status_code=200):
        self.payload = payload or {}
        self.status_code = status_code

    def json(self):
        return self.payload

    def raise_for_status(self):
        if self.status_code >= 400:
            raise RuntimeError(f"HTTP {self.status_code}")


class InvalidJsonResponse(FakeResponse):
    def json(self):
        raise ValueError("invalid json")


class FakeSession:
    def __init__(self, local_result):
        self.calls = []
        self.local_result = local_result

    def get(self, url, **kwargs):
        self.calls.append(("GET", url, kwargs))
        if url == fasr.UPLOAD_POLICY_URL:
            return FakeResponse(
                {
                    "data": {
                        "policy": "policy-token",
                        "signature": "signature-token",
                        "upload_dir": "dashscope-instant/test-dir",
                        "upload_host": "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com",
                        "oss_access_key_id": "oss-ak",
                        "x_oss_object_acl": "private",
                        "x_oss_forbid_overwrite": "true",
                        "max_file_size_mb": 1,
                    }
                }
            )
        if url == "https://result.example/transcription.json":
            return FakeResponse(self.local_result)
        return FakeResponse({}, 404)

    def post(self, url, **kwargs):
        self.calls.append(("POST", url, kwargs))
        if url == "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com":
            return FakeResponse({})
        if url == fasr.TRANSCRIPTION_URL:
            return FakeResponse({"output": {"task_status": "PENDING", "task_id": "task-123"}})
        if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
            return FakeResponse(
                {
                    "output": {
                        "task_status": "SUCCEEDED",
                        "results": [
                            {
                                "file_url": "oss://dashscope-instant/test-dir/audio.wav",
                                "transcription_url": "https://result.example/transcription.json",
                                "subtask_status": "SUCCEEDED",
                            }
                        ],
                    }
                }
            )
        return FakeResponse({}, 404)


OFFICIAL_SHAPE_RESULT = {
    "transcripts": [
        {
            "sentences": [
                {
                    "begin_time": 0,
                    "end_time": 3600,
                    "text": "你好欢迎观看今天的内容",
                    "speaker_id": 0,
                    "words": [
                        {"begin_time": 0, "end_time": 400, "text": "你好", "punctuation": ","},
                        {"begin_time": 400, "end_time": 900, "text": "欢迎", "punctuation": ""},
                        {"begin_time": 900, "end_time": 1300, "text": "观看", "punctuation": ""},
                        {"begin_time": 1300, "end_time": 1800, "text": "今天", "punctuation": ""},
                        {"begin_time": 1800, "end_time": 2400, "text": "的内容", "punctuation": "。"},
                    ],
                }
            ]
        }
    ]
}


class FunAsrSrtConversionTests(unittest.TestCase):
    def test_official_shape_words_convert_ms_and_speaker_label(self):
        srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)

        self.assertIn("1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,", srt)
        self.assertIn("2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。", srt)
        self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")

    def test_long_word_sequence_splits_into_fine_blocks(self):
        result = {
            "transcripts": [
                {
                    "sentences": [
                        {
                            "begin_time": 0,
                            "end_time": 6000,
                            "speaker_id": 1,
                            "words": [
                                {"begin_time": i * 500, "end_time": (i + 1) * 500, "text": token, "punctuation": ""}
                                for i, token in enumerate(["这是", "一个", "很长", "字幕", "需要", "拆分"])
                            ],
                        }
                    ]
                }
            ]
        }
        srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)

        self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
        self.assertIn("1\n00:00:00,000", srt)

    def test_sentence_fallback_uses_ms_without_zero_duration(self):
        result = {
            "transcripts": [
                {
                    "sentences": [
                        {
                            "begin_time": 1000,
                            "end_time": 3000,
                            "text": "没有词级时间戳也可以拆分。",
                            "speaker_id": 0,
                            "words": [],
                        }
                    ]
                }
            ]
        }
        srt = fasr.fun_asr_result_to_srt(result, max_chars=5)

        self.assertIn("00:00:01,000", srt)
        self.assertIn("说话人1:", srt)
        self.assertNotIn("--> 00:00:01,000\n", srt)

    def test_empty_result_raises_clear_error(self):
        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt({"transcripts": []})

    def test_malformed_word_timestamp_raises_fun_asr_error(self):
        result = {
            "transcripts": [
                {
                    "sentences": [
                        {
                            "begin_time": 0,
                            "end_time": 1000,
                            "speaker_id": 0,
                            "words": [
                                {"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
                            ],
                        }
                    ]
                }
            ]
        }

        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)

    def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
        result = {
            "transcripts": [
                {
                    "sentences": [
                        {
                            "begin_time": "bad",
                            "end_time": 1000,
                            "text": "坏时间",
                            "speaker_id": 0,
                            "words": [],
                        }
                    ]
                }
            ]
        }

        with self.assertRaises(fasr.FunAsrError):
            fasr.fun_asr_result_to_srt(result)


class FunAsrServiceTests(unittest.TestCase):
    def test_create_with_fun_asr_uses_expected_rest_flow(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file = Path(tmp_dir) / "audio.wav"
            local_file.write_bytes(b"audio")
            subtitle_file = Path(tmp_dir) / "out.srt"
            session = FakeSession(OFFICIAL_SHAPE_RESULT)

            result_path = fasr.create_with_fun_asr(
                str(local_file),
                subtitle_file=str(subtitle_file),
                api_key="sk-test",
                speaker_count=2,
                poll_interval=0,
                session=session,
            )

            self.assertEqual(str(subtitle_file), result_path)
            self.assertTrue(subtitle_file.exists())
            self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))

            policy_call = session.calls[0]
            self.assertEqual("GET", policy_call[0])
            self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1])
            self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"])
            self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"])

            upload_call = session.calls[1]
            self.assertEqual("POST", upload_call[0])
            self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1])
            upload_data = upload_call[2]["data"]
            self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"])
            self.assertEqual("policy-token", upload_data["policy"])
            self.assertEqual("signature-token", upload_data["Signature"])
            self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"])
            self.assertEqual("200", upload_data["success_action_status"])

            submit_call = session.calls[2]
            self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1])
            headers = submit_call[2]["headers"]
            self.assertEqual("enable", headers["X-DashScope-Async"])
            self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"])
            payload = submit_call[2]["json"]
            self.assertEqual("fun-asr", payload["model"])
            self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"])
            self.assertTrue(payload["parameters"]["diarization_enabled"])
            self.assertEqual(2, payload["parameters"]["speaker_count"])

            poll_call = session.calls[3]
            self.assertEqual("POST", poll_call[0])
            self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123"))

            download_call = session.calls[4]
            self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2])

    def test_upload_policy_size_validation_fails_before_upload(self):
        policy = fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=0.000001,
        )
        with tempfile.NamedTemporaryFile() as f:
            f.write(b"too-large")
            f.flush()
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(f.name, policy, session=FakeSession({}))

    def test_failed_subtask_raises(self):
        class FailedSession(FakeSession):
            def post(self, url, **kwargs):
                if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
                    return FakeResponse(
                        {
                            "output": {
                                "task_status": "SUCCEEDED",
                                "results": [{"subtask_status": "FAILED"}],
                            }
                        }
                    )
                return super().post(url, **kwargs)

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({}))

    def test_missing_api_key_raises_before_request(self):
        session = FakeSession(OFFICIAL_SHAPE_RESULT)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("", session=session)

        self.assertEqual([], session.calls)

    def test_upload_policy_http_error_raises(self):
        class PolicyHttpErrorSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({}, status_code=403)

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({}))

    def test_malformed_upload_policy_raises(self):
        class MalformedPolicySession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return FakeResponse({"data": {"policy": "missing-required-fields"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.request_upload_policy("sk-test", session=MalformedPolicySession({}))

    def test_upload_http_failure_raises(self):
        class UploadFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        policy = fasr.UploadPolicy(
            upload_host="https://upload.example",
            upload_dir="dashscope-instant/test",
            policy="p",
            signature="s",
            oss_access_key_id="ak",
            max_file_size_mb=1,
        )
        with tempfile.NamedTemporaryFile() as f:
            f.write(b"audio")
            f.flush()
            with self.assertRaises(fasr.FunAsrError):
                fasr.upload_to_temporary_oss(f.name, policy, session=UploadFailureSession({}))

    def test_submit_failure_raises(self):
        class SubmitFailureSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({}, status_code=500)

        with self.assertRaises(fasr.FunAsrError):
            fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({}))

    def test_poll_timeout_raises(self):
        class PendingSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "RUNNING"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, timeout=-1, session=PendingSession({}))

    def test_task_failed_status_raises(self):
        class FailedTaskSession(FakeSession):
            def post(self, url, **kwargs):
                self.calls.append(("POST", url, kwargs))
                return FakeResponse({"output": {"task_status": "FAILED"}})

        with self.assertRaises(fasr.FunAsrError):
            fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({}))

    def test_missing_transcription_url_raises(self):
        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("", session=FakeSession({}))

    def test_malformed_downloaded_json_raises(self):
        class MalformedDownloadSession(FakeSession):
            def get(self, url, **kwargs):
                self.calls.append(("GET", url, kwargs))
                return InvalidJsonResponse()

        with self.assertRaises(fasr.FunAsrError):
            fasr.download_transcription_result("https://result.example/bad.json", session=MalformedDownloadSession({}))


class FunAsrConfigTests(unittest.TestCase):
    def test_save_config_persists_fun_asr_section(self):
        original_config_file = cfg.config_file
        original_fun_asr = cfg.fun_asr
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / "config.toml"
                cfg.config_file = str(config_path)
                cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
                cfg.save_config()
                saved = tomllib.loads(config_path.read_text(encoding="utf-8"))
        finally:
            cfg.config_file = original_config_file
            cfg.fun_asr = original_fun_asr

        self.assertEqual("sk-local", saved["fun_asr"]["api_key"])
        self.assertEqual("fun-asr", saved["fun_asr"]["model"])

    def test_config_example_fun_asr_section_parses(self):
        config_data = tomllib.loads(Path("config.example.toml").read_text(encoding="utf-8"))
        self.assertEqual("fun-asr", config_data["fun_asr"]["model"])
        self.assertIn("api_key", config_data["fun_asr"])


if __name__ == "__main__":
    unittest.main()
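The assertions in `test_official_shape_words_convert_ms_and_speaker_label` hinge on `begin_time`/`end_time` being milliseconds: a word ending at `end_time=400` must render as `00:00:00,400`, not `00:06:40,000` (which is what treating 400 as seconds would give). A minimal sketch of the ms-to-SRT-timestamp conversion those tests expect (the helper name is illustrative, not from the repo):

```python
def ms_to_srt_timestamp(ms: int) -> str:
    """Format integer milliseconds as an SRT timestamp, HH:MM:SS,mmm."""
    hours, rem = divmod(ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, millis = divmod(rem, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

print(ms_to_srt_timestamp(400))        # 00:00:00,400 -- milliseconds, not seconds
print(ms_to_srt_timestamp(3_600_000))  # 01:00:00,000
```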
@@ -1116,6 +1116,125 @@ def should_use_azure_speech_services(voice_name: str) -> bool:
     return False

def doubaotts_tts(text: str, voice_name: str, voice_file: str, speed: float = 1.0) -> Union[SubMaker, None]:
    """
    Generate speech with Doubao (Volcano Engine) TTS.
    """
    # Read configuration
    doubaotts_cfg = getattr(config, "doubaotts", {}) or {}
    appid = doubaotts_cfg.get("appid", "")
    token = doubaotts_cfg.get("token", "")
    ak = doubaotts_cfg.get("ak", "")
    sk = doubaotts_cfg.get("sk", "")
    cluster = doubaotts_cfg.get("cluster", "volcano_tts")

    if not appid or not token:
        logger.error("豆包语音 TTS 配置未完成")
        return None

    # Prepare parameters
    voice_type = voice_name
    safe_speed = float(max(0.2, min(3.0, speed)))
    text = text.strip()

    # Build the request payload
    import uuid
    reqid = str(uuid.uuid4())

    # Advanced parameters
    volume = doubaotts_cfg.get("volume", 1.0)
    pitch = doubaotts_cfg.get("pitch", 1.0)
    silence_duration = doubaotts_cfg.get("silence_duration", 0.125)

    payload = {
        "app": {
            "appid": appid,
            "token": token,
            "cluster": cluster
        },
        "user": {
            "uid": "NarratoAI"
        },
        "audio": {
            "voice_type": voice_type,
            "encoding": "mp3",
            "rate": 24000,
            "speed_ratio": safe_speed,
            "volume_ratio": float(volume),
            "pitch_ratio": float(pitch)
        },
        "request": {
            "reqid": reqid,
            "text": text,
            "text_type": "plain",
            "operation": "query"
        }
    }

    # Include the trailing-silence duration in the request if configured
    if silence_duration > 0:
        payload["audio"]["silence_duration"] = float(silence_duration)

    # API endpoint
    url = "https://openspeech.bytedance.com/api/v1/tts"

    # Build request headers (Bearer token auth)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer;{token}"
    }

    for i in range(3):
        try:
            logger.info(f"=== 豆包语音 TTS 请求参数 (第 {i+1} 次调用) ===")

            # Send the request
            import requests
            # Apply proxy settings
            proxies = None
            proxy_enabled = config.proxy.get("enabled", False)
            if proxy_enabled:
                proxy_url = config.proxy.get("https", config.proxy.get("http", ""))
                if proxy_url:
                    proxies = {"https": proxy_url, "http": proxy_url}
            response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=60)

            if response.status_code == 200:
                result = response.json()
                if result.get("code") == 3000:
                    # Success
                    audio_data = result.get("data", "")
                    if audio_data:
                        # Decode the base64 audio payload
                        import base64
                        audio_bytes = base64.b64decode(audio_data)

                        # Write to file
                        with open(voice_file, "wb") as f:
                            f.write(audio_bytes)

                        logger.success(f"豆包语音 TTS 合成成功: {voice_file}")

                        # Create a SubMaker object (simplified, no timestamps)
                        sub_maker = new_sub_maker()
                        return sub_maker
                    else:
                        logger.error("豆包语音 TTS 响应中无音频数据")
                else:
                    logger.error(f"豆包语音 TTS 失败: {result.get('message', '未知错误')}")
            else:
                logger.error(f"豆包语音 TTS API 请求失败: {response.status_code}, {response.text}")

            if i < 2:
                time.sleep(1)
        except Exception as e:
            logger.error(f"豆包语音 TTS 错误: {str(e)}")
            if i < 2:
                time.sleep(3)

    return None
def tts(
    text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str, tts_engine: str
) -> Union[SubMaker, None]:
@@ -1147,6 +1266,10 @@ def tts(
     if tts_engine == "indextts2":
         logger.info("分发到 IndexTTS2")
         return indextts2_tts(text, voice_name, voice_file, speed=voice_rate)
 
+    if tts_engine == "doubaotts":
+        logger.info("分发到豆包语音 TTS")
+        return doubaotts_tts(text, voice_name, voice_file, speed=voice_rate)
+
     # Fallback for unknown engine - default to azure v1
     logger.warning(f"未知的 TTS 引擎: '{tts_engine}', 将默认使用 Edge TTS (Azure V1)。")
@@ -1606,8 +1729,8 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
                            f"或者使用其他 tts 引擎")
             continue
         else:
-            # The SoulVoice, Qwen3, and IndexTTS2 engines do not generate subtitle files
-            if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2":
+            # The SoulVoice, Qwen3, IndexTTS2, and Doubao TTS engines do not generate subtitle files
+            if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2" or tts_engine == "doubaotts":
                 # Get the actual audio file duration
                 duration = get_audio_duration_from_file(audio_file)
                 if duration <= 0:
@@ -1615,8 +1738,27 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
                     duration = get_audio_duration(sub_maker)
                     if duration <= 0:
                         # Last-resort fallback: estimate from text length
-                        duration = max(1.0, len(text) / 3.0)
-                        logger.warning(f"无法获取音频时长,使用文本估算: {duration:.2f}秒")
+                        # Use a more accurate estimate for English text:
+                        # average English speech runs at about 150-180 words per minute, i.e. 2.5-3 words per second;
+                        # Chinese speech runs at roughly 3-4 characters per second
+                        import re
+                        # Count English words
+                        english_words = len(re.findall(r'\b\w+\b', text))
+                        # Count Chinese characters
+                        chinese_chars = len(re.findall(r'[\u4e00-\u9fa5]', text))
+
+                        if english_words > chinese_chars:
+                            # Mostly English text; assume ~0.35 s per word
+                            estimated_duration = max(1.0, english_words * 0.35)
+                        else:
+                            # Mostly Chinese text; assume ~0.3 s per character
+                            estimated_duration = max(1.0, chinese_chars * 0.3)
+
+                        # Keep the estimate within a sensible range
+                        duration = max(1.0, estimated_duration)
+                        logger.warning(f"无法获取音频时长,使用文本估算: {duration:.2f}秒 (英文单词: {english_words}, 中文字符: {chinese_chars})")
                 # Do not create a subtitle file
                 subtitle_file = ""
             else:
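The new fallback in the hunk above estimates duration from the text's content instead of a flat `len(text) / 3.0`. A standalone sketch of that heuristic, using the same regexes and rates as the diff (~0.35 s per English word, ~0.3 s per CJK character; the function name is illustrative):

```python
import re

def estimate_tts_duration(text: str) -> float:
    """Estimate speech duration in seconds from text length (fallback only)."""
    english_words = len(re.findall(r'\b\w+\b', text))
    chinese_chars = len(re.findall(r'[\u4e00-\u9fa5]', text))
    if english_words > chinese_chars:
        # Mostly English: ~0.35 s per word
        return max(1.0, english_words * 0.35)
    # Mostly Chinese: ~0.3 s per character
    return max(1.0, chinese_chars * 0.3)

print(round(estimate_tts_duration("one two three four five six"), 2))  # 2.1
print(estimate_tts_duration("你好"))  # 1.0 (floor applied to 0.6)
```

Note that `\w` matches CJK characters too under Python's default Unicode matching, so a pure-Chinese sentence still registers a few "word" matches; the comparison against `chinese_chars` is what steers it into the per-character branch.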
@@ -1658,8 +1800,6 @@ def get_audio_duration_from_file(audio_file: str) -> float:
         # But the actual file also contains header data, so adjust the factor
         estimated_duration = max(1.0, file_size / 20000)  # adjusted to a more conservative estimate
-        # For Chinese speech, apply a second correction based on text length
-        # Chinese speech typically runs at about 3-4 characters per second
         logger.warning(f"使用文件大小估算音频时长: {estimated_duration:.2f}秒")
         return estimated_duration
     except Exception as e:
@ -570,29 +570,39 @@ def temp_dir(sub_dir: str = ""):
|
|||||||
     return d
 
 
-def clear_keyframes_cache(video_path: str = None):
+def clear_keyframes_cache(video_path: str = None, cache_scope: str = "keyframes"):
     """
     清理关键帧缓存
     Args:
         video_path: 视频文件路径,如果指定则只清理该视频的缓存
+        cache_scope: 缓存作用域目录,默认 keyframes
     """
     try:
-        keyframes_dir = os.path.join(temp_dir(), "keyframes")
-        if not os.path.exists(keyframes_dir):
+        cache_dir = os.path.join(temp_dir(), cache_scope)
+        if not os.path.exists(cache_dir):
             return
 
+        import shutil
+
         if video_path:
-            # 理指定视频的缓存
-            video_hash = md5(video_path + str(os.path.getmtime(video_path)))
-            video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
-            if os.path.exists(video_keyframes_dir):
-                import shutil
-                shutil.rmtree(video_keyframes_dir)
-                logger.info(f"已清理视频关键帧缓存: {video_path}")
+            # 清理指定视频的缓存(兼容前缀扩展键)
+            try:
+                video_mtime = os.path.getmtime(video_path)
+            except OSError:
+                video_mtime = 0
+            video_hash = md5(video_path + str(video_mtime))
+            for entry in os.listdir(cache_dir):
+                if not entry.startswith(video_hash):
+                    continue
+                target_path = os.path.join(cache_dir, entry)
+                if os.path.isdir(target_path):
+                    shutil.rmtree(target_path)
+                else:
+                    os.remove(target_path)
+            logger.info(f"已清理视频关键帧缓存: {video_path}")
         else:
             # 清理所有缓存
-            import shutil
-            shutil.rmtree(keyframes_dir)
+            shutil.rmtree(cache_dir)
             logger.info("已清理所有关键帧缓存")
 
     except Exception as e:
@@ -185,6 +185,95 @@ class VideoProcessor:
 
         return frame_numbers
 
+    def extract_frames_by_interval_with_fallback(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
+        """
+        先尝试单次 ffmpeg 快路径抽帧,失败时回退到高兼容方案。
+        """
+        if interval_seconds <= 0:
+            raise ValueError("interval_seconds must be > 0")
+
+        os.makedirs(output_dir, exist_ok=True)
+
+        try:
+            return self._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds)
+        except Exception as exc:
+            logger.warning(f"快路径抽帧失败,回退到兼容模式: {exc}")
+            self._cleanup_fast_path_artifacts(output_dir)
+            self.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds)
+            return self._collect_extracted_frame_paths(output_dir)
+
+    def _extract_frames_fast_path(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
+        """
+        使用单次 ffmpeg 命令按固定间隔抽帧,随后重命名为既有 keyframe 约定格式。
+        """
+        if interval_seconds <= 0:
+            raise ValueError("interval_seconds must be > 0")
+
+        os.makedirs(output_dir, exist_ok=True)
+        raw_pattern = os.path.join(output_dir, "fastframe_%06d.jpg")
+        cmd = [
+            "ffmpeg",
+            "-hide_banner",
+            "-loglevel",
+            "error",
+            "-i",
+            self.video_path,
+            "-vf",
+            f"fps=1/{interval_seconds}",
+            "-q:v",
+            "2",
+            "-start_number",
+            "0",
+            "-y",
+            raw_pattern,
+        ]
+        subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120)
+
+        raw_files = sorted(
+            filename
+            for filename in os.listdir(output_dir)
+            if re.fullmatch(r"fastframe_\d{6}\.jpg", filename)
+        )
+        if not raw_files:
+            raise RuntimeError("Fast-path extraction produced no frames")
+
+        renamed_files: List[str] = []
+        for index, filename in enumerate(raw_files):
+            timestamp = index * interval_seconds
+            frame_number = int(timestamp * self.fps)
+            token = self._format_timestamp_token(timestamp)
+            source_path = os.path.join(output_dir, filename)
+            target_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{token}.jpg")
+            os.replace(source_path, target_path)
+            renamed_files.append(target_path)
+
+        return renamed_files
+
+    @staticmethod
+    def _format_timestamp_token(timestamp: float) -> str:
+        hours = int(timestamp // 3600)
+        minutes = int((timestamp % 3600) // 60)
+        seconds = int(timestamp % 60)
+        milliseconds = int((timestamp % 1) * 1000)
+        return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
+
+    @staticmethod
+    def _collect_extracted_frame_paths(output_dir: str) -> List[str]:
+        return sorted(
+            os.path.join(output_dir, name)
+            for name in os.listdir(output_dir)
+            if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
+        )
+
+    @staticmethod
+    def _cleanup_fast_path_artifacts(output_dir: str) -> None:
+        for name in os.listdir(output_dir):
+            if not re.fullmatch(r"fastframe_\d{6}\.jpg", name):
+                continue
+            artifact_path = os.path.join(output_dir, name)
+            if os.path.isfile(artifact_path):
+                os.remove(artifact_path)
+
     def _extract_single_frame_optimized(self, timestamp: float, output_path: str,
                                         use_hw_accel: bool, hwaccel_type: str) -> bool:
         """
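The `_format_timestamp_token` helper in this hunk packs a timestamp into the 9-digit `HHMMSSmmm` token used in keyframe filenames. The same arithmetic as a standalone sketch:

```python
def format_timestamp_token(timestamp: float) -> str:
    # Pack hours/minutes/seconds/milliseconds into a zero-padded
    # HHMMSSmmm token, mirroring the static method above.
    hours = int(timestamp // 3600)
    minutes = int((timestamp % 3600) // 60)
    seconds = int(timestamp % 60)
    milliseconds = int((timestamp % 1) * 1000)
    return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"


# 1 h 2 min 3.5 s → "01" + "02" + "03" + "500"
print(format_timestamp_token(3723.5))  # → 010203500
```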
@@ -1,5 +1,5 @@
 [app]
-project_version="0.7.6"
+project_version="0.7.8"
 
 # LLM API 超时配置(秒)
 llm_vision_timeout = 120 # 视觉模型基础超时时间

@@ -93,6 +93,12 @@
 # 访问 https://bailian.console.aliyun.com/?tab=model#/api-key 获取你的 API 密钥
 api_key = ""
 model_name = "qwen3-tts-flash"
 
+[fun_asr]
+# 阿里百炼 Fun-ASR 字幕转录配置
+# 访问 https://bailian.console.aliyun.com/?tab=model#/api-key 获取你的 API 密钥
+api_key = ""
+model = "fun-asr"
+
 [indextts2]
 # IndexTTS2 语音克隆配置

@@ -114,9 +120,25 @@
 do_sample = true
 num_beams = 3
 repetition_penalty = 10.0
 
+[doubaotts]
+# 豆包语音 TTS 配置
+# 申请流程:
+# 1. 打开 https://console.volcengine.com/iam/keymanage 新建 Access Key 和 Secret Key
+# 2. 打开 https://www.volcengine.com/product/voice-tech 点击立即使用
+# 3. 在 API 服务中心找到音频生成下面的语音合成,获取 APPID 和 Token
+ak = ""
+sk = ""
+appid = ""
+token = ""
+cluster = "volcano_tts"
+
+# 高级参数
+volume = 1.0
+pitch = 1.0
+silence_duration = 0.125
+
 [ui]
-# TTS引擎选择 (edge_tts, azure_speech, soulvoice, tencent_tts, tts_qwen)
+# TTS引擎选择 (edge_tts, azure_speech, soulvoice, tencent_tts, tts_qwen, doubaotts)
 tts_engine = "edge_tts"
 
 # Edge TTS 配置

@@ -130,6 +152,10 @@
 azure_volume = 80
 azure_rate = 1.0
 azure_pitch = 0
 
+# 豆包语音 TTS 配置
+doubaotts_voice_type = "BV700_V2_streaming"
+doubaotts_rate = 1.0
+
 ##########################################
 # 代理和网络配置

@@ -152,3 +178,6 @@
 
 # 大模型单次处理的关键帧数量
 vision_batch_size = 10
+
+# 视觉批处理最大并发批次数(OpenAI 兼容 provider)
+vision_max_concurrency = 2
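The new `vision_max_concurrency` setting caps how many vision batches run at once. A minimal sketch of how such a cap is typically enforced with `asyncio.Semaphore` — illustrative only; `run_batches` and the sleep stand-in are not NarratoAI code:

```python
import asyncio


async def run_batches(batches, max_concurrency=2):
    # At most `max_concurrency` batches hold the semaphore at a time,
    # mirroring the vision_max_concurrency knob above.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(batch):
        async with semaphore:
            await asyncio.sleep(0)  # stand-in for a vision API call
            return f"analyzed:{batch}"

    # gather preserves input order regardless of completion order
    return await asyncio.gather(*(run_one(b) for b in batches))


print(asyncio.run(run_batches(["b0", "b1", "b2"])))  # → ['analyzed:b0', 'analyzed:b1', 'analyzed:b2']
```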
conftest.py  (new file, 11 additions)
@@ -0,0 +1,11 @@
+"""Pytest collection rules for the repository.
+
+These files are executable smoke-check scripts that live next to the LLM
+implementation for convenience. They require live credentials or manual
+execution semantics, so keep them out of the default automated test suite.
+"""
+
+collect_ignore = [
+    "app/services/llm/test_llm_service.py",
+    "app/services/llm/test_openai_compatible_integration.py",
+]
@@ -1 +1 @@
-0.7.7
+0.7.9
@@ -35,3 +35,6 @@ tenacity>=9.0.0
 # torch>=2.0.0
 # torchvision>=0.15.0
 # torchaudio>=2.0.0
+
+# 剪映草稿导出依赖
+pyJianYingDraft>=0.1.0
tests/test_documentary_frame_analysis_service.py  (new file, 275 additions)
@@ -0,0 +1,275 @@
+import unittest
+import os
+from tempfile import TemporaryDirectory
+from unittest.mock import patch
+
+from app.services.documentary.frame_analysis_models import DocumentaryAnalysisConfig
+from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
+from app.utils import utils
+
+
+class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
+    def test_build_analysis_prompt_formats_real_frame_count(self):
+        service = DocumentaryFrameAnalysisService()
+
+        prompt = service._build_analysis_prompt(frame_count=3)
+
+        self.assertIn("我提供了 3 张视频帧", prompt)
+        self.assertNotIn("%s", prompt)
+        self.assertIn("frame_observations", prompt)
+        self.assertIn("overall_activity_summary", prompt)
+
+    def test_parse_failed_batch_keeps_raw_response_and_time_range(self):
+        service = DocumentaryFrameAnalysisService()
+
+        batch = service._build_failed_batch_result(
+            batch_index=2,
+            raw_response="not-json",
+            error_message="JSON decode failed",
+            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
+            time_range="00:00:00,000-00:00:03,000",
+        )
+
+        self.assertEqual("failed", batch.status)
+        self.assertEqual("not-json", batch.raw_response)
+        self.assertEqual("00:00:00,000-00:00:03,000", batch.time_range)
+        self.assertTrue(batch.fallback_summary)
+
+    def test_parse_failed_batch_uses_non_empty_fallback_when_raw_response_is_empty(self):
+        service = DocumentaryFrameAnalysisService()
+
+        batch = service._build_failed_batch_result(
+            batch_index=3,
+            raw_response="",
+            error_message="Empty model response",
+            frame_paths=["/tmp/keyframe_000001_000001000.jpg"],
+            time_range="00:00:03,000-00:00:06,000",
+        )
+
+        self.assertEqual("failed", batch.status)
+        self.assertEqual("", batch.raw_response)
+        self.assertTrue(batch.fallback_summary)
+
+    def test_failed_batch_result_uses_prompt_contract_field_names(self):
+        service = DocumentaryFrameAnalysisService()
+
+        batch = service._build_failed_batch_result(
+            batch_index=4,
+            raw_response="not-json",
+            error_message="JSON decode failed",
+            frame_paths=["/tmp/keyframe_000002_000002000.jpg"],
+            time_range="00:00:06,000-00:00:09,000",
+        )
+
+        self.assertEqual([], batch.frame_observations)
+        self.assertEqual("", batch.overall_activity_summary)
+        self.assertFalse(hasattr(batch, "observations"))
+        self.assertFalse(hasattr(batch, "summary"))
+
+    def test_parse_batch_returns_failed_result_when_json_is_invalid(self):
+        service = DocumentaryFrameAnalysisService()
+
+        batch = service._parse_batch_response(
+            batch_index=0,
+            raw_response="plain text",
+            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
+            time_range="00:00:00,000-00:00:03,000",
+        )
+
+        self.assertEqual("failed", batch.status)
+        self.assertEqual("plain text", batch.raw_response)
+        self.assertEqual(["/tmp/keyframe_000000_000000000.jpg"], batch.frame_paths)
+        self.assertEqual([], batch.frame_observations)
+        self.assertEqual("", batch.overall_activity_summary)
+
+    def test_parse_batch_returns_failed_result_for_empty_json_object(self):
+        service = DocumentaryFrameAnalysisService()
+
+        batch = service._parse_batch_response(
+            batch_index=0,
+            raw_response="{}",
+            frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
+            time_range="00:00:00,000-00:00:03,000",
+        )
+
+        self.assertEqual("failed", batch.status)
+        self.assertEqual("{}", batch.raw_response)
+        self.assertIn("frame_observations", batch.error_message)
+
+    def test_parse_batch_returns_failed_result_when_observations_are_too_short(self):
+        service = DocumentaryFrameAnalysisService()
+        raw_response = """
+{
+    "frame_observations": [
+        {"observation": "第一帧画面"}
+    ],
+    "overall_activity_summary": "只有一条帧观察"
+}
+""".strip()
+
+        batch = service._parse_batch_response(
+            batch_index=1,
+            raw_response=raw_response,
+            frame_paths=[
+                "/tmp/keyframe_000000_000000000.jpg",
+                "/tmp/keyframe_000075_000003000.jpg",
+            ],
+            time_range="00:00:00,000-00:00:06,000",
+        )
+
+        self.assertEqual("failed", batch.status)
+        self.assertEqual(raw_response, batch.raw_response)
+        self.assertIn("frame_observations", batch.error_message)
+
+    def test_parse_batch_parses_code_fenced_json_into_structured_result(self):
+        service = DocumentaryFrameAnalysisService()
+        raw_response = """```json
+{
+    "frame_observations": [
+        {"observation": "第一帧画面"},
+        {"observation": "第二帧画面"}
+    ],
+    "overall_activity_summary": "人物从房间走到街道"
+}
+```"""
+
+        batch = service._parse_batch_response(
+            batch_index=1,
+            raw_response=raw_response,
+            frame_paths=[
+                "/tmp/keyframe_000000_000000000.jpg",
+                "/tmp/keyframe_000075_000003000.jpg",
+            ],
+            time_range="00:00:00,000-00:00:06,000",
+        )
+
+        self.assertEqual("success", batch.status)
+        self.assertEqual(
+            [
+                {
+                    "frame_path": "/tmp/keyframe_000000_000000000.jpg",
+                    "timestamp": "",
+                    "observation": "第一帧画面",
+                },
+                {
+                    "frame_path": "/tmp/keyframe_000075_000003000.jpg",
+                    "timestamp": "",
+                    "observation": "第二帧画面",
+                },
+            ],
+            batch.frame_observations,
+        )
+        self.assertEqual("人物从房间走到街道", batch.overall_activity_summary)
+        self.assertEqual("", batch.fallback_summary)
+
+    def test_parse_batch_preserves_frames_when_summary_is_missing(self):
+        service = DocumentaryFrameAnalysisService()
+        raw_response = """
+{
+    "frame_observations": [
+        {"observation": "第一帧画面"},
+        {"observation": "第二帧画面"}
+    ]
+}
+""".strip()
+
+        batch = service._parse_batch_response(
+            batch_index=2,
+            raw_response=raw_response,
+            frame_paths=[
+                "/tmp/keyframe_000000_000000000.jpg",
+                "/tmp/keyframe_000075_000003000.jpg",
+            ],
+            time_range="00:00:00,000-00:00:06,000",
+        )
+
+        self.assertEqual("success", batch.status)
+        self.assertEqual(2, len(batch.frame_observations))
+        self.assertEqual("", batch.overall_activity_summary)
+
+    def test_cache_key_changes_when_interval_changes(self):
+        service = DocumentaryFrameAnalysisService()
+
+        with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
+            key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
+            key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
+
+        self.assertNotEqual(key_a, key_b)
+
+    def test_cache_key_changes_when_model_changes(self):
+        service = DocumentaryFrameAnalysisService()
+
+        with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
+            key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
+            key_b = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-b", 10, 2)
+
+        self.assertNotEqual(key_a, key_b)
+
+    def test_cache_key_starts_with_legacy_video_hash_prefix(self):
+        service = DocumentaryFrameAnalysisService()
+
+        with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0):
+            key = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
+
+        expected_prefix = utils.md5("video.mp4" + "123.0")
+        self.assertTrue(key.startswith(expected_prefix))
+
+    def test_clear_keyframes_cache_respects_scope_and_prefix_match(self):
+        with TemporaryDirectory() as temp_root:
+            service = DocumentaryFrameAnalysisService()
+            analysis_dir = os.path.join(temp_root, "analysis")
+            os.makedirs(analysis_dir, exist_ok=True)
+
+            with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0):
+                target_key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
+                target_key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
+                keep_key = service._build_cache_key("other.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
+
+            target_dir_a = os.path.join(analysis_dir, target_key_a)
+            target_dir_b = os.path.join(analysis_dir, target_key_b)
+            keep_dir = os.path.join(analysis_dir, keep_key)
+            os.makedirs(target_dir_a, exist_ok=True)
+            os.makedirs(target_dir_b, exist_ok=True)
+            os.makedirs(keep_dir, exist_ok=True)
+
+            with patch("app.utils.utils.temp_dir", return_value=temp_root), patch(
+                "app.utils.utils.os.path.getmtime", return_value=123.0
+            ):
+                utils.clear_keyframes_cache(video_path="video.mp4", cache_scope="analysis")
+
+            self.assertFalse(os.path.exists(target_dir_a))
+            self.assertFalse(os.path.exists(target_dir_b))
+            self.assertTrue(os.path.exists(keep_dir))
+
+
+class DocumentaryAnalysisConfigTests(unittest.TestCase):
+    def test_config_rejects_non_positive_frame_interval(self):
+        with self.assertRaises(ValueError):
+            DocumentaryAnalysisConfig(
+                video_path="/tmp/demo.mp4",
+                frame_interval_seconds=0,
+                vision_batch_size=5,
+                vision_llm_provider="openai",
+                vision_model_name="gpt-4o-mini",
+            )
+
+    def test_config_rejects_non_positive_batch_size(self):
+        with self.assertRaises(ValueError):
+            DocumentaryAnalysisConfig(
+                video_path="/tmp/demo.mp4",
+                frame_interval_seconds=5,
+                vision_batch_size=0,
+                vision_llm_provider="openai",
+                vision_model_name="gpt-4o-mini",
+            )
+
+    def test_config_rejects_non_positive_max_concurrency(self):
+        with self.assertRaises(ValueError):
+            DocumentaryAnalysisConfig(
+                video_path="/tmp/demo.mp4",
+                frame_interval_seconds=5,
+                vision_batch_size=5,
+                vision_llm_provider="openai",
+                vision_model_name="gpt-4o-mini",
+                max_concurrency=0,
+            )
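The cache-key tests above pin down a contract: keys start with the legacy `md5(video_path + mtime)` prefix and change whenever any analysis knob changes. A hypothetical key builder consistent with those tests — not the actual `_build_cache_key` implementation:

```python
import hashlib


def md5(text: str) -> str:
    # Same shape as the project's utils.md5 helper: hex digest of the text.
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def build_cache_key(video_path: str, mtime: float, *knobs) -> str:
    # Legacy-compatible prefix (video identity), extended with a digest
    # of every analysis knob so any change invalidates the cache entry.
    prefix = md5(video_path + str(mtime))
    suffix = md5("|".join(str(k) for k in knobs))
    return f"{prefix}_{suffix}"


key = build_cache_key("video.mp4", 123.0, 3.0, "prompt-v1", "model-a", 10, 2)
print(key.startswith(md5("video.mp4" + "123.0")))  # → True
```

Prefix matching is what lets `clear_keyframes_cache` remove every variant of one video's cache in a single pass.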
tests/test_generate_narration_script_documentary_unittest.py  (new file, 58 additions)
@@ -0,0 +1,58 @@
+import json
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from app.services.generate_narration_script import parse_frame_analysis_to_markdown
+
+
+class GenerateNarrationMarkdownTests(unittest.TestCase):
+    def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self):
+        artifact = {
+            "batches": [
+                {
+                    "batch_index": 1,
+                    "time_range": "00:00:03,000-00:00:06,000",
+                    "overall_activity_summary": "人物转身跑向远处",
+                    "fallback_summary": "",
+                    "frame_observations": [
+                        {
+                            "timestamp": "00:00:03,000",
+                            "observation": "人物突然回头",
+                        }
+                    ],
+                },
+                {
+                    "batch_index": 0,
+                    "time_range": "00:00:00,000-00:00:03,000",
+                    "overall_activity_summary": "",
+                    "fallback_summary": "原始响应回退摘要",
+                    "frame_observations": [
+                        {
+                            "timestamp": "00:00:00,000",
+                            "observation": "镜头里有一只猫",
+                        }
+                    ],
+                },
+            ]
+        }
+
+        with TemporaryDirectory() as temp_dir:
+            analysis_path = Path(temp_dir) / "frame-analysis.json"
+            analysis_path.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8")
+            markdown = parse_frame_analysis_to_markdown(str(analysis_path))
+
+        first_range_index = markdown.find("00:00:00,000-00:00:03,000")
+        second_range_index = markdown.find("00:00:03,000-00:00:06,000")
+
+        self.assertIn("原始响应回退摘要", markdown)
+        self.assertIn("镜头里有一只猫", markdown)
+        self.assertIn("人物转身跑向远处", markdown)
+        self.assertIn("人物突然回头", markdown)
+        self.assertNotEqual(-1, first_range_index)
+        self.assertNotEqual(-1, second_range_index)
+        self.assertLess(first_range_index, second_range_index)
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/test_generate_script_docu_unittest.py  (new file, 19 additions)
@@ -0,0 +1,19 @@
+import unittest
+
+from webui.tools.generate_script_docu import _normalize_progress_value
+
+
+class GenerateScriptDocuProgressTests(unittest.TestCase):
+    def test_normalize_progress_rounds_percentage_float_to_valid_streamlit_int(self):
+        self.assertEqual(43, _normalize_progress_value(43.125))
+
+    def test_normalize_progress_converts_ratio_float_to_percentage_int(self):
+        self.assertEqual(43, _normalize_progress_value(0.43125))
+
+    def test_normalize_progress_clamps_out_of_range_values(self):
+        self.assertEqual(0, _normalize_progress_value(-5))
+        self.assertEqual(100, _normalize_progress_value(101))
+
+
+if __name__ == "__main__":
+    unittest.main()
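The three tests above fully specify the helper's observable behavior: percentages round to ints, ratios in (0, 1] scale to percentages, and everything clamps to 0-100. One plausible implementation that satisfies them — a sketch, not the actual `_normalize_progress_value` in `webui/tools/generate_script_docu.py`:

```python
def normalize_progress_value(value: float) -> int:
    # Values in (0, 1] are treated as ratios and scaled to percentages;
    # everything else is already a percentage. Clamp to Streamlit's 0-100 range.
    if 0 < value <= 1:
        value *= 100
    return max(0, min(100, round(value)))


print(normalize_progress_value(0.43125))  # → 43
```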
tests/test_script_service_documentary_unittest.py  (new file, 316 additions)
@@ -0,0 +1,316 @@
+import json
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest.mock import AsyncMock, patch
+
+from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
+from app.services.script_service import ScriptGenerator
+
+
+class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
+    async def test_generate_script_forwards_explicit_values_to_shared_service(self):
+        expected_script = [
+            {
+                "timestamp": "00:00:00,000-00:00:03,000",
+                "picture": "批次描述",
+                "narration": "这里是解说词",
+                "OST": 2,
+            }
+        ]
+        callback = lambda _percent, _message: None
+
+        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
+            service = service_cls.return_value
+            service.generate_documentary_script = AsyncMock(return_value=expected_script)
+            generator = ScriptGenerator()
+
+            result = await generator.generate_script(
+                video_path="demo.mp4",
+                video_theme="荒野生存",
+                custom_prompt="请聚焦生存动作",
+                frame_interval_input=3,
+                vision_batch_size=6,
+                vision_llm_provider="openai",
+                progress_callback=callback,
+            )
+
+        self.assertEqual(expected_script, result)
+        self.assertTrue(result[0]["narration"])
+        service.generate_documentary_script.assert_awaited_once()
+        called_kwargs = service.generate_documentary_script.await_args.kwargs
+        self.assertEqual("demo.mp4", called_kwargs["video_path"])
+        self.assertEqual(3, called_kwargs["frame_interval_input"])
+        self.assertEqual(6, called_kwargs["vision_batch_size"])
+        self.assertEqual("openai", called_kwargs["vision_llm_provider"])
+        self.assertEqual("荒野生存", called_kwargs["video_theme"])
+        self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
+        self.assertIs(called_kwargs["progress_callback"], callback)
+
+    async def test_generate_script_forwards_unset_values_as_none(self):
+        expected_script = [
+            {
+                "timestamp": "00:00:00,000-00:00:03,000",
+                "picture": "批次描述",
+                "narration": "这里是解说词",
+                "OST": 2,
+            }
+        ]
+        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
+            service = service_cls.return_value
+            service.generate_documentary_script = AsyncMock(return_value=expected_script)
+            generator = ScriptGenerator()
+
+            await generator.generate_script(video_path="demo.mp4")
+
+        called_kwargs = service.generate_documentary_script.await_args.kwargs
+        self.assertIsNone(called_kwargs["frame_interval_input"])
+        self.assertIsNone(called_kwargs["vision_batch_size"])
+        self.assertIsNone(called_kwargs["vision_llm_provider"])
+
+    async def test_generate_script_warns_when_skip_seconds_or_threshold_are_non_default(self):
+        expected_script = [
+            {
+                "timestamp": "00:00:00,000-00:00:03,000",
+                "picture": "批次描述",
+                "narration": "这里是解说词",
+                "OST": 2,
+            }
+        ]
+        with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls, patch(
+            "app.services.script_service.logger.warning"
+        ) as warning:
+            service = service_cls.return_value
+            service.generate_documentary_script = AsyncMock(return_value=expected_script)
+            generator = ScriptGenerator()
+            await generator.generate_script(
+                video_path="demo.mp4",
+                skip_seconds=2,
+                threshold=20,
+            )
+
+        warning.assert_called_once()
+        warning_message = warning.call_args.args[0]
+        self.assertIn("skip_seconds", warning_message)
+        self.assertIn("threshold", warning_message)
+        self.assertIn("does not currently apply", warning_message)
+
+
+class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyncioTestCase):
+    async def test_generate_documentary_script_returns_final_narrated_items(self):
+        service = DocumentaryFrameAnalysisService()
+        analysis_payload = {
+            "batches": [
+                {
+                    "batch_index": 0,
+                    "time_range": "00:00:00,000-00:00:03,000",
+                    "overall_activity_summary": "",
+                    "fallback_summary": "回退摘要",
+                    "frame_observations": [
+                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
+                    ],
+                }
+            ]
+        }
+
+        with TemporaryDirectory() as temp_dir:
+            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
+            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
+
+            with patch.object(
+                DocumentaryFrameAnalysisService,
+                "analyze_video",
+                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
+            ), patch.dict(
+                "app.services.documentary.frame_analysis_service.config.app",
+                {
+                    "text_llm_provider": "openai",
+                    "text_openai_api_key": "test-key",
+                    "text_openai_model_name": "test-model",
+                    "text_openai_base_url": "https://example.com/v1",
+                },
+            ), patch(
+                "app.services.documentary.frame_analysis_service.generate_narration",
+                return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
+            ):
+                result = await service.generate_documentary_script(video_path="demo.mp4")
+
+        self.assertEqual(1, len(result))
+        self.assertEqual("00:00:00,000-00:00:03,000", result[0]["timestamp"])
+        self.assertEqual("镜头里有一只猫", result[0]["picture"])
+        self.assertEqual("一只猫警觉地望向镜头。", result[0]["narration"])
+        self.assertEqual(2, result[0]["OST"])
+
+    async def test_generate_documentary_script_raises_when_narration_json_is_malformed(self):
+        service = DocumentaryFrameAnalysisService()
+        analysis_payload = {
+            "batches": [
+                {
+                    "batch_index": 0,
+                    "time_range": "00:00:00,000-00:00:03,000",
+                    "overall_activity_summary": "测试摘要",
+                    "fallback_summary": "",
+                    "frame_observations": [
+                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
+                    ],
+                }
+            ]
+        }
+
+        with TemporaryDirectory() as temp_dir:
+            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
+            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
+
+            with patch.object(
+                DocumentaryFrameAnalysisService,
+                "analyze_video",
+                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
+            ), patch.dict(
+                "app.services.documentary.frame_analysis_service.config.app",
+                {
+                    "text_llm_provider": "openai",
+                    "text_openai_api_key": "test-key",
+                    "text_openai_model_name": "test-model",
+                    "text_openai_base_url": "https://example.com/v1",
+                },
+            ), patch(
+                "app.services.documentary.frame_analysis_service.generate_narration",
+                return_value="malformed narration payload",
+            ):
+                with self.assertRaises(Exception) as ctx:
+                    await service.generate_documentary_script(video_path="demo.mp4")
+
+        self.assertIn("解说文案格式错误", str(ctx.exception))
+        self.assertIn("items", str(ctx.exception))
+
+    def test_parse_narration_items_recovers_from_common_json_damage(self):
+        service = DocumentaryFrameAnalysisService()
+        damaged_payload = """
+解释文字
+```json
+{{
+    "items": [
+        {{
+            "timestamp": "00:00:00,000-00:00:03,000",
+            "picture": "镜头里有一只猫",
+            "narration": "一只猫警觉地望向镜头。",
+        }},
+    ],
+}}
+```
+补充文字
+""".strip()
+
+        parsed_items = service._parse_narration_items(damaged_payload)
+
+        self.assertEqual(1, len(parsed_items))
+        self.assertEqual("00:00:00,000-00:00:03,000", parsed_items[0]["timestamp"])
+        self.assertEqual("镜头里有一只猫", parsed_items[0]["picture"])
+        self.assertEqual("一只猫警觉地望向镜头。", parsed_items[0]["narration"])
+
+    def test_parse_narration_items_raises_for_unrecoverable_payload(self):
+        service = DocumentaryFrameAnalysisService()
+
+        with self.assertRaises(ValueError) as ctx:
+            service._parse_narration_items("not-json-at-all ::: ???")
+
+        self.assertIn("解说文案格式错误", str(ctx.exception))
+        self.assertIn("items", str(ctx.exception))
+
+    async def test_generate_documentary_script_includes_theme_and_custom_prompt_for_narration(self):
+        service = DocumentaryFrameAnalysisService()
+        analysis_payload = {
+            "batches": [
+                {
+                    "batch_index": 0,
+                    "time_range": "00:00:00,000-00:00:03,000",
+                    "overall_activity_summary": "测试摘要",
+                    "fallback_summary": "",
+                    "frame_observations": [
+                        {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
+                    ],
+                }
+            ]
+        }
+
+        with TemporaryDirectory() as temp_dir:
+            analysis_path = Path(temp_dir) / "frame_analysis_test.json"
+            analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
+
+            with patch.object(
+                DocumentaryFrameAnalysisService,
+                "analyze_video",
+                AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
+            ), patch.dict(
+                "app.services.documentary.frame_analysis_service.config.app",
+                {
+                    "text_llm_provider": "openai",
+                    "text_openai_api_key": "test-key",
+                    "text_openai_model_name": "test-model",
+                    "text_openai_base_url": "https://example.com/v1",
+                },
|
||||||
|
), patch(
|
||||||
|
"app.services.documentary.frame_analysis_service.generate_narration",
|
||||||
|
return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
|
||||||
|
) as mocked_generate:
|
||||||
|
await service.generate_documentary_script(
|
||||||
|
video_path="demo.mp4",
|
||||||
|
video_theme="野生动物纪录片",
|
||||||
|
custom_prompt="重点描述危险信号",
|
||||||
|
)
|
||||||
|
|
||||||
|
narration_input = mocked_generate.call_args.args[0]
|
||||||
|
self.assertIn("## 创作上下文", narration_input)
|
||||||
|
self.assertIn("视频主题:野生动物纪录片", narration_input)
|
||||||
|
self.assertIn("补充创作要求:重点描述危险信号", narration_input)
|
||||||
|
|
||||||
|
async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self):
|
||||||
|
service = DocumentaryFrameAnalysisService()
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
"app.services.documentary.frame_analysis_service.config.app",
|
||||||
|
{
|
||||||
|
"vision_llm_provider": "openai",
|
||||||
|
"vision_openai_api_key": "config-key",
|
||||||
|
"vision_openai_model_name": "config-model",
|
||||||
|
"vision_openai_base_url": "https://config.example/v1",
|
||||||
|
},
|
||||||
|
), patch(
|
||||||
|
"app.services.documentary.frame_analysis_service.os.path.exists",
|
||||||
|
return_value=True,
|
||||||
|
), patch.object(
|
||||||
|
service,
|
||||||
|
"_load_or_extract_keyframes",
|
||||||
|
return_value=["/tmp/keyframe_000001_000000100.jpg"],
|
||||||
|
), patch.object(
|
||||||
|
service,
|
||||||
|
"_analyze_batches",
|
||||||
|
AsyncMock(return_value=[]),
|
||||||
|
), patch.object(
|
||||||
|
service,
|
||||||
|
"_save_analysis_artifact",
|
||||||
|
return_value="/tmp/frame_analysis_test.json",
|
||||||
|
), patch.object(
|
||||||
|
service,
|
||||||
|
"_build_video_clip_json",
|
||||||
|
return_value=[],
|
||||||
|
), patch(
|
||||||
|
"app.services.documentary.frame_analysis_service.create_vision_analyzer",
|
||||||
|
return_value=object(),
|
||||||
|
) as mocked_create_analyzer:
|
||||||
|
await service.analyze_video(
|
||||||
|
video_path="/tmp/demo.mp4",
|
||||||
|
vision_api_key="explicit-key",
|
||||||
|
vision_model_name="explicit-model",
|
||||||
|
vision_base_url="",
|
||||||
|
)
|
||||||
|
|
||||||
|
called_kwargs = mocked_create_analyzer.call_args.kwargs
|
||||||
|
self.assertEqual("openai", called_kwargs["provider"])
|
||||||
|
self.assertEqual("explicit-key", called_kwargs["api_key"])
|
||||||
|
self.assertEqual("explicit-model", called_kwargs["model"])
|
||||||
|
self.assertEqual("", called_kwargs["base_url"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
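The JSON-recovery behavior these tests pin down can be sketched as a standalone function. This is an illustrative reconstruction, not the service's actual `_parse_narration_items`; it assumes the damage modes shown in the fixture: surrounding prose, a markdown code fence, `{{`/`}}` template escaping, and trailing commas.

```python
import json
import re

_FENCE = "`" * 3  # built dynamically so this example contains no literal fence
_BLOCK = re.compile(_FENCE + r"(?:json)?\s*(.*?)" + _FENCE, re.DOTALL)


def parse_narration_items(raw: str) -> list:
    """Recover an {"items": [...]} payload from noisy LLM output (sketch only)."""
    match = _BLOCK.search(raw)
    text = match.group(1) if match else raw
    # Undo prompt-template escaping ({{ }}) and trailing commas before } or ].
    text = text.replace("{{", "{").replace("}}", "}")
    text = re.sub(r",\s*([}\]])", r"\1", text)
    try:
        payload = json.loads(text.strip())
    except json.JSONDecodeError as exc:
        raise ValueError("解说文案格式错误: 无法解析出 items 列表") from exc
    items = payload.get("items") if isinstance(payload, dict) else None
    if not isinstance(items, list):
        raise ValueError("解说文案格式错误: 无法解析出 items 列表")
    return items
```

The unrecoverable-payload case falls through both repairs and surfaces as `ValueError`, matching the test expectations above.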
91 tests/test_video_processor_documentary_unittest.py (Normal file)
@@ -0,0 +1,91 @@
import os
import unittest
from tempfile import TemporaryDirectory
from unittest.mock import patch

from app.utils.video_processor import VideoProcessor


class VideoProcessorDocumentaryTests(unittest.TestCase):
    @patch.object(VideoProcessor, "_extract_frames_fast_path", return_value=["a.jpg"])
    def test_extract_frames_by_interval_prefers_fast_path(self, fast_path):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        result = processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=3.0)

        self.assertEqual(["a.jpg"], result)
        fast_path.assert_called_once_with("/tmp/out", interval_seconds=3.0)

    def test_extract_frames_by_interval_falls_back_to_ultra_compatible(self):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        with TemporaryDirectory() as output_dir:
            expected_frame_path = os.path.join(output_dir, "keyframe_000000_000000000.jpg")

            def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
                with open(expected_frame_path, "wb") as frame_file:
                    frame_file.write(b"frame")
                return [0]

            with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=RuntimeError("fast path failed")) as fast_path, patch.object(
                VideoProcessor,
                "extract_frames_by_interval_ultra_compatible",
                side_effect=ultra_compatible_fallback,
                autospec=True,
            ) as fallback:
                result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)

            self.assertEqual([expected_frame_path], result)
            fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
            fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)

    def test_extract_frames_by_interval_rejects_non_positive_interval(self):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        with patch.object(VideoProcessor, "extract_frames_by_interval_ultra_compatible", autospec=True) as fallback:
            with self.assertRaises(ValueError):
                processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=0)

        fallback.assert_not_called()

    def test_extract_frames_by_interval_fallback_cleans_partial_fast_path_artifacts(self):
        processor = VideoProcessor.__new__(VideoProcessor)
        processor.video_path = "demo.mp4"
        processor.duration = 6.0
        processor.fps = 25.0

        with TemporaryDirectory() as output_dir:
            stale_fastframe = os.path.join(output_dir, "fastframe_000000.jpg")
            expected_keyframe = os.path.join(output_dir, "keyframe_000000_000000000.jpg")

            def fast_path_with_partial_output(_output_dir, interval_seconds=5.0):
                with open(stale_fastframe, "wb") as frame_file:
                    frame_file.write(b"stale")
                raise RuntimeError("simulated fast-path failure")

            def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
                with open(expected_keyframe, "wb") as frame_file:
                    frame_file.write(b"frame")
                return [0]

            with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=fast_path_with_partial_output) as fast_path, patch.object(
                VideoProcessor,
                "extract_frames_by_interval_ultra_compatible",
                side_effect=ultra_compatible_fallback,
                autospec=True,
            ) as fallback:
                result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)

            self.assertEqual([expected_keyframe], result)
            self.assertFalse(os.path.exists(stale_fastframe))
            fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
            fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)
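The fallback contract exercised by these tests can be sketched independently of `VideoProcessor`. Here `fast_path` and `compatible_path` are hypothetical callables standing in for the real extraction backends; the cleanup glob patterns mirror the `fastframe_*`/`keyframe_*` names used in the tests.

```python
import glob
import os


def extract_frames_with_fallback(output_dir, fast_path, compatible_path, interval_seconds=5.0):
    """Illustrative sketch: validate input, try the fast path, clean up and fall back."""
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be positive")
    try:
        return fast_path(output_dir, interval_seconds=interval_seconds)
    except Exception:
        # Remove partial fast-path artifacts so the fallback starts from a clean state.
        for stale in glob.glob(os.path.join(output_dir, "fastframe_*.jpg")):
            os.remove(stale)
        compatible_path(output_dir, interval_seconds=interval_seconds)
        return sorted(glob.glob(os.path.join(output_dir, "keyframe_*.jpg")))
```

The interval check runs before either backend is touched, which is why the tests can assert the fallback is never called for a non-positive interval.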
141 webui.py
@@ -1,6 +1,7 @@
import streamlit as st
import os
import sys
+import time
from loguru import logger
from app.config import config
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
@@ -221,6 +222,145 @@ def render_generate_button():
        time.sleep(0.5)
+
+
+def get_voice_name_for_tts_engine(tts_engine: str) -> str:
+    """根据TTS引擎获取用户选择的音色"""
+    if tts_engine == 'doubaotts':
+        return st.session_state.get('voice_name', config.ui.get('doubaotts_voice_type', 'BV700_streaming'))
+    elif tts_engine == 'azure_speech':
+        return st.session_state.get('voice_name', config.ui.get('azure_voice_name', 'zh-CN-XiaoxiaoMultilingualNeural'))
+    else:
+        return st.session_state.get('voice_name', config.ui.get('edge_voice_name', 'zh-CN-XiaoxiaoNeural-Female'))
+
+
+def get_jianying_export_params() -> VideoClipParams:
+    """获取导出到剪映草稿的参数"""
+    tts_engine = st.session_state.get('tts_engine', 'azure')
+    voice_name = get_voice_name_for_tts_engine(tts_engine)
+    voice_rate = st.session_state.get('voice_rate', 1.0)
+    voice_pitch = st.session_state.get('voice_pitch', 1.0)
+
+    return VideoClipParams(
+        video_clip_json_path=st.session_state['video_clip_json_path'],
+        video_origin_path=st.session_state['video_origin_path'],
+        tts_engine=tts_engine,
+        voice_name=voice_name,
+        voice_rate=voice_rate,
+        voice_pitch=voice_pitch,
+        n_threads=config.app.get('n_threads', 4),
+        video_aspect=VideoAspect.landscape,
+        subtitle_enabled=st.session_state.get('subtitle_enabled', False),
+        font_name=st.session_state.get('font_name', 'Microsoft YaHei'),
+        font_size=st.session_state.get('font_size', 24),
+        text_fore_color=st.session_state.get('text_fore_color', '#FFFFFF'),
+        subtitle_position=st.session_state.get('subtitle_position', 'bottom'),
+        custom_position=st.session_state.get('custom_position', 70.0),
+        tts_volume=st.session_state.get('tts_volume', 1.0),
+        original_volume=st.session_state.get('original_volume', 0.7),
+        bgm_volume=st.session_state.get('bgm_volume', 0.3),
+        draft_name=st.session_state.get('draft_name_input', f"NarratoAI_{int(time.time())}")
+    )
+
+
+def render_export_jianying_button():
+    """渲染导出到剪映草稿按钮和处理逻辑"""
+    import os
+    import time
+    import uuid
+    from loguru import logger
+
+    # 初始化session state
+    if 'show_jianying_export_form' not in st.session_state:
+        st.session_state['show_jianying_export_form'] = False
+    if 'jianying_export_result' not in st.session_state:
+        st.session_state['jianying_export_result'] = None
+    if 'jianying_export_error' not in st.session_state:
+        st.session_state['jianying_export_error'] = None
+
+    if st.button("📤 导出到剪映草稿", use_container_width=True, type="secondary"):
+        config.save_config()
+
+        if not st.session_state.get('video_clip_json_path'):
+            st.error("脚本文件不能为空")
+            return
+        if not st.session_state.get('video_origin_path'):
+            st.error("视频文件不能为空")
+            return
+
+        jianying_draft_path = config.ui.get("jianying_draft_path", "")
+        if not jianying_draft_path:
+            st.error("请在基础设置中配置剪映草稿地址")
+            return
+
+        if not os.path.exists(jianying_draft_path):
+            st.error(f"剪映草稿文件夹不存在: {jianying_draft_path}")
+            return
+
+        # 显示导出表单
+        st.session_state['show_jianying_export_form'] = True
+        st.session_state['jianying_export_result'] = None
+        st.session_state['jianying_export_error'] = None
+
+    # 显示导出表单
+    if st.session_state['show_jianying_export_form']:
+        st.markdown("---")
+        st.subheader("导出到剪映草稿")
+
+        draft_name = st.text_input(
+            "请输入剪映草稿名称",
+            value=f"NarratoAI_{int(time.time())}",
+            key="draft_name_input"
+        )
+
+        if st.button("确认导出", key="confirm_export"):
+            if not draft_name:
+                st.error("请输入草稿名称")
+                return
+
+            # 创建任务ID
+            task_id = str(uuid.uuid4())
+            st.session_state['task_id'] = task_id
+
+            # 构建参数
+            try:
+                params = get_jianying_export_params()
+            except Exception as e:
+                logger.error(f"构建参数失败: {e}")
+                st.error(f"参数构建失败: {e}")
+                return
+
+            with st.spinner("正在导出到剪映草稿,请稍候..."):
+                try:
+                    from app.services import jianying_task
+
+                    # 调用导出到剪映草稿的任务
+                    result = jianying_task.start_export_jianying_draft(task_id, params)
+
+                    # 记录日志
+                    logger.info(f"成功导出到剪映草稿: {result['draft_name']}")
+                    logger.info(f"草稿已保存到: {result['draft_path']}")
+
+                    # 保存结果到session state
+                    st.session_state['jianying_export_result'] = result
+                    st.session_state['jianying_export_error'] = None
+                    st.session_state['show_jianying_export_form'] = False
+
+                    st.success(f"✅ 成功导出到剪映草稿: {result['draft_name']}")
+                    st.info(f"📁 草稿已保存到: {result['draft_path']}")
+                except Exception as e:
+                    logger.error(f"导出到剪映草稿失败: {e}")
+                    import traceback
+                    logger.error(f"错误详情: {traceback.format_exc()}")
+                    st.session_state['jianying_export_error'] = str(e)
+                    st.session_state['jianying_export_result'] = None
+                    st.error(f"❌ 导出到剪映草稿失败: {e}")
+
+        if st.button("取消", key="cancel_export"):
+            st.session_state['show_jianying_export_form'] = False
+            st.session_state['jianying_export_result'] = None
+            st.session_state['jianying_export_error'] = None
+            st.rerun()


def main():
    """主函数"""
@@ -285,6 +425,7 @@ def main():

    # 放到最后渲染生成按钮和处理逻辑
    render_generate_button()
+    render_export_jianying_button()


if __name__ == "__main__":
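The repeated `if key not in st.session_state` guards in the export button above can equivalently be collapsed with `dict.setdefault`, which `st.session_state` also supports; a plain dict stands in for the session state in this sketch.

```python
# Plain dict standing in for st.session_state; same defaulting semantics.
session_state = {"jianying_export_result": {"draft_name": "kept"}}

for key, default in [
    ("show_jianying_export_form", False),
    ("jianying_export_result", None),
    ("jianying_export_error", None),
]:
    session_state.setdefault(key, default)  # only fills keys that are missing
```

Existing values survive the loop, which is the property the per-key `in` checks are protecting.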
@@ -26,7 +26,8 @@ def get_tts_engine_options():
        "azure_speech": "Azure Speech Services",
        "tencent_tts": "腾讯云 TTS",
        "qwen3_tts": "通义千问 Qwen3 TTS",
-        "indextts2": "IndexTTS2 语音克隆"
+        "indextts2": "IndexTTS2 语音克隆",
+        "doubaotts": "豆包语音 TTS"
    }
@@ -62,6 +63,12 @@ def get_tts_engine_descriptions():
            "features": "零样本语音克隆,上传参考音频即可合成相同音色的语音,需要本地或私有部署",
            "use_case": "下载地址:https://pan.quark.cn/s/0767c9bcefd5",
            "registration": None
+        },
+        "doubaotts": {
+            "title": "豆包语音 TTS",
+            "features": "火山引擎豆包语音合成,支持多种音色和情感,国内访问速度快",
+            "use_case": "需要高质量中文语音合成的用户",
+            "registration": "https://www.volcengine.com/product/voice-tech"
        }
    }
@@ -147,6 +154,8 @@ def render_tts_settings(tr):
        render_qwen3_tts_settings(tr)
    elif selected_engine == "indextts2":
        render_indextts2_tts_settings(tr)
+    elif selected_engine == "doubaotts":
+        render_doubaotts_settings(tr)

    # 4. 试听功能
    render_voice_preview_new(tr, selected_engine)
@@ -703,6 +712,250 @@ def render_indextts2_tts_settings(tr):
        config.ui["voice_name"] = f"indextts2:{reference_audio}"
+
+
+def render_doubaotts_settings(tr):
+    """渲染豆包语音 TTS 设置"""
+    # AK 输入
+    ak = st.text_input(
+        "Access Key",
+        value=config.doubaotts.get("ak", ""),
+        help="火山引擎 Access Key"
+    )
+
+    # SK 输入
+    sk = st.text_input(
+        "Secret Key",
+        value=config.doubaotts.get("sk", ""),
+        type="password",
+        help="火山引擎 Secret Key"
+    )
+
+    # AppID 输入
+    appid = st.text_input(
+        "AppID",
+        value=config.doubaotts.get("appid", ""),
+        help="豆包语音应用 AppID"
+    )
+
+    # Token 输入
+    token = st.text_input(
+        "Token",
+        value=config.doubaotts.get("token", ""),
+        type="password",
+        help="豆包语音应用 Token"
+    )
+
+    # 集群配置
+    cluster = st.text_input(
+        "集群",
+        value=config.doubaotts.get("cluster", "volcano_tts"),
+        help="业务集群,标准音色使用 volcano_tts"
+    )
+
+    # 音色选择
+    # 在线音色列表(从文档中提取)
+    voice_options = {
+        "BV700_V2_streaming": "灿灿 2.0",
+        "BV705_streaming": "炀炀",
+        "BV701_V2_streaming": "擎苍 2.0",
+        "BV001_V2_streaming": "通用女声 2.0",
+        "BV700_streaming": "灿灿",
+        "BV406_V2_streaming": "超自然音色-梓梓2.0",
+        "BV406_streaming": "超自然音色-梓梓",
+        "BV407_V2_streaming": "超自然音色-燃燃2.0",
+        "BV407_streaming": "超自然音色-燃燃",
+        "BV001_streaming": "通用女声",
+        "BV002_streaming": "通用男声",
+        "BV701_streaming": "擎苍",
+        "BV123_streaming": "阳光青年",
+        "BV120_streaming": "反卷青年",
+        "BV119_streaming": "通用赘婿",
+        "BV115_streaming": "古风少御",
+        "BV107_streaming": "霸气青叔",
+        "BV100_streaming": "质朴青年",
+        "BV104_streaming": "温柔淑女",
+        "BV004_streaming": "开朗青年",
+        "BV113_streaming": "甜宠少御",
+        "BV102_streaming": "儒雅青年",
+        "BV405_streaming": "甜美小源",
+        "BV007_streaming": "亲切女声",
+        "BV009_streaming": "知性女声",
+        "BV419_streaming": "诚诚",
+        "BV415_streaming": "童童",
+        "BV008_streaming": "亲切男声",
+        "BV408_streaming": "译制片男声",
+        "BV426_streaming": "懒小羊",
+        "BV428_streaming": "清新文艺女声",
+        "BV403_streaming": "鸡汤女声",
+        "BV158_streaming": "智慧老者",
+        "BV157_streaming": "慈爱姥姥",
+        "BR001_streaming": "说唱小哥",
+        "BV410_streaming": "活力解说男",
+        "BV411_streaming": "影视解说小帅",
+        "BV437_streaming": "解说小帅-多情感",
+        "BV412_streaming": "影视解说小美",
+        "BV159_streaming": "纨绔青年",
+        "BV418_streaming": "直播一姐",
+        "BV142_streaming": "沉稳解说男",
+        "BV143_streaming": "潇洒青年",
+        "BV056_streaming": "阳光男声",
+        "BV005_streaming": "活泼女声",
+        "BV064_streaming": "小萝莉",
+        "BV051_streaming": "奶气萌娃",
+        "BV063_streaming": "动漫海绵",
+        "BV417_streaming": "动漫海星",
+        "BV050_streaming": "动漫小新",
+        "BV061_streaming": "天才童声",
+        "BV401_streaming": "促销男声",
+        "BV402_streaming": "促销女声",
+        "BV006_streaming": "磁性男声",
+        "BV011_streaming": "新闻女声",
+        "BV012_streaming": "新闻男声",
+        "BV034_streaming": "知性姐姐-双语",
+        "BV033_streaming": "温柔小哥",
+        "BV511_streaming": "慵懒女声-Ava",
+        "BV505_streaming": "议论女声-Alicia",
+        "BV138_streaming": "情感女声-Lawrence",
+        "BV027_streaming": "美式女声-Amelia",
+        "BV502_streaming": "讲述女声-Amanda",
+        "BV503_streaming": "活力女声-Ariana",
+        "BV504_streaming": "活力男声-Jackson",
+        "BV421_streaming": "天才少女",
+        "BV702_streaming": "Stefan",
+        "BV506_streaming": "天真萌娃-Lily",
+        "BV040_streaming": "亲切女声-Anna",
+        "BV516_streaming": "澳洲男声-Henry",
+        "BV520_streaming": "元气少女",
+        "BV521_streaming": "萌系少女",
+        "BV522_streaming": "气质女声",
+        "BV524_streaming": "日语男声",
+        "BV531_streaming": "活力男声Carlos(巴西地区)",
+        "BV530_streaming": "活力女声(巴西地区)",
+        "BV065_streaming": "气质御姐(墨西哥地区)",
+        "BV021_streaming": "东北老铁",
+        "BV020_streaming": "东北丫头",
+        "BV704_streaming": "方言灿灿",
+        "BV210_streaming": "西安佟掌柜",
+        "BV217_streaming": "沪上阿姐",
+        "BV213_streaming": "广西表哥",
+        "BV025_streaming": "甜美台妹",
+        "BV227_streaming": "台普男声",
+        "BV026_streaming": "港剧男神",
+        "BV424_streaming": "广东女仔",
+        "BV212_streaming": "相声演员",
+        "BV019_streaming": "重庆小伙",
+        "BV221_streaming": "四川甜妹儿",
+        "BV423_streaming": "重庆幺妹儿",
+        "BV214_streaming": "乡村企业家",
+        "BV226_streaming": "湖南妹坨",
+        "BV216_streaming": "长沙靓女"
+    }
+
+    saved_voice_type = config.ui.get("doubaotts_voice_type", "BV700_streaming")
+    if saved_voice_type not in voice_options:
+        voice_options[saved_voice_type] = f"自定义音色 ({saved_voice_type})"
+
+    selected_voice_display = st.selectbox(
+        "音色选择",
+        options=list(voice_options.values()),
+        index=list(voice_options.keys()).index(saved_voice_type) if saved_voice_type in voice_options else 0,
+        help="选择豆包语音 TTS 音色"
+    )
+
+    # 获取实际的音色ID
+    voice_type = list(voice_options.keys())[
+        list(voice_options.values()).index(selected_voice_display)
+    ]
+
+    # 高级参数折叠面板
+    with st.expander("🔧 高级参数", expanded=False):
+        col1, col2 = st.columns(2)
+
+        with col1:
+            # 语速调节
+            voice_rate = st.slider(
+                "语速调节",
+                min_value=0.2,
+                max_value=3.0,
+                value=config.ui.get("doubaotts_rate", 1.0),
+                step=0.1,
+                help="调节语音速度 (0.2-3.0)"
+            )
+
+            # 音量调节
+            voice_volume = st.slider(
+                "音量调节",
+                min_value=0.1,
+                max_value=2.0,
+                value=config.doubaotts.get("volume", 1.0),
+                step=0.1,
+                help="调节语音音量 (0.1-2.0)"
+            )
+
+        with col2:
+            # 音高调节
+            voice_pitch = st.slider(
+                "音高调节",
+                min_value=0.5,
+                max_value=1.5,
+                value=config.doubaotts.get("pitch", 1.0),
+                step=0.1,
+                help="调节语音音高 (0.5-1.5)"
+            )
+
+            # 句尾静音时长
+            silence_duration = st.slider(
+                "句尾静音时长 (秒)",
+                min_value=0.0,
+                max_value=2.0,
+                value=config.doubaotts.get("silence_duration", 0.125),
+                step=0.05,
+                help="调节句尾静音时长 (0.0-2.0秒)"
+            )
+
+    # 显示API Key申请流程
+    with st.expander("💡 豆包语音 TTS API Key申请流程", expanded=False):
+        st.write("**申请步骤:**")
+        st.write("1. 打开 [https://console.volcengine.com/iam/keymanage](https://console.volcengine.com/iam/keymanage)")
+        st.write("2. 新建 Access Key 和 Secret Key")
+        st.write("3. 打开 [https://www.volcengine.com/product/voice-tech](https://www.volcengine.com/product/voice-tech)")
+        st.write("4. 点击立即使用")
+        st.write("5. 在最左边的API服务中心找到音频生成下面的语音合成(注意:是语音合成,不是语音合成大模型)")
+        st.write("6. 翻到最下面获取 APPID 和 Access Token")
+
+        st.write("")
+        st.info("💡 请将获取到的 Access Key、Secret Key、AppID 和 Token 填写到上方的配置中")
+
+    # 保存配置
+    config.doubaotts["ak"] = ak
+    config.doubaotts["sk"] = sk
+    config.doubaotts["appid"] = appid
+    config.doubaotts["token"] = token
+    config.doubaotts["cluster"] = cluster
+    config.doubaotts["volume"] = voice_volume
+    config.doubaotts["pitch"] = voice_pitch
+    config.doubaotts["silence_duration"] = silence_duration
+    config.ui["doubaotts_voice_type"] = voice_type
+    config.ui["doubaotts_rate"] = voice_rate
+    config.ui["voice_name"] = voice_type  # 兼容性
+    st.session_state['voice_rate'] = voice_rate  # 确保语速参数被保存到session state
+
+    # 显示配置状态
+    if ak and sk and appid and token:
+        st.success("✅ 豆包语音 TTS 配置已设置")
+    else:
+        missing = []
+        if not ak:
+            missing.append("Access Key")
+        if not sk:
+            missing.append("Secret Key")
+        if not appid:
+            missing.append("AppID")
+        if not token:
+            missing.append("Token")
+        if missing:
+            st.warning(f"⚠️ 请配置: {', '.join(missing)}")


def render_voice_preview_new(tr, selected_engine):
    """渲染新的语音试听功能"""
    if st.button("🎵 试听语音合成", use_container_width=True):
@@ -746,6 +999,11 @@ def render_voice_preview_new(tr, selected_engine):
        voice_name = f"indextts2:{reference_audio}"
        voice_rate = 1.0  # IndexTTS2 不支持速度调节
        voice_pitch = 1.0  # IndexTTS2 不支持音调调节
+    elif selected_engine == "doubaotts":
+        voice_type = config.ui.get("doubaotts_voice_type", "BV700_streaming")
+        voice_name = voice_type
+        voice_rate = config.ui.get("doubaotts_rate", 1.0)
+        voice_pitch = 1.0  # 豆包语音 TTS 不支持音调调节

    if not voice_name:
        st.error("请先配置语音设置")
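The selectbox above recovers the voice ID from the chosen display label with a positional reverse lookup over `voice_options`. A minimal sketch with a two-entry excerpt of the dict, alongside the equivalent inverted-mapping form (valid as long as the labels are unique):

```python
# Two-entry excerpt of the voice_options dict from the diff above.
voice_options = {
    "BV700_streaming": "灿灿",
    "BV001_V2_streaming": "通用女声 2.0",
}
selected_voice_display = "通用女声 2.0"

# Positional reverse lookup, as written in the diff: dicts preserve insertion
# order, so keys and values line up index-for-index.
voice_type = list(voice_options.keys())[
    list(voice_options.values()).index(selected_voice_display)
]

# Equivalent single-step lookup via an inverted mapping.
display_to_id = {label: vid for vid, label in voice_options.items()}
assert display_to_id[selected_voice_display] == voice_type
```

The inverted mapping avoids two list materializations per rerun, at the cost of assuming no two voice IDs share a display label.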
@@ -217,6 +217,15 @@ def render_proxy_settings(tr):
        config.proxy["http"] = ""
        config.proxy["https"] = ""
+
+    # 剪映草稿地址设置
+    st.subheader("剪映草稿设置")
+    jianying_draft_path = st.text_input(
+        "剪映草稿文件夹路径",
+        value=config.ui.get("jianying_draft_path", ""),
+        help="剪映草稿文件夹路径,例如:C:\\Users\\用户名\\Documents\\JianyingPro Drafts"
+    )
+    config.ui["jianying_draft_path"] = jianying_draft_path


def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
    """测试视觉模型连接
@ -327,6 +327,8 @@ def short_drama_summary(tr):
|
|||||||
# 检查是否已经处理过字幕文件
|
# 检查是否已经处理过字幕文件
|
||||||
if 'subtitle_file_processed' not in st.session_state:
|
if 'subtitle_file_processed' not in st.session_state:
|
||||||
st.session_state['subtitle_file_processed'] = False
|
st.session_state['subtitle_file_processed'] = False
|
||||||
|
|
||||||
|
render_fun_asr_transcription(tr)
|
||||||
|
|
||||||
subtitle_file = st.file_uploader(
|
subtitle_file = st.file_uploader(
|
||||||
tr("上传字幕文件"),
|
tr("上传字幕文件"),
|
||||||
@ -401,6 +403,95 @@ def short_drama_summary(tr):
|
|||||||
return video_theme
|
return video_theme
|
||||||
|
|
||||||
|
|
||||||
|
def render_fun_asr_transcription(tr):
|
||||||
|
"""使用阿里百炼 Fun-ASR 从本地音视频转写生成字幕。"""
|
||||||
|
def clear_fun_asr_subtitle_state():
|
||||||
|
st.session_state['subtitle_path'] = None
|
||||||
|
st.session_state['subtitle_content'] = None
|
||||||
|
st.session_state['subtitle_file_processed'] = False
|
||||||
|
|
||||||
|
with st.expander("阿里百炼 Fun-ASR 字幕转录", expanded=False):
|
||||||
|
st.caption("上传本地音频/视频后,将自动上传到阿里百炼临时存储并通过 fun-asr 生成 SRT 字幕。")
|
||||||
|
st.markdown(
|
||||||
|
"API Key 获取地址:"
|
||||||
|
"[https://bailian.console.aliyun.com/?tab=model#/api-key]"
|
||||||
|
"(https://bailian.console.aliyun.com/?tab=model#/api-key)"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = st.text_input(
|
||||||
|
"阿里百炼 API Key",
|
||||||
|
```diff
+        value=config.fun_asr.get("api_key", ""),
+        type="password",
+        help="请输入你自己的阿里百炼 API Key;保存配置后会写入本地 config.toml",
+        key="fun_asr_api_key",
+    )
+    uploaded_media = st.file_uploader(
+        "上传需要转录的音频/视频",
+        type=[
+            "aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov",
+            "mp3", "mp4", "mpeg", "ogg", "opus", "wav", "webm", "wma", "wmv",
+        ],
+        accept_multiple_files=False,
+        key="fun_asr_media_uploader",
+    )
+
+    if st.button("转写生成字幕", key="fun_asr_transcribe"):
+        if not api_key.strip():
+            clear_fun_asr_subtitle_state()
+            st.error("请先输入阿里百炼 API Key")
+            return
+        if uploaded_media is None:
+            clear_fun_asr_subtitle_state()
+            st.error("请先上传需要转录的音频或视频文件")
+            return
+
+        try:
+            clear_fun_asr_subtitle_state()
+            from app.services import fun_asr_subtitle
+
+            config.fun_asr["api_key"] = api_key.strip()
+            config.fun_asr["model"] = "fun-asr"
+            config.save_config()
+
+            temp_dir = utils.temp_dir("fun_asr")
+            safe_filename = os.path.basename(uploaded_media.name)
+            media_path = os.path.join(temp_dir, safe_filename)
+            file_name, file_extension = os.path.splitext(safe_filename)
+            if os.path.exists(media_path):
+                timestamp = time.strftime("%Y%m%d%H%M%S")
+                media_path = os.path.join(temp_dir, f"{file_name}_{timestamp}{file_extension}")
+
+            with open(media_path, "wb") as f:
+                f.write(uploaded_media.getbuffer())
+
+            subtitle_name = f"{os.path.splitext(os.path.basename(media_path))[0]}_fun_asr.srt"
+            subtitle_path = os.path.join(utils.subtitle_dir(), subtitle_name)
+
+            with st.spinner("正在使用阿里百炼 Fun-ASR 转写字幕,请稍候..."):
+                generated_path = fun_asr_subtitle.create_with_fun_asr(
+                    local_file=media_path,
+                    subtitle_file=subtitle_path,
+                    api_key=api_key.strip(),
+                )
+
+            if not generated_path or not os.path.exists(generated_path):
+                clear_fun_asr_subtitle_state()
+                st.error("Fun-ASR 转写失败:未生成字幕文件")
+                return
+
+            with open(generated_path, "r", encoding="utf-8") as f:
+                subtitle_content = f.read()
+
+            st.session_state['subtitle_path'] = generated_path
+            st.session_state['subtitle_content'] = subtitle_content
+            st.session_state['subtitle_file_processed'] = True
+            st.success(f"字幕转写成功: {os.path.basename(generated_path)}")
+        except Exception as e:
+            clear_fun_asr_subtitle_state()
+            logger.error(f"Fun-ASR 字幕转写失败: {traceback.format_exc()}")
+            st.error(f"Fun-ASR 字幕转写失败: {str(e)}")
+
+
 def render_script_buttons(tr, params):
     """渲染脚本操作按钮"""
     # 获取当前选择的脚本类型
```
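The upload handler above saves the uploaded media into a temp directory and, on a name collision, falls back to a timestamp-suffixed filename. A minimal standalone sketch of that rename step (the function name `unique_media_path` is mine, not part of the patch):

```python
import os
import time


def unique_media_path(temp_dir: str, filename: str) -> str:
    """Return a save path in temp_dir; append a timestamp suffix if the name is taken."""
    safe_filename = os.path.basename(filename)  # drop any client-supplied directories
    media_path = os.path.join(temp_dir, safe_filename)
    file_name, file_extension = os.path.splitext(safe_filename)
    if os.path.exists(media_path):
        # same scheme as the patch: clip.mp4 -> clip_20260403120000.mp4
        timestamp = time.strftime("%Y%m%d%H%M%S")
        media_path = os.path.join(temp_dir, f"{file_name}_{timestamp}{file_extension}")
    return media_path
```

Note this only guards against one collision per second of wall-clock time; the patch accepts that trade-off because each upload session uses its own temp directory.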
```diff
@@ -1,21 +1,32 @@
 # 纪录片脚本生成
-import os
+import asyncio
 import json
 import time
-import asyncio
 import traceback
 
 import streamlit as st
 from loguru import logger
-from datetime import datetime
 
 from app.config import config
-from app.utils import utils, video_processor
-from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps
+from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
 
+
+def _normalize_progress_value(progress: float | int) -> int:
+    """Normalize mixed progress inputs to Streamlit's 0-100 integer range."""
+    try:
+        value = float(progress)
+    except (TypeError, ValueError):
+        return 0
+
+    if 0.0 <= value <= 1.0:
+        value *= 100
+
+    return max(0, min(100, int(round(value))))
+
+
 def generate_script_docu(params):
     """
-    生成 纪录片 视频脚本
+    生成纪录片视频脚本。
     要求: 原视频无字幕无配音
     适合场景: 纪录片、动物搞笑解说、荒野建造等
     """
@@ -23,419 +34,72 @@ def generate_script_docu(params):
     status_text = st.empty()
 
     def update_progress(progress: float, message: str = ""):
-        progress_bar.progress(progress)
+        normalized_progress = _normalize_progress_value(progress)
+        progress_bar.progress(normalized_progress)
         if message:
             status_text.text(f"🎬 {message}")
         else:
-            status_text.text(f"📊 进度: {progress}%")
+            status_text.text(f"📊 进度: {normalized_progress}%")
 
     try:
         with st.spinner("正在生成脚本..."):
             if not params.video_origin_path:
                 st.error("请先选择视频文件")
                 return
-            """
-            1. 提取键帧
-            """
-            update_progress(10, "正在提取关键帧...")
-
-            # 创建临时目录用于存储关键帧
-            keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
-            video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
-            video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
-
-            # 检查是否已经提取过关键帧
-            keyframe_files = []
-            if os.path.exists(video_keyframes_dir):
-                # 取已有的关键帧文件
-                for filename in sorted(os.listdir(video_keyframes_dir)):
-                    if filename.endswith('.jpg'):
-                        keyframe_files.append(os.path.join(video_keyframes_dir, filename))
-
-                if keyframe_files:
-                    logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
-                    st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)} 帧")
-                    update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧")
-
-            # 如果没有缓存的关键帧,则进行提取
-            if not keyframe_files:
-                try:
-                    # 确保目录存在
-                    os.makedirs(video_keyframes_dir, exist_ok=True)
-
-                    # 初始化视频处理器
-                    processor = video_processor.VideoProcessor(params.video_origin_path)
-
-                    # 显示视频信息
-                    st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}秒")
-
-                    # 处理视频并提取关键帧 - 直接使用超级兼容性方案
-                    update_progress(15, "正在提取关键帧(使用超级兼容性方案)...")
-
-                    try:
-                        # 使用优化的关键帧提取方法
-                        processor.extract_frames_by_interval_ultra_compatible(
-                            output_dir=video_keyframes_dir,
-                            interval_seconds=st.session_state.get('frame_interval_input'),
-                        )
-                    except Exception as extract_error:
-                        logger.error(f"关键帧提取失败: {extract_error}")
-
-                        # 提供详细的错误信息和解决建议
-                        error_msg = str(extract_error)
-                        if "权限" in error_msg or "permission" in error_msg.lower():
-                            suggestion = "建议:检查输出目录权限,或更换输出位置"
-                        elif "空间" in error_msg or "space" in error_msg.lower():
-                            suggestion = "建议:检查磁盘空间是否足够"
-                        else:
-                            suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式"
-
-                        raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}")
-
-                    # 获取所有关键帧文件路径
-                    for filename in sorted(os.listdir(video_keyframes_dir)):
-                        if filename.endswith('.jpg'):
-                            keyframe_files.append(os.path.join(video_keyframes_dir, filename))
-
-                    if not keyframe_files:
-                        # 检查目录中是否有其他文件
-                        all_files = os.listdir(video_keyframes_dir)
-                        logger.error(f"关键帧目录内容: {all_files}")
-                        raise Exception("未提取到任何关键帧文件,请检查视频文件格式")
-
-                    update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧")
-                    st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧")
-
-                except Exception as e:
-                    # 如果提取失败,清理创建的目录
-                    try:
-                        if os.path.exists(video_keyframes_dir):
-                            import shutil
-                            shutil.rmtree(video_keyframes_dir)
-                    except Exception as cleanup_err:
-                        logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
-
-                    raise Exception(f"关键帧提取失败: {str(e)}")
 
-            """
-            2. 视觉分析(批量分析每一帧)
-            """
-            # 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值
             vision_llm_provider = (
-                st.session_state.get('vision_llm_provider') or
-                config.app.get('vision_llm_provider', 'openai')
+                st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
             ).lower()
-
-            logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")
-
-            try:
-                # ===================初始化视觉分析器===================
-                update_progress(30, "正在初始化视觉分析器...")
-
-                # 使用统一的配置键格式获取配置(支持所有 provider)
-                vision_api_key = (
-                    st.session_state.get(f'vision_{vision_llm_provider}_api_key') or
-                    config.app.get(f'vision_{vision_llm_provider}_api_key')
-                )
-                vision_model = (
-                    st.session_state.get(f'vision_{vision_llm_provider}_model_name') or
-                    config.app.get(f'vision_{vision_llm_provider}_model_name')
-                )
-                vision_base_url = (
-                    st.session_state.get(f'vision_{vision_llm_provider}_base_url') or
-                    config.app.get(f'vision_{vision_llm_provider}_base_url', '')
-                )
-
-                # 验证必需配置
-                if not vision_api_key or not vision_model:
-                    raise ValueError(
-                        f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
-                        f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
-                    )
-
-                # 创建视觉分析器实例(使用统一接口)
-                llm_params = {
-                    "vision_provider": vision_llm_provider,
-                    "vision_api_key": vision_api_key,
-                    "vision_model_name": vision_model,
-                    "vision_base_url": vision_base_url,
-                }
-
-                logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}")
-
-                analyzer = create_vision_analyzer(
-                    provider=vision_llm_provider,
-                    api_key=vision_api_key,
-                    model=vision_model,
-                    base_url=vision_base_url
-                )
-
-                update_progress(40, "正在分析关键帧...")
-
-                # ===================创建异步事件循环===================
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-                # 执行异步分析
-                vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
-                vision_analysis_prompt = """
-我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
-
-首先,请详细描述每一帧的关键视觉信息(包含:主要内容、人物、动作和场景)。
-然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
-
-请务必使用 JSON 格式输出你的结果。JSON 结构应如下:
-{
-    "frame_observations": [
-        {
-            "frame_number": 1, // 或其他标识帧的方式
-            "observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
-        },
-        // ... 更多帧的观察 ...
-    ],
-    "overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
-}
-
-请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素
-
-请只返回 JSON 字符串,不要包含任何其他解释性文字。
-"""
-                results = loop.run_until_complete(
-                    analyzer.analyze_images(
-                        images=keyframe_files,
-                        prompt=vision_analysis_prompt,
-                        batch_size=vision_batch_size
-                    )
-                )
-                loop.close()
-
-                """
-                3. 处理分析结果(格式化为 json 数据)
-                """
-                # ===================处理分析结果===================
-                update_progress(60, "正在整理分析结果...")
-
-                # 合并所有批次的分析结果
-                frame_analysis = ""
-                merged_frame_observations = []  # 合并所有批次的帧观察
-                overall_activity_summaries = []  # 合并所有批次的整体总结
-                prev_batch_files = None
-                frame_counter = 1  # 初始化帧计数器,用于给所有帧分配连续的序号
-
-                # 确保分析目录存在
-                analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
-                os.makedirs(analysis_dir, exist_ok=True)
-                origin_res = os.path.join(analysis_dir, "frame_analysis.json")
-                with open(origin_res, 'w', encoding='utf-8') as f:
-                    json.dump(results, f, ensure_ascii=False, indent=2)
-
-                # 开始处理
-                for result in results:
-                    if 'error' in result:
-                        logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
-                        continue
-
-                    # 获取当前批次的文件列表
-                    batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
-
-                    # 获取批次的时间戳范围
-                    first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
-
-                    # 解析响应中的JSON数据
-                    response_text = result['response']
-                    try:
-                        # 处理可能包含```json```格式的响应
-                        if "```json" in response_text:
-                            json_content = response_text.split("```json")[1].split("```")[0].strip()
-                        elif "```" in response_text:
-                            json_content = response_text.split("```")[1].split("```")[0].strip()
-                        else:
-                            json_content = response_text.strip()
-
-                        response_data = json.loads(json_content)
-
-                        # 提取frame_observations和overall_activity_summary
-                        if "frame_observations" in response_data:
-                            frame_obs = response_data["frame_observations"]
-                            overall_summary = response_data.get("overall_activity_summary", "")
-
-                            # 添加时间戳信息到每个帧观察
-                            for i, obs in enumerate(frame_obs):
-                                if i < len(batch_files):
-                                    # 从文件名中提取时间戳
-                                    file_path = batch_files[i]
-                                    file_name = os.path.basename(file_path)
-                                    # 提取时间戳字符串 (格式如: keyframe_000675_000027000.jpg)
-                                    # 格式解析: keyframe_帧序号_毫秒时间戳.jpg
-                                    timestamp_parts = file_name.split('_')
-                                    if len(timestamp_parts) >= 3:
-                                        timestamp_str = timestamp_parts[-1].split('.')[0]
-                                        try:
-                                            # 修正时间戳解析逻辑
-                                            # 格式为000100000,表示00:01:00,000,即1分钟
-                                            # 需要按照对应位数进行解析:
-                                            # 前两位是小时,中间两位是分钟,后面是秒和毫秒
-                                            if len(timestamp_str) >= 9:  # 确保格式正确
-                                                hours = int(timestamp_str[0:2])
-                                                minutes = int(timestamp_str[2:4])
-                                                seconds = int(timestamp_str[4:6])
-                                                milliseconds = int(timestamp_str[6:9])
-
-                                                # 计算总秒数
-                                                timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
-                                                formatted_time = utils.format_time(timestamp_seconds)  # 格式化时间戳
-                                            else:
-                                                # 兼容旧的解析方式
-                                                timestamp_seconds = int(timestamp_str) / 1000  # 转换为秒
-                                                formatted_time = utils.format_time(timestamp_seconds)  # 格式化时间戳
-                                        except ValueError:
-                                            logger.warning(f"无法解析时间戳: {timestamp_str}")
-                                            timestamp_seconds = 0
-                                            formatted_time = "00:00:00,000"
-                                    else:
-                                        logger.warning(f"文件名格式不符合预期: {file_name}")
-                                        timestamp_seconds = 0
-                                        formatted_time = "00:00:00,000"
-
-                                    # 添加额外信息到帧观察
-                                    obs["frame_path"] = file_path
-                                    obs["timestamp"] = formatted_time
-                                    obs["timestamp_seconds"] = timestamp_seconds
-                                    obs["batch_index"] = result['batch_index']
-
-                                    # 使用全局递增的帧计数器替换原始的frame_number
-                                    if "frame_number" in obs:
-                                        obs["original_frame_number"] = obs["frame_number"]  # 保留原始编号作为参考
-                                    obs["frame_number"] = frame_counter  # 赋值连续的帧编号
-                                    frame_counter += 1  # 增加帧计数器
-
-                                    # 添加到合并列表
-                                    merged_frame_observations.append(obs)
-
-                            # 添加批次整体总结信息
-                            if overall_summary:
-                                # 从文件名中提取时间戳数值
-                                first_time_str = first_timestamp.split('_')[-1].split('.')[0]
-                                last_time_str = last_timestamp.split('_')[-1].split('.')[0]
-
-                                # 转换为毫秒并计算持续时间(秒)
-                                try:
-                                    # 修正解析逻辑,与上面相同的方式解析时间戳
-                                    if len(first_time_str) >= 9 and len(last_time_str) >= 9:
-                                        # 解析第一个时间戳
-                                        first_hours = int(first_time_str[0:2])
-                                        first_minutes = int(first_time_str[2:4])
-                                        first_seconds = int(first_time_str[4:6])
-                                        first_ms = int(first_time_str[6:9])
-                                        first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000
-
-                                        # 解析第二个时间戳
-                                        last_hours = int(last_time_str[0:2])
-                                        last_minutes = int(last_time_str[2:4])
-                                        last_seconds = int(last_time_str[4:6])
-                                        last_ms = int(last_time_str[6:9])
-                                        last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000
-
-                                        batch_duration = last_time_seconds - first_time_seconds
-                                    else:
-                                        # 兼容旧的解析方式
-                                        first_time_ms = int(first_time_str)
-                                        last_time_ms = int(last_time_str)
-                                        batch_duration = (last_time_ms - first_time_ms) / 1000
-                                except ValueError:
-                                    # 使用 utils.time_to_seconds 函数处理格式化的时间戳
-                                    first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ','))
-                                    last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ','))
-                                    batch_duration = last_time_seconds - first_time_seconds
-
-                                overall_activity_summaries.append({
-                                    "batch_index": result['batch_index'],
-                                    "time_range": f"{first_timestamp}-{last_timestamp}",
-                                    "duration_seconds": batch_duration,
-                                    "summary": overall_summary
-                                })
-                    except Exception as e:
-                        logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}")
-                        # 添加原始响应作为回退
-                        frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
-                        frame_analysis += response_text
-                        frame_analysis += "\n"
-
-                    # 更新上一个批次的文件
-                    prev_batch_files = batch_files
-
-                # 将合并后的结果转为JSON字符串
-                merged_results = {
-                    "frame_observations": merged_frame_observations,
-                    "overall_activity_summaries": overall_activity_summaries
-                }
-
-                # 使用当前时间创建文件名
-                now = datetime.now()
-                timestamp_str = now.strftime("%Y%m%d_%H%M")
-
-                # 保存完整的分析结果为JSON
-                analysis_filename = f"frame_analysis_{timestamp_str}.json"
-                analysis_json_path = os.path.join(analysis_dir, analysis_filename)
-                with open(analysis_json_path, 'w', encoding='utf-8') as f:
-                    json.dump(merged_results, f, ensure_ascii=False, indent=2)
-                logger.info(f"分析结果已保存到: {analysis_json_path}")
-
-                """
-                4. 生成文案
-                """
-                logger.info("开始生成解说文案")
-                update_progress(80, "正在生成解说文案...")
-                from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration
-                # 从配置中获取文本生成相关配置
-                text_provider = config.app.get('text_llm_provider', 'gemini').lower()
-                text_api_key = config.app.get(f'text_{text_provider}_api_key')
-                text_model = config.app.get(f'text_{text_provider}_model_name')
-                text_base_url = config.app.get(f'text_{text_provider}_base_url')
-                llm_params.update({
-                    "text_provider": text_provider,
-                    "text_api_key": text_api_key,
-                    "text_model_name": text_model,
-                    "text_base_url": text_base_url
-                })
-                # 整理帧分析数据
-                markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
-
-                # 生成解说文案
-                narration = generate_narration(
-                    markdown_output,
-                    text_api_key,
-                    base_url=text_base_url,
-                    model=text_model
-                )
-
-                # 使用增强的JSON解析器
-                from webui.tools.generate_short_summary import parse_and_fix_json
-                narration_data = parse_and_fix_json(narration)
-
-                if not narration_data or 'items' not in narration_data:
-                    logger.error(f"解说文案JSON解析失败,原始内容: {narration[:200]}...")
-                    raise Exception("解说文案格式错误,无法解析JSON或缺少items字段")
-
-                narration_dict = narration_data['items']
-                # 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音
-                narration_dict = [{**item, "OST": 2} for item in narration_dict]
-                logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段")
-                # 结果转换为JSON字符串
-                script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
-
-            except Exception as e:
-                logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
-                raise Exception(f"分析失败: {str(e)}")
-
-            if script is None:
-                st.error("生成脚本失败,请检查日志")
-                st.stop()
-            logger.info(f"纪录片解说脚本生成完成")
+            vision_api_key = (
+                st.session_state.get(f"vision_{vision_llm_provider}_api_key")
+                or config.app.get(f"vision_{vision_llm_provider}_api_key")
+            )
+            vision_model = (
+                st.session_state.get(f"vision_{vision_llm_provider}_model_name")
+                or config.app.get(f"vision_{vision_llm_provider}_model_name")
+            )
+            vision_base_url = (
+                st.session_state.get(f"vision_{vision_llm_provider}_base_url")
+                or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
+            )
+            if not vision_api_key or not vision_model:
+                raise ValueError(
+                    f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
+                    f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
+                )
+
+            frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
+                "frame_interval_input", 3
+            )
+            vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
+            vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
+                "vision_max_concurrency", 2
+            )
+
+            update_progress(10, "正在提取关键帧...")
+            service = DocumentaryFrameAnalysisService()
+            script_items = asyncio.run(
+                service.generate_documentary_script(
+                    video_path=params.video_origin_path,
+                    video_theme=st.session_state.get("video_theme", ""),
+                    custom_prompt=st.session_state.get("custom_prompt", ""),
+                    frame_interval_input=frame_interval_input,
+                    vision_batch_size=vision_batch_size,
+                    vision_llm_provider=vision_llm_provider,
+                    progress_callback=update_progress,
+                    vision_api_key=vision_api_key,
+                    vision_model_name=vision_model,
+                    vision_base_url=vision_base_url,
+                    max_concurrency=vision_max_concurrency,
+                )
+            )
+
+            logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段")
+            script = json.dumps(script_items, ensure_ascii=False, indent=2)
 
         if isinstance(script, list):
-            st.session_state['video_clip_json'] = script
+            st.session_state["video_clip_json"] = script
         elif isinstance(script, str):
-            st.session_state['video_clip_json'] = json.loads(script)
+            st.session_state["video_clip_json"] = json.loads(script)
         update_progress(100, "脚本生成完成")
 
         time.sleep(0.1)
```