Compare commits

...

26 Commits
v0.7.7 ... main

Author SHA1 Message Date
viccy
c0b72ec603 chore: bump project version to 0.7.9 and polish the README
- Bump the project version from 0.7.8 to 0.7.9
- Improve the layout and structure of README.md for better readability
- Update the feature list and latest news, adding notes on the 0.7.9 release
- Remove outdated promotional content and update the sponsor badge
2026-04-27 18:51:49 +08:00
viccy
99dd4193ae feat(subtitles): add Aliyun Bailian Fun-ASR audio/video subtitle transcription
- Add a Fun-ASR transcription page to the WebUI, supporting uploads in common audio/video formats and generating SRT subtitles
- Add the `app/services/fun_asr_subtitle.py` service module, implementing the full REST API flow: fetching upload credentials, uploading the file, submitting the task, polling for results, and converting to SRT
- Add a `[fun_asr]` section to the config file for saving the API Key
- Add unit tests covering the core conversion logic and the service flow
- For compatibility with Python below 3.11, change the `tomllib` import to a try/except import that falls back to `tomli`
- Add `from __future__ import annotations` to `defaults.py` to support the type annotations
2026-04-27 18:15:54 +08:00
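The submit-then-poll portion of the REST flow described in that commit can be sketched roughly as follows. This is an illustrative skeleton only: `submit_task` and `poll_task` stand in for the real Fun-ASR HTTP calls (credential fetch and file upload would happen before `submit_task`), and the status strings are assumptions, not the service's actual contract.

```python
import time

def transcribe_sketch(submit_task, poll_task, poll_interval=2.0, timeout=300.0):
    """Illustrative polling loop: submit a transcription task, then poll
    until it reports success or failure, or the timeout elapses."""
    task_id = submit_task()
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status, payload = poll_task(task_id)
        if status == "SUCCEEDED":
            return payload  # transcript payload, to be converted to SRT
        if status == "FAILED":
            raise RuntimeError(f"transcription task {task_id} failed")
        time.sleep(poll_interval)
    raise TimeoutError(f"task {task_id} did not finish within {timeout}s")
```

A loop like this is why the commit separates "submit task" from "poll results": long audio files finish asynchronously, so the client must re-query rather than block on one request.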
viccy
8c129790c7
Merge pull request #237 from aw123456dew/feature/doubao-tts
add doubao tts
2026-04-08 15:14:10 +08:00
viccy
de33c6d0bd
Merge pull request #238 from aw123456dew/feature/export-jianying-draft
add export jianying draft feature
2026-04-08 15:13:02 +08:00
aw123456dew
852f5ae34c fix: jianying draft export failure due to floating-point precision in audio duration 2026-04-07 17:13:43 +08:00
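A plausible illustration of that class of floating-point bug (the helper and numbers below are hypothetical, not the actual draft exporter): summing raw float segment durations accumulates sub-millisecond drift, so computed track offsets stop matching exact integer-microsecond boundaries; quantizing each duration once keeps the timeline exact.

```python
def offsets_us(durations_s):
    """Cumulative start offsets in integer microseconds, rounding each
    segment duration once instead of accumulating raw floats."""
    offsets, total = [], 0
    for d in durations_s:
        offsets.append(total)
        total += round(d * 1_000_000)  # quantize per segment
    return offsets

durations = [0.1] * 10            # ten 100 ms clips
raw_end = sum(durations)          # float drift: not exactly 1.0
exact_end = offsets_us(durations)[-1] + round(durations[-1] * 1_000_000)
```

Here `raw_end` comes out as 0.9999999999999999 while `exact_end` is exactly 1 000 000 µs, which is the kind of mismatch that can make a strict draft importer reject the file.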
aw123456dew
d45c1858c9 add export jianying draft feature 2026-04-07 11:33:12 +08:00
aw123456dew
71dfc99839 add doubao tts 2026-04-07 09:10:50 +08:00
viccy
be653c5748
Merge pull request #236 from linyqh/codex/refactor-documentary-frame-analysis-pipeline
refactor(documentary): centralize frame analysis pipeline
2026-04-03 13:16:02 +08:00
linyq
d5c63cf4b4 chore: bump version to 0.7.8 2026-04-03 13:09:26 +08:00
linyq
e53156f4f2 fix(documentary): normalize streamlit progress values 2026-04-03 12:57:24 +08:00
linyq
abc9db22e5 Fix documentary narration parsing and explicit vision overrides 2026-04-03 12:04:09 +08:00
linyq
4e2560651f fix(documentary): restore narration repair and explicit vision overrides 2026-04-03 11:29:27 +08:00
linyq
a8b6a5bb6b fix(documentary): fail on malformed narration payload 2026-04-03 02:45:33 +08:00
linyq
d678bf62b1 fix(documentary): centralize final script generation in shared service 2026-04-03 02:38:54 +08:00
linyq
ac63fea953 refactor(documentary): route adapters through shared analysis service 2026-04-03 02:24:30 +08:00
linyq
df034d104b fix(documentary): keep frames when batch summary is missing 2026-04-03 02:09:02 +08:00
linyq
ad02059e5d fix(documentary): validate batch response contract before success 2026-04-03 02:04:21 +08:00
linyq
4d21c43b89 feat(documentary): preserve failed batches and add vision concurrency 2026-04-03 01:54:47 +08:00
linyq
8201911b82 fix(documentary): harden fast-path fallback and cache key prefix 2026-04-03 01:42:43 +08:00
linyq
3d76bff442 perf(documentary): add fast frame extraction and cache keys 2026-04-03 01:30:51 +08:00
linyq
40a48cc9ff fix(documentary): align batch result fields with prompt contract 2026-04-03 01:23:05 +08:00
linyq
c83841a2e0 chore(gitignore): restore minimal tests ignore exception 2026-04-03 01:18:39 +08:00
linyq
f9539eac8c fix(documentary): tighten prompt contract and config guards 2026-04-03 01:14:41 +08:00
linyq
1d148370c5 feat(documentary): add shared frame analysis contract 2026-04-03 00:55:19 +08:00
linyq
093c8aa329 test: ignore manual llm smoke scripts in pytest 2026-04-03 00:47:30 +08:00
linyq
1057bd215c chore: ignore local worktrees 2026-04-03 00:36:46 +08:00
36 changed files with 3844 additions and 826 deletions

.gitignore

@@ -39,9 +39,15 @@ bug清单.md
task.md
.claude/*
.serena/*
.worktrees/
# OpenSpec: 忽略活动的变更提案,但保留归档和规范
openspec/*
AGENTS.md
CLAUDE.md
tests/*
tests/*
!tests/test_documentary_frame_analysis_service.py
!tests/test_video_processor_documentary_unittest.py
!tests/test_script_service_documentary_unittest.py
!tests/test_generate_narration_script_documentary_unittest.py
!tests/test_generate_script_docu_unittest.py
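For context on the hunk above: git cannot re-include a file whose parent directory is itself excluded, so pairing `tests/` with `!tests/foo.py` would not work, while `tests/*` with negation patterns does. A quick self-contained check (assumes `git` is on PATH; file names are illustrative):

```shell
tmp=$(mktemp -d) && cd "$tmp" && git init -q
mkdir tests && touch tests/keep.py tests/drop.py
printf 'tests/*\n!tests/keep.py\n' > .gitignore
git check-ignore -q tests/drop.py && echo "drop.py ignored"
git check-ignore -q tests/keep.py || echo "keep.py kept"
```

This prints "drop.py ignored" and "keep.py kept", matching how the repository ignores everything in `tests/` except the listed unit tests.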


@@ -33,6 +33,7 @@ NarratoAI is an automated video narration tool that provides an all-in-one solut
</div>
## Latest News
- 2026.04.03 Released version 0.7.8, refactored the documentary frame-analysis pipeline with a shared service and improved extraction, caching, vision batching, and narration generation
- 2025.05.11 Released new version 0.6.0, supports **short drama commentary** and optimized editing process
- 2025.03.06 Released new version 0.5.2, supports DeepSeek R1 and DeepSeek V3 models for short drama mixing
- 2024.12.16 Released new version 0.3.9, supports Alibaba Qwen2-VL model for video understanding; supports short drama mixing


@@ -1,38 +1,48 @@
<div align="center">
<h1 align="center" style="font-size: 2cm;"> NarratoAI 😎📽️ </h1>
<h1 align="center"> NarratoAI 😎📽️ </h1>
<h3 align="center">一站式 AI 影视解说+自动化剪辑工具🎬🎞️ </h3>
<p align="center">
📖 <a href="README-en.md">English</a> | 简体中文 | <a href="https://www.narratoai.cn">☁️ <b>云端版入口 (NarratoAI.cn)</b></a>
</p>
<h3>📖 <a href="README-en.md">English</a> | 简体中文 </h3>
<div align="center">
<br>
> **🔥 隆重推荐VibeCut 的新范式 —— [speclip.com](https://speclip.com)**
>
> **一个真正意义上的视频剪辑 Agent像聊天(vibecoding)一样剪辑视频。**
> **[👉 点击立即免费下载 Speclip](https://speclip.com)**
[//]: # ( <a href="https://trendshift.io/repositories/8731" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8731" alt="harry0703%2FNarratoAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>)
</div>
<br>
NarratoAI 是一个自动化影视解说工具基于LLM实现文案撰写、自动化视频剪辑、配音和字幕生成的一站式流程助力高效内容创作。
<br>
> **🔥 隆重推荐VibeCut 的新范式 —— [Speclip](https://speclip.com) ,一个真正意义上的剪辑 Agent[👉 点击免费下载](https://speclip.com)**
NarratoAI 是一款自动化影视解说工具,基于 LLM 实现文案撰写、自动化视频剪辑、配音和字幕生成的一站式流程,助力高效内容创作。支持本地部署开源版及 [云端托管版](https://www.narratoai.cn)。
[![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/linyqh/NarratoAI)
[![GitHub license](https://img.shields.io/github/license/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/blob/main/LICENSE)
[![GitHub issues](https://img.shields.io/github/issues/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/issues)
[![GitHub stars](https://img.shields.io/github/stars/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/stargazers)
<br>
<a href="https://discord.com/invite/V2pbAqqQNb" target="_blank">💬 加入 discord 开源社区,获取项目动态和最新资讯。</a>
[![GitHub stars](https://img.shields.io/github/stars/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/stargazers) [![GitHub issues](https://img.shields.io/github/issues/linyqh/NarratoAI?style=for-the-badge)](https://github.com/linyqh/NarratoAI/issues) [![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/linyqh/NarratoAI)
<br>
<a href="https://github.com/linyqh/NarratoAI/wiki" target="_blank">💬 加入开源社群,获取项目动态和最新资讯</a>
<br>
<h2><a href="https://p9mf6rjv3c.feishu.cn/wiki/SP8swLLZki5WRWkhuFvc2CyInDg?from=from_copylink" target="_blank">🎉🎉🎉 官方文档 🎉🎉🎉</a> </h2>
<h3>首页</h3>
### 界面预览
![](docs/index-zh.png)
</div>
## 许可证
本项目仅供学习和研究使用,不得商用。如需商业授权,请联系作者。
## 最新资讯
- 2026.04.27 发布新版本 0.7.9,新增 **Fun-ASR一键转录字幕**
- 2026.04.03 发布新版本 0.7.8,重构纪录片逐帧分析链路,统一共享服务并优化抽帧、缓存、视觉并发与文案生成流程
- 2026.03.27 发布新版本 0.7.7,出于安全考虑,已移除 LiteLLM 依赖,统一使用 OpenAI 兼容请求链路
- 2025.11.20 发布新版本 0.7.5,新增 [IndexTTS2](https://github.com/index-tts/index-tts) 语音克隆支持
- 2025.10.15 发布新版本 0.7.3,升级大模型供应商管理能力
@@ -47,17 +57,7 @@ NarratoAI 是一个自动化影视解说工具基于LLM实现文案撰写、
- 2024.11.10 发布新版本 v0.3.5;优化视频剪辑流程,
## 重磅福利 🎉
> 1
> **开发者专属福利一站式AI平台注册即送体验金**
>
> 还在为接入各种AI模型烦恼吗向您推荐 302.AI一个企业级的AI资源中心。一次接入即可调用上百种AI模型涵盖语言、图像、音视频等按量付费极大降低开发成本。
>
> 通过下方我的专属链接注册,**立获1美元免费体验金**助您轻松开启AI开发之旅。
>
> **立即注册领取:** [https://share.302.ai/I9P6mP](https://share.302.ai/I9P6mP)
---
> 2
> 即日起全面支持硅基流动注册即享2000万免费Token价值16元平台配额剪辑10分钟视频仅需0.1元!
>
> 🔥 快速领福利:
@@ -96,7 +96,7 @@ _**1. NarratoAI 是一款完全免费的软件,近期在社交媒体(抖音,B
- [x] 一键合并素材
- [x] 一键转录
- [x] 一键清理缓存
- [ ] 支持导出剪映草稿
- [x] 支持导出剪映草稿
- [X] 支持短剧解说
- [ ] 主角人脸匹配
- [ ] 支持根据口播,文案,视频素材自动匹配
@@ -168,7 +168,9 @@ streamlit run webui.py --server.maxUploadSize=2048
</div>
## 赞助
[![Powered by DartNode](https://dartnode.com/branding/DN-Open-Source-sm.png)](https://dartnode.com "Powered by DartNode - Free VPS for Open Source")
<a href="https://dartnode.com">
<img src="https://dartnode.com/_branding/white_color_full.png" alt="Powered by DartNode" style="background-color: #333; padding: 10px; border-radius: 4px;">
</a>
## 许可证 📝


@@ -81,7 +81,9 @@ def save_config():
_cfg["soulvoice"] = soulvoice
_cfg["ui"] = ui
_cfg["tts_qwen"] = tts_qwen
_cfg["fun_asr"] = fun_asr
_cfg["indextts2"] = indextts2
_cfg["doubaotts"] = doubaotts
f.write(toml.dumps(_cfg))
@@ -95,7 +97,9 @@ soulvoice = _cfg.get("soulvoice", {})
ui = _cfg.get("ui", {})
frames = _cfg.get("frames", {})
tts_qwen = _cfg.get("tts_qwen", {})
fun_asr = _cfg.get("fun_asr", {})
indextts2 = _cfg.get("indextts2", {})
doubaotts = _cfg.get("doubaotts", {})
hostname = socket.gethostname()


@@ -1,5 +1,7 @@
"""Shared config defaults used by both bootstrap and WebUI fallbacks."""
from __future__ import annotations
DEFAULT_OPENAI_COMPATIBLE_BASE_URL = "https://api.siliconflow.cn/v1"
DEFAULT_OPENAI_COMPATIBLE_PROVIDER = "openai"


@@ -2,7 +2,10 @@ import tempfile
import unittest
from pathlib import Path
import tomllib
try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib
from app.config import config as cfg
from app.config.defaults import (


@@ -196,6 +196,7 @@ class VideoClipParams(BaseModel):
tts_volume: Optional[float] = Field(default=AudioVolumeDefaults.TTS_VOLUME, description="解说语音音量(后处理)")
original_volume: Optional[float] = Field(default=AudioVolumeDefaults.ORIGINAL_VOLUME, description="视频原声音量")
bgm_volume: Optional[float] = Field(default=AudioVolumeDefaults.BGM_VOLUME, description="背景音乐音量")
draft_name: Optional[str] = Field(default="", description="剪映草稿名称")


@@ -0,0 +1,13 @@
from app.services.documentary.frame_analysis_models import (
DocumentaryAnalysisConfig,
FrameBatchResult,
)
from app.services.documentary.frame_analysis_service import (
DocumentaryFrameAnalysisService,
)
__all__ = [
"DocumentaryAnalysisConfig",
"FrameBatchResult",
"DocumentaryFrameAnalysisService",
]


@@ -0,0 +1,33 @@
from dataclasses import dataclass, field
@dataclass(slots=True)
class DocumentaryAnalysisConfig:
video_path: str
frame_interval_seconds: float
vision_batch_size: int
vision_llm_provider: str
vision_model_name: str
custom_prompt: str = ""
max_concurrency: int = 2
def __post_init__(self) -> None:
if self.frame_interval_seconds <= 0:
raise ValueError("frame_interval_seconds must be > 0")
if self.vision_batch_size <= 0:
raise ValueError("vision_batch_size must be > 0")
if self.max_concurrency <= 0:
raise ValueError("max_concurrency must be > 0")
@dataclass(slots=True)
class FrameBatchResult:
batch_index: int
status: str
time_range: str
raw_response: str
frame_paths: list[str] = field(default_factory=list)
frame_observations: list[dict] = field(default_factory=list)
overall_activity_summary: str = ""
fallback_summary: str = ""
error_message: str = ""


@@ -0,0 +1,761 @@
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Callable
from loguru import logger
from app.config import config
from app.services.documentary.frame_analysis_models import FrameBatchResult
from app.services.generate_narration_script import generate_narration, parse_frame_analysis_to_markdown
from app.services.llm.migration_adapter import create_vision_analyzer
from app.utils import utils, video_processor
class DocumentaryFrameAnalysisService:
PROMPT_TEMPLATE = """
我提供了 {frame_count} 张视频帧它们按时间顺序排列代表一个连续的视频片段
首先请详细描述每一帧的关键视觉信息包含主要内容人物动作和场景
然后基于所有帧的分析请用简洁的语言总结整个视频片段中发生的主要活动或事件流程
请务必使用 JSON 格式输出
JSON 必须包含以下键
- frame_observations: 数组且长度必须为 {frame_count}
- overall_activity_summary: 字符串描述整个批次主要活动
示例结构
{{
"frame_observations": [
{{"timestamp": "00:00:00,000", "observation": "画面描述"}}
],
"overall_activity_summary": "本批次主要活动总结"
}}
请务必不要遗漏视频帧我提供了 {frame_count} 张视频帧frame_observations 必须包含 {frame_count} 个元素
请只返回 JSON 字符串不要附加解释文字
""".strip()
async def generate_documentary_script(
self,
*,
video_path: str,
video_theme: str = "",
custom_prompt: str = "",
frame_interval_input: int | float | None = None,
vision_batch_size: int | None = None,
vision_llm_provider: str | None = None,
progress_callback: Callable[[float, str], None] | None = None,
vision_api_key: str | None = None,
vision_model_name: str | None = None,
vision_base_url: str | None = None,
max_concurrency: int | None = None,
) -> list[dict]:
progress = progress_callback or (lambda _p, _m: None)
analysis_result = await self.analyze_video(
video_path=video_path,
video_theme=video_theme,
custom_prompt=custom_prompt,
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=progress_callback,
vision_api_key=vision_api_key,
vision_model_name=vision_model_name,
vision_base_url=vision_base_url,
max_concurrency=max_concurrency,
)
analysis_json_path = analysis_result["analysis_json_path"]
progress(80, "正在生成解说文案...")
text_provider = config.app.get("text_llm_provider", "openai").lower()
text_api_key = config.app.get(f"text_{text_provider}_api_key")
text_model = config.app.get(f"text_{text_provider}_model_name")
text_base_url = config.app.get(f"text_{text_provider}_base_url")
if not text_api_key or not text_model:
raise ValueError(
f"未配置 {text_provider} 的文本模型参数。"
f"请在设置中配置 text_{text_provider}_api_key 和 text_{text_provider}_model_name"
)
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
narration_input = self._build_narration_input(
markdown_output=markdown_output,
video_theme=video_theme,
custom_prompt=custom_prompt,
)
narration_raw = generate_narration(
narration_input,
text_api_key,
base_url=text_base_url,
model=text_model,
)
narration_items = self._parse_narration_items(narration_raw)
final_script = [{**item, "OST": 2} for item in narration_items]
progress(100, "脚本生成完成")
return final_script
async def analyze_video(
self,
*,
video_path: str,
video_theme: str = "",
custom_prompt: str = "",
frame_interval_input: int | float | None = None,
vision_batch_size: int | None = None,
vision_llm_provider: str | None = None,
progress_callback: Callable[[float, str], None] | None = None,
vision_api_key: str | None = None,
vision_model_name: str | None = None,
vision_base_url: str | None = None,
max_concurrency: int | None = None,
) -> dict[str, Any]:
progress = progress_callback or (lambda _p, _m: None)
if not video_path or not os.path.exists(video_path):
raise FileNotFoundError(f"视频文件不存在: {video_path}")
frame_interval_seconds = self._resolve_frame_interval(frame_interval_input)
batch_size = self._resolve_batch_size(vision_batch_size)
concurrency = self._resolve_max_concurrency(max_concurrency)
provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()
api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key")
model_name = (
vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name")
)
base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "")
if not api_key or not model_name:
raise ValueError(
f"未配置 {provider} 的 API Key 或模型名称。"
f"请在设置中配置 vision_{provider}_api_key 和 vision_{provider}_model_name"
)
progress(10, "正在提取关键帧...")
keyframe_files = self._load_or_extract_keyframes(video_path, frame_interval_seconds)
progress(25, f"关键帧准备完成,共 {len(keyframe_files)}")
progress(30, "正在初始化视觉分析器...")
analyzer = create_vision_analyzer(
provider=provider,
api_key=api_key,
model=model_name,
base_url=base_url,
)
batches = self._chunk_keyframes(keyframe_files, batch_size=batch_size)
if not batches:
raise RuntimeError("未能构建任何关键帧批次")
progress(40, f"正在分析关键帧,共 {len(batches)} 个批次...")
batch_results = await self._analyze_batches(
analyzer=analyzer,
batches=batches,
custom_prompt=custom_prompt,
video_theme=video_theme,
max_concurrency=concurrency,
progress_callback=progress,
)
progress(65, "正在整理分析结果...")
sorted_batches = self._sort_batch_results(batch_results)
artifact = self._build_analysis_artifact(
sorted_batches,
video_path=video_path,
frame_interval_seconds=frame_interval_seconds,
vision_batch_size=batch_size,
vision_llm_provider=provider,
vision_model_name=model_name,
max_concurrency=concurrency,
)
analysis_json_path = self._save_analysis_artifact(artifact)
video_clip_json = self._build_video_clip_json(sorted_batches)
progress(75, "逐帧分析完成")
return {
"analysis_json_path": analysis_json_path,
"analysis_artifact": artifact,
"video_clip_json": video_clip_json,
"keyframe_files": keyframe_files,
}
def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]:
parsed = self._repair_narration_payload(narration_raw)
items: list[dict[str, Any]] = []
if isinstance(parsed, dict):
raw_items = parsed.get("items")
if isinstance(raw_items, list):
items = [item for item in raw_items if isinstance(item, dict)]
if not items:
raise ValueError("解说文案格式错误无法解析JSON或缺少items字段")
return items
def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str:
context_lines: list[str] = []
if (video_theme or "").strip():
context_lines.append(f"视频主题:{video_theme.strip()}")
if (custom_prompt or "").strip():
context_lines.append(f"补充创作要求:{custom_prompt.strip()}")
if not context_lines:
return markdown_output
context_block = "\n".join(f"- {line}" for line in context_lines)
return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n"
def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None:
def load_json_candidate(payload: str) -> dict[str, Any] | None:
try:
parsed = json.loads(payload)
return parsed if isinstance(parsed, dict) else None
except Exception:
return None
cleaned = (narration_raw or "").strip()
if not cleaned:
return None
candidates: list[str] = [cleaned]
candidates.append(cleaned.replace("{{", "{").replace("}}", "}"))
json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL)
if json_block:
candidates.append(json_block.group(1).strip())
start = cleaned.find("{")
end = cleaned.rfind("}")
if start >= 0 and end > start:
candidates.append(cleaned[start : end + 1])
for candidate in candidates:
parsed = load_json_candidate(candidate)
if parsed is not None:
return parsed
fixed = cleaned.replace("{{", "{").replace("}}", "}")
fixed_start = fixed.find("{")
fixed_end = fixed.rfind("}")
if fixed_start >= 0 and fixed_end > fixed_start:
fixed = fixed[fixed_start : fixed_end + 1]
fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE)
fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE)
fixed = re.sub(r",\s*}", "}", fixed)
fixed = re.sub(r",\s*]", "]", fixed)
fixed = re.sub(r"'([^']*)'\s*:", r'"\1":', fixed)
fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed)
fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed)
return load_json_candidate(fixed)
def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
interval = frame_interval_input
if interval in (None, ""):
interval = config.frames.get("frame_interval_input", 3)
try:
value = float(interval)
except (TypeError, ValueError):
value = 3.0
if value <= 0:
raise ValueError("frame_interval_input must be > 0")
return value
def _resolve_batch_size(self, vision_batch_size: int | None) -> int:
size = vision_batch_size or config.frames.get("vision_batch_size", 10)
try:
value = int(size)
except (TypeError, ValueError):
value = 10
if value <= 0:
raise ValueError("vision_batch_size must be > 0")
return value
def _resolve_max_concurrency(self, max_concurrency: int | None) -> int:
value = max_concurrency if max_concurrency is not None else config.frames.get("vision_max_concurrency", 2)
try:
parsed = int(value)
except (TypeError, ValueError):
parsed = 1
return max(1, parsed)
def _load_or_extract_keyframes(self, video_path: str, frame_interval_seconds: float) -> list[str]:
keyframes_root = os.path.join(utils.temp_dir(), "keyframes")
os.makedirs(keyframes_root, exist_ok=True)
cache_key = self._build_keyframe_cache_key(video_path, frame_interval_seconds)
cache_dir = os.path.join(keyframes_root, cache_key)
os.makedirs(cache_dir, exist_ok=True)
cached_files = self._collect_keyframe_paths(cache_dir)
if cached_files:
logger.info(f"使用已缓存关键帧: {cache_dir}, 共 {len(cached_files)}")
return cached_files
processor = video_processor.VideoProcessor(video_path)
extracted = processor.extract_frames_by_interval_with_fallback(
output_dir=cache_dir,
interval_seconds=frame_interval_seconds,
)
keyframe_files = sorted(str(path) for path in extracted if str(path).endswith(".jpg"))
if not keyframe_files:
keyframe_files = self._collect_keyframe_paths(cache_dir)
if not keyframe_files:
raise RuntimeError("未提取到任何关键帧")
logger.info(f"关键帧提取完成: {cache_dir}, 共 {len(keyframe_files)}")
return keyframe_files
def _build_keyframe_cache_key(self, video_path: str, frame_interval_seconds: float) -> str:
try:
video_mtime = os.path.getmtime(video_path)
except OSError:
video_mtime = 0
legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
payload = "|".join(
[
str(video_path),
str(video_mtime),
str(frame_interval_seconds),
"documentary-keyframes-v2",
]
)
return f"{legacy_prefix}_{utils.md5(payload)}"
@staticmethod
def _collect_keyframe_paths(cache_dir: str) -> list[str]:
if not os.path.exists(cache_dir):
return []
return sorted(
os.path.join(cache_dir, name)
for name in os.listdir(cache_dir)
if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
)
@staticmethod
def _chunk_keyframes(keyframe_files: list[str], batch_size: int) -> list[list[str]]:
return [keyframe_files[index : index + batch_size] for index in range(0, len(keyframe_files), batch_size)]
async def _analyze_batches(
self,
*,
analyzer: Any,
batches: list[list[str]],
custom_prompt: str,
video_theme: str,
max_concurrency: int,
progress_callback: Callable[[float, str], None],
) -> list[FrameBatchResult]:
semaphore = asyncio.Semaphore(max(1, max_concurrency))
total = len(batches)
done = 0
done_lock = asyncio.Lock()
batch_time_ranges: list[str] = []
previous_batch_files: list[str] | None = None
for batch_files in batches:
_, _, time_range = self._get_batch_timestamps(batch_files, previous_batch_files)
batch_time_ranges.append(time_range)
previous_batch_files = batch_files
async def run_single(batch_index: int, frame_paths: list[str], time_range: str) -> FrameBatchResult:
nonlocal done
prompt = self._build_batch_prompt(
frame_count=len(frame_paths),
video_theme=video_theme,
custom_prompt=custom_prompt,
)
try:
async with semaphore:
raw_results = await analyzer.analyze_images(
images=frame_paths,
prompt=prompt,
batch_size=max(1, len(frame_paths)),
max_concurrency=1,
)
raw_response, error_message = self._extract_batch_response(raw_results)
if error_message:
return self._build_failed_batch_result(
batch_index=batch_index,
raw_response=raw_response,
error_message=error_message,
frame_paths=frame_paths,
time_range=time_range,
)
return self._parse_batch_response(
batch_index=batch_index,
raw_response=raw_response,
frame_paths=frame_paths,
time_range=time_range,
)
except Exception as exc:
return self._build_failed_batch_result(
batch_index=batch_index,
raw_response="",
error_message=str(exc),
frame_paths=frame_paths,
time_range=time_range,
)
finally:
async with done_lock:
done += 1
progress = 40 + (done / max(1, total)) * 25
progress_callback(progress, f"正在分析关键帧批次 ({done}/{total})...")
tasks = [
run_single(batch_index=index, frame_paths=batch_files, time_range=batch_time_ranges[index])
for index, batch_files in enumerate(batches)
]
return await asyncio.gather(*tasks)
def _build_batch_prompt(self, *, frame_count: int, video_theme: str, custom_prompt: str) -> str:
prompt = self._build_analysis_prompt(frame_count=frame_count)
extra_lines: list[str] = []
if (video_theme or "").strip():
extra_lines.append(f"视频主题:{video_theme.strip()}")
if (custom_prompt or "").strip():
extra_lines.append(custom_prompt.strip())
if not extra_lines:
return prompt
extras = "\n".join(f"- {line}" for line in extra_lines)
return f"{prompt}\n\n补充分析要求:\n{extras}"
def _extract_batch_response(self, raw_results: Any) -> tuple[str, str]:
if not raw_results:
return "", "Batch response is empty"
first_result = raw_results[0] if isinstance(raw_results, list) else raw_results
if isinstance(first_result, dict):
raw_response = str(first_result.get("response", "") or "")
error_message = str(first_result.get("error", "") or "")
if error_message:
if not raw_response:
raw_response = error_message
return raw_response, error_message
if not raw_response.strip():
return raw_response, "Batch response is empty"
return raw_response, ""
raw_response = str(first_result or "")
if not raw_response.strip():
return raw_response, "Batch response is empty"
return raw_response, ""
def _sort_batch_results(self, batch_results: list[FrameBatchResult]) -> list[FrameBatchResult]:
return sorted(batch_results, key=lambda item: (self._time_range_sort_key(item.time_range), item.batch_index))
def _build_analysis_artifact(
self,
batch_results: list[FrameBatchResult],
*,
video_path: str,
frame_interval_seconds: float,
vision_batch_size: int,
vision_llm_provider: str,
vision_model_name: str,
max_concurrency: int,
) -> dict[str, Any]:
sorted_batches = self._sort_batch_results(batch_results)
batch_dicts: list[dict[str, Any]] = []
frame_observations: list[dict[str, Any]] = []
overall_activity_summaries: list[dict[str, Any]] = []
for batch in sorted_batches:
batch_payload = {
"batch_index": batch.batch_index,
"status": batch.status,
"time_range": batch.time_range,
"raw_response": batch.raw_response,
"frame_paths": list(batch.frame_paths),
"frame_observations": list(batch.frame_observations),
"overall_activity_summary": batch.overall_activity_summary,
"fallback_summary": batch.fallback_summary,
"error_message": batch.error_message,
}
batch_dicts.append(batch_payload)
for observation in batch.frame_observations:
observation_payload = dict(observation)
observation_payload["batch_index"] = batch.batch_index
observation_payload["time_range"] = batch.time_range
frame_observations.append(observation_payload)
summary_text = (batch.overall_activity_summary or batch.fallback_summary or "").strip()
if summary_text:
overall_activity_summaries.append(
{
"batch_index": batch.batch_index,
"time_range": batch.time_range,
"summary": summary_text,
}
)
return {
"artifact_version": "documentary-frame-analysis-v2",
"generated_at": datetime.now().isoformat(),
"video_path": video_path,
"frame_interval_seconds": frame_interval_seconds,
"vision_batch_size": vision_batch_size,
"vision_llm_provider": vision_llm_provider,
"vision_model_name": vision_model_name,
"vision_max_concurrency": max_concurrency,
"batches": batch_dicts,
# Backward compatibility with the legacy parser structure
"frame_observations": frame_observations,
"overall_activity_summaries": overall_activity_summaries,
}
def _save_analysis_artifact(self, artifact: dict[str, Any]) -> str:
analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
os.makedirs(analysis_dir, exist_ok=True)
filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
file_path = os.path.join(analysis_dir, filename)
suffix = 1
while os.path.exists(file_path):
filename = f"frame_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{suffix:02d}.json"
file_path = os.path.join(analysis_dir, filename)
suffix += 1
with open(file_path, "w", encoding="utf-8") as fp:
json.dump(artifact, fp, ensure_ascii=False, indent=2)
logger.info(f"分析结果已保存到: {file_path}")
return file_path
def _build_video_clip_json(self, batch_results: list[FrameBatchResult]) -> list[dict]:
clips: list[dict] = []
for batch in self._sort_batch_results(batch_results):
picture = self._build_batch_picture(batch)
clips.append(
{
"timestamp": batch.time_range,
"picture": picture,
"narration": "",
"OST": 2,
}
)
return clips
def _build_batch_picture(self, batch: FrameBatchResult) -> str:
summary = (batch.overall_activity_summary or "").strip()
if summary:
return summary
fallback = (batch.fallback_summary or "").strip()
if fallback:
return fallback
observation_lines = []
for frame in batch.frame_observations:
timestamp = str(frame.get("timestamp", "") or "").strip()
observation = str(frame.get("observation", "") or "").strip()
if timestamp and observation:
observation_lines.append(f"{timestamp}: {observation}")
elif observation:
observation_lines.append(observation)
if observation_lines:
return " ".join(observation_lines)
raw_response = (batch.raw_response or "").strip()
if raw_response:
return raw_response[:200]
return "该批次分析失败,未返回可用描述。"
def _time_range_sort_key(self, time_range: str) -> tuple[int, str]:
start = (time_range or "").split("-", 1)[0].strip()
return self._timestamp_to_milliseconds(start), time_range
@staticmethod
def _timestamp_to_milliseconds(timestamp: str) -> int:
text = (timestamp or "").strip()
try:
if "," in text:
time_part, ms_part = text.split(",", 1)
milliseconds = int(ms_part)
else:
time_part = text
milliseconds = 0
parts = [int(part) for part in time_part.split(":") if part]
while len(parts) < 3:
parts.insert(0, 0)
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
except Exception:
return 0
def _get_batch_timestamps(
self,
batch_files: list[str],
prev_batch_files: list[str] | None = None,
) -> tuple[str, str, str]:
if not batch_files:
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
if len(batch_files) == 1 and prev_batch_files:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
else:
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
first_timestamp = self._timestamp_from_keyframe_name(first_frame)
last_timestamp = self._timestamp_from_keyframe_name(last_frame)
return first_timestamp, last_timestamp, f"{first_timestamp}-{last_timestamp}"
def _timestamp_from_keyframe_name(self, filename: str) -> str:
match = re.search(r"keyframe_\d{6}_(\d{9})\.jpg$", filename)
if not match:
return "00:00:00,000"
token = match.group(1)
hours = int(token[0:2])
minutes = int(token[2:4])
seconds = int(token[4:6])
milliseconds = int(token[6:9])
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _build_analysis_prompt(self, frame_count: int) -> str:
return self.PROMPT_TEMPLATE.format(frame_count=frame_count)
def _build_failed_batch_result(
self,
*,
batch_index: int,
raw_response: str,
error_message: str,
frame_paths: list[str],
time_range: str,
) -> FrameBatchResult:
fallback_summary = (raw_response or "").strip()[:200]
if not fallback_summary:
fallback_summary = f"Batch {batch_index} analysis failed: {error_message or 'unknown error'}"
return FrameBatchResult(
batch_index=batch_index,
status="failed",
time_range=time_range,
raw_response=raw_response,
frame_paths=list(frame_paths),
fallback_summary=fallback_summary,
error_message=error_message,
)
def _build_cache_key(
self,
video_path: str,
interval_seconds: float,
prompt_version: str,
model_name: str,
batch_size: int,
max_concurrency: int,
) -> str:
try:
video_mtime = os.path.getmtime(video_path)
except OSError:
video_mtime = 0
legacy_prefix = utils.md5(f"{video_path}{video_mtime}")
payload = "|".join(
[
str(video_path),
str(video_mtime),
str(interval_seconds),
str(prompt_version),
str(model_name),
str(batch_size),
str(max_concurrency),
"documentary-frame-analysis-v2",
]
)
return f"{legacy_prefix}_{utils.md5(payload)}"
def _strip_code_fence(self, response_text: str) -> str:
cleaned = (response_text or "").strip()
cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", cleaned)
cleaned = re.sub(r"\s*```$", "", cleaned)
return cleaned.strip()
def _parse_batch_response(
self,
*,
batch_index: int,
raw_response: str,
frame_paths: list[str],
time_range: str,
) -> FrameBatchResult:
try:
payload = json.loads(self._strip_code_fence(raw_response))
except Exception as exc:
return self._build_failed_batch_result(
batch_index=batch_index,
raw_response=raw_response,
error_message=str(exc),
frame_paths=frame_paths,
time_range=time_range,
)
validation_error = self._validate_batch_payload_contract(payload, expected_frame_count=len(frame_paths))
if validation_error:
return self._build_failed_batch_result(
batch_index=batch_index,
raw_response=raw_response,
error_message=validation_error,
frame_paths=frame_paths,
time_range=time_range,
)
raw_observations = payload["frame_observations"]
frame_observations: list[dict] = []
for index, frame_path in enumerate(frame_paths):
entry = raw_observations[index] if index < len(raw_observations) else {}
if isinstance(entry, dict):
observation = str(entry.get("observation", "") or "")
timestamp = str(entry.get("timestamp", "") or "")
else:
observation = str(entry or "")
timestamp = ""
frame_observations.append(
{
"frame_path": frame_path,
"timestamp": timestamp,
"observation": observation,
}
)
raw_summary = payload.get("overall_activity_summary", "")
if isinstance(raw_summary, str):
summary = raw_summary
elif raw_summary is None:
summary = ""
else:
summary = str(raw_summary)
return FrameBatchResult(
batch_index=batch_index,
status="success",
time_range=time_range,
raw_response=raw_response,
frame_paths=list(frame_paths),
frame_observations=frame_observations,
overall_activity_summary=summary,
)
def _validate_batch_payload_contract(self, payload: object, *, expected_frame_count: int) -> str:
if not isinstance(payload, dict):
return "Batch response JSON payload must be an object"
if "frame_observations" not in payload or not isinstance(payload["frame_observations"], list):
return "Batch response must include frame_observations as a list"
if len(payload["frame_observations"]) < expected_frame_count:
return (
"Batch response frame_observations length is shorter than provided frame_paths: "
f"{len(payload['frame_observations'])} < {expected_frame_count}"
)
return ""
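The contract check returns an error string (empty when valid), which lets the caller route every malformed payload through the same failed-batch builder. A standalone sketch of that validate-then-branch shape (names are illustrative):

```python
def validate_payload(payload, expected_frame_count):
    # Each failure mode produces a human-readable message; "" means valid.
    if not isinstance(payload, dict):
        return "Batch response JSON payload must be an object"
    observations = payload.get("frame_observations")
    if not isinstance(observations, list):
        return "Batch response must include frame_observations as a list"
    if len(observations) < expected_frame_count:
        return (f"frame_observations length is shorter than provided frame_paths: "
                f"{len(observations)} < {expected_frame_count}")
    return ""
```

A caller can then write `if err := validate_payload(p, n): build_failed(err)` and keep the success path flat.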

View File

@ -0,0 +1,452 @@
"""Aliyun Bailian Fun-ASR subtitle transcription helpers.
This module intentionally uses the REST API because the official Fun-ASR
recorded-file API supports temporary `oss://` resources only through REST.
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from typing import Any, Optional
import requests
from loguru import logger
from app.utils import utils
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com"
UPLOAD_POLICY_URL = f"{DASHSCOPE_BASE_URL}/api/v1/uploads"
TRANSCRIPTION_URL = f"{DASHSCOPE_BASE_URL}/api/v1/services/audio/asr/transcription"
TASK_URL_TEMPLATE = f"{DASHSCOPE_BASE_URL}/api/v1/tasks/{{task_id}}"
MODEL_NAME = "fun-asr"
TERMINAL_FAILED_STATUSES = {"FAILED", "CANCELED", "UNKNOWN"}
PUNCTUATION_BREAKS = set(",。!?;,.!?;")
class FunAsrError(RuntimeError):
"""Raised for user-actionable Fun-ASR transcription failures."""
@dataclass
class UploadPolicy:
upload_host: str
upload_dir: str
policy: str
signature: str
oss_access_key_id: str
x_oss_object_acl: str = "private"
x_oss_forbid_overwrite: str = "true"
max_file_size_mb: Optional[float] = None
def _auth_headers(api_key: str, extra: Optional[dict[str, str]] = None) -> dict[str, str]:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if extra:
headers.update(extra)
return headers
def _raise_for_http(response: requests.Response, action: str) -> None:
try:
response.raise_for_status()
except Exception as exc: # requests may be mocked with generic exceptions
raise FunAsrError(f"{action}失败,请检查阿里百炼 API Key、网络或服务状态") from exc
def _json(response: requests.Response, action: str) -> dict[str, Any]:
_raise_for_http(response, action)
try:
data = response.json()
except Exception as exc:
raise FunAsrError(f"{action}返回了无效 JSON") from exc
if not isinstance(data, dict):
raise FunAsrError(f"{action}返回格式无效")
return data
def _require_api_key(api_key: str) -> str:
api_key = (api_key or "").strip()
if not api_key:
raise FunAsrError("请先输入阿里百炼 API Key")
return api_key
def _safe_upload_name(local_file: str) -> str:
name = os.path.basename(local_file).strip() or f"audio_{int(time.time())}.wav"
return name.replace("/", "_").replace("\\", "_")
def _session_get(session, url: str, **kwargs):
return session.get(url, **kwargs)
def _session_post(session, url: str, **kwargs):
return session.post(url, **kwargs)
def request_upload_policy(api_key: str, model: str = MODEL_NAME, session=requests) -> UploadPolicy:
"""Request Bailian temporary-storage upload policy for the target model."""
api_key = _require_api_key(api_key)
response = _session_get(
session,
UPLOAD_POLICY_URL,
params={"action": "getPolicy", "model": model},
headers=_auth_headers(api_key),
timeout=30,
)
data = _json(response, "获取临时存储上传凭证")
policy_data = data.get("data") or {}
required = ["upload_host", "upload_dir", "policy", "signature", "oss_access_key_id"]
missing = [field for field in required if not policy_data.get(field)]
if missing:
raise FunAsrError(f"临时存储上传凭证缺少字段: {', '.join(missing)}")
return UploadPolicy(
upload_host=str(policy_data["upload_host"]),
upload_dir=str(policy_data["upload_dir"]).rstrip("/"),
policy=str(policy_data["policy"]),
signature=str(policy_data["signature"]),
oss_access_key_id=str(policy_data["oss_access_key_id"]),
x_oss_object_acl=str(policy_data.get("x_oss_object_acl") or "private"),
x_oss_forbid_overwrite=str(policy_data.get("x_oss_forbid_overwrite") or "true"),
max_file_size_mb=policy_data.get("max_file_size_mb"),
)
def _validate_file_size(local_file: str, policy: UploadPolicy) -> None:
if policy.max_file_size_mb is None:
return
max_bytes = float(policy.max_file_size_mb) * 1024 * 1024
size = os.path.getsize(local_file)
if size > max_bytes:
raise FunAsrError(
f"文件大小超过阿里百炼临时存储限制: {size / 1024 / 1024:.2f}MB > {float(policy.max_file_size_mb):.2f}MB"
)
def upload_to_temporary_oss(local_file: str, policy: UploadPolicy, session=requests) -> str:
"""Upload local file to temporary OSS and return `oss://...` URL."""
if not os.path.isfile(local_file):
raise FunAsrError(f"待转写文件不存在: {local_file}")
_validate_file_size(local_file, policy)
key = f"{policy.upload_dir}/{_safe_upload_name(local_file)}"
data = {
"OSSAccessKeyId": policy.oss_access_key_id,
"policy": policy.policy,
"Signature": policy.signature,
"key": key,
"x-oss-object-acl": policy.x_oss_object_acl,
"x-oss-forbid-overwrite": policy.x_oss_forbid_overwrite,
"success_action_status": "200",
}
with open(local_file, "rb") as file_obj:
files = {"file": (_safe_upload_name(local_file), file_obj)}
response = _session_post(session, policy.upload_host, data=data, files=files, timeout=120)
_raise_for_http(response, "上传文件到阿里百炼临时存储")
return f"oss://{key}"
def submit_transcription_task(
api_key: str,
oss_url: str,
speaker_count: Optional[int] = None,
model: str = MODEL_NAME,
session=requests,
) -> str:
"""Submit async Fun-ASR task and return task_id."""
api_key = _require_api_key(api_key)
parameters: dict[str, Any] = {"diarization_enabled": True}
if speaker_count:
parameters["speaker_count"] = int(speaker_count)
payload = {
"model": model,
"input": {"file_urls": [oss_url]},
"parameters": parameters,
}
response = _session_post(
session,
TRANSCRIPTION_URL,
headers=_auth_headers(
api_key,
{
"X-DashScope-Async": "enable",
"X-DashScope-OssResourceResolve": "enable",
},
),
json=payload,
timeout=30,
)
data = _json(response, "提交 Fun-ASR 转写任务")
task_id = ((data.get("output") or {}).get("task_id") or "").strip()
if not task_id:
raise FunAsrError("提交 Fun-ASR 转写任务失败:未返回 task_id")
return task_id
def poll_transcription_task(
api_key: str,
task_id: str,
poll_interval: float = 2.0,
timeout: float = 600.0,
session=requests,
) -> dict[str, Any]:
"""Poll task until terminal status and return successful result item."""
api_key = _require_api_key(api_key)
deadline = time.time() + timeout
last_status = "PENDING"
while time.time() < deadline:
response = _session_post(
session,
TASK_URL_TEMPLATE.format(task_id=task_id),
headers=_auth_headers(api_key),
timeout=30,
)
data = _json(response, "查询 Fun-ASR 转写任务")
output = data.get("output") or {}
last_status = str(output.get("task_status") or "").upper()
if last_status == "SUCCEEDED":
results = output.get("results") or []
for result in results:
subtask_status = str(result.get("subtask_status") or "").upper()
if subtask_status and subtask_status != "SUCCEEDED":
raise FunAsrError(f"Fun-ASR 子任务失败: {subtask_status}")
if not results:
raise FunAsrError("Fun-ASR 转写成功但未返回结果")
return results[0]
if last_status in TERMINAL_FAILED_STATUSES:
raise FunAsrError(f"Fun-ASR 转写任务失败: {last_status}")
time.sleep(poll_interval)
raise FunAsrError(f"Fun-ASR 转写任务超时,最后状态: {last_status}")
def download_transcription_result(transcription_url: str, session=requests) -> dict[str, Any]:
if not transcription_url:
raise FunAsrError("Fun-ASR 结果缺少 transcription_url")
response = _session_get(session, transcription_url, timeout=60)
return _json(response, "下载 Fun-ASR 转写结果")
def _ms_to_srt_time(ms: float) -> str:
total_ms = max(0, int(round(float(ms))))
hours = total_ms // 3_600_000
total_ms %= 3_600_000
minutes = total_ms // 60_000
total_ms %= 60_000
seconds = total_ms // 1_000
milliseconds = total_ms % 1_000
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def _srt_block(index: int, start_ms: float, end_ms: float, text: str) -> str:
if end_ms <= start_ms:
end_ms = start_ms + 500
return f"{index}\n{_ms_to_srt_time(start_ms)} --> {_ms_to_srt_time(end_ms)}\n{text.strip()}\n"
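The two formatting helpers above compose into one SRT entry: millisecond offsets become `HH:MM:SS,mmm` stamps, and a zero or negative span is padded to 500 ms. A standalone sketch:

```python
def ms_to_srt_time(ms: float) -> str:
    # Decompose total milliseconds into hours/minutes/seconds/milliseconds.
    total_ms = max(0, int(round(float(ms))))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, milliseconds = divmod(rem, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

def srt_block(index, start_ms, end_ms, text):
    if end_ms <= start_ms:  # guard against zero or negative spans
        end_ms = start_ms + 500
    return f"{index}\n{ms_to_srt_time(start_ms)} --> {ms_to_srt_time(end_ms)}\n{text.strip()}\n"

print(srt_block(1, 3_723_004, 3_724_000, "你好"))
# → 1
#   01:02:03,004 --> 01:02:04,000
#   你好
```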
def _timestamp_ms(value: Any, field_name: str) -> float:
try:
return float(value)
except (TypeError, ValueError) as exc:
raise FunAsrError(f"Fun-ASR 转写结果时间戳无效: {field_name}={value!r}") from exc
def _speaker_prefix(speaker_id: Any) -> str:
if speaker_id is None or speaker_id == "":
return ""
try:
label = int(speaker_id) + 1
except (TypeError, ValueError):
label = str(speaker_id)
return f"说话人{label}: "
def _iter_sentences(result_json: dict[str, Any]):
transcripts = result_json.get("transcripts")
if transcripts is None and "sentences" in result_json:
transcripts = [{"sentences": result_json.get("sentences") or []}]
if not transcripts:
raise FunAsrError("Fun-ASR 转写结果为空:未找到 transcripts")
for transcript in transcripts:
for sentence in transcript.get("sentences") or []:
yield sentence
def _word_text(word: dict[str, Any]) -> str:
text = str(word.get("text") or word.get("word") or "")
punctuation = str(word.get("punctuation") or "")
if punctuation and not text.endswith(punctuation):
text += punctuation
return text
def _flush_block(blocks: list[dict[str, Any]], current: dict[str, Any]) -> None:
text = current.get("text", "").strip()
if text:
blocks.append(current.copy())
def _blocks_from_words(sentence: dict[str, Any], max_chars: int, max_duration: float) -> list[dict[str, Any]]:
words = sentence.get("words") or []
blocks: list[dict[str, Any]] = []
current: Optional[dict[str, Any]] = None
max_duration_ms = max_duration * 1000
sentence_speaker = sentence.get("speaker_id")
for word in words:
text = _word_text(word)
if not text:
continue
start = word.get("begin_time", word.get("start_time"))
end = word.get("end_time")
if start is None or end is None:
continue
speaker_id = word.get("speaker_id", sentence_speaker)
start_ms = _timestamp_ms(start, "word.begin_time")
end_ms = _timestamp_ms(end, "word.end_time")
if current is None:
current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
else:
should_split_before = (
speaker_id != current.get("speaker_id")
or len(current["text"] + text) > max_chars
or (end_ms - current["start"]) > max_duration_ms
)
if should_split_before:
_flush_block(blocks, current)
current = {"start": start_ms, "end": end_ms, "text": text, "speaker_id": speaker_id}
else:
current["text"] += text
current["end"] = end_ms
if current and text[-1:] in PUNCTUATION_BREAKS:
_flush_block(blocks, current)
current = None
if current:
_flush_block(blocks, current)
return blocks
def _split_text(text: str, max_chars: int) -> list[str]:
chunks: list[str] = []
current = ""
for char in text:
current += char
if char in PUNCTUATION_BREAKS or len(current) >= max_chars:
chunks.append(current.strip())
current = ""
if current.strip():
chunks.append(current.strip())
return [chunk for chunk in chunks if chunk]
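`_split_text` chunks sentence text greedily: a chunk closes whenever the current character is a punctuation break or the character budget is reached. Sketched standalone:

```python
PUNCTUATION_BREAKS = set(",。!?;,.!?;")

def split_text(text: str, max_chars: int) -> list[str]:
    chunks, current = [], ""
    for char in text:
        current += char
        # Close the chunk on punctuation or when the budget is exhausted.
        if char in PUNCTUATION_BREAKS or len(current) >= max_chars:
            chunks.append(current.strip())
            current = ""
    if current.strip():
        chunks.append(current.strip())
    return [c for c in chunks if c]

print(split_text("今天天气很好,我们去公园散步。", 10))
# → ['今天天气很好,', '我们去公园散步。']
```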
def _blocks_from_sentence(sentence: dict[str, Any], max_chars: int) -> list[dict[str, Any]]:
text = str(sentence.get("text") or "").strip()
if not text:
return []
start = sentence.get("begin_time", 0)
end = sentence.get("end_time")
start_ms = _timestamp_ms(start, "sentence.begin_time")
end_ms = _timestamp_ms(end, "sentence.end_time") if end is not None else start_ms + 500
chunks = _split_text(text, max_chars)
if not chunks:
return []
duration = max(500.0, end_ms - start_ms)
total_chars = max(1, sum(len(chunk) for chunk in chunks))
cursor = start_ms
blocks: list[dict[str, Any]] = []
for i, chunk in enumerate(chunks):
if i == len(chunks) - 1:
chunk_end = end_ms
else:
chunk_end = cursor + duration * (len(chunk) / total_chars)
blocks.append(
{
"start": cursor,
"end": max(cursor + 200, chunk_end),
"text": chunk,
"speaker_id": sentence.get("speaker_id"),
}
)
cursor = chunk_end
return blocks
def fun_asr_result_to_srt(result_json: dict[str, Any], max_chars: int = 20, max_duration: float = 3.5) -> str:
"""Convert downloaded Fun-ASR JSON into fine-grained SRT.
Official downloaded schema is `transcripts[*].sentences[*].words[*]`.
Fun-ASR timestamps are milliseconds.
"""
blocks: list[dict[str, Any]] = []
for sentence in _iter_sentences(result_json):
sentence_blocks = _blocks_from_words(sentence, max_chars, max_duration)
if not sentence_blocks:
sentence_blocks = _blocks_from_sentence(sentence, max_chars)
blocks.extend(sentence_blocks)
if not blocks:
raise FunAsrError("Fun-ASR 转写结果为空:未找到可用字幕内容")
lines = []
for index, block in enumerate(blocks, start=1):
text = f"{_speaker_prefix(block.get('speaker_id'))}{block['text']}"
lines.append(_srt_block(index, block["start"], block["end"], text))
return "\n".join(lines).rstrip() + "\n"
def write_srt_file(srt_content: str, subtitle_file: str = "") -> str:
if not subtitle_file:
subtitle_file = os.path.join(utils.subtitle_dir(), f"fun_asr_{int(time.time())}.srt")
parent = os.path.dirname(subtitle_file)
if parent:
os.makedirs(parent, exist_ok=True)
with open(subtitle_file, "w", encoding="utf-8") as f:
f.write(srt_content)
return subtitle_file
def create_with_fun_asr(
local_file: str,
subtitle_file: str = "",
api_key: str = "",
speaker_count: Optional[int] = None,
poll_interval: float = 2.0,
timeout: float = 600.0,
session=requests,
) -> Optional[str]:
"""Upload local media to Bailian temporary storage and create a Fun-ASR SRT file."""
api_key = _require_api_key(api_key)
try:
policy = request_upload_policy(api_key, session=session)
oss_url = upload_to_temporary_oss(local_file, policy, session=session)
task_id = submit_transcription_task(api_key, oss_url, speaker_count=speaker_count, session=session)
task_result = poll_transcription_task(
api_key,
task_id,
poll_interval=poll_interval,
timeout=timeout,
session=session,
)
transcription_url = task_result.get("transcription_url")
result_json = download_transcription_result(transcription_url, session=session)
srt_content = fun_asr_result_to_srt(result_json)
output_file = write_srt_file(srt_content, subtitle_file)
logger.info(f"Fun-ASR 字幕文件已生成: {output_file}")
return output_file
except FunAsrError:
raise
except Exception as exc:
raise FunAsrError("Fun-ASR 字幕转写失败,请检查文件、网络或阿里百炼服务状态") from exc
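The word-merging rule in `_blocks_from_words` can be isolated: a new subtitle block starts on speaker change, character-budget overflow, or duration overflow, and the current block also closes early on sentence-final punctuation. A simplified standalone sketch (field names follow the Fun-ASR word schema used above; the helper itself is illustrative):

```python
PUNCTUATION_BREAKS = set(",。!?;,.!?;")

def merge_words(words, max_chars=20, max_duration_ms=3500):
    blocks, current = [], None
    for w in words:
        text, start, end = w["text"], w["begin_time"], w["end_time"]
        speaker = w.get("speaker_id")
        if current is None:
            current = {"start": start, "end": end, "text": text, "speaker_id": speaker}
        elif (speaker != current["speaker_id"]
              or len(current["text"] + text) > max_chars
              or end - current["start"] > max_duration_ms):
            # Split BEFORE appending: the new word opens a fresh block.
            blocks.append(current)
            current = {"start": start, "end": end, "text": text, "speaker_id": speaker}
        else:
            current["text"] += text
            current["end"] = end
        if current and text[-1:] in PUNCTUATION_BREAKS:  # close on sentence-final punctuation
            blocks.append(current)
            current = None
    if current:
        blocks.append(current)
    return blocks

words = [
    {"text": "你好,", "begin_time": 0, "end_time": 400, "speaker_id": 0},
    {"text": "世界", "begin_time": 400, "end_time": 800, "speaker_id": 0},
    {"text": "再见。", "begin_time": 900, "end_time": 1300, "speaker_id": 1},
]
print([b["text"] for b in merge_words(words)])  # → ['你好,', '世界', '再见。']
```

The speaker change between the second and third words forces a split even though the combined text would fit the character budget.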

View File

@ -38,46 +38,90 @@ def parse_frame_analysis_to_markdown(json_file_path):
with open(json_file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
# Initialize the Markdown string
def time_to_milliseconds(time_text):
time_text = (time_text or "").strip()
if not time_text:
return 0
try:
if "," in time_text:
hhmmss, ms = time_text.split(",", 1)
milliseconds = int(ms)
else:
hhmmss = time_text
milliseconds = 0
parts = [int(part) for part in hhmmss.split(":") if part]
while len(parts) < 3:
parts.insert(0, 0)
hours, minutes, seconds = parts[-3], parts[-2], parts[-1]
return ((hours * 3600 + minutes * 60 + seconds) * 1000) + milliseconds
except Exception:
return 0
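The parser above accepts `HH:MM:SS,mmm` as well as shorter `MM:SS` forms, left-padding missing components, and maps anything malformed to 0 so it sorts first. Sketched standalone:

```python
def time_to_ms(time_text: str) -> int:
    time_text = (time_text or "").strip()
    if not time_text:
        return 0
    try:
        hhmmss, _, ms = time_text.partition(",")
        milliseconds = int(ms) if ms else 0
        parts = [int(p) for p in hhmmss.split(":") if p]
        while len(parts) < 3:
            parts.insert(0, 0)  # pad missing hours/minutes on the left
        hours, minutes, seconds = parts[-3:]
        return (hours * 3600 + minutes * 60 + seconds) * 1000 + milliseconds
    except Exception:
        return 0  # malformed timestamps sort first rather than crash

print(time_to_ms("00:01:02,500"))  # → 62500
```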
def batch_sort_key(batch):
time_range = batch.get("time_range", "")
start = time_range.split("-", 1)[0].strip()
return time_to_milliseconds(start), batch.get("batch_index", 0)
markdown = ""
# Summaries and frame observation data
# New structure: the complete analysis artifacts are stored per batch
if isinstance(data.get("batches"), list):
ordered_batches = sorted(data.get("batches", []), key=batch_sort_key)
for i, batch in enumerate(ordered_batches, 1):
time_range = batch.get("time_range", "")
summary = (
batch.get("overall_activity_summary")
or batch.get("summary")
or batch.get("fallback_summary")
or ""
)
observations = batch.get("frame_observations") or batch.get("observations") or []
markdown += f"## 片段 {i}\n"
markdown += f"- 时间范围:{time_range}\n"
markdown += f"- 片段描述:{summary}\n" if summary else "- 片段描述:\n"
markdown += "- 详细描述:\n"
for frame in observations:
timestamp = frame.get("timestamp", "")
observation = frame.get("observation", "")
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
markdown += "\n"
return markdown
# Backwards compatibility with the legacy structure
summaries = data.get('overall_activity_summaries', [])
frame_observations = data.get('frame_observations', [])
# Group the data by batch
batch_frames = {}
for frame in frame_observations:
batch_index = frame.get('batch_index')
if batch_index not in batch_frames:
batch_frames[batch_index] = []
batch_frames[batch_index].append(frame)
# Generate the Markdown content
for i, summary in enumerate(summaries, 1):
batch_index = summary.get('batch_index')
time_range = summary.get('time_range', '')
batch_summary = summary.get('summary', '')
markdown += f"## 片段 {i}\n"
markdown += f"- 时间范围:{time_range}\n"
# Append the clip description
markdown += f"- 片段描述:{batch_summary}\n" if batch_summary else "- 片段描述:\n"
markdown += "- 详细描述:\n"
# Append this batch's frame observation details
frames = batch_frames.get(batch_index, [])
for frame in frames:
timestamp = frame.get('timestamp', '')
observation = frame.get('observation', '')
# Use the raw text directly; do not split it
markdown += f" - {timestamp}: {observation}\n" if observation else f" - {timestamp}: \n"
markdown += "\n"
return markdown
except Exception as e:

View File

@ -0,0 +1,241 @@
import json
import os
import subprocess
import time
from os import path
from loguru import logger
from app.config import config
from app.models import const
from app.models.schema import VideoClipParams
from app.services import voice, clip_video, update_script
from app.services import state as sm
from app.utils import utils
def get_audio_duration_ffprobe(audio_file: str) -> float:
"""
Get the precise duration of an audio file via ffprobe.
Args:
audio_file: path to the audio file
Returns:
float: duration in seconds, with microsecond precision
"""
try:
cmd = [
'ffprobe',
'-v', 'error',
'-show_entries', 'format=duration',
'-of', 'csv=p=0',
audio_file
]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
duration = float(result.stdout.strip())
logger.debug(f"使用ffprobe获取音频时长: {duration:.6f}")
return duration
except subprocess.CalledProcessError as e:
logger.error(f"ffprobe执行失败: {e.stderr}")
raise
except Exception as e:
logger.error(f"获取音频时长失败: {str(e)}")
raise
def start_export_jianying_draft(task_id: str, params: VideoClipParams):
"""
Background task that exports a project to a Jianying (CapCut) draft.
Args:
task_id: task ID
params: video parameters
"""
logger.info(f"\n\n## 开始导出到剪映草稿任务: {task_id}")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=0)
"""
1. Load the editing script
"""
logger.info("\n\n## 1. 加载视频脚本")
video_script_path = params.video_clip_json_path
if path.exists(video_script_path):
try:
with open(video_script_path, "r", encoding="utf-8") as f:
list_script = json.load(f)
video_list = [i['narration'] for i in list_script]
video_ost = [i['OST'] for i in list_script]
time_list = [i['timestamp'] for i in list_script]
video_script = " ".join(video_list)
logger.debug(f"解说完整脚本: \n{video_script}")
logger.debug(f"解说 OST 列表: \n{video_ost}")
logger.debug(f"解说时间戳列表: \n{time_list}")
except Exception as e:
logger.error(f"无法读取视频json脚本: {e}")
raise ValueError("无法读取视频json脚本,请检查脚本格式是否正确")
else:
logger.error(f"解说脚本文件不存在: {video_script_path},请先点击【保存脚本】按钮保存脚本后再生成视频")
raise ValueError("解说脚本文件不存在!请先点击【保存脚本】按钮保存脚本后再生成视频。")
"""
2. Generate audio assets with TTS
"""
logger.info("\n\n## 2. 根据OST设置生成音频列表")
tts_segments = [
segment for segment in list_script
if segment['OST'] in [0, 2]
]
logger.debug(f"需要生成TTS的片段数: {len(tts_segments)}")
tts_results = voice.tts_multiple(
task_id=task_id,
list_script=tts_segments,  # pass only the segments that need TTS
tts_engine=params.tts_engine,
voice_name=params.voice_name,
voice_rate=params.voice_rate,
voice_pitch=params.voice_pitch,
)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
"""
3. Unified video clipping - strategy differentiated by OST type
"""
logger.info("\n\n## 3. 统一视频裁剪基于OST类型")
video_clip_result = clip_video.clip_video_unified(
video_origin_path=params.video_origin_path,
script_list=list_script,
tts_results=tts_results
)
tts_clip_result = {tts_result['_id']: tts_result['audio_file'] for tts_result in tts_results}
subclip_clip_result = {
tts_result['_id']: tts_result['subtitle_file'] for tts_result in tts_results
}
new_script_list = update_script.update_script_timestamps(list_script, video_clip_result, tts_clip_result, subclip_clip_result)
logger.info(f"统一裁剪完成,处理了 {len(video_clip_result)} 个视频片段")
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=60)
"""
4. Export to the Jianying draft
"""
logger.info("\n\n## 4. 导出到剪映草稿")
try:
import pyJianYingDraft
from pyJianYingDraft import DraftFolder, VideoSegment, AudioSegment, trange, TrackType
jianying_draft_path = config.ui.get("jianying_draft_path", "")
if not jianying_draft_path:
raise ValueError("剪映草稿路径未配置")
# Create a DraftFolder instance
draft_folder = DraftFolder(jianying_draft_path)
# Use the draft name from params; fall back to the default when empty
draft_name = getattr(params, 'draft_name', "")
logger.debug(f"从params获取的草稿名称: '{draft_name}' (类型: {type(draft_name)})")
if not draft_name:
draft_name = f"NarratoAI_{int(time.time())}"
logger.debug(f"使用默认草稿名称: '{draft_name}'")
# Create a new draft
script = draft_folder.create_draft(draft_name, 1920, 1080)
# Add the video and audio tracks
script.add_track(TrackType.video, '视频轨道')
script.add_track(TrackType.audio, '音频轨道')
# Process the script data
current_time = 0
output_dir = utils.task_dir(task_id)
for item in new_script_list:
# Extract timing information
start_time = float(item.get('start_time', 0.0))
duration = float(item.get('duration', 0.0))
timestamp = item.get('timestamp', '')
logger.info(f"处理片段: OST={item['OST']}, start_time={start_time}, duration={duration}, timestamp={timestamp}")
# Build the audio file path
audio_file = ""
if timestamp:
timestamp_formatted = timestamp.replace(':', '_')
audio_file = os.path.join(
output_dir,
f"audio_{timestamp_formatted}.mp3"
)
# Check whether a clipped video file exists
video_file = item.get('video', '')
if video_file and not os.path.exists(video_file):
video_file = ""
# Add the video segment
if video_file:
# Use the clipped video file
# For clipped videos, the second argument of target_timerange is the duration
video_segment = VideoSegment(
video_file,
trange(f"{current_time}s", f"{duration}s")
)
else:
# Use the original video file
# source_timerange is the portion cut from the source video
# target_timerange is the segment's position on the timeline
video_segment = VideoSegment(
params.video_origin_path,
trange(f"{current_time}s", f"{duration}s"),
source_timerange=trange(f"{start_time}s", f"{duration}s")
)
script.add_segment(video_segment, '视频轨道')
# Handle audio
if item['OST'] in [0, 2]:  # segments that need TTS
if os.path.exists(audio_file):
# Use ffprobe for the exact audio duration, avoiding mismatches caused by TTS engine differences
actual_audio_duration = get_audio_duration_ffprobe(audio_file)
logger.info(f"音频文件实际时长: {actual_audio_duration:.6f}秒, 脚本时长(视频): {duration:.3f}")
# Use the smaller of the actual audio duration and the video duration so the segment never exceeds the material
# When the TTS rate is adjusted the audio may run longer or shorter than the video; taking the minimum avoids overflow
safe_duration = min(actual_audio_duration, duration)
logger.info(f"使用时长: {safe_duration:.6f}秒 (取音频和视频时长的较小值)")
audio_segment = AudioSegment(
audio_file,
trange(f"{current_time}s", f"{safe_duration}s")
)
script.add_segment(audio_segment, '音频轨道')
else:
logger.warning(f"音频文件不存在: {audio_file}")
# OST=1 segments keep the original audio; no extra track is added
# Advance the current time
current_time += duration
# Save the draft
script.save()
draft_path = os.path.join(jianying_draft_path, draft_name)
logger.success(f"成功导出到剪映草稿: {draft_name}")
logger.info(f"草稿已保存到: {draft_path}")
# Update the task state
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, draft_path=draft_path, draft_name=draft_name)
return {"draft_path": draft_path, "draft_name": draft_name}
except ImportError as e:
logger.error(f"导入pyJianYingDraft失败: {e}")
raise ImportError(f"pyJianYingDraft库导入失败: {e}\n请确保已正确安装该库")
except Exception as e:
logger.error(f"导出到剪映草稿失败: {e}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
raise Exception(f"导出到剪映草稿失败: {e}")
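The timeline math above places segments end-to-end and clamps each audio clip to its video slot, which is what the floating-point duration fix relies on: a TTS clip that renders slightly long must never overflow the material. A simplified sketch of that placement logic (helper and field names are illustrative):

```python
def place_segments(durations, audio_durations):
    # Lay video segments end-to-end; clamp each audio clip to its slot.
    placements, cursor = [], 0.0
    for video_d, audio_d in zip(durations, audio_durations):
        safe_audio = min(audio_d, video_d) if audio_d is not None else None
        placements.append({"start": cursor, "video": video_d, "audio": safe_audio})
        cursor += video_d  # next segment starts where this one ends
    return placements

# 3.000123s of TTS audio against a 3.0s video slot gets clamped to 3.0s.
timeline = place_segments([3.0, 4.5], [3.000123, None])
```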

View File

@ -108,6 +108,7 @@ class VisionModelProvider(BaseLLMProvider):
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
max_concurrency: int = 1,
**kwargs) -> List[str]:
"""
Analyze images and return the results
@ -116,6 +117,7 @@ class VisionModelProvider(BaseLLMProvider):
images: list of image paths or PIL image objects
prompt: the analysis prompt
batch_size: batch size
max_concurrency: maximum number of concurrent batches (takes effect when the implementation supports it)
**kwargs: additional parameters
Returns:

View File

@ -5,7 +5,6 @@
"""
import asyncio
import json
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
@ -13,6 +12,7 @@ from loguru import logger
from .unified_service import UnifiedLLMService
from .exceptions import LLMServiceError
from .manager import LLMServiceManager
# Import the new prompt management system
from app.services.prompts import PromptManager
@ -110,41 +110,11 @@ class LegacyLLMAdapter:
temperature=1.5,
response_format="json"
)
# Use the enhanced JSON parser
from webui.tools.generate_short_summary import parse_and_fix_json
parsed_result = parse_and_fix_json(result)
if not parsed_result:
logger.error("无法解析LLM返回的JSON数据")
# Return a basic JSON structure instead of an error string
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "解析失败请检查LLM输出",
"narration": "解说文案生成失败,请重试"
}
]
}, ensure_ascii=False)
# Ensure a JSON string is returned
return json.dumps(parsed_result, ensure_ascii=False)
return result if isinstance(result, str) else str(result)
except Exception as e:
logger.error(f"生成解说文案失败: {str(e)}")
# Return a basic JSON structure instead of an error string
return json.dumps({
"items": [
{
"_id": 1,
"timestamp": "00:00:00-00:00:10",
"picture": "生成失败",
"narration": f"解说文案生成失败: {str(e)}"
}
]
}, ensure_ascii=False)
raise
class VisionAnalyzerAdapter:
@ -155,11 +125,29 @@ class VisionAnalyzerAdapter:
self.api_key = api_key
self.model = model
self.base_url = base_url
def _build_provider_with_explicit_settings(self):
provider_name = (self.provider or "").lower()
if not LLMServiceManager.is_registered():
from .providers import register_all_providers
register_all_providers()
provider_class = LLMServiceManager._vision_providers.get(provider_name)
if provider_class is None:
raise LLMServiceError(f"视觉模型提供商未注册: {provider_name}")
return provider_class(
api_key=self.api_key,
model_name=self.model,
base_url=self.base_url,
)
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10) -> List[Dict[str, Any]]:
batch_size: int = 10,
max_concurrency: int = 1) -> List[Dict[str, Any]]:
"""
Analyze images - compatible with the legacy interface
@ -167,17 +155,20 @@ class VisionAnalyzerAdapter:
images: list of images
prompt: the analysis prompt
batch_size: batch size
max_concurrency: maximum number of concurrent batches
Returns:
list of analysis results, format-compatible with the old implementation
"""
try:
# Analyze images through the unified service
results = await UnifiedLLMService.analyze_images(
provider = self._build_provider_with_explicit_settings()
results = await provider.analyze_images(
images=images,
prompt=prompt,
provider=self.provider,
batch_size=batch_size
batch_size=batch_size,
max_concurrency=max_concurrency,
api_key=self.api_key,
api_base=self.base_url,
)
# Convert to the old format to preserve backward compatibility

View File

@ -4,6 +4,7 @@ OpenAI 兼容提供商实现
Calls OpenAI-compatible endpoints through the official OpenAI SDK; supports both text and vision models
"""
import asyncio
import io
import base64
import re
@ -96,24 +97,35 @@ class OpenAICompatibleVisionProvider(_OpenAICompatibleBase, VisionModelProvider)
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
max_concurrency: int = 1,
**kwargs,
) -> List[str]:
logger.info(f"开始使用 OpenAI 兼容接口 ({self.model_name}) 分析 {len(images)} 张图片")
processed_images = self._prepare_images(images)
results: List[str] = []
if not processed_images:
return []
for i in range(0, len(processed_images), batch_size):
batch = processed_images[i : i + batch_size]
logger.info(f"处理第 {i // batch_size + 1} 批,共 {len(batch)} 张图片")
try:
result = await self._analyze_batch(batch, prompt, **kwargs)
results.append(result)
except Exception as exc:
logger.error(f"批次 {i // batch_size + 1} 处理失败: {exc}")
results.append(f"批次处理失败: {exc}")
bounded_concurrency = max(1, int(max_concurrency))
semaphore = asyncio.Semaphore(bounded_concurrency)
batches = [
(index // batch_size, processed_images[index : index + batch_size])
for index in range(0, len(processed_images), batch_size)
]
return results
async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张图片")
async with semaphore:
try:
result = await self._analyze_batch(batch, prompt, **kwargs)
return batch_index, result
except Exception as exc:
logger.error(f"批次 {batch_index + 1} 处理失败: {exc}")
return batch_index, f"批次处理失败: {exc}"
completed = await asyncio.gather(*(run_batch(index, batch) for index, batch in batches))
completed.sort(key=lambda item: item[0])
return [result for _, result in completed]
async def _analyze_batch(self, batch: List[PIL.Image.Image], prompt: str, **kwargs) -> str:
content = [{"type": "text", "text": prompt}]

View File

@ -1,10 +1,14 @@
"""Minimal regression tests for the OpenAI-compatible provider."""
import asyncio
import unittest
from unittest.mock import patch
from app.config import config
from app.services.llm.base import TextModelProvider
from app.services.llm.manager import LLMServiceManager
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
from app.services.llm.providers import register_all_providers
@ -63,5 +67,128 @@ class OpenAICompatManagerTests(unittest.TestCase):
self.assertEqual("https://new.example/v1", provider.base_url)
class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
async def test_analyze_images_keeps_batch_order_when_running_concurrently(self):
provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
provider._prepare_images = lambda images: list(images)
async def fake_analyze_batch(batch, prompt, **kwargs):
delays = {"a": 0.03, "c": 0.01, "e": 0.0}
await asyncio.sleep(delays[batch[0]])
return f"batch-{batch[0]}"
provider._analyze_batch = fake_analyze_batch
result = await provider.analyze_images(
images=["a", "b", "c", "d", "e", "f"],
prompt="prompt",
batch_size=2,
max_concurrency=2,
)
self.assertEqual(["batch-a", "batch-c", "batch-e"], result)
async def test_analyze_images_respects_max_concurrency_limit(self):
provider = OpenAICompatibleVisionProvider(api_key="k", model_name="m")
provider._prepare_images = lambda images: list(images)
in_flight = 0
max_in_flight = 0
async def fake_analyze_batch(batch, prompt, **kwargs):
nonlocal in_flight, max_in_flight
in_flight += 1
max_in_flight = max(max_in_flight, in_flight)
await asyncio.sleep(0.02)
in_flight -= 1
return f"batch-{batch[0]}"
provider._analyze_batch = fake_analyze_batch
result = await provider.analyze_images(
images=["a", "b", "c", "d", "e", "f"],
prompt="prompt",
batch_size=1,
max_concurrency=2,
)
self.assertEqual(6, len(result))
self.assertEqual(2, max_in_flight)
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
class _CapturingVisionProvider:
last_init: tuple[str, str, str | None] | None = None
last_call_kwargs: dict | None = None
def __init__(self, api_key: str, model_name: str, base_url: str | None = None):
self.api_key = api_key
self.model_name = model_name
self.base_url = base_url
ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url)
async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs):
ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs)
return [f"{self.model_name}|{self.api_key}|{self.base_url}"]
def setUp(self):
_reset_manager_state()
self._original_app = dict(config.app)
def tearDown(self):
_reset_manager_state()
config.app.clear()
config.app.update(self._original_app)
async def test_adapter_uses_explicit_settings_instead_of_global_config(self):
LLMServiceManager.register_vision_provider("openai", self._CapturingVisionProvider)
config.app["vision_openai_api_key"] = "config-key"
config.app["vision_openai_model_name"] = "config-model"
config.app["vision_openai_base_url"] = "https://config.example/v1"
adapter = VisionAnalyzerAdapter(
provider="openai",
api_key="explicit-key",
model="explicit-model",
base_url="https://explicit.example/v1",
)
result = await adapter.analyze_images(
images=["/tmp/keyframe_000001_000000100.jpg"],
prompt="描述画面",
batch_size=1,
max_concurrency=1,
)
self.assertEqual(
("explicit-key", "explicit-model", "https://explicit.example/v1"),
self._CapturingVisionProvider.last_init,
)
self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"])
self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"])
self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"])
class LegacyNarrationAdapterBehaviorTests(unittest.TestCase):
def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self):
raw_payload = "not-json-at-all ::: ???"
with patch(
"app.services.llm.migration_adapter.PromptManager.get_prompt",
return_value="prompt",
), patch(
"app.services.llm.migration_adapter._run_async_safely",
return_value=raw_payload,
):
result = LegacyLLMAdapter.generate_narration(
markdown_content="markdown",
api_key="test-key",
base_url="https://example.com/v1",
model="test-model",
)
self.assertEqual(raw_payload, result)
self.assertNotIn('"items"', result)
if __name__ == "__main__":
unittest.main()
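The ordering and concurrency guarantees exercised by the two tests above come from running one coroutine per batch behind a shared limit. Below is a minimal sketch of that pattern using `asyncio.Semaphore` plus `asyncio.gather`; the `analyze_batches` helper and its fake worker are illustrative names, not the project's API:

```python
import asyncio

async def analyze_batches(batches, worker, max_concurrency=1):
    # Run one coroutine per batch, but let at most max_concurrency of
    # them execute at once; gather() keeps results in input order.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run(batch):
        async with semaphore:
            return await worker(batch)

    return await asyncio.gather(*(run(batch) for batch in batches))

async def main():
    in_flight = 0
    max_in_flight = 0

    async def worker(batch):
        nonlocal in_flight, max_in_flight
        in_flight += 1
        max_in_flight = max(max_in_flight, in_flight)
        await asyncio.sleep(0.02)  # simulate one vision API round-trip
        in_flight -= 1
        return f"batch-{batch[0]}"

    results = await analyze_batches(
        [["a"], ["b"], ["c"], ["d"]], worker, max_concurrency=2
    )
    return results, max_in_flight

results, peak = asyncio.run(main())
```

Because `asyncio.gather` returns results in submission order regardless of completion order, a test like `test_analyze_images_respects_max_concurrency_limit` can assert both the result count and the observed peak concurrency.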

View File

@@ -1,324 +1,40 @@
import os
import json
import time
import asyncio
import requests
from app.utils import video_processor
from loguru import logger
from typing import List, Dict, Any, Callable
from typing import Any, Callable
from app.utils import utils, gemini_analyzer, video_processor
from app.utils.script_generator import ScriptProcessor
from app.config import config
from loguru import logger
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
class ScriptGenerator:
def __init__(self):
self.temp_dir = utils.temp_dir()
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
def __init__(self, documentary_service: DocumentaryFrameAnalysisService | None = None):
self.documentary_service = documentary_service or DocumentaryFrameAnalysisService()
async def generate_script(
self,
video_path: str,
video_theme: str = "",
custom_prompt: str = "",
frame_interval_input: int = 5,
frame_interval_input: int | None = None,
skip_seconds: int = 0,
threshold: int = 30,
vision_batch_size: int = 5,
vision_llm_provider: str = "gemini",
progress_callback: Callable[[float, str], None] = None
) -> List[Dict[Any, Any]]:
"""
生成视频脚本的核心逻辑
Args:
video_path: 视频文件路径
video_theme: 视频主题
custom_prompt: 自定义提示词
skip_seconds: 跳过开始的秒数
threshold: 差异阈值
vision_batch_size: 视觉处理批次大小
vision_llm_provider: 视觉模型提供商
progress_callback: 进度回调函数
Returns:
List[Dict]: 生成的视频脚本
"""
if progress_callback is None:
progress_callback = lambda p, m: None
try:
# 提取关键帧
progress_callback(10, "正在提取关键帧...")
keyframe_files = await self._extract_keyframes(
video_path,
skip_seconds,
threshold
)
# 使用统一的 LLM 接口(支持所有 provider)
script = await self._process_with_llm(
keyframe_files,
video_theme,
custom_prompt,
vision_batch_size,
vision_llm_provider,
progress_callback
)
return json.loads(script) if isinstance(script, str) else script
except Exception as e:
logger.exception("Generate script failed")
raise
async def _extract_keyframes(
self,
video_path: str,
skip_seconds: int,
threshold: int
) -> List[str]:
"""提取视频关键帧"""
video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)
# 检查缓存
keyframe_files = []
if os.path.exists(video_keyframes_dir):
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"Using cached keyframes: {video_keyframes_dir}")
return keyframe_files
# 提取新的关键帧
os.makedirs(video_keyframes_dir, exist_ok=True)
try:
processor = video_processor.VideoProcessor(video_path)
processor.process_video_pipeline(
output_dir=video_keyframes_dir,
skip_seconds=skip_seconds,
threshold=threshold
vision_batch_size: int | None = None,
vision_llm_provider: str | None = None,
progress_callback: Callable[[float, str], None] | None = None,
) -> list[dict[Any, Any]]:
callback = progress_callback or (lambda _p, _m: None)
if skip_seconds != 0 or threshold != 30:
logger.warning(
"ScriptGenerator documentary path received "
f"skip_seconds={skip_seconds} threshold={threshold}; "
"the shared documentary frame pipeline does not currently apply these parameters."
)
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
return keyframe_files
except Exception as e:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
raise
async def _process_with_llm(
self,
keyframe_files: List[str],
video_theme: str,
custom_prompt: str,
vision_batch_size: int,
vision_llm_provider: str,
progress_callback: Callable[[float, str], None]
) -> str:
"""使用统一 LLM 接口处理视频帧"""
progress_callback(30, "正在初始化视觉分析器...")
# 使用新的 LLM 迁移适配器(支持所有 provider)
from app.services.llm.migration_adapter import create_vision_analyzer
# 获取配置
text_provider = config.app.get('text_llm_provider', 'openai').lower()
vision_api_key = config.app.get(f'vision_{vision_llm_provider}_api_key')
vision_model = config.app.get(f'vision_{vision_llm_provider}_model_name')
vision_base_url = config.app.get(f'vision_{vision_llm_provider}_base_url')
if not vision_api_key or not vision_model:
raise ValueError(f"未配置 {vision_llm_provider} API Key 或者模型")
# 创建统一的视觉分析器
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
return await self.documentary_service.generate_documentary_script(
video_path=video_path,
video_theme=video_theme,
custom_prompt=custom_prompt,
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=callback,
)
progress_callback(40, "正在分析关键帧...")
# 执行异步分析
results = await analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=vision_batch_size
)
progress_callback(60, "正在整理分析结果...")
# 合并所有批次的分析结果
frame_analysis = ""
prev_batch_files = None
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)
# 添加带时间戳的分析结果
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response']
frame_analysis += "\n"
prev_batch_files = batch_files
if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果")
progress_callback(70, "正在生成脚本...")
# 构建帧内容列表
frame_content_list = []
prev_batch_files = None
for result in results:
if 'error' in result:
continue
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
_, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
frame_content = {
"timestamp": timestamp_range,
"picture": result['response'],
"narration": "",
"OST": 2
}
frame_content_list.append(frame_content)
prev_batch_files = batch_files
if not frame_content_list:
raise Exception("没有有效的帧内容可以处理")
progress_callback(90, "正在生成文案...")
# 获取文本生成配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 根据提供商类型选择合适的处理器
if text_provider == 'gemini(openai)':
# 使用OpenAI兼容的Gemini代理
from app.utils.script_generator import GeminiOpenAIGenerator
generator = GeminiOpenAIGenerator(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url
)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
processor.generator = generator
else:
# 使用标准处理器(包括原生 Gemini)
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
base_url=text_base_url,
prompt=custom_prompt,
video_theme=video_theme
)
return processor.process_frames(frame_content_list)
def _get_batch_files(
self,
keyframe_files: List[str],
result: Dict[str, Any],
batch_size: int
) -> List[str]:
"""获取当前批次的图片文件"""
batch_start = result['batch_index'] * batch_size
batch_end = min(batch_start + batch_size, len(keyframe_files))
return keyframe_files[batch_start:batch_end]
def _get_batch_timestamps(
self,
batch_files: List[str],
prev_batch_files: List[str] = None
) -> tuple[str, str, str]:
"""获取一批文件的时间戳范围,支持毫秒级精度"""
if not batch_files:
logger.warning("Empty batch files")
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
first_frame = os.path.basename(prev_batch_files[-1])
last_frame = os.path.basename(batch_files[0])
else:
first_frame = os.path.basename(batch_files[0])
last_frame = os.path.basename(batch_files[-1])
first_time = first_frame.split('_')[2].replace('.jpg', '')
last_time = last_frame.split('_')[2].replace('.jpg', '')
def format_timestamp(time_str: str) -> str:
"""将时间字符串转换为 HH:MM:SS,mmm 格式"""
try:
if len(time_str) < 4:
logger.warning(f"Invalid timestamp format: {time_str}")
return "00:00:00,000"
# 处理毫秒部分
if ',' in time_str:
time_part, ms_part = time_str.split(',')
ms = int(ms_part)
else:
time_part = time_str
ms = 0
# 处理时分秒
parts = time_part.split(':')
if len(parts) == 3: # HH:MM:SS
h, m, s = map(int, parts)
elif len(parts) == 2: # MM:SS
h = 0
m, s = map(int, parts)
else: # SS
h = 0
m = 0
s = int(parts[0])
# 处理进位
if s >= 60:
m += s // 60
s = s % 60
if m >= 60:
h += m // 60
m = m % 60
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
except Exception as e:
logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
return "00:00:00,000"
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
return first_timestamp, last_timestamp, timestamp_range
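The keyframe filenames parsed above follow a `keyframe_<frame:06d>_<HHMMSSmmm>.jpg` convention, where the second field is a 9-digit timestamp token (the format `_format_timestamp_token` in the extraction fast path emits). A sketch of round-tripping that token to the `HH:MM:SS,mmm` form used in the timestamp ranges; both helper names here are illustrative:

```python
def format_timestamp_token(timestamp: float) -> str:
    # Float seconds -> 9-digit HHMMSSmmm token, as embedded in
    # keyframe_%06d_%09d.jpg filenames.
    hours = int(timestamp // 3600)
    minutes = int((timestamp % 3600) // 60)
    seconds = int(timestamp % 60)
    milliseconds = int((timestamp % 1) * 1000)
    return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"

def parse_timestamp_token(token: str) -> str:
    # 9-digit HHMMSSmmm token -> SRT-style HH:MM:SS,mmm string.
    if len(token) != 9 or not token.isdigit():
        raise ValueError(f"invalid timestamp token: {token!r}")
    return f"{token[0:2]}:{token[2:4]}:{token[4:6]},{token[6:9]}"
```

Keeping the token fixed-width is what allows a simple lexicographic sort of the filenames to double as a chronological sort.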

View File

@@ -0,0 +1,403 @@
import tempfile
import unittest
from pathlib import Path
try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib
from app.config import config as cfg
from app.services import fun_asr_subtitle as fasr
class FakeResponse:
def __init__(self, payload=None, status_code=200):
self.payload = payload or {}
self.status_code = status_code
def json(self):
return self.payload
def raise_for_status(self):
if self.status_code >= 400:
raise RuntimeError(f"HTTP {self.status_code}")
class InvalidJsonResponse(FakeResponse):
def json(self):
raise ValueError("invalid json")
class FakeSession:
def __init__(self, local_result):
self.calls = []
self.local_result = local_result
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
if url == fasr.UPLOAD_POLICY_URL:
return FakeResponse(
{
"data": {
"policy": "policy-token",
"signature": "signature-token",
"upload_dir": "dashscope-instant/test-dir",
"upload_host": "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com",
"oss_access_key_id": "oss-ak",
"x_oss_object_acl": "private",
"x_oss_forbid_overwrite": "true",
"max_file_size_mb": 1,
}
}
)
if url == "https://result.example/transcription.json":
return FakeResponse(self.local_result)
return FakeResponse({}, 404)
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
if url == "https://dashscope-file-test.oss-cn-beijing.aliyuncs.com":
return FakeResponse({})
if url == fasr.TRANSCRIPTION_URL:
return FakeResponse({"output": {"task_status": "PENDING", "task_id": "task-123"}})
if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
return FakeResponse(
{
"output": {
"task_status": "SUCCEEDED",
"results": [
{
"file_url": "oss://dashscope-instant/test-dir/audio.wav",
"transcription_url": "https://result.example/transcription.json",
"subtask_status": "SUCCEEDED",
}
],
}
}
)
return FakeResponse({}, 404)
OFFICIAL_SHAPE_RESULT = {
"transcripts": [
{
"sentences": [
{
"begin_time": 0,
"end_time": 3600,
"text": "你好欢迎观看今天的内容",
"speaker_id": 0,
"words": [
{"begin_time": 0, "end_time": 400, "text": "你好", "punctuation": ""},
{"begin_time": 400, "end_time": 900, "text": "欢迎", "punctuation": ""},
{"begin_time": 900, "end_time": 1300, "text": "观看", "punctuation": ""},
{"begin_time": 1300, "end_time": 1800, "text": "今天", "punctuation": ""},
{"begin_time": 1800, "end_time": 2400, "text": "的内容", "punctuation": ""},
],
}
]
}
]
}
class FunAsrSrtConversionTests(unittest.TestCase):
def test_official_shape_words_convert_ms_and_speaker_label(self):
srt = fasr.fun_asr_result_to_srt(OFFICIAL_SHAPE_RESULT, max_chars=20, max_duration=3.5)
self.assertIn("1\n00:00:00,000 --> 00:00:00,400\n说话人1: 你好,", srt)
self.assertIn("2\n00:00:00,400 --> 00:00:02,400\n说话人1: 欢迎观看今天的内容。", srt)
self.assertNotIn("00:06:40,000", srt, "milliseconds must not be treated as seconds")
def test_long_word_sequence_splits_into_fine_blocks(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": 0,
"end_time": 6000,
"speaker_id": 1,
"words": [
{"begin_time": i * 500, "end_time": (i + 1) * 500, "text": token, "punctuation": ""}
for i, token in enumerate(["这是", "一个", "很长", "字幕", "需要", "拆分"])
],
}
]
}
]
}
srt = fasr.fun_asr_result_to_srt(result, max_chars=4, max_duration=10)
self.assertGreaterEqual(srt.count("\n说话人2:"), 3)
self.assertIn("1\n00:00:00,000", srt)
def test_sentence_fallback_uses_ms_without_zero_duration(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": 1000,
"end_time": 3000,
"text": "没有词级时间戳也可以拆分。",
"speaker_id": 0,
"words": [],
}
]
}
]
}
srt = fasr.fun_asr_result_to_srt(result, max_chars=5)
self.assertIn("00:00:01,000", srt)
self.assertIn("说话人1:", srt)
self.assertNotIn("--> 00:00:01,000\n", srt)
def test_empty_result_raises_clear_error(self):
with self.assertRaises(fasr.FunAsrError):
fasr.fun_asr_result_to_srt({"transcripts": []})
def test_malformed_word_timestamp_raises_fun_asr_error(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": 0,
"end_time": 1000,
"speaker_id": 0,
"words": [
{"begin_time": "bad", "end_time": 500, "text": "坏时间", "punctuation": ""}
],
}
]
}
]
}
with self.assertRaises(fasr.FunAsrError):
fasr.fun_asr_result_to_srt(result)
def test_malformed_sentence_timestamp_raises_fun_asr_error(self):
result = {
"transcripts": [
{
"sentences": [
{
"begin_time": "bad",
"end_time": 1000,
"text": "坏时间",
"speaker_id": 0,
"words": [],
}
]
}
]
}
with self.assertRaises(fasr.FunAsrError):
fasr.fun_asr_result_to_srt(result)
class FunAsrServiceTests(unittest.TestCase):
def test_create_with_fun_asr_uses_expected_rest_flow(self):
with tempfile.TemporaryDirectory() as tmp_dir:
local_file = Path(tmp_dir) / "audio.wav"
local_file.write_bytes(b"audio")
subtitle_file = Path(tmp_dir) / "out.srt"
session = FakeSession(OFFICIAL_SHAPE_RESULT)
result_path = fasr.create_with_fun_asr(
str(local_file),
subtitle_file=str(subtitle_file),
api_key="sk-test",
speaker_count=2,
poll_interval=0,
session=session,
)
self.assertEqual(str(subtitle_file), result_path)
self.assertTrue(subtitle_file.exists())
self.assertIn("说话人1:", subtitle_file.read_text(encoding="utf-8"))
policy_call = session.calls[0]
self.assertEqual("GET", policy_call[0])
self.assertEqual(fasr.UPLOAD_POLICY_URL, policy_call[1])
self.assertEqual({"action": "getPolicy", "model": "fun-asr"}, policy_call[2]["params"])
self.assertEqual("Bearer sk-test", policy_call[2]["headers"]["Authorization"])
upload_call = session.calls[1]
self.assertEqual("POST", upload_call[0])
self.assertEqual("https://dashscope-file-test.oss-cn-beijing.aliyuncs.com", upload_call[1])
upload_data = upload_call[2]["data"]
self.assertEqual("oss-ak", upload_data["OSSAccessKeyId"])
self.assertEqual("policy-token", upload_data["policy"])
self.assertEqual("signature-token", upload_data["Signature"])
self.assertEqual("dashscope-instant/test-dir/audio.wav", upload_data["key"])
self.assertEqual("200", upload_data["success_action_status"])
submit_call = session.calls[2]
self.assertEqual(fasr.TRANSCRIPTION_URL, submit_call[1])
headers = submit_call[2]["headers"]
self.assertEqual("enable", headers["X-DashScope-Async"])
self.assertEqual("enable", headers["X-DashScope-OssResourceResolve"])
payload = submit_call[2]["json"]
self.assertEqual("fun-asr", payload["model"])
self.assertEqual(["oss://dashscope-instant/test-dir/audio.wav"], payload["input"]["file_urls"])
self.assertTrue(payload["parameters"]["diarization_enabled"])
self.assertEqual(2, payload["parameters"]["speaker_count"])
poll_call = session.calls[3]
self.assertEqual("POST", poll_call[0])
self.assertTrue(poll_call[1].endswith("/api/v1/tasks/task-123"))
download_call = session.calls[4]
self.assertEqual(("GET", "https://result.example/transcription.json"), download_call[:2])
def test_upload_policy_size_validation_fails_before_upload(self):
policy = fasr.UploadPolicy(
upload_host="https://upload.example",
upload_dir="dashscope-instant/test",
policy="p",
signature="s",
oss_access_key_id="ak",
max_file_size_mb=0.000001,
)
with tempfile.NamedTemporaryFile() as f:
f.write(b"too-large")
f.flush()
with self.assertRaises(fasr.FunAsrError):
fasr.upload_to_temporary_oss(f.name, policy, session=FakeSession({}))
def test_failed_subtask_raises(self):
class FailedSession(FakeSession):
def post(self, url, **kwargs):
if url == fasr.TASK_URL_TEMPLATE.format(task_id="task-123"):
return FakeResponse(
{
"output": {
"task_status": "SUCCEEDED",
"results": [{"subtask_status": "FAILED"}],
}
}
)
return super().post(url, **kwargs)
with self.assertRaises(fasr.FunAsrError):
fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedSession({}))
def test_missing_api_key_raises_before_request(self):
session = FakeSession(OFFICIAL_SHAPE_RESULT)
with self.assertRaises(fasr.FunAsrError):
fasr.request_upload_policy("", session=session)
self.assertEqual([], session.calls)
def test_upload_policy_http_error_raises(self):
class PolicyHttpErrorSession(FakeSession):
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
return FakeResponse({}, status_code=403)
with self.assertRaises(fasr.FunAsrError):
fasr.request_upload_policy("sk-test", session=PolicyHttpErrorSession({}))
def test_malformed_upload_policy_raises(self):
class MalformedPolicySession(FakeSession):
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
return FakeResponse({"data": {"policy": "missing-required-fields"}})
with self.assertRaises(fasr.FunAsrError):
fasr.request_upload_policy("sk-test", session=MalformedPolicySession({}))
def test_upload_http_failure_raises(self):
class UploadFailureSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({}, status_code=500)
policy = fasr.UploadPolicy(
upload_host="https://upload.example",
upload_dir="dashscope-instant/test",
policy="p",
signature="s",
oss_access_key_id="ak",
max_file_size_mb=1,
)
with tempfile.NamedTemporaryFile() as f:
f.write(b"audio")
f.flush()
with self.assertRaises(fasr.FunAsrError):
fasr.upload_to_temporary_oss(f.name, policy, session=UploadFailureSession({}))
def test_submit_failure_raises(self):
class SubmitFailureSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({}, status_code=500)
with self.assertRaises(fasr.FunAsrError):
fasr.submit_transcription_task("sk-test", "oss://file", session=SubmitFailureSession({}))
def test_poll_timeout_raises(self):
class PendingSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({"output": {"task_status": "RUNNING"}})
with self.assertRaises(fasr.FunAsrError):
fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, timeout=-1, session=PendingSession({}))
def test_task_failed_status_raises(self):
class FailedTaskSession(FakeSession):
def post(self, url, **kwargs):
self.calls.append(("POST", url, kwargs))
return FakeResponse({"output": {"task_status": "FAILED"}})
with self.assertRaises(fasr.FunAsrError):
fasr.poll_transcription_task("sk-test", "task-123", poll_interval=0, session=FailedTaskSession({}))
def test_missing_transcription_url_raises(self):
with self.assertRaises(fasr.FunAsrError):
fasr.download_transcription_result("", session=FakeSession({}))
def test_malformed_downloaded_json_raises(self):
class MalformedDownloadSession(FakeSession):
def get(self, url, **kwargs):
self.calls.append(("GET", url, kwargs))
return InvalidJsonResponse()
with self.assertRaises(fasr.FunAsrError):
fasr.download_transcription_result("https://result.example/bad.json", session=MalformedDownloadSession({}))
class FunAsrConfigTests(unittest.TestCase):
def test_save_config_persists_fun_asr_section(self):
original_config_file = cfg.config_file
original_fun_asr = cfg.fun_asr
try:
with tempfile.TemporaryDirectory() as tmp_dir:
config_path = Path(tmp_dir) / "config.toml"
cfg.config_file = str(config_path)
cfg.fun_asr = {"api_key": "sk-local", "model": "fun-asr"}
cfg.save_config()
saved = tomllib.loads(config_path.read_text(encoding="utf-8"))
finally:
cfg.config_file = original_config_file
cfg.fun_asr = original_fun_asr
self.assertEqual("sk-local", saved["fun_asr"]["api_key"])
self.assertEqual("fun-asr", saved["fun_asr"]["model"])
def test_config_example_fun_asr_section_parses(self):
config_data = tomllib.loads(Path("config.example.toml").read_text(encoding="utf-8"))
self.assertEqual("fun-asr", config_data["fun_asr"]["model"])
self.assertIn("api_key", config_data["fun_asr"])
if __name__ == "__main__":
unittest.main()
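The conversion tests above hinge on Fun-ASR's `begin_time`/`end_time` fields being milliseconds. A sketch of the timestamp formatting they expect; the helper name is an assumption, not the service module's API:

```python
def ms_to_srt_timestamp(ms: int) -> str:
    # Fun-ASR begin_time/end_time are in milliseconds; render them in
    # the SRT HH:MM:SS,mmm form asserted by the tests above.
    if ms < 0:
        raise ValueError("timestamp must be non-negative")
    hours, rem = divmod(ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
```

Misreading the field as seconds is exactly the failure mode `test_official_shape_words_convert_ms_and_speaker_label` guards against: a `begin_time` of 400 would then render as `00:06:40,000` instead of `00:00:00,400`.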

View File

@@ -1116,6 +1116,125 @@ def should_use_azure_speech_services(voice_name: str) -> bool:
return False
def doubaotts_tts(text: str, voice_name: str, voice_file: str, speed: float = 1.0) -> Union[SubMaker, None]:
"""
使用豆包语音 TTS 生成语音
"""
# 读取配置
doubaotts_cfg = getattr(config, "doubaotts", {}) or {}
appid = doubaotts_cfg.get("appid", "")
token = doubaotts_cfg.get("token", "")
ak = doubaotts_cfg.get("ak", "")
sk = doubaotts_cfg.get("sk", "")
cluster = doubaotts_cfg.get("cluster", "volcano_tts")
if not appid or not token:
logger.error("豆包语音 TTS 配置未完成")
return None
# 准备参数
voice_type = voice_name
safe_speed = float(max(0.2, min(3.0, speed)))
text = text.strip()
# 构建请求参数
import uuid
reqid = str(uuid.uuid4())
# 获取高级参数
volume = doubaotts_cfg.get("volume", 1.0)
pitch = doubaotts_cfg.get("pitch", 1.0)
silence_duration = doubaotts_cfg.get("silence_duration", 0.125)
payload = {
"app": {
"appid": appid,
"token": token,
"cluster": cluster
},
"user": {
"uid": "NarratoAI"
},
"audio": {
"voice_type": voice_type,
"encoding": "mp3",
"rate": 24000,
"speed_ratio": safe_speed,
"volume_ratio": float(volume),
"pitch_ratio": float(pitch)
},
"request": {
"reqid": reqid,
"text": text,
"text_type": "plain",
"operation": "query"
}
}
# 如果设置了句尾静音时长,添加到请求参数中
if silence_duration > 0:
payload["audio"]["silence_duration"] = float(silence_duration)
# API 地址
url = "https://openspeech.bytedance.com/api/v1/tts"
# 构建请求头,使用 Bearer Token 认证
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer;{token}"
}
for i in range(3):
try:
logger.info(f"=== 豆包语音 TTS 请求参数 (第 {i+1} 次调用) ===")
# 发送请求
import requests
# 处理代理设置
proxies = None
proxy_enabled = config.proxy.get("enabled", False)
if proxy_enabled:
proxy_url = config.proxy.get("https", config.proxy.get("http", ""))
if proxy_url:
proxies = {"https": proxy_url, "http": proxy_url}
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=60)
if response.status_code == 200:
result = response.json()
if result.get("code") == 3000:
# 成功
audio_data = result.get("data", "")
if audio_data:
# 解码 base64 音频数据
import base64
audio_bytes = base64.b64decode(audio_data)
# 写入文件
with open(voice_file, "wb") as f:
f.write(audio_bytes)
logger.success(f"豆包语音 TTS 合成成功: {voice_file}")
# 创建 SubMaker 对象(简化版,不包含时间戳)
sub_maker = new_sub_maker()
return sub_maker
else:
logger.error("豆包语音 TTS 响应中无音频数据")
else:
logger.error(f"豆包语音 TTS 失败: {result.get('message', '未知错误')}")
else:
logger.error(f"豆包语音 TTS API 请求失败: {response.status_code}, {response.text}")
if i < 2:
time.sleep(1)
except Exception as e:
logger.error(f"豆包语音 TTS 错误: {str(e)}")
if i < 2:
time.sleep(3)
return None
def tts(
text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str, tts_engine: str
) -> Union[SubMaker, None]:
@@ -1147,6 +1266,10 @@ def tts(
if tts_engine == "indextts2":
logger.info("分发到 IndexTTS2")
return indextts2_tts(text, voice_name, voice_file, speed=voice_rate)
if tts_engine == "doubaotts":
logger.info("分发到豆包语音 TTS")
return doubaotts_tts(text, voice_name, voice_file, speed=voice_rate)
# Fallback for unknown engine - default to azure v1
logger.warning(f"未知的 TTS 引擎: '{tts_engine}', 将默认使用 Edge TTS (Azure V1)。")
@@ -1606,8 +1729,8 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
f"或者使用其他 tts 引擎")
continue
else:
# SoulVoice、Qwen3、IndexTTS2 引擎不生成字幕文件
if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2":
# SoulVoice、Qwen3、IndexTTS2、豆包语音 引擎不生成字幕文件
if is_soulvoice_voice(voice_name) or is_qwen_engine(tts_engine) or tts_engine == "indextts2" or tts_engine == "doubaotts":
# 获取实际音频文件的时长
duration = get_audio_duration_from_file(audio_file)
if duration <= 0:
@@ -1615,8 +1738,27 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
duration = get_audio_duration(sub_maker)
if duration <= 0:
# 最后的 fallback基于文本长度估算
duration = max(1.0, len(text) / 3.0)
logger.warning(f"无法获取音频时长,使用文本估算: {duration:.2f}")
# 对于英文文本,使用更准确的估算方法
# 英文平均语速约为每分钟 150-180 个单词,即每秒 2.5-3 个单词
# 中文文本约为每秒 3-4 字
import re
# 计算英文单词数
english_words = len(re.findall(r'\b\w+\b', text))
# 计算中文字符数
chinese_chars = len(re.findall(r'[\u4e00-\u9fa5]', text))
if english_words > chinese_chars:
# 主要是英文文本
# 假设平均每个单词需要0.35秒
estimated_duration = max(1.0, english_words * 0.35)
else:
# 主要是中文文本
# 假设平均每个汉字需要0.3秒
estimated_duration = max(1.0, chinese_chars * 0.3)
# 确保估算时长合理
duration = max(1.0, estimated_duration)
logger.warning(f"无法获取音频时长,使用文本估算: {duration:.2f}秒 (英文单词: {english_words}, 中文字符: {chinese_chars})")
# 不创建字幕文件
subtitle_file = ""
else:
@@ -1658,8 +1800,6 @@ def get_audio_duration_from_file(audio_file: str) -> float:
# 但实际文件还包含头部信息,所以调整系数
estimated_duration = max(1.0, file_size / 20000) # 调整为更保守的估算
# 对于中文语音,根据文本长度进行二次校正
# 一般中文语音速度约为 3-4 字/秒
logger.warning(f"使用文件大小估算音频时长: {estimated_duration:.2f}")
return estimated_duration
except Exception as e:

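The text-length fallback added in `tts_multiple` above can be isolated into one small helper. A sketch under the same assumptions (roughly 0.35 s per English word, 0.3 s per Chinese character, with a 1-second floor; the function name is illustrative):

```python
import re

def estimate_duration_from_text(text: str) -> float:
    # Fallback estimate used only when no audio duration is available.
    english_words = len(re.findall(r"\b\w+\b", text))
    chinese_chars = len(re.findall(r"[\u4e00-\u9fa5]", text))
    if english_words > chinese_chars:
        # Mostly English: ~0.35 s per word (150-180 words per minute).
        return max(1.0, english_words * 0.35)
    # Mostly Chinese: ~0.3 s per character (3-4 characters per second).
    return max(1.0, chinese_chars * 0.3)
```

Note that `\w+` also matches runs of CJK characters, so a pure-Chinese string counts as one "word"; the `english_words > chinese_chars` comparison is what routes it to the per-character branch.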
View File

@@ -570,29 +570,39 @@ def temp_dir(sub_dir: str = ""):
return d
def clear_keyframes_cache(video_path: str = None):
def clear_keyframes_cache(video_path: str = None, cache_scope: str = "keyframes"):
"""
清理关键帧缓存
Args:
video_path: 视频文件路径如果指定则只清理该视频的缓存
cache_scope: 缓存作用域目录,默认为 keyframes
"""
try:
keyframes_dir = os.path.join(temp_dir(), "keyframes")
if not os.path.exists(keyframes_dir):
cache_dir = os.path.join(temp_dir(), cache_scope)
if not os.path.exists(cache_dir):
return
import shutil
if video_path:
# 清理指定视频的缓存
video_hash = md5(video_path + str(os.path.getmtime(video_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
logger.info(f"已清理视频关键帧缓存: {video_path}")
# 清理指定视频的缓存(兼容前缀扩展键)
try:
video_mtime = os.path.getmtime(video_path)
except OSError:
video_mtime = 0
video_hash = md5(video_path + str(video_mtime))
for entry in os.listdir(cache_dir):
if not entry.startswith(video_hash):
continue
target_path = os.path.join(cache_dir, entry)
if os.path.isdir(target_path):
shutil.rmtree(target_path)
else:
os.remove(target_path)
logger.info(f"已清理视频关键帧缓存: {video_path}")
else:
# 清理所有缓存
import shutil
shutil.rmtree(keyframes_dir)
shutil.rmtree(cache_dir)
logger.info("已清理所有关键帧缓存")
except Exception as e:

View File

@@ -185,6 +185,95 @@ class VideoProcessor:
return frame_numbers
def extract_frames_by_interval_with_fallback(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
"""
先尝试单次 ffmpeg 快路径抽帧,失败时回退到高兼容方案
"""
if interval_seconds <= 0:
raise ValueError("interval_seconds must be > 0")
os.makedirs(output_dir, exist_ok=True)
try:
return self._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds)
except Exception as exc:
logger.warning(f"快路径抽帧失败,回退到兼容模式: {exc}")
self._cleanup_fast_path_artifacts(output_dir)
self.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds)
return self._collect_extracted_frame_paths(output_dir)
def _extract_frames_fast_path(self, output_dir: str, interval_seconds: float = 5.0) -> List[str]:
"""
使用单次 ffmpeg 命令按固定间隔抽帧,随后重命名为既有 keyframe 约定格式
"""
if interval_seconds <= 0:
raise ValueError("interval_seconds must be > 0")
os.makedirs(output_dir, exist_ok=True)
raw_pattern = os.path.join(output_dir, "fastframe_%06d.jpg")
cmd = [
"ffmpeg",
"-hide_banner",
"-loglevel",
"error",
"-i",
self.video_path,
"-vf",
f"fps=1/{interval_seconds}",
"-q:v",
"2",
"-start_number",
"0",
"-y",
raw_pattern,
]
subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120)
raw_files = sorted(
filename
for filename in os.listdir(output_dir)
if re.fullmatch(r"fastframe_\d{6}\.jpg", filename)
)
if not raw_files:
raise RuntimeError("Fast-path extraction produced no frames")
renamed_files: List[str] = []
for index, filename in enumerate(raw_files):
timestamp = index * interval_seconds
frame_number = int(timestamp * self.fps)
token = self._format_timestamp_token(timestamp)
source_path = os.path.join(output_dir, filename)
target_path = os.path.join(output_dir, f"keyframe_{frame_number:06d}_{token}.jpg")
os.replace(source_path, target_path)
renamed_files.append(target_path)
return renamed_files
@staticmethod
def _format_timestamp_token(timestamp: float) -> str:
hours = int(timestamp // 3600)
minutes = int((timestamp % 3600) // 60)
seconds = int(timestamp % 60)
milliseconds = int((timestamp % 1) * 1000)
return f"{hours:02d}{minutes:02d}{seconds:02d}{milliseconds:03d}"
@staticmethod
def _collect_extracted_frame_paths(output_dir: str) -> List[str]:
return sorted(
os.path.join(output_dir, name)
for name in os.listdir(output_dir)
if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
)
@staticmethod
def _cleanup_fast_path_artifacts(output_dir: str) -> None:
for name in os.listdir(output_dir):
if not re.fullmatch(r"fastframe_\d{6}\.jpg", name):
continue
artifact_path = os.path.join(output_dir, name)
if os.path.isfile(artifact_path):
os.remove(artifact_path)
def _extract_single_frame_optimized(self, timestamp: float, output_path: str,
use_hw_accel: bool, hwaccel_type: str) -> bool:
"""

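The fast path above boils down to a single ffmpeg invocation with the `fps` filter. A sketch of the command construction mirroring those arguments (the builder function is illustrative, not part of `VideoProcessor`):

```python
def build_interval_extraction_cmd(
    video_path: str, raw_pattern: str, interval_seconds: float
) -> list[str]:
    # fps=1/N emits one frame every N seconds; frames are written as
    # sequentially numbered JPEGs matching raw_pattern.
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be > 0")
    return [
        "ffmpeg",
        "-hide_banner", "-loglevel", "error",
        "-i", video_path,
        "-vf", f"fps=1/{interval_seconds}",
        "-q:v", "2",            # high JPEG quality (2 is near-lossless)
        "-start_number", "0",
        "-y", raw_pattern,
    ]
```

The service runs this via `subprocess.run(..., check=True, timeout=120)` and falls back to the ultra-compatible per-frame extractor when the single-pass command fails.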
View File

@@ -1,5 +1,5 @@
[app]
project_version="0.7.6"
project_version="0.7.8"
# LLM API 超时配置(秒)
llm_vision_timeout = 120 # 视觉模型基础超时时间
@@ -93,6 +93,12 @@
# 访问 https://bailian.console.aliyun.com/?tab=model#/api-key 获取你的 API 密钥
api_key = ""
model_name = "qwen3-tts-flash"
[fun_asr]
# 阿里百炼 Fun-ASR 字幕转录配置
# 访问 https://bailian.console.aliyun.com/?tab=model#/api-key 获取你的 API 密钥
api_key = ""
model = "fun-asr"
[indextts2]
# IndexTTS2 语音克隆配置
@@ -114,9 +120,25 @@
do_sample = true
num_beams = 3
repetition_penalty = 10.0
[doubaotts]
# 豆包语音 TTS 配置
# 申请流程:
# 1. 打开 https://console.volcengine.com/iam/keymanage 新建 Access Key 和 Secret Key
# 2. 打开 https://www.volcengine.com/product/voice-tech 点击立即使用
# 3. 在 API 服务中心找到音频生成下面的语音合成,获取 APPID 和 Token
ak = ""
sk = ""
appid = ""
token = ""
cluster = "volcano_tts"
# 高级参数
volume = 1.0
pitch = 1.0
silence_duration = 0.125
[ui]
-# TTS engine selection (edge_tts, azure_speech, soulvoice, tencent_tts, tts_qwen)
+# TTS engine selection (edge_tts, azure_speech, soulvoice, tencent_tts, tts_qwen, doubaotts)
tts_engine = "edge_tts"
# Edge TTS settings
@@ -130,6 +152,10 @@
azure_volume = 80
azure_rate = 1.0
azure_pitch = 0
# Doubao TTS settings
doubaotts_voice_type = "BV700_V2_streaming"
doubaotts_rate = 1.0
##########################################
# Proxy and network settings
@@ -152,3 +178,6 @@
# Number of keyframes the vision LLM processes per request
vision_batch_size = 10
# Maximum number of concurrent vision batches (OpenAI-compatible providers)
vision_max_concurrency = 2

conftest.py

@@ -0,0 +1,11 @@
"""Pytest collection rules for the repository.
These files are executable smoke-check scripts that live next to the LLM
implementation for convenience. They require live credentials or manual
execution semantics, so keep them out of the default automated test suite.
"""
collect_ignore = [
"app/services/llm/test_llm_service.py",
"app/services/llm/test_openai_compatible_integration.py",
]


@@ -1 +1 @@
-0.7.7
+0.7.9


@@ -35,3 +35,6 @@ tenacity>=9.0.0
# torch>=2.0.0
# torchvision>=0.15.0
# torchaudio>=2.0.0
# JianYing draft export dependency
pyJianYingDraft>=0.1.0


@@ -0,0 +1,275 @@
import unittest
import os
from tempfile import TemporaryDirectory
from unittest.mock import patch
from app.services.documentary.frame_analysis_models import DocumentaryAnalysisConfig
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.utils import utils
class DocumentaryFrameAnalysisServiceTests(unittest.TestCase):
def test_build_analysis_prompt_formats_real_frame_count(self):
service = DocumentaryFrameAnalysisService()
prompt = service._build_analysis_prompt(frame_count=3)
self.assertIn("我提供了 3 张视频帧", prompt)
self.assertNotIn("%s", prompt)
self.assertIn("frame_observations", prompt)
self.assertIn("overall_activity_summary", prompt)
def test_parse_failed_batch_keeps_raw_response_and_time_range(self):
service = DocumentaryFrameAnalysisService()
batch = service._build_failed_batch_result(
batch_index=2,
raw_response="not-json",
error_message="JSON decode failed",
frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
time_range="00:00:00,000-00:00:03,000",
)
self.assertEqual("failed", batch.status)
self.assertEqual("not-json", batch.raw_response)
self.assertEqual("00:00:00,000-00:00:03,000", batch.time_range)
self.assertTrue(batch.fallback_summary)
def test_parse_failed_batch_uses_non_empty_fallback_when_raw_response_is_empty(self):
service = DocumentaryFrameAnalysisService()
batch = service._build_failed_batch_result(
batch_index=3,
raw_response="",
error_message="Empty model response",
frame_paths=["/tmp/keyframe_000001_000001000.jpg"],
time_range="00:00:03,000-00:00:06,000",
)
self.assertEqual("failed", batch.status)
self.assertEqual("", batch.raw_response)
self.assertTrue(batch.fallback_summary)
def test_failed_batch_result_uses_prompt_contract_field_names(self):
service = DocumentaryFrameAnalysisService()
batch = service._build_failed_batch_result(
batch_index=4,
raw_response="not-json",
error_message="JSON decode failed",
frame_paths=["/tmp/keyframe_000002_000002000.jpg"],
time_range="00:00:06,000-00:00:09,000",
)
self.assertEqual([], batch.frame_observations)
self.assertEqual("", batch.overall_activity_summary)
self.assertFalse(hasattr(batch, "observations"))
self.assertFalse(hasattr(batch, "summary"))
def test_parse_batch_returns_failed_result_when_json_is_invalid(self):
service = DocumentaryFrameAnalysisService()
batch = service._parse_batch_response(
batch_index=0,
raw_response="plain text",
frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
time_range="00:00:00,000-00:00:03,000",
)
self.assertEqual("failed", batch.status)
self.assertEqual("plain text", batch.raw_response)
self.assertEqual(["/tmp/keyframe_000000_000000000.jpg"], batch.frame_paths)
self.assertEqual([], batch.frame_observations)
self.assertEqual("", batch.overall_activity_summary)
def test_parse_batch_returns_failed_result_for_empty_json_object(self):
service = DocumentaryFrameAnalysisService()
batch = service._parse_batch_response(
batch_index=0,
raw_response="{}",
frame_paths=["/tmp/keyframe_000000_000000000.jpg"],
time_range="00:00:00,000-00:00:03,000",
)
self.assertEqual("failed", batch.status)
self.assertEqual("{}", batch.raw_response)
self.assertIn("frame_observations", batch.error_message)
def test_parse_batch_returns_failed_result_when_observations_are_too_short(self):
service = DocumentaryFrameAnalysisService()
raw_response = """
{
"frame_observations": [
{"observation": "第一帧画面"}
],
"overall_activity_summary": "只有一条帧观察"
}
""".strip()
batch = service._parse_batch_response(
batch_index=1,
raw_response=raw_response,
frame_paths=[
"/tmp/keyframe_000000_000000000.jpg",
"/tmp/keyframe_000075_000003000.jpg",
],
time_range="00:00:00,000-00:00:06,000",
)
self.assertEqual("failed", batch.status)
self.assertEqual(raw_response, batch.raw_response)
self.assertIn("frame_observations", batch.error_message)
def test_parse_batch_parses_code_fenced_json_into_structured_result(self):
service = DocumentaryFrameAnalysisService()
raw_response = """```json
{
"frame_observations": [
{"observation": "第一帧画面"},
{"observation": "第二帧画面"}
],
"overall_activity_summary": "人物从房间走到街道"
}
```"""
batch = service._parse_batch_response(
batch_index=1,
raw_response=raw_response,
frame_paths=[
"/tmp/keyframe_000000_000000000.jpg",
"/tmp/keyframe_000075_000003000.jpg",
],
time_range="00:00:00,000-00:00:06,000",
)
self.assertEqual("success", batch.status)
self.assertEqual(
[
{
"frame_path": "/tmp/keyframe_000000_000000000.jpg",
"timestamp": "",
"observation": "第一帧画面",
},
{
"frame_path": "/tmp/keyframe_000075_000003000.jpg",
"timestamp": "",
"observation": "第二帧画面",
},
],
batch.frame_observations,
)
self.assertEqual("人物从房间走到街道", batch.overall_activity_summary)
self.assertEqual("", batch.fallback_summary)
def test_parse_batch_preserves_frames_when_summary_is_missing(self):
service = DocumentaryFrameAnalysisService()
raw_response = """
{
"frame_observations": [
{"observation": "第一帧画面"},
{"observation": "第二帧画面"}
]
}
""".strip()
batch = service._parse_batch_response(
batch_index=2,
raw_response=raw_response,
frame_paths=[
"/tmp/keyframe_000000_000000000.jpg",
"/tmp/keyframe_000075_000003000.jpg",
],
time_range="00:00:00,000-00:00:06,000",
)
self.assertEqual("success", batch.status)
self.assertEqual(2, len(batch.frame_observations))
self.assertEqual("", batch.overall_activity_summary)
def test_cache_key_changes_when_interval_changes(self):
service = DocumentaryFrameAnalysisService()
with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
self.assertNotEqual(key_a, key_b)
def test_cache_key_changes_when_model_changes(self):
service = DocumentaryFrameAnalysisService()
with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=100.0):
key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
key_b = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-b", 10, 2)
self.assertNotEqual(key_a, key_b)
def test_cache_key_starts_with_legacy_video_hash_prefix(self):
service = DocumentaryFrameAnalysisService()
with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0):
key = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
expected_prefix = utils.md5("video.mp4" + "123.0")
self.assertTrue(key.startswith(expected_prefix))
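The three cache-key tests above pin down the key layout only loosely: it must start with the legacy `md5(video_path + str(mtime))` prefix (so prefix-based cleanup keeps working) and change whenever the interval or model changes. One sketch that satisfies those constraints; the actual `_build_cache_key` may differ:

```python
import hashlib

def md5_hex(text: str) -> str:
    return hashlib.md5(text.encode("utf-8")).hexdigest()

def build_cache_key(video_path, mtime, interval, prompt_version, model, batch_size, concurrency):
    # Legacy per-video prefix, kept so clear_keyframes_cache can prefix-match.
    prefix = md5_hex(video_path + str(mtime))
    # Parameter fingerprint: any change in these inputs yields a new key.
    suffix = md5_hex(f"{interval}|{prompt_version}|{model}|{batch_size}|{concurrency}")
    return f"{prefix}_{suffix}"

key_a = build_cache_key("video.mp4", 123.0, 3.0, "prompt-v1", "model-a", 10, 2)
key_b = build_cache_key("video.mp4", 123.0, 5.0, "prompt-v1", "model-a", 10, 2)
assert key_a != key_b
assert key_a.startswith(md5_hex("video.mp4" + "123.0"))
```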
def test_clear_keyframes_cache_respects_scope_and_prefix_match(self):
with TemporaryDirectory() as temp_root:
service = DocumentaryFrameAnalysisService()
analysis_dir = os.path.join(temp_root, "analysis")
os.makedirs(analysis_dir, exist_ok=True)
with patch("app.services.documentary.frame_analysis_service.os.path.getmtime", return_value=123.0):
target_key_a = service._build_cache_key("video.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
target_key_b = service._build_cache_key("video.mp4", 5.0, "prompt-v1", "model-a", 10, 2)
keep_key = service._build_cache_key("other.mp4", 3.0, "prompt-v1", "model-a", 10, 2)
target_dir_a = os.path.join(analysis_dir, target_key_a)
target_dir_b = os.path.join(analysis_dir, target_key_b)
keep_dir = os.path.join(analysis_dir, keep_key)
os.makedirs(target_dir_a, exist_ok=True)
os.makedirs(target_dir_b, exist_ok=True)
os.makedirs(keep_dir, exist_ok=True)
with patch("app.utils.utils.temp_dir", return_value=temp_root), patch(
"app.utils.utils.os.path.getmtime", return_value=123.0
):
utils.clear_keyframes_cache(video_path="video.mp4", cache_scope="analysis")
self.assertFalse(os.path.exists(target_dir_a))
self.assertFalse(os.path.exists(target_dir_b))
self.assertTrue(os.path.exists(keep_dir))
class DocumentaryAnalysisConfigTests(unittest.TestCase):
def test_config_rejects_non_positive_frame_interval(self):
with self.assertRaises(ValueError):
DocumentaryAnalysisConfig(
video_path="/tmp/demo.mp4",
frame_interval_seconds=0,
vision_batch_size=5,
vision_llm_provider="openai",
vision_model_name="gpt-4o-mini",
)
def test_config_rejects_non_positive_batch_size(self):
with self.assertRaises(ValueError):
DocumentaryAnalysisConfig(
video_path="/tmp/demo.mp4",
frame_interval_seconds=5,
vision_batch_size=0,
vision_llm_provider="openai",
vision_model_name="gpt-4o-mini",
)
def test_config_rejects_non_positive_max_concurrency(self):
with self.assertRaises(ValueError):
DocumentaryAnalysisConfig(
video_path="/tmp/demo.mp4",
frame_interval_seconds=5,
vision_batch_size=5,
vision_llm_provider="openai",
vision_model_name="gpt-4o-mini",
max_concurrency=0,
)


@@ -0,0 +1,58 @@
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from app.services.generate_narration_script import parse_frame_analysis_to_markdown
class GenerateNarrationMarkdownTests(unittest.TestCase):
def test_markdown_keeps_batches_without_summary_and_sorts_by_time(self):
artifact = {
"batches": [
{
"batch_index": 1,
"time_range": "00:00:03,000-00:00:06,000",
"overall_activity_summary": "人物转身跑向远处",
"fallback_summary": "",
"frame_observations": [
{
"timestamp": "00:00:03,000",
"observation": "人物突然回头",
}
],
},
{
"batch_index": 0,
"time_range": "00:00:00,000-00:00:03,000",
"overall_activity_summary": "",
"fallback_summary": "原始响应回退摘要",
"frame_observations": [
{
"timestamp": "00:00:00,000",
"observation": "镜头里有一只猫",
}
],
},
]
}
with TemporaryDirectory() as temp_dir:
analysis_path = Path(temp_dir) / "frame-analysis.json"
analysis_path.write_text(json.dumps(artifact, ensure_ascii=False), encoding="utf-8")
markdown = parse_frame_analysis_to_markdown(str(analysis_path))
first_range_index = markdown.find("00:00:00,000-00:00:03,000")
second_range_index = markdown.find("00:00:03,000-00:00:06,000")
self.assertIn("原始响应回退摘要", markdown)
self.assertIn("镜头里有一只猫", markdown)
self.assertIn("人物转身跑向远处", markdown)
self.assertIn("人物突然回头", markdown)
self.assertNotEqual(-1, first_range_index)
self.assertNotEqual(-1, second_range_index)
self.assertLess(first_range_index, second_range_index)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,19 @@
import unittest
from webui.tools.generate_script_docu import _normalize_progress_value
class GenerateScriptDocuProgressTests(unittest.TestCase):
def test_normalize_progress_rounds_percentage_float_to_valid_streamlit_int(self):
self.assertEqual(43, _normalize_progress_value(43.125))
def test_normalize_progress_converts_ratio_float_to_percentage_int(self):
self.assertEqual(43, _normalize_progress_value(0.43125))
def test_normalize_progress_clamps_out_of_range_values(self):
self.assertEqual(0, _normalize_progress_value(-5))
self.assertEqual(100, _normalize_progress_value(101))
if __name__ == "__main__":
unittest.main()
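An implementation consistent with the three tests above would accept both ratio `(0, 1]` and percentage inputs, round, and clamp to the 0-100 integer range Streamlit's progress widget requires. A sketch only; the shipped `_normalize_progress_value` may differ:

```python
def normalize_progress_value(value: float) -> int:
    """Coerce a progress value into the 0-100 int range st.progress expects."""
    if 0 < value <= 1:
        value *= 100  # treat (0, 1] as a ratio, e.g. 0.43125 -> 43.125
    return max(0, min(100, round(value)))

print(normalize_progress_value(43.125))   # 43
print(normalize_progress_value(0.43125))  # 43
print(normalize_progress_value(-5))       # 0
print(normalize_progress_value(101))      # 100
```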


@@ -0,0 +1,316 @@
import json
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import AsyncMock, patch
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
from app.services.script_service import ScriptGenerator
class ScriptGeneratorDocumentaryTests(unittest.IsolatedAsyncioTestCase):
async def test_generate_script_forwards_explicit_values_to_shared_service(self):
expected_script = [
{
"timestamp": "00:00:00,000-00:00:03,000",
"picture": "批次描述",
"narration": "这里是解说词",
"OST": 2,
}
]
callback = lambda _percent, _message: None
with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
service = service_cls.return_value
service.generate_documentary_script = AsyncMock(return_value=expected_script)
generator = ScriptGenerator()
result = await generator.generate_script(
video_path="demo.mp4",
video_theme="荒野生存",
custom_prompt="请聚焦生存动作",
frame_interval_input=3,
vision_batch_size=6,
vision_llm_provider="openai",
progress_callback=callback,
)
self.assertEqual(expected_script, result)
self.assertTrue(result[0]["narration"])
service.generate_documentary_script.assert_awaited_once()
called_kwargs = service.generate_documentary_script.await_args.kwargs
self.assertEqual("demo.mp4", called_kwargs["video_path"])
self.assertEqual(3, called_kwargs["frame_interval_input"])
self.assertEqual(6, called_kwargs["vision_batch_size"])
self.assertEqual("openai", called_kwargs["vision_llm_provider"])
self.assertEqual("荒野生存", called_kwargs["video_theme"])
self.assertEqual("请聚焦生存动作", called_kwargs["custom_prompt"])
self.assertIs(called_kwargs["progress_callback"], callback)
async def test_generate_script_forwards_unset_values_as_none(self):
expected_script = [
{
"timestamp": "00:00:00,000-00:00:03,000",
"picture": "批次描述",
"narration": "这里是解说词",
"OST": 2,
}
]
with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls:
service = service_cls.return_value
service.generate_documentary_script = AsyncMock(return_value=expected_script)
generator = ScriptGenerator()
await generator.generate_script(video_path="demo.mp4")
called_kwargs = service.generate_documentary_script.await_args.kwargs
self.assertIsNone(called_kwargs["frame_interval_input"])
self.assertIsNone(called_kwargs["vision_batch_size"])
self.assertIsNone(called_kwargs["vision_llm_provider"])
async def test_generate_script_warns_when_skip_seconds_or_threshold_are_non_default(self):
expected_script = [
{
"timestamp": "00:00:00,000-00:00:03,000",
"picture": "批次描述",
"narration": "这里是解说词",
"OST": 2,
}
]
with patch("app.services.script_service.DocumentaryFrameAnalysisService") as service_cls, patch(
"app.services.script_service.logger.warning"
) as warning:
service = service_cls.return_value
service.generate_documentary_script = AsyncMock(return_value=expected_script)
generator = ScriptGenerator()
await generator.generate_script(
video_path="demo.mp4",
skip_seconds=2,
threshold=20,
)
warning.assert_called_once()
warning_message = warning.call_args.args[0]
self.assertIn("skip_seconds", warning_message)
self.assertIn("threshold", warning_message)
self.assertIn("does not currently apply", warning_message)
class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyncioTestCase):
async def test_generate_documentary_script_returns_final_narrated_items(self):
service = DocumentaryFrameAnalysisService()
analysis_payload = {
"batches": [
{
"batch_index": 0,
"time_range": "00:00:00,000-00:00:03,000",
"overall_activity_summary": "",
"fallback_summary": "回退摘要",
"frame_observations": [
{"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
],
}
]
}
with TemporaryDirectory() as temp_dir:
analysis_path = Path(temp_dir) / "frame_analysis_test.json"
analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
with patch.object(
DocumentaryFrameAnalysisService,
"analyze_video",
AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
), patch.dict(
"app.services.documentary.frame_analysis_service.config.app",
{
"text_llm_provider": "openai",
"text_openai_api_key": "test-key",
"text_openai_model_name": "test-model",
"text_openai_base_url": "https://example.com/v1",
},
), patch(
"app.services.documentary.frame_analysis_service.generate_narration",
return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
):
result = await service.generate_documentary_script(video_path="demo.mp4")
self.assertEqual(1, len(result))
self.assertEqual("00:00:00,000-00:00:03,000", result[0]["timestamp"])
self.assertEqual("镜头里有一只猫", result[0]["picture"])
self.assertEqual("一只猫警觉地望向镜头。", result[0]["narration"])
self.assertEqual(2, result[0]["OST"])
async def test_generate_documentary_script_raises_when_narration_json_is_malformed(self):
service = DocumentaryFrameAnalysisService()
analysis_payload = {
"batches": [
{
"batch_index": 0,
"time_range": "00:00:00,000-00:00:03,000",
"overall_activity_summary": "测试摘要",
"fallback_summary": "",
"frame_observations": [
{"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
],
}
]
}
with TemporaryDirectory() as temp_dir:
analysis_path = Path(temp_dir) / "frame_analysis_test.json"
analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
with patch.object(
DocumentaryFrameAnalysisService,
"analyze_video",
AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
), patch.dict(
"app.services.documentary.frame_analysis_service.config.app",
{
"text_llm_provider": "openai",
"text_openai_api_key": "test-key",
"text_openai_model_name": "test-model",
"text_openai_base_url": "https://example.com/v1",
},
), patch(
"app.services.documentary.frame_analysis_service.generate_narration",
return_value="malformed narration payload",
):
with self.assertRaises(Exception) as ctx:
await service.generate_documentary_script(video_path="demo.mp4")
self.assertIn("解说文案格式错误", str(ctx.exception))
self.assertIn("items", str(ctx.exception))
def test_parse_narration_items_recovers_from_common_json_damage(self):
service = DocumentaryFrameAnalysisService()
damaged_payload = """
解释文字
```json
{{
"items": [
{{
"timestamp": "00:00:00,000-00:00:03,000",
"picture": "镜头里有一只猫",
"narration": "一只猫警觉地望向镜头。",
}},
],
}}
```
补充文字
""".strip()
parsed_items = service._parse_narration_items(damaged_payload)
self.assertEqual(1, len(parsed_items))
self.assertEqual("00:00:00,000-00:00:03,000", parsed_items[0]["timestamp"])
self.assertEqual("镜头里有一只猫", parsed_items[0]["picture"])
self.assertEqual("一只猫警觉地望向镜头。", parsed_items[0]["narration"])
def test_parse_narration_items_raises_for_unrecoverable_payload(self):
service = DocumentaryFrameAnalysisService()
with self.assertRaises(ValueError) as ctx:
service._parse_narration_items("not-json-at-all ::: ???")
self.assertIn("解说文案格式错误", str(ctx.exception))
self.assertIn("items", str(ctx.exception))
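The two `_parse_narration_items` tests above describe a repair contract: strip surrounding prose and code fences, tolerate trailing commas, and raise the same "解说文案格式错误" error when nothing recoverable remains. A best-effort sketch of such a parser (not the project's actual implementation, which also handles further damage such as doubled braces):

```python
import json
import re

def parse_llm_json(raw: str) -> dict:
    """Best-effort recovery of a JSON object from an LLM reply (sketch)."""
    # Prefer the body of a ```json fenced block if one is present.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", raw, re.DOTALL)
    if fenced:
        raw = fenced.group(1)
    # Keep only the outermost {...} span, discarding surrounding prose.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("解说文案格式错误: no JSON object found")
    raw = raw[start:end + 1]
    # Remove trailing commas before closing brackets.
    raw = re.sub(r",\s*([}\]])", r"\1", raw)
    return json.loads(raw)

damaged = 'intro text\n```json\n{"items": [{"narration": "ok"},]}\n```\ntrailing text'
print(parse_llm_json(damaged))  # {'items': [{'narration': 'ok'}]}
```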
async def test_generate_documentary_script_includes_theme_and_custom_prompt_for_narration(self):
service = DocumentaryFrameAnalysisService()
analysis_payload = {
"batches": [
{
"batch_index": 0,
"time_range": "00:00:00,000-00:00:03,000",
"overall_activity_summary": "测试摘要",
"fallback_summary": "",
"frame_observations": [
{"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
],
}
]
}
with TemporaryDirectory() as temp_dir:
analysis_path = Path(temp_dir) / "frame_analysis_test.json"
analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")
with patch.object(
DocumentaryFrameAnalysisService,
"analyze_video",
AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
), patch.dict(
"app.services.documentary.frame_analysis_service.config.app",
{
"text_llm_provider": "openai",
"text_openai_api_key": "test-key",
"text_openai_model_name": "test-model",
"text_openai_base_url": "https://example.com/v1",
},
), patch(
"app.services.documentary.frame_analysis_service.generate_narration",
return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
) as mocked_generate:
await service.generate_documentary_script(
video_path="demo.mp4",
video_theme="野生动物纪录片",
custom_prompt="重点描述危险信号",
)
narration_input = mocked_generate.call_args.args[0]
self.assertIn("## 创作上下文", narration_input)
self.assertIn("视频主题:野生动物纪录片", narration_input)
self.assertIn("补充创作要求:重点描述危险信号", narration_input)
async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self):
service = DocumentaryFrameAnalysisService()
with patch.dict(
"app.services.documentary.frame_analysis_service.config.app",
{
"vision_llm_provider": "openai",
"vision_openai_api_key": "config-key",
"vision_openai_model_name": "config-model",
"vision_openai_base_url": "https://config.example/v1",
},
), patch(
"app.services.documentary.frame_analysis_service.os.path.exists",
return_value=True,
), patch.object(
service,
"_load_or_extract_keyframes",
return_value=["/tmp/keyframe_000001_000000100.jpg"],
), patch.object(
service,
"_analyze_batches",
AsyncMock(return_value=[]),
), patch.object(
service,
"_save_analysis_artifact",
return_value="/tmp/frame_analysis_test.json",
), patch.object(
service,
"_build_video_clip_json",
return_value=[],
), patch(
"app.services.documentary.frame_analysis_service.create_vision_analyzer",
return_value=object(),
) as mocked_create_analyzer:
await service.analyze_video(
video_path="/tmp/demo.mp4",
vision_api_key="explicit-key",
vision_model_name="explicit-model",
vision_base_url="",
)
called_kwargs = mocked_create_analyzer.call_args.kwargs
self.assertEqual("openai", called_kwargs["provider"])
self.assertEqual("explicit-key", called_kwargs["api_key"])
self.assertEqual("explicit-model", called_kwargs["model"])
self.assertEqual("", called_kwargs["base_url"])
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,91 @@
import os
import unittest
from tempfile import TemporaryDirectory
from unittest.mock import patch
from app.utils.video_processor import VideoProcessor
class VideoProcessorDocumentaryTests(unittest.TestCase):
@patch.object(VideoProcessor, "_extract_frames_fast_path", return_value=["a.jpg"])
def test_extract_frames_by_interval_prefers_fast_path(self, fast_path):
processor = VideoProcessor.__new__(VideoProcessor)
processor.video_path = "demo.mp4"
processor.duration = 6.0
processor.fps = 25.0
result = processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=3.0)
self.assertEqual(["a.jpg"], result)
fast_path.assert_called_once_with("/tmp/out", interval_seconds=3.0)
def test_extract_frames_by_interval_falls_back_to_ultra_compatible(self):
processor = VideoProcessor.__new__(VideoProcessor)
processor.video_path = "demo.mp4"
processor.duration = 6.0
processor.fps = 25.0
with TemporaryDirectory() as output_dir:
expected_frame_path = os.path.join(output_dir, "keyframe_000000_000000000.jpg")
def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
with open(expected_frame_path, "wb") as frame_file:
frame_file.write(b"frame")
return [0]
with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=RuntimeError("fast path failed")) as fast_path, patch.object(
VideoProcessor,
"extract_frames_by_interval_ultra_compatible",
side_effect=ultra_compatible_fallback,
autospec=True,
) as fallback:
result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)
self.assertEqual([expected_frame_path], result)
fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)
def test_extract_frames_by_interval_rejects_non_positive_interval(self):
processor = VideoProcessor.__new__(VideoProcessor)
processor.video_path = "demo.mp4"
processor.duration = 6.0
processor.fps = 25.0
with patch.object(VideoProcessor, "extract_frames_by_interval_ultra_compatible", autospec=True) as fallback:
with self.assertRaises(ValueError):
processor.extract_frames_by_interval_with_fallback("/tmp/out", interval_seconds=0)
fallback.assert_not_called()
def test_extract_frames_by_interval_fallback_cleans_partial_fast_path_artifacts(self):
processor = VideoProcessor.__new__(VideoProcessor)
processor.video_path = "demo.mp4"
processor.duration = 6.0
processor.fps = 25.0
with TemporaryDirectory() as output_dir:
stale_fastframe = os.path.join(output_dir, "fastframe_000000.jpg")
expected_keyframe = os.path.join(output_dir, "keyframe_000000_000000000.jpg")
def fast_path_with_partial_output(_output_dir, interval_seconds=5.0):
with open(stale_fastframe, "wb") as frame_file:
frame_file.write(b"stale")
raise RuntimeError("simulated fast-path failure")
def ultra_compatible_fallback(self, output_dir_arg, interval_seconds=5.0):
with open(expected_keyframe, "wb") as frame_file:
frame_file.write(b"frame")
return [0]
with patch.object(VideoProcessor, "_extract_frames_fast_path", side_effect=fast_path_with_partial_output) as fast_path, patch.object(
VideoProcessor,
"extract_frames_by_interval_ultra_compatible",
side_effect=ultra_compatible_fallback,
autospec=True,
) as fallback:
result = processor.extract_frames_by_interval_with_fallback(output_dir, interval_seconds=3.0)
self.assertEqual([expected_keyframe], result)
self.assertFalse(os.path.exists(stale_fastframe))
fast_path.assert_called_once_with(output_dir, interval_seconds=3.0)
fallback.assert_called_once_with(processor, output_dir, interval_seconds=3.0)
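The four tests above pin down the fallback contract of `extract_frames_by_interval_with_fallback`: reject non-positive intervals before doing any work, try the fast path first, and on failure delete partial `fastframe_*` output before returning the sorted `keyframe_*` paths produced by the ultra-compatible path. A control-flow sketch under those assumptions (not the shipped implementation, which also mirrors `_cleanup_fast_path_artifacts` and `_collect_extracted_frame_paths`):

```python
import os
import re

def extract_frames_with_fallback_sketch(processor, output_dir: str, interval_seconds: float = 5.0):
    """Sketch of the fallback contract the tests describe; not the shipped method."""
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be positive")
    try:
        return processor._extract_frames_fast_path(output_dir, interval_seconds=interval_seconds)
    except Exception:
        # Drop partial fast-path output so it cannot leak into the results.
        for name in os.listdir(output_dir):
            if re.fullmatch(r"fastframe_\d{6}\.jpg", name):
                os.remove(os.path.join(output_dir, name))
        processor.extract_frames_by_interval_ultra_compatible(output_dir, interval_seconds=interval_seconds)
        # Collect the sorted keyframe paths the fallback produced.
        return sorted(
            os.path.join(output_dir, name)
            for name in os.listdir(output_dir)
            if re.fullmatch(r"keyframe_\d{6}_\d{9}\.jpg", name)
        )
```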

webui.py

@@ -1,6 +1,7 @@
import streamlit as st
import os
import sys
import time
from loguru import logger
from app.config import config
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, \
@@ -221,6 +222,145 @@ def render_generate_button():
time.sleep(0.5)
def get_voice_name_for_tts_engine(tts_engine: str) -> str:
"""Return the user-selected voice for the given TTS engine."""
if tts_engine == 'doubaotts':
return st.session_state.get('voice_name', config.ui.get('doubaotts_voice_type', 'BV700_streaming'))
elif tts_engine == 'azure_speech':
return st.session_state.get('voice_name', config.ui.get('azure_voice_name', 'zh-CN-XiaoxiaoMultilingualNeural'))
else:
return st.session_state.get('voice_name', config.ui.get('edge_voice_name', 'zh-CN-XiaoxiaoNeural-Female'))
def get_jianying_export_params() -> VideoClipParams:
"""Build the parameters for exporting to a JianYing draft."""
tts_engine = st.session_state.get('tts_engine', 'azure')
voice_name = get_voice_name_for_tts_engine(tts_engine)
voice_rate = st.session_state.get('voice_rate', 1.0)
voice_pitch = st.session_state.get('voice_pitch', 1.0)
return VideoClipParams(
video_clip_json_path=st.session_state['video_clip_json_path'],
video_origin_path=st.session_state['video_origin_path'],
tts_engine=tts_engine,
voice_name=voice_name,
voice_rate=voice_rate,
voice_pitch=voice_pitch,
n_threads=config.app.get('n_threads', 4),
video_aspect=VideoAspect.landscape,
subtitle_enabled=st.session_state.get('subtitle_enabled', False),
font_name=st.session_state.get('font_name', 'Microsoft YaHei'),
font_size=st.session_state.get('font_size', 24),
text_fore_color=st.session_state.get('text_fore_color', '#FFFFFF'),
subtitle_position=st.session_state.get('subtitle_position', 'bottom'),
custom_position=st.session_state.get('custom_position', 70.0),
tts_volume=st.session_state.get('tts_volume', 1.0),
original_volume=st.session_state.get('original_volume', 0.7),
bgm_volume=st.session_state.get('bgm_volume', 0.3),
draft_name=st.session_state.get('draft_name_input', f"NarratoAI_{int(time.time())}")
)
def render_export_jianying_button():
"""Render the JianYing draft export button and its handling logic."""
import os
import time
import uuid
from loguru import logger
# Initialize session state
if 'show_jianying_export_form' not in st.session_state:
st.session_state['show_jianying_export_form'] = False
if 'jianying_export_result' not in st.session_state:
st.session_state['jianying_export_result'] = None
if 'jianying_export_error' not in st.session_state:
st.session_state['jianying_export_error'] = None
if st.button("📤 导出到剪映草稿", use_container_width=True, type="secondary"):
config.save_config()
if not st.session_state.get('video_clip_json_path'):
st.error("脚本文件不能为空")
return
if not st.session_state.get('video_origin_path'):
st.error("视频文件不能为空")
return
jianying_draft_path = config.ui.get("jianying_draft_path", "")
if not jianying_draft_path:
st.error("请在基础设置中配置剪映草稿地址")
return
if not os.path.exists(jianying_draft_path):
st.error(f"剪映草稿文件夹不存在: {jianying_draft_path}")
return
# Show the export form
st.session_state['show_jianying_export_form'] = True
st.session_state['jianying_export_result'] = None
st.session_state['jianying_export_error'] = None
# Render the export form if requested
if st.session_state['show_jianying_export_form']:
st.markdown("---")
st.subheader("导出到剪映草稿")
draft_name = st.text_input(
"请输入剪映草稿名称",
value=f"NarratoAI_{int(time.time())}",
key="draft_name_input"
)
if st.button("确认导出", key="confirm_export"):
if not draft_name:
st.error("请输入草稿名称")
return
# Create a task ID
task_id = str(uuid.uuid4())
st.session_state['task_id'] = task_id
# Build the parameters
try:
params = get_jianying_export_params()
except Exception as e:
logger.error(f"构建参数失败: {e}")
st.error(f"参数构建失败: {e}")
return
with st.spinner("正在导出到剪映草稿,请稍候..."):
try:
from app.services import jianying_task
# Run the JianYing draft export task
result = jianying_task.start_export_jianying_draft(task_id, params)
# Log the result
logger.info(f"成功导出到剪映草稿: {result['draft_name']}")
logger.info(f"草稿已保存到: {result['draft_path']}")
# Save the result to session state
st.session_state['jianying_export_result'] = result
st.session_state['jianying_export_error'] = None
st.session_state['show_jianying_export_form'] = False
st.success(f"✅ 成功导出到剪映草稿: {result['draft_name']}")
st.info(f"📁 草稿已保存到: {result['draft_path']}")
except Exception as e:
logger.error(f"导出到剪映草稿失败: {e}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
st.session_state['jianying_export_error'] = str(e)
st.session_state['jianying_export_result'] = None
st.error(f"❌ 导出到剪映草稿失败: {e}")
if st.button("取消", key="cancel_export"):
st.session_state['show_jianying_export_form'] = False
st.session_state['jianying_export_result'] = None
st.session_state['jianying_export_error'] = None
st.rerun()
def main():
"""Main entry point."""
@@ -285,6 +425,7 @@ def main():
# Render the generate button and its handling logic last
render_generate_button()
render_export_jianying_button()
if __name__ == "__main__":


@@ -26,7 +26,8 @@ def get_tts_engine_options():
"azure_speech": "Azure Speech Services",
"tencent_tts": "腾讯云 TTS",
"qwen3_tts": "通义千问 Qwen3 TTS",
-"indextts2": "IndexTTS2 语音克隆"
+"indextts2": "IndexTTS2 语音克隆",
+"doubaotts": "豆包语音 TTS"
}
@@ -62,6 +63,12 @@ def get_tts_engine_descriptions():
"features": "零样本语音克隆,上传参考音频即可合成相同音色的语音,需要本地或私有部署",
"use_case": "下载地址https://pan.quark.cn/s/0767c9bcefd5",
"registration": None
},
"doubaotts": {
"title": "豆包语音 TTS",
"features": "火山引擎豆包语音合成,支持多种音色和情感,国内访问速度快",
"use_case": "需要高质量中文语音合成的用户",
"registration": "https://www.volcengine.com/product/voice-tech"
}
}
@@ -147,6 +154,8 @@ def render_tts_settings(tr):
render_qwen3_tts_settings(tr)
elif selected_engine == "indextts2":
render_indextts2_tts_settings(tr)
elif selected_engine == "doubaotts":
render_doubaotts_settings(tr)
# 4. Voice preview
render_voice_preview_new(tr, selected_engine)
@@ -703,6 +712,250 @@ def render_indextts2_tts_settings(tr):
config.ui["voice_name"] = f"indextts2:{reference_audio}"
def render_doubaotts_settings(tr):
"""Render the Doubao TTS settings."""
# Access Key input
ak = st.text_input(
"Access Key",
value=config.doubaotts.get("ak", ""),
help="火山引擎 Access Key"
)
# Secret Key input
sk = st.text_input(
"Secret Key",
value=config.doubaotts.get("sk", ""),
type="password",
help="火山引擎 Secret Key"
)
# AppID input
appid = st.text_input(
"AppID",
value=config.doubaotts.get("appid", ""),
help="豆包语音应用 AppID"
)
# Token input
token = st.text_input(
"Token",
value=config.doubaotts.get("token", ""),
type="password",
help="豆包语音应用 Token"
)
# Cluster setting
cluster = st.text_input(
"集群",
value=config.doubaotts.get("cluster", "volcano_tts"),
help="业务集群,标准音色使用 volcano_tts"
)
# Voice selection
# Online voice list (extracted from the documentation)
voice_options = {
"BV700_V2_streaming": "灿灿 2.0",
"BV705_streaming": "炀炀",
"BV701_V2_streaming": "擎苍 2.0",
"BV001_V2_streaming": "通用女声 2.0",
"BV700_streaming": "灿灿",
"BV406_V2_streaming": "超自然音色-梓梓2.0",
"BV406_streaming": "超自然音色-梓梓",
"BV407_V2_streaming": "超自然音色-燃燃2.0",
"BV407_streaming": "超自然音色-燃燃",
"BV001_streaming": "通用女声",
"BV002_streaming": "通用男声",
"BV701_streaming": "擎苍",
"BV123_streaming": "阳光青年",
"BV120_streaming": "反卷青年",
"BV119_streaming": "通用赘婿",
"BV115_streaming": "古风少御",
"BV107_streaming": "霸气青叔",
"BV100_streaming": "质朴青年",
"BV104_streaming": "温柔淑女",
"BV004_streaming": "开朗青年",
"BV113_streaming": "甜宠少御",
"BV102_streaming": "儒雅青年",
"BV405_streaming": "甜美小源",
"BV007_streaming": "亲切女声",
"BV009_streaming": "知性女声",
"BV419_streaming": "诚诚",
"BV415_streaming": "童童",
"BV008_streaming": "亲切男声",
"BV408_streaming": "译制片男声",
"BV426_streaming": "懒小羊",
"BV428_streaming": "清新文艺女声",
"BV403_streaming": "鸡汤女声",
"BV158_streaming": "智慧老者",
"BV157_streaming": "慈爱姥姥",
"BR001_streaming": "说唱小哥",
"BV410_streaming": "活力解说男",
"BV411_streaming": "影视解说小帅",
"BV437_streaming": "解说小帅-多情感",
"BV412_streaming": "影视解说小美",
"BV159_streaming": "纨绔青年",
"BV418_streaming": "直播一姐",
"BV142_streaming": "沉稳解说男",
"BV143_streaming": "潇洒青年",
"BV056_streaming": "阳光男声",
"BV005_streaming": "活泼女声",
"BV064_streaming": "小萝莉",
"BV051_streaming": "奶气萌娃",
"BV063_streaming": "动漫海绵",
"BV417_streaming": "动漫海星",
"BV050_streaming": "动漫小新",
"BV061_streaming": "天才童声",
"BV401_streaming": "促销男声",
"BV402_streaming": "促销女声",
"BV006_streaming": "磁性男声",
"BV011_streaming": "新闻女声",
"BV012_streaming": "新闻男声",
"BV034_streaming": "知性姐姐-双语",
"BV033_streaming": "温柔小哥",
"BV511_streaming": "慵懒女声-Ava",
"BV505_streaming": "议论女声-Alicia",
"BV138_streaming": "情感女声-Lawrence",
"BV027_streaming": "美式女声-Amelia",
"BV502_streaming": "讲述女声-Amanda",
"BV503_streaming": "活力女声-Ariana",
"BV504_streaming": "活力男声-Jackson",
"BV421_streaming": "天才少女",
"BV702_streaming": "Stefan",
"BV506_streaming": "天真萌娃-Lily",
"BV040_streaming": "亲切女声-Anna",
"BV516_streaming": "澳洲男声-Henry",
"BV520_streaming": "元气少女",
"BV521_streaming": "萌系少女",
"BV522_streaming": "气质女声",
"BV524_streaming": "日语男声",
"BV531_streaming": "活力男声Carlos巴西地区",
"BV530_streaming": "活力女声(巴西地区)",
"BV065_streaming": "气质御姐(墨西哥地区)",
"BV021_streaming": "东北老铁",
"BV020_streaming": "东北丫头",
"BV704_streaming": "方言灿灿",
"BV210_streaming": "西安佟掌柜",
"BV217_streaming": "沪上阿姐",
"BV213_streaming": "广西表哥",
"BV025_streaming": "甜美台妹",
"BV227_streaming": "台普男声",
"BV026_streaming": "港剧男神",
"BV424_streaming": "广东女仔",
"BV212_streaming": "相声演员",
"BV019_streaming": "重庆小伙",
"BV221_streaming": "四川甜妹儿",
"BV423_streaming": "重庆幺妹儿",
"BV214_streaming": "乡村企业家",
"BV226_streaming": "湖南妹坨",
"BV216_streaming": "长沙靓女"
}
saved_voice_type = config.ui.get("doubaotts_voice_type", "BV700_streaming")
if saved_voice_type not in voice_options:
voice_options[saved_voice_type] = f"自定义音色 ({saved_voice_type})"
selected_voice_display = st.selectbox(
"音色选择",
options=list(voice_options.values()),
index=list(voice_options.keys()).index(saved_voice_type) if saved_voice_type in voice_options else 0,
help="选择豆包语音 TTS 音色"
)
# 获取实际的音色ID
voice_type = list(voice_options.keys())[
list(voice_options.values()).index(selected_voice_display)
]
# 高级参数折叠面板
with st.expander("🔧 高级参数", expanded=False):
col1, col2 = st.columns(2)
with col1:
# 语速调节
voice_rate = st.slider(
"语速调节",
min_value=0.2,
max_value=3.0,
value=config.ui.get("doubaotts_rate", 1.0),
step=0.1,
help="调节语音速度 (0.2-3.0)"
)
# 音量调节
voice_volume = st.slider(
"音量调节",
min_value=0.1,
max_value=2.0,
value=config.doubaotts.get("volume", 1.0),
step=0.1,
help="调节语音音量 (0.1-2.0)"
)
with col2:
# 音高调节
voice_pitch = st.slider(
"音高调节",
min_value=0.5,
max_value=1.5,
value=config.doubaotts.get("pitch", 1.0),
step=0.1,
help="调节语音音高 (0.5-1.5)"
)
# 句尾静音时长
silence_duration = st.slider(
"句尾静音时长 (秒)",
min_value=0.0,
max_value=2.0,
value=config.doubaotts.get("silence_duration", 0.125),
step=0.05,
help="调节句尾静音时长 (0.0-2.0秒)"
)
# 显示API Key申请流程
with st.expander("💡 豆包语音 TTS API Key申请流程", expanded=False):
st.write("**申请步骤:**")
st.write("1. 打开 [https://console.volcengine.com/iam/keymanage](https://console.volcengine.com/iam/keymanage)")
st.write("2. 新建 Access Key 和 Secret Key")
st.write("3. 打开 [https://www.volcengine.com/product/voice-tech](https://www.volcengine.com/product/voice-tech)")
st.write("4. 点击立即使用")
st.write("5. 在最左边的API服务中心找到音频生成下面的语音合成注意是语音合成不是语音合成大模型")
st.write("6. 翻到最下面获取 APPID 和 Access Token")
st.write("")
st.info("💡 请将获取到的 Access Key、Secret Key、AppID 和 Token 填写到上方的配置中")
# 保存配置
config.doubaotts["ak"] = ak
config.doubaotts["sk"] = sk
config.doubaotts["appid"] = appid
config.doubaotts["token"] = token
config.doubaotts["cluster"] = cluster
config.doubaotts["volume"] = voice_volume
config.doubaotts["pitch"] = voice_pitch
config.doubaotts["silence_duration"] = silence_duration
config.ui["doubaotts_voice_type"] = voice_type
config.ui["doubaotts_rate"] = voice_rate
config.ui["voice_name"] = voice_type # 兼容性
st.session_state['voice_rate'] = voice_rate # 确保语速参数被保存到session state
# 显示配置状态
if ak and sk and appid and token:
st.success("✅ 豆包语音 TTS 配置已设置")
else:
missing = []
if not ak:
missing.append("Access Key")
if not sk:
missing.append("Secret Key")
if not appid:
missing.append("AppID")
if not token:
missing.append("Token")
if missing:
st.warning(f"⚠️ 请配置: {', '.join(missing)}")
def render_voice_preview_new(tr, selected_engine):
"""渲染新的语音试听功能"""
if st.button("🎵 试听语音合成", use_container_width=True):
@@ -746,6 +999,11 @@ def render_voice_preview_new(tr, selected_engine):
voice_name = f"indextts2:{reference_audio}"
voice_rate = 1.0 # IndexTTS2 不支持速度调节
voice_pitch = 1.0 # IndexTTS2 不支持音调调节
elif selected_engine == "doubaotts":
voice_type = config.ui.get("doubaotts_voice_type", "BV700_streaming")
voice_name = voice_type
voice_rate = config.ui.get("doubaotts_rate", 1.0)
voice_pitch = 1.0 # 豆包语音 TTS 不支持音调调节
if not voice_name:
st.error("请先配置语音设置")

View File

@@ -217,6 +217,15 @@ def render_proxy_settings(tr):
config.proxy["http"] = ""
config.proxy["https"] = ""
# 剪映草稿地址设置
st.subheader("剪映草稿设置")
jianying_draft_path = st.text_input(
"剪映草稿文件夹路径",
value=config.ui.get("jianying_draft_path", ""),
help="剪映草稿文件夹路径例如C:\\Users\\用户名\\Documents\\JianyingPro Drafts"
)
config.ui["jianying_draft_path"] = jianying_draft_path
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
"""测试视觉模型连接

View File

@@ -327,6 +327,8 @@ def short_drama_summary(tr):
# 检查是否已经处理过字幕文件
if 'subtitle_file_processed' not in st.session_state:
st.session_state['subtitle_file_processed'] = False
render_fun_asr_transcription(tr)
subtitle_file = st.file_uploader(
tr("上传字幕文件"),
@@ -401,6 +403,95 @@ def short_drama_summary(tr):
return video_theme
def render_fun_asr_transcription(tr):
"""使用阿里百炼 Fun-ASR 从本地音视频转写生成字幕。"""
def clear_fun_asr_subtitle_state():
st.session_state['subtitle_path'] = None
st.session_state['subtitle_content'] = None
st.session_state['subtitle_file_processed'] = False
with st.expander("阿里百炼 Fun-ASR 字幕转录", expanded=False):
st.caption("上传本地音频/视频后,将自动上传到阿里百炼临时存储并通过 fun-asr 生成 SRT 字幕。")
st.markdown(
"API Key 获取地址:"
"[https://bailian.console.aliyun.com/?tab=model#/api-key]"
"(https://bailian.console.aliyun.com/?tab=model#/api-key)"
)
api_key = st.text_input(
"阿里百炼 API Key",
value=config.fun_asr.get("api_key", ""),
type="password",
help="请输入你自己的阿里百炼 API Key保存配置后会写入本地 config.toml",
key="fun_asr_api_key",
)
uploaded_media = st.file_uploader(
"上传需要转录的音频/视频",
type=[
"aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov",
"mp3", "mp4", "mpeg", "ogg", "opus", "wav", "webm", "wma", "wmv",
],
accept_multiple_files=False,
key="fun_asr_media_uploader",
)
if st.button("转写生成字幕", key="fun_asr_transcribe"):
if not api_key.strip():
clear_fun_asr_subtitle_state()
st.error("请先输入阿里百炼 API Key")
return
if uploaded_media is None:
clear_fun_asr_subtitle_state()
st.error("请先上传需要转录的音频或视频文件")
return
try:
clear_fun_asr_subtitle_state()
from app.services import fun_asr_subtitle
config.fun_asr["api_key"] = api_key.strip()
config.fun_asr["model"] = "fun-asr"
config.save_config()
temp_dir = utils.temp_dir("fun_asr")
safe_filename = os.path.basename(uploaded_media.name)
media_path = os.path.join(temp_dir, safe_filename)
file_name, file_extension = os.path.splitext(safe_filename)
if os.path.exists(media_path):
timestamp = time.strftime("%Y%m%d%H%M%S")
media_path = os.path.join(temp_dir, f"{file_name}_{timestamp}{file_extension}")
with open(media_path, "wb") as f:
f.write(uploaded_media.getbuffer())
subtitle_name = f"{os.path.splitext(os.path.basename(media_path))[0]}_fun_asr.srt"
subtitle_path = os.path.join(utils.subtitle_dir(), subtitle_name)
with st.spinner("正在使用阿里百炼 Fun-ASR 转写字幕,请稍候..."):
generated_path = fun_asr_subtitle.create_with_fun_asr(
local_file=media_path,
subtitle_file=subtitle_path,
api_key=api_key.strip(),
)
if not generated_path or not os.path.exists(generated_path):
clear_fun_asr_subtitle_state()
st.error("Fun-ASR 转写失败:未生成字幕文件")
return
with open(generated_path, "r", encoding="utf-8") as f:
subtitle_content = f.read()
st.session_state['subtitle_path'] = generated_path
st.session_state['subtitle_content'] = subtitle_content
st.session_state['subtitle_file_processed'] = True
st.success(f"字幕转写成功: {os.path.basename(generated_path)}")
except Exception as e:
clear_fun_asr_subtitle_state()
logger.error(f"Fun-ASR 字幕转写失败: {traceback.format_exc()}")
st.error(f"Fun-ASR 字幕转写失败: {str(e)}")
def render_script_buttons(tr, params):
"""渲染脚本操作按钮"""
# 获取当前选择的脚本类型

View File

@@ -1,21 +1,32 @@
# 纪录片脚本生成
import os
import json
import time
import asyncio
import traceback
import streamlit as st
from loguru import logger
from datetime import datetime
from app.config import config
from app.utils import utils, video_processor
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps
from app.services.documentary.frame_analysis_service import DocumentaryFrameAnalysisService
def _normalize_progress_value(progress: float | int) -> int:
"""Normalize mixed progress inputs to Streamlit's 0-100 integer range."""
try:
value = float(progress)
except (TypeError, ValueError):
return 0
if 0.0 <= value <= 1.0:
value *= 100
return max(0, min(100, int(round(value))))
def generate_script_docu(params):
"""
生成纪录片视频脚本
要求: 原视频无字幕、无配音
适合场景: 纪录片、动物搞笑解说、荒野建造等
"""
@@ -23,419 +34,72 @@ def generate_script_docu(params):
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
normalized_progress = _normalize_progress_value(progress)
progress_bar.progress(normalized_progress)
if message:
status_text.text(f"🎬 {message}")
else:
status_text.text(f"📊 进度: {progress}%")
status_text.text(f"📊 进度: {normalized_progress}%")
try:
with st.spinner("正在生成脚本..."):
if not params.video_origin_path:
st.error("请先选择视频文件")
return
"""
1. 提取键帧
"""
update_progress(10, "正在提取关键帧...")
# 创建临时目录用于存储关键帧
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
# 检查是否已经提取过关键帧
keyframe_files = []
if os.path.exists(video_keyframes_dir):
# 读取已有的关键帧文件
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
st.info(f"✅ 使用已缓存关键帧,共 {len(keyframe_files)}")
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)}")
# 如果没有缓存的关键帧,则进行提取
if not keyframe_files:
try:
# 确保目录存在
os.makedirs(video_keyframes_dir, exist_ok=True)
# 初始化视频处理器
processor = video_processor.VideoProcessor(params.video_origin_path)
# 显示视频信息
st.info(f"📹 视频信息: {processor.width}x{processor.height}, {processor.fps:.1f}fps, {processor.duration:.1f}")
# 处理视频并提取关键帧 - 直接使用超级兼容性方案
update_progress(15, "正在提取关键帧(使用超级兼容性方案)...")
try:
# 使用优化的关键帧提取方法
processor.extract_frames_by_interval_ultra_compatible(
output_dir=video_keyframes_dir,
interval_seconds=st.session_state.get('frame_interval_input'),
)
except Exception as extract_error:
logger.error(f"关键帧提取失败: {extract_error}")
# 提供详细的错误信息和解决建议
error_msg = str(extract_error)
if "权限" in error_msg or "permission" in error_msg.lower():
suggestion = "建议:检查输出目录权限,或更换输出位置"
elif "空间" in error_msg or "space" in error_msg.lower():
suggestion = "建议:检查磁盘空间是否足够"
else:
suggestion = "建议:检查视频文件是否损坏,或尝试转换为标准格式"
raise Exception(f"关键帧提取失败: {error_msg}\n{suggestion}")
# 获取所有关键帧文件路径
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if not keyframe_files:
# 检查目录中是否有其他文件
all_files = os.listdir(video_keyframes_dir)
logger.error(f"关键帧目录内容: {all_files}")
raise Exception("未提取到任何关键帧文件,请检查视频文件格式")
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)}")
st.success(f"✅ 成功提取 {len(keyframe_files)} 个关键帧")
except Exception as e:
# 如果提取失败,清理创建的目录
try:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取失败: {str(e)}")
"""
2. 视觉分析(批量分析每一帧)
"""
# 最佳实践:使用 get() 的默认值参数 + 从 config 获取备用值
vision_llm_provider = (
st.session_state.get("vision_llm_provider") or config.app.get("vision_llm_provider", "openai")
).lower()
logger.info(f"使用 {vision_llm_provider.upper()} 进行视觉分析")
try:
# ===================初始化视觉分析器===================
update_progress(30, "正在初始化视觉分析器...")
# 使用统一的配置键格式获取配置(支持所有 provider
vision_api_key = (
st.session_state.get(f"vision_{vision_llm_provider}_api_key")
or config.app.get(f"vision_{vision_llm_provider}_api_key")
)
vision_model = (
st.session_state.get(f"vision_{vision_llm_provider}_model_name")
or config.app.get(f"vision_{vision_llm_provider}_model_name")
)
vision_base_url = (
st.session_state.get(f"vision_{vision_llm_provider}_base_url")
or config.app.get(f"vision_{vision_llm_provider}_base_url", "")
)
if not vision_api_key or not vision_model:
raise ValueError(
f"未配置 {vision_llm_provider} 的 API Key 或模型名称。"
f"请在设置页面配置 vision_{vision_llm_provider}_api_key 和 vision_{vision_llm_provider}_model_name"
)
frame_interval_input = st.session_state.get("frame_interval_input") or config.frames.get(
"frame_interval_input", 3
)
vision_batch_size = st.session_state.get("vision_batch_size") or config.frames.get("vision_batch_size", 10)
vision_max_concurrency = st.session_state.get("vision_max_concurrency") or config.frames.get(
"vision_max_concurrency", 2
)
# 创建视觉分析器实例(使用统一接口)
llm_params = {
"vision_provider": vision_llm_provider,
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url,
}
logger.debug(f"视觉分析器配置: provider={vision_llm_provider}, model={vision_model}")
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
update_progress(10, "正在提取关键帧...")
service = DocumentaryFrameAnalysisService()
script_items = asyncio.run(
service.generate_documentary_script(
video_path=params.video_origin_path,
video_theme=st.session_state.get("video_theme", ""),
custom_prompt=st.session_state.get("custom_prompt", ""),
frame_interval_input=frame_interval_input,
vision_batch_size=vision_batch_size,
vision_llm_provider=vision_llm_provider,
progress_callback=update_progress,
vision_api_key=vision_api_key,
vision_model_name=vision_model,
vision_base_url=vision_base_url,
max_concurrency=vision_max_concurrency,
)
)
update_progress(40, "正在分析关键帧...")
# ===================创建异步事件循环===================
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 执行异步分析
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
vision_analysis_prompt = """
我提供了 %s 张视频帧,它们按时间顺序排列,代表一个连续的视频片段。请仔细分析每一帧的内容,并关注帧与帧之间的变化,以理解整个片段的活动。
首先,请详细描述每一帧的关键视觉信息,包含主要内容、人物、动作和场景。
然后,基于所有帧的分析,请用**简洁的语言**总结整个视频片段中发生的主要活动或事件流程。
请务必使用 JSON 格式输出你的结果,JSON 结构应如下:
{
"frame_observations": [
{
"frame_number": 1, // 或其他标识帧的方式
"observation": "描述每张视频帧中的主要内容、人物、动作和场景。"
},
// ... 更多帧的观察 ...
],
"overall_activity_summary": "在这里填写你总结的整个片段的主要活动,保持简洁。"
}
请务必不要遗漏视频帧,我提供了 %s 张视频帧,frame_observations 必须包含 %s 个元素。
请只返回 JSON 字符串,不要包含任何其他解释性文字。
"""
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=vision_analysis_prompt,
batch_size=vision_batch_size
)
)
loop.close()
"""
3. 处理分析结果格式化为 json 数据
"""
# ===================处理分析结果===================
update_progress(60, "正在整理分析结果...")
# 合并所有批次的分析结果
frame_analysis = ""
merged_frame_observations = [] # 合并所有批次的帧观察
overall_activity_summaries = [] # 合并所有批次的整体总结
prev_batch_files = None
frame_counter = 1 # 初始化帧计数器,用于给所有帧分配连续的序号
# 确保分析目录存在
analysis_dir = os.path.join(utils.storage_dir(), "temp", "analysis")
os.makedirs(analysis_dir, exist_ok=True)
origin_res = os.path.join(analysis_dir, "frame_analysis.json")
with open(origin_res, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
# 开始处理
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
# 获取当前批次的文件列表
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
# 获取批次的时间戳范围
first_timestamp, last_timestamp, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
# 解析响应中的JSON数据
response_text = result['response']
try:
# 处理可能包含```json```格式的响应
if "```json" in response_text:
json_content = response_text.split("```json")[1].split("```")[0].strip()
elif "```" in response_text:
json_content = response_text.split("```")[1].split("```")[0].strip()
else:
json_content = response_text.strip()
response_data = json.loads(json_content)
# 提取frame_observations和overall_activity_summary
if "frame_observations" in response_data:
frame_obs = response_data["frame_observations"]
overall_summary = response_data.get("overall_activity_summary", "")
# 添加时间戳信息到每个帧观察
for i, obs in enumerate(frame_obs):
if i < len(batch_files):
# 从文件名中提取时间戳
file_path = batch_files[i]
file_name = os.path.basename(file_path)
# 提取时间戳字符串 (格式如: keyframe_000675_000027000.jpg)
# 格式解析: keyframe_帧序号_毫秒时间戳.jpg
timestamp_parts = file_name.split('_')
if len(timestamp_parts) >= 3:
timestamp_str = timestamp_parts[-1].split('.')[0]
try:
# 修正时间戳解析逻辑
# 格式为 000100000,表示 00:01:00,000,即 1 分钟
# 需要按照对应位数进行解析:
# 前两位是小时,中间两位是分钟,后面是秒和毫秒
if len(timestamp_str) >= 9: # 确保格式正确
hours = int(timestamp_str[0:2])
minutes = int(timestamp_str[2:4])
seconds = int(timestamp_str[4:6])
milliseconds = int(timestamp_str[6:9])
# 计算总秒数
timestamp_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳
else:
# 兼容旧的解析方式
timestamp_seconds = int(timestamp_str) / 1000 # 转换为秒
formatted_time = utils.format_time(timestamp_seconds) # 格式化时间戳
except ValueError:
logger.warning(f"无法解析时间戳: {timestamp_str}")
timestamp_seconds = 0
formatted_time = "00:00:00,000"
else:
logger.warning(f"文件名格式不符合预期: {file_name}")
timestamp_seconds = 0
formatted_time = "00:00:00,000"
# 添加额外信息到帧观察
obs["frame_path"] = file_path
obs["timestamp"] = formatted_time
obs["timestamp_seconds"] = timestamp_seconds
obs["batch_index"] = result['batch_index']
# 使用全局递增的帧计数器替换原始的frame_number
if "frame_number" in obs:
obs["original_frame_number"] = obs["frame_number"] # 保留原始编号作为参考
obs["frame_number"] = frame_counter # 赋值连续的帧编号
frame_counter += 1 # 增加帧计数器
# 添加到合并列表
merged_frame_observations.append(obs)
# 添加批次整体总结信息
if overall_summary:
# 从文件名中提取时间戳数值
first_time_str = first_timestamp.split('_')[-1].split('.')[0]
last_time_str = last_timestamp.split('_')[-1].split('.')[0]
# 转换为毫秒并计算持续时间(秒)
try:
# 修正解析逻辑,与上面相同的方式解析时间戳
if len(first_time_str) >= 9 and len(last_time_str) >= 9:
# 解析第一个时间戳
first_hours = int(first_time_str[0:2])
first_minutes = int(first_time_str[2:4])
first_seconds = int(first_time_str[4:6])
first_ms = int(first_time_str[6:9])
first_time_seconds = first_hours * 3600 + first_minutes * 60 + first_seconds + first_ms / 1000
# 解析第二个时间戳
last_hours = int(last_time_str[0:2])
last_minutes = int(last_time_str[2:4])
last_seconds = int(last_time_str[4:6])
last_ms = int(last_time_str[6:9])
last_time_seconds = last_hours * 3600 + last_minutes * 60 + last_seconds + last_ms / 1000
batch_duration = last_time_seconds - first_time_seconds
else:
# 兼容旧的解析方式
first_time_ms = int(first_time_str)
last_time_ms = int(last_time_str)
batch_duration = (last_time_ms - first_time_ms) / 1000
except ValueError:
# 使用 utils.time_to_seconds 函数处理格式化的时间戳
first_time_seconds = utils.time_to_seconds(first_time_str.replace('_', ':').replace('-', ','))
last_time_seconds = utils.time_to_seconds(last_time_str.replace('_', ':').replace('-', ','))
batch_duration = last_time_seconds - first_time_seconds
overall_activity_summaries.append({
"batch_index": result['batch_index'],
"time_range": f"{first_timestamp}-{last_timestamp}",
"duration_seconds": batch_duration,
"summary": overall_summary
})
except Exception as e:
logger.error(f"解析批次 {result['batch_index']} 的响应数据失败: {str(e)}")
# 添加原始响应作为回退
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += response_text
frame_analysis += "\n"
# 更新上一个批次的文件
prev_batch_files = batch_files
# 将合并后的结果转为JSON字符串
merged_results = {
"frame_observations": merged_frame_observations,
"overall_activity_summaries": overall_activity_summaries
}
# 使用当前时间创建文件名
now = datetime.now()
timestamp_str = now.strftime("%Y%m%d_%H%M")
# 保存完整的分析结果为JSON
analysis_filename = f"frame_analysis_{timestamp_str}.json"
analysis_json_path = os.path.join(analysis_dir, analysis_filename)
with open(analysis_json_path, 'w', encoding='utf-8') as f:
json.dump(merged_results, f, ensure_ascii=False, indent=2)
logger.info(f"分析结果已保存到: {analysis_json_path}")
"""
4. 生成文案
"""
logger.info("开始生成解说文案")
update_progress(80, "正在生成解说文案...")
from app.services.generate_narration_script import parse_frame_analysis_to_markdown, generate_narration
# 从配置中获取文本生成相关配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
llm_params.update({
"text_provider": text_provider,
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url
})
# 整理帧分析数据
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
# 生成解说文案
narration = generate_narration(
markdown_output,
text_api_key,
base_url=text_base_url,
model=text_model
)
# 使用增强的JSON解析器
from webui.tools.generate_short_summary import parse_and_fix_json
narration_data = parse_and_fix_json(narration)
if not narration_data or 'items' not in narration_data:
logger.error(f"解说文案JSON解析失败原始内容: {narration[:200]}...")
raise Exception("解说文案格式错误无法解析JSON或缺少items字段")
narration_dict = narration_data['items']
# 为 narration_dict 中每个 item 新增一个 OST: 2 的字段, 代表保留原声和配音
narration_dict = [{**item, "OST": 2} for item in narration_dict]
logger.info(f"解说文案生成完成,共 {len(narration_dict)} 个片段")
# 结果转换为JSON字符串
script = json.dumps(narration_dict, ensure_ascii=False, indent=2)
except Exception as e:
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
raise Exception(f"分析失败: {str(e)}")
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
logger.info(f"纪录片解说脚本生成完成")
logger.info(f"纪录片解说脚本生成完成,共 {len(script_items)} 个片段")
script = json.dumps(script_items, ensure_ascii=False, indent=2)
if isinstance(script, list):
st.session_state["video_clip_json"] = script
elif isinstance(script, str):
st.session_state["video_clip_json"] = json.loads(script)
update_progress(100, "脚本生成完成")
time.sleep(0.1)