mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-06-28 18:22:04 +00:00
Merge pull request #258 from mohit-twelvelabs/feat/twelvelabs-integration
feat(llm): 新增 TwelveLabs Pegasus 视频理解视觉提供商(可选) / add optional TwelveLabs Pegasus vision provider
This commit is contained in:
commit
18c9ff81d2
@ -72,6 +72,7 @@ Below is a screenshot of this person's x (Twitter) homepage
|
||||
- [x] Optimized the story generation process and improved the generation effect
|
||||
- [x] Released version 0.3.5 integration package
|
||||
- [x] Support Alibaba Qwen2-VL large model for video understanding
|
||||
- [x] Support TwelveLabs Pegasus as an optional video-understanding backend (analyzes footage natively to drive highlight selection and commentary; opt-in, set `vision_llm_provider = "twelvelabs"`)
|
||||
- [x] Support short drama commentary
|
||||
- [x] One-click merge materials
|
||||
- [x] One-click transcription
|
||||
|
||||
@ -93,6 +93,7 @@ _**1. NarratoAI 是一款完全免费的软件,近期在社交媒体(抖音,B
|
||||
- [x] 优化剧情生成流程,提升生成效果
|
||||
- [x] 发布 0.3.5 整合包
|
||||
- [x] 支持阿里 Qwen2-VL 大模型理解视频
|
||||
- [x] 支持 TwelveLabs Pegasus 作为可选的视频理解后端(原生理解整段画面以挑选高光、生成解说;可选启用,设置 `vision_llm_provider = "twelvelabs"`)
|
||||
- [x] 支持短剧混剪
|
||||
- [x] 一键合并素材
|
||||
- [x] 一键转录
|
||||
|
||||
@ -12,7 +12,7 @@ def register_all_providers():
|
||||
"""
|
||||
注册所有提供商
|
||||
|
||||
当前实现:只注册 OpenAI 兼容统一接口
|
||||
当前实现:注册 OpenAI 兼容统一接口,并可选注册 TwelveLabs Pegasus 视频理解。
|
||||
"""
|
||||
# 在函数内部导入,避免循环依赖
|
||||
from ..manager import LLMServiceManager
|
||||
@ -32,6 +32,14 @@ def register_all_providers():
|
||||
|
||||
logger.info("✅ OpenAI 兼容提供商注册完成")
|
||||
|
||||
# ===== 注册 TwelveLabs Pegasus 视频理解(可选视觉提供商)=====
|
||||
# 仅当用户将 vision_llm_provider 设为 "twelvelabs" 时启用;默认行为保持不变。
|
||||
from ..twelvelabs_provider import TwelveLabsVisionProvider
|
||||
|
||||
LLMServiceManager.register_vision_provider('twelvelabs', TwelveLabsVisionProvider)
|
||||
|
||||
logger.info("✅ TwelveLabs Pegasus 视觉提供商注册完成")
|
||||
|
||||
|
||||
# 导出注册函数
|
||||
__all__ = [
|
||||
|
||||
@ -45,11 +45,12 @@ class OpenAICompatManagerTests(unittest.TestCase):
|
||||
config.app.clear()
|
||||
config.app.update(self._original_app)
|
||||
|
||||
def test_register_all_providers_only_registers_openai_provider(self):
|
||||
def test_register_all_providers_registers_expected_providers(self):
|
||||
register_all_providers()
|
||||
|
||||
# 文本仅 OpenAI 兼容;视觉额外提供可选的 TwelveLabs Pegasus。
|
||||
self.assertEqual({"openai"}, set(LLMServiceManager.list_text_providers()))
|
||||
self.assertEqual({"openai"}, set(LLMServiceManager.list_vision_providers()))
|
||||
self.assertEqual({"openai", "twelvelabs"}, set(LLMServiceManager.list_vision_providers()))
|
||||
|
||||
def test_get_text_provider_uses_openai_keys(self):
|
||||
LLMServiceManager.register_text_provider("openai", DummyOpenAITextProvider)
|
||||
|
||||
91
app/services/llm/test_twelvelabs_provider_unittest.py
Normal file
91
app/services/llm/test_twelvelabs_provider_unittest.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""TwelveLabs Pegasus 视觉 provider 的最小回归测试。
|
||||
|
||||
- 无网络单元测试:mock SDK 与 ffmpeg,校验 provider 把关键帧批次转成 Pegasus 文本,
|
||||
并正确执行 max_tokens 下限、批次降级与 Asset 清理。
|
||||
- 可选 live 测试:仅在设置 TWELVELABS_API_KEY 时运行,验证真实 SDK 契约。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import PIL.Image
|
||||
|
||||
from app.config import config
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.providers import register_all_providers
|
||||
from app.services.llm.twelvelabs_provider import TwelveLabsVisionProvider
|
||||
|
||||
|
||||
def _make_provider() -> TwelveLabsVisionProvider:
    """Build a provider for the offline tests without needing ffmpeg on the host.

    The provider resolves the ffmpeg binary while it initialises, so
    ``shutil.which`` is patched to report a fake binary for the duration of
    the construction.
    """
    which_target = "app.services.llm.twelvelabs_provider.shutil.which"
    with patch(which_target, return_value="/usr/bin/ffmpeg"):
        provider = TwelveLabsVisionProvider(api_key="test-key", model_name="pegasus1.5")
    return provider
|
||||
|
||||
|
||||
class TwelveLabsProviderUnitTests(unittest.TestCase):
    """Offline unit tests for the TwelveLabs Pegasus vision provider."""

    def test_registered_as_vision_provider(self):
        """``register_all_providers`` registers the ``twelvelabs`` vision provider."""
        # Swap in empty registries only for the duration of this test, so the
        # process-wide provider tables are restored afterwards instead of being
        # wiped in place and leaking into other test cases.
        empty_vision = type(LLMServiceManager._vision_providers)()
        empty_text = type(LLMServiceManager._text_providers)()
        with patch.object(LLMServiceManager, "_vision_providers", empty_vision), \
                patch.object(LLMServiceManager, "_text_providers", empty_text):
            register_all_providers()
            self.assertIn("twelvelabs", LLMServiceManager.list_vision_providers())

    def test_resolve_max_tokens_enforces_floor(self):
        """Values below the Pegasus floor of 512 are raised to 512."""
        provider = _make_provider()
        self.assertEqual(512, provider._resolve_max_tokens(10))
        self.assertEqual(2048, provider._resolve_max_tokens(2048))

    def test_analyze_images_returns_pegasus_text(self):
        """A key-frame batch is turned into the text returned by Pegasus."""
        # ``_analyze_batch_sync`` imports types from the optional ``twelvelabs``
        # SDK even when the client itself is mocked, so skip rather than fail
        # when the SDK is not installed.
        try:
            import twelvelabs  # noqa: F401
        except ImportError:
            self.skipTest("twelvelabs SDK 未安装,跳过该测试")

        provider = _make_provider()

        # Fake SDK client: upload -> ready -> analyze returns text.
        fake_client = MagicMock()
        fake_client.assets.create.return_value = MagicMock(id="asset-1")
        fake_client.assets.retrieve.return_value = MagicMock(status="ready")
        fake_client.analyze.return_value = MagicMock(
            data="A red frame fades to blue.", finish_reason="stop"
        )

        img = PIL.Image.new("RGB", (64, 64), (200, 30, 30))

        with patch.object(provider, "_build_client", return_value=fake_client), \
                patch.object(provider, "_frames_to_clip", return_value="/tmp/clip.mp4"), \
                patch("app.services.llm.twelvelabs_provider.os.path.getsize", return_value=1234), \
                patch("builtins.open", MagicMock()):
            results = asyncio.run(
                provider.analyze_images(images=[img, img], prompt="describe", batch_size=10)
            )

        self.assertEqual(["A red frame fades to blue."], results)
        fake_client.analyze.assert_called_once()
        # The call uses the configured model and a max_tokens of at least 512.
        _, kwargs = fake_client.analyze.call_args
        self.assertEqual("pegasus1.5", kwargs["model_name"])
        self.assertGreaterEqual(kwargs["max_tokens"], 512)
        # The remote asset is cleaned up.
        fake_client.assets.delete.assert_called_once_with(asset_id="asset-1")

    def test_analyze_images_degrades_on_batch_error(self):
        """A failing batch yields an error string instead of raising."""
        provider = _make_provider()
        with patch.object(provider, "_analyze_batch_sync", side_effect=RuntimeError("boom")):
            img = PIL.Image.new("RGB", (64, 64), (0, 0, 0))
            results = asyncio.run(
                provider.analyze_images(images=[img], prompt="p", batch_size=10)
            )
        self.assertEqual(1, len(results))
        self.assertIn("批次处理失败", results[0])
|
||||
|
||||
|
||||
class TwelveLabsProviderLiveTests(unittest.TestCase):
    """Live tests that need a real API key and ffmpeg; skipped when the key is unset."""

    @unittest.skipUnless(os.getenv("TWELVELABS_API_KEY"), "TWELVELABS_API_KEY 未设置,跳过 live 测试")
    def test_live_keyframe_analysis_returns_text(self):
        """Three solid-colour frames produce one non-empty, non-error description."""
        api_key = os.environ["TWELVELABS_API_KEY"]
        provider = TwelveLabsVisionProvider(api_key=api_key, model_name="pegasus1.5")

        colours = [(220, 40, 40), (40, 180, 60), (40, 80, 220)]
        frames = [PIL.Image.new("RGB", (640, 360), rgb) for rgb in colours]

        analysis = provider.analyze_images(
            images=frames, prompt="Describe what is shown.", batch_size=10
        )
        results = asyncio.run(analysis)

        self.assertEqual(1, len(results))
        description = results[0]
        self.assertTrue(description.strip())
        self.assertNotIn("批次处理失败", description)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
253
app/services/llm/twelvelabs_provider.py
Normal file
253
app/services/llm/twelvelabs_provider.py
Normal file
@ -0,0 +1,253 @@
|
||||
"""
|
||||
TwelveLabs 视觉模型提供商实现
|
||||
|
||||
使用 TwelveLabs 官方 Python SDK 调用 Pegasus 视频理解模型。
|
||||
|
||||
与其它视觉提供商(OpenAI 兼容接口)不同,Pegasus 是一个原生的*视频*理解模型,
|
||||
而非逐帧图像模型。为了在不改动现有调用方(关键帧批次 -> 文本描述)的前提下接入,
|
||||
本提供商把每个关键帧批次用 ffmpeg 组装成一段短视频片段,上传为 TwelveLabs Asset,
|
||||
再调用 Pegasus 进行分析,返回与其它视觉提供商一致的文本结果。
|
||||
|
||||
这是一个**可选**的视觉提供商:仅当 `vision_llm_provider = "twelvelabs"` 时才会启用,
|
||||
默认行为保持不变。未配置 TwelveLabs API Key 时,整套流程与之前完全一致。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from .base import VisionModelProvider
|
||||
from .exceptions import APICallError, AuthenticationError, ConfigurationError, RateLimitError
|
||||
|
||||
# Pegasus 对分析窗口的硬性要求:最短 4 秒。
|
||||
_MIN_CLIP_SECONDS = 4
|
||||
# Pegasus 1.5 同步分析对 max_tokens 的有效区间为 [512, 98304]。
|
||||
_MIN_MAX_TOKENS = 512
|
||||
# 本地直传 Asset 的体积上限(method="direct",约 200MB)。关键帧拼接的短片远小于此值。
|
||||
_DIRECT_UPLOAD_LIMIT_BYTES = 200 * 1024 * 1024
|
||||
|
||||
|
||||
class TwelveLabsVisionProvider(VisionModelProvider):
    """Vision provider backed by the TwelveLabs Pegasus video-understanding model.

    Pegasus analyses video rather than single images. To keep the existing
    "key-frame batch -> text" contract, every batch of key frames is stitched
    into a short clip with ffmpeg, uploaded as a TwelveLabs asset, analysed,
    and the remote asset is deleted afterwards.
    """

    @property
    def provider_name(self) -> str:
        """Name under which this provider is registered."""
        return "twelvelabs"

    @property
    def supported_models(self) -> List[str]:
        """Model names this provider knows about."""
        return ["pegasus1.5", "pegasus1.2"]

    def _validate_model_support(self):
        """Warn (without failing) when the configured model is not a known one."""
        if self.model_name not in self.supported_models:
            logger.warning(
                f"模型 {self.model_name} 不在 TwelveLabs 预定义列表中,"
                f"将按原样传递给 API。支持的模型: {self.supported_models}"
            )

    def _initialize(self):
        """Resolve the ffmpeg binary; the SDK client itself is built per request."""
        self._ffmpeg_bin = self._resolve_ffmpeg()

    @staticmethod
    def _resolve_ffmpeg() -> str:
        """Return the ffmpeg executable to use.

        The ``ffmpeg_path`` app setting wins when it is non-empty; otherwise
        ``ffmpeg`` is looked up on ``PATH``.

        Raises:
            ConfigurationError: if no ffmpeg is configured or found.
        """
        configured = (config.app.get("ffmpeg_path") or "").strip()
        if configured:
            return configured
        found = shutil.which("ffmpeg")
        if not found:
            raise ConfigurationError(
                "TwelveLabs 提供商需要 ffmpeg 将关键帧拼接为视频片段,但未找到 ffmpeg。"
                "请安装 ffmpeg 或在配置中设置 ffmpeg_path。",
                "ffmpeg_path",
            )
        return found

    def _build_client(self):
        """Create a TwelveLabs SDK client for the configured API key.

        Raises:
            ConfigurationError: if the optional ``twelvelabs`` package is missing.
        """
        try:
            from twelvelabs import TwelveLabs
        except ImportError as exc:  # pragma: no cover - only without the optional dependency
            # The requirement is quoted: unquoted, a shell treats ">" as an
            # output redirection and the version constraint is lost.
            raise ConfigurationError(
                '未安装 twelvelabs SDK。请运行 `pip install "twelvelabs>=1.2.8"` 后重试。',
                "twelvelabs",
            ) from exc
        return TwelveLabs(api_key=self.api_key)

    async def analyze_images(
        self,
        images: List[Union[str, Path, PIL.Image.Image]],
        prompt: str,
        batch_size: int = 10,
        max_concurrency: int = 1,
        **kwargs,
    ) -> List[str]:
        """Analyse key frames in batches and return one text result per batch.

        Args:
            images: Key frames to analyse.
            prompt: Prompt passed to Pegasus for every batch.
            batch_size: Frames per batch; values below 1 are treated as 1.
            max_concurrency: Batches analysed at once; values below 1 are
                treated as 1.
            **kwargs: ``max_tokens`` overrides the configured token budget.

        Returns:
            Results ordered by batch index. A batch that fails contributes an
            error string instead of aborting the whole call.
        """
        logger.info(
            f"开始使用 TwelveLabs Pegasus ({self.model_name}) 分析 {len(images)} 张关键帧"
        )

        processed_images = self._prepare_images(images)
        if not processed_images:
            return []

        # A step of 0 would make range() raise and a negative step would
        # silently yield no batches, so clamp like max_concurrency below.
        batch_size = max(1, int(batch_size))
        bounded_concurrency = max(1, int(max_concurrency))
        semaphore = asyncio.Semaphore(bounded_concurrency)
        batches = [
            (index // batch_size, processed_images[index : index + batch_size])
            for index in range(0, len(processed_images), batch_size)
        ]

        max_tokens = self._resolve_max_tokens(kwargs.get("max_tokens"))

        async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]:
            logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张关键帧")
            async with semaphore:
                try:
                    # The SDK is synchronous; run it in a worker thread so the
                    # event loop is not blocked.
                    result = await asyncio.to_thread(
                        self._analyze_batch_sync, batch, prompt, max_tokens
                    )
                    return batch_index, result
                except Exception as exc:  # degrade per batch instead of failing the whole call
                    logger.error(f"批次 {batch_index + 1} 处理失败: {exc}")
                    return batch_index, f"批次处理失败: {exc}"

        completed = await asyncio.gather(
            *(run_batch(index, batch) for index, batch in batches)
        )
        completed.sort(key=lambda item: item[0])
        return [result for _, result in completed]

    def _resolve_max_tokens(self, override: Any) -> int:
        """Return the token budget, never lower than the Pegasus minimum.

        ``override`` wins when it is not ``None``; otherwise the
        ``vision_twelvelabs_max_tokens`` app setting (default 1024) is used.
        Values that cannot be converted to ``int`` fall back to 1024.
        """
        configured = override if override is not None else config.app.get(
            "vision_twelvelabs_max_tokens", 1024
        )
        try:
            value = int(configured)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers float infinity, which int() cannot convert.
            value = 1024
        return max(_MIN_MAX_TOKENS, value)

    def _analyze_batch_sync(
        self, batch: List[PIL.Image.Image], prompt: str, max_tokens: int
    ) -> str:
        """Stitch a batch into a clip, upload it, analyse it and return the text.

        Raises:
            ConfigurationError: if the ``twelvelabs`` SDK is not installed.
            AuthenticationError: when the SDK raises ``ForbiddenError``.
            RateLimitError: when the SDK raises ``TooManyRequestsError``.
            APICallError: for an oversized clip, a bad request, an empty
                response, or asset/ffmpeg failures.
        """
        # Build the client first: it converts a missing SDK into a
        # ConfigurationError with install instructions, which the raw imports
        # below would otherwise pre-empt with a bare ImportError.
        client = self._build_client()

        from twelvelabs.types.video_context import VideoContext_AssetId
        from twelvelabs.errors import (
            BadRequestError,
            ForbiddenError,
            TooManyRequestsError,
        )

        asset_id: Optional[str] = None

        with tempfile.TemporaryDirectory(prefix="tl_pegasus_") as tmp_dir:
            clip_path = self._frames_to_clip(batch, tmp_dir)
            size = os.path.getsize(clip_path)
            if size > _DIRECT_UPLOAD_LIMIT_BYTES:
                raise APICallError(
                    f"拼接片段过大({size} 字节),超过直传上限 {_DIRECT_UPLOAD_LIMIT_BYTES} 字节。"
                    "请减小 vision_batch_size 或降低关键帧分辨率。"
                )

            try:
                with open(clip_path, "rb") as fh:
                    asset = client.assets.create(
                        method="direct", file=fh, filename="keyframes.mp4"
                    )
                asset_id = asset.id
                self._wait_for_asset_ready(client, asset_id)

                response = client.analyze(
                    model_name=self.model_name,
                    video=VideoContext_AssetId(asset_id=asset_id),
                    prompt=prompt,
                    max_tokens=max_tokens,
                )
                text = (response.data or "").strip()
                if not text:
                    raise APICallError("TwelveLabs Pegasus 返回空响应")
                return text
            except ForbiddenError as exc:
                raise AuthenticationError(str(exc)) from exc
            except TooManyRequestsError as exc:
                raise RateLimitError(str(exc)) from exc
            except BadRequestError as exc:
                raise APICallError(f"请求错误: {getattr(exc, 'body', exc)}") from exc
            finally:
                # Best-effort removal of the remote asset so it does not keep
                # consuming quota.
                if asset_id:
                    try:
                        client.assets.delete(asset_id=asset_id)
                    except Exception as exc:  # pragma: no cover - cleanup failure is not fatal
                        logger.debug(f"清理 TwelveLabs Asset 失败 {asset_id}: {exc}")

    def _frames_to_clip(self, batch: List[PIL.Image.Image], tmp_dir: str) -> str:
        """Encode the frames into an H.264 clip of at least ``_MIN_CLIP_SECONDS``.

        Args:
            batch: Frames to encode, in display order.
            tmp_dir: Existing directory that receives the frame files, the
                concat list and the resulting ``clip.mp4``.

        Returns:
            Path of the generated clip.

        Raises:
            APICallError: if the batch is empty or ffmpeg fails, times out or
                cannot be executed.
        """
        if not batch:
            raise APICallError("没有可用于拼接的关键帧")

        frame_paths: List[str] = []
        for idx, img in enumerate(batch):
            frame_path = os.path.join(tmp_dir, f"frame_{idx:04d}.jpg")
            img.convert("RGB").save(frame_path, format="JPEG", quality=85)
            frame_paths.append(frame_path)

        # Display time per frame, chosen so the total is never shorter than
        # _MIN_CLIP_SECONDS.
        per_frame_seconds = max(1.0, _MIN_CLIP_SECONDS / len(frame_paths))

        def concat_entry(path: str) -> str:
            # Inside a single-quoted concat path a literal quote has to be
            # written as '\'' (close the quote, escaped quote, reopen).
            escaped = path.replace("'", "'\\''")
            return f"file '{escaped}'\n"

        list_file = os.path.join(tmp_dir, "frames.txt")
        with open(list_file, "w", encoding="utf-8") as fh:
            for frame_path in frame_paths:
                fh.write(concat_entry(frame_path))
                fh.write(f"duration {per_frame_seconds}\n")
            # The concat demuxer needs the last frame repeated for it to be
            # shown for its full duration.
            fh.write(concat_entry(frame_paths[-1]))

        clip_path = os.path.join(tmp_dir, "clip.mp4")
        cmd = [
            self._ffmpeg_bin,
            "-y",
            "-loglevel", "error",
            "-f", "concat",
            "-safe", "0",
            "-i", list_file,
            # Force even dimensions and yuv420p for H.264 compatibility.
            "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-pix_fmt", "yuv420p",
            "-r", "24",
            "-c:v", "libx264",
            clip_path,
        ]
        try:
            subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120)
        except subprocess.CalledProcessError as exc:
            raise APICallError(f"ffmpeg 拼接关键帧失败: {exc.stderr or exc}") from exc
        except subprocess.TimeoutExpired as exc:
            raise APICallError("ffmpeg 拼接关键帧超时") from exc
        except OSError as exc:
            # Raised when the ffmpeg binary is missing or not executable.
            raise APICallError(f"无法执行 ffmpeg ({self._ffmpeg_bin}): {exc}") from exc
        return clip_path

    @staticmethod
    def _wait_for_asset_ready(client, asset_id: str, timeout_seconds: int = 180) -> None:
        """Poll every 3 seconds until the uploaded asset reports ``ready``.

        Raises:
            APICallError: if the asset reports ``failed`` or is not ready
                before ``timeout_seconds`` elapse.
        """
        import time

        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            asset = client.assets.retrieve(asset_id=asset_id)
            status = (asset.status or "").lower()
            if status == "ready":
                return
            if status == "failed":
                raise APICallError(f"TwelveLabs Asset 处理失败: {asset_id}")
            time.sleep(3)
        raise APICallError(f"等待 TwelveLabs Asset 就绪超时: {asset_id}")

    async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``payload`` unchanged; this provider calls the SDK directly."""
        return payload
|
||||
@ -32,6 +32,16 @@
|
||||
vision_openai_max_tokens = 65536
|
||||
vision_openai_thinking_level = "auto" # auto/off/low/medium/high
|
||||
|
||||
# ===== 可选:TwelveLabs Pegasus 视频理解 =====
|
||||
# 将 vision_llm_provider 改为 "twelvelabs" 即可启用(默认仍为 openai,不影响现有行为)。
|
||||
# Pegasus 是原生的视频理解模型,会把每批关键帧拼成短片后整体理解,更擅长把握镜头内的
|
||||
# 动作、时序与高光片段,从而生成更贴合画面的解说。需要本地可用的 ffmpeg。
|
||||
# 免费 API Key 获取地址:https://twelvelabs.io (有较充裕的免费额度)
|
||||
# vision_llm_provider = "twelvelabs"
|
||||
vision_twelvelabs_api_key = "" # 填入 TwelveLabs API Key
|
||||
vision_twelvelabs_model_name = "pegasus1.5" # pegasus1.5 / pegasus1.2
|
||||
vision_twelvelabs_max_tokens = 1024 # Pegasus 1.5 有效区间 512-98304
|
||||
|
||||
# ===== 文本模型配置 =====
|
||||
text_llm_provider = "openai"
|
||||
|
||||
|
||||
@ -25,6 +25,9 @@ tqdm>=4.66.6
|
||||
tenacity>=9.0.0
|
||||
|
||||
# 可选依赖(根据功能需要)
|
||||
# 如果使用 TwelveLabs Pegasus 视频理解作为视觉提供商,取消注释下面的行
|
||||
# twelvelabs>=1.2.8
|
||||
|
||||
# 如果需要本地语音识别,取消注释下面的行
|
||||
# faster-whisper>=1.0.1
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user