diff --git a/README-en.md b/README-en.md index 5c57ef6..9e9d481 100644 --- a/README-en.md +++ b/README-en.md @@ -72,6 +72,7 @@ Below is a screenshot of this person's x (Twitter) homepage - [x] Optimized the story generation process and improved the generation effect - [x] Released version 0.3.5 integration package - [x] Support Alibaba Qwen2-VL large model for video understanding +- [x] Support TwelveLabs Pegasus as an optional video-understanding backend (analyzes footage natively to drive highlight selection and commentary; opt-in, set `vision_llm_provider = "twelvelabs"`) - [x] Support short drama commentary - [x] One-click merge materials - [x] One-click transcription diff --git a/README.md b/README.md index 03e4dac..515ab26 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ _**1. NarratoAI 是一款完全免费的软件,近期在社交媒体(抖音,B - [x] 优化剧情生成流程,提升生成效果 - [x] 发布 0.3.5 整合包 - [x] 支持阿里 Qwen2-VL 大模型理解视频 +- [x] 支持 TwelveLabs Pegasus 作为可选的视频理解后端(原生理解整段画面以挑选高光、生成解说;可选启用,设置 `vision_llm_provider = "twelvelabs"`) - [x] 支持短剧混剪 - [x] 一键合并素材 - [x] 一键转录 diff --git a/app/services/llm/providers/__init__.py b/app/services/llm/providers/__init__.py index d8ecc65..cf78ac5 100644 --- a/app/services/llm/providers/__init__.py +++ b/app/services/llm/providers/__init__.py @@ -12,7 +12,7 @@ def register_all_providers(): """ 注册所有提供商 - 当前实现:只注册 OpenAI 兼容统一接口 + 当前实现:注册 OpenAI 兼容统一接口,并可选注册 TwelveLabs Pegasus 视频理解。 """ # 在函数内部导入,避免循环依赖 from ..manager import LLMServiceManager @@ -32,6 +32,14 @@ def register_all_providers(): logger.info("✅ OpenAI 兼容提供商注册完成") + # ===== 注册 TwelveLabs Pegasus 视频理解(可选视觉提供商)===== + # 仅当用户将 vision_llm_provider 设为 "twelvelabs" 时启用;默认行为保持不变。 + from ..twelvelabs_provider import TwelveLabsVisionProvider + + LLMServiceManager.register_vision_provider('twelvelabs', TwelveLabsVisionProvider) + + logger.info("✅ TwelveLabs Pegasus 视觉提供商注册完成") + # 导出注册函数 __all__ = [ diff --git a/app/services/llm/test_openai_compat_unittest.py b/app/services/llm/test_openai_compat_unittest.py index 14b3ab1..f6f12c5 100644 --- a/app/services/llm/test_openai_compat_unittest.py +++ b/app/services/llm/test_openai_compat_unittest.py @@ -45,11 +45,12 @@ class OpenAICompatManagerTests(unittest.TestCase): config.app.clear() config.app.update(self._original_app) - def test_register_all_providers_only_registers_openai_provider(self): + def test_register_all_providers_registers_expected_providers(self): register_all_providers() + # 文本仅 OpenAI 兼容;视觉额外提供可选的 TwelveLabs Pegasus。 self.assertEqual({"openai"}, set(LLMServiceManager.list_text_providers())) - self.assertEqual({"openai"}, set(LLMServiceManager.list_vision_providers())) + self.assertEqual({"openai", "twelvelabs"}, set(LLMServiceManager.list_vision_providers())) def test_get_text_provider_uses_openai_keys(self): LLMServiceManager.register_text_provider("openai", DummyOpenAITextProvider) diff --git a/app/services/llm/test_twelvelabs_provider_unittest.py b/app/services/llm/test_twelvelabs_provider_unittest.py new file mode 100644 index 0000000..5b2245e --- /dev/null +++ b/app/services/llm/test_twelvelabs_provider_unittest.py @@ -0,0 +1,91 @@ +"""TwelveLabs Pegasus 视觉 provider 的最小回归测试。 + +- 无网络单元测试:mock SDK 与 ffmpeg,校验 provider 把关键帧批次转成 Pegasus 文本, + 并正确执行 max_tokens 下限、批次降级与 Asset 清理。 +- 可选 live 测试:仅在设置 TWELVELABS_API_KEY 时运行,验证真实 SDK 契约。 +""" + +import asyncio +import os +import unittest +from unittest.mock import MagicMock, patch + +import PIL.Image + +from app.config import config +from app.services.llm.manager import LLMServiceManager +from app.services.llm.providers import register_all_providers +from app.services.llm.twelvelabs_provider import TwelveLabsVisionProvider + + +def _make_provider() -> TwelveLabsVisionProvider: + # _resolve_ffmpeg 在 _initialize 中执行,patch shutil.which 让其在无 ffmpeg 环境也可构建。 + with patch("app.services.llm.twelvelabs_provider.shutil.which", return_value="/usr/bin/ffmpeg"): + return TwelveLabsVisionProvider(api_key="test-key", model_name="pegasus1.5") + + +class TwelveLabsProviderUnitTests(unittest.TestCase): + def test_registered_as_vision_provider(self): + LLMServiceManager._vision_providers.clear() + LLMServiceManager._text_providers.clear() + register_all_providers() + self.assertIn("twelvelabs", LLMServiceManager.list_vision_providers()) + + def test_resolve_max_tokens_enforces_floor(self): + provider = _make_provider() + # 低于 Pegasus 下限 512 时被抬到 512。 + self.assertEqual(512, provider._resolve_max_tokens(10)) + self.assertEqual(2048, provider._resolve_max_tokens(2048)) + + def test_analyze_images_returns_pegasus_text(self): + provider = _make_provider() + + # 伪造 SDK client:上传 -> ready -> analyze 返回文本。 + fake_client = MagicMock() + fake_client.assets.create.return_value = MagicMock(id="asset-1") + fake_client.assets.retrieve.return_value = MagicMock(status="ready") + fake_client.analyze.return_value = MagicMock(data="A red frame fades to blue.", finish_reason="stop") + + img = PIL.Image.new("RGB", (64, 64), (200, 30, 30)) + + with patch.object(provider, "_build_client", return_value=fake_client), \ + patch.object(provider, "_frames_to_clip", return_value="/tmp/clip.mp4"), \ + patch("app.services.llm.twelvelabs_provider.os.path.getsize", return_value=1234), \ + patch("builtins.open", MagicMock()): + results = asyncio.run(provider.analyze_images(images=[img, img], prompt="describe", batch_size=10)) + + self.assertEqual(["A red frame fades to blue."], results) + fake_client.analyze.assert_called_once() + # 调用使用配置的模型与 >=512 的 max_tokens。 + _, kwargs = fake_client.analyze.call_args + self.assertEqual("pegasus1.5", kwargs["model_name"]) + self.assertGreaterEqual(kwargs["max_tokens"], 512) + # 远端 Asset 被清理。 + fake_client.assets.delete.assert_called_once_with(asset_id="asset-1") + + def test_analyze_images_degrades_on_batch_error(self): + provider = _make_provider() + with patch.object(provider, "_analyze_batch_sync", side_effect=RuntimeError("boom")): + img = PIL.Image.new("RGB", (64, 64), (0, 0, 0)) + results = asyncio.run(provider.analyze_images(images=[img], prompt="p", batch_size=10)) + self.assertEqual(1, len(results)) + self.assertIn("批次处理失败", results[0]) + + +class TwelveLabsProviderLiveTests(unittest.TestCase): + """需要真实 API Key 与 ffmpeg;未配置时跳过。""" + + @unittest.skipUnless(os.getenv("TWELVELABS_API_KEY"), "TWELVELABS_API_KEY 未设置,跳过 live 测试") + def test_live_keyframe_analysis_returns_text(self): + provider = TwelveLabsVisionProvider( + api_key=os.environ["TWELVELABS_API_KEY"], model_name="pegasus1.5" + ) + frames = [PIL.Image.new("RGB", (640, 360), c) for c in [(220, 40, 40), (40, 180, 60), (40, 80, 220)]] + results = asyncio.run(provider.analyze_images(images=frames, prompt="Describe what is shown.", batch_size=10)) + self.assertEqual(1, len(results)) + self.assertTrue(results[0].strip()) + self.assertNotIn("批次处理失败", results[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/app/services/llm/twelvelabs_provider.py b/app/services/llm/twelvelabs_provider.py new file mode 100644 index 0000000..8ea6d17 --- /dev/null +++ b/app/services/llm/twelvelabs_provider.py @@ -0,0 +1,253 @@ +""" +TwelveLabs 视觉模型提供商实现 + +使用 TwelveLabs 官方 Python SDK 调用 Pegasus 视频理解模型。 + +与其它视觉提供商(OpenAI 兼容接口)不同,Pegasus 是一个原生的*视频*理解模型, +而非逐帧图像模型。为了在不改动现有调用方(关键帧批次 -> 文本描述)的前提下接入, +本提供商把每个关键帧批次用 ffmpeg 组装成一段短视频片段,上传为 TwelveLabs Asset, +再调用 Pegasus 进行分析,返回与其它视觉提供商一致的文本结果。 + +这是一个**可选**的视觉提供商:仅当 `vision_llm_provider = "twelvelabs"` 时才会启用, +默认行为保持不变。未配置 TwelveLabs API Key 时,整套流程与之前完全一致。 +""" + +import asyncio +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import PIL.Image +from loguru import logger + +from app.config import config +from .base import VisionModelProvider +from .exceptions import APICallError, AuthenticationError, ConfigurationError, RateLimitError + +# Pegasus 对分析窗口的硬性要求:最短 4 秒。 +_MIN_CLIP_SECONDS = 4 +# Pegasus 1.5 同步分析对 max_tokens 的有效区间为 [512, 98304]。 +_MIN_MAX_TOKENS = 512 +# 本地直传 Asset 的体积上限(method="direct",约 200MB)。关键帧拼接的短片远小于此值。 +_DIRECT_UPLOAD_LIMIT_BYTES = 200 * 1024 * 1024 + + +class TwelveLabsVisionProvider(VisionModelProvider): + """TwelveLabs Pegasus 视频理解提供商。""" + + @property + def provider_name(self) -> str: + return "twelvelabs" + + @property + def supported_models(self) -> List[str]: + return ["pegasus1.5", "pegasus1.2"] + + def _validate_model_support(self): + # Pegasus 模型列表稳定,保持宽松校验(与基类一致,仅记录警告)。 + if self.model_name not in self.supported_models: + logger.warning( + f"模型 {self.model_name} 不在 TwelveLabs 预定义列表中," + f"将按原样传递给 API。支持的模型: {self.supported_models}" + ) + + def _initialize(self): + # SDK client 按请求构建,这里仅校验 ffmpeg 可用性(拼接关键帧片段需要)。 + self._ffmpeg_bin = self._resolve_ffmpeg() + + @staticmethod + def _resolve_ffmpeg() -> str: + configured = (config.app.get("ffmpeg_path") or "").strip() + if configured: + return configured + found = shutil.which("ffmpeg") + if not found: + raise ConfigurationError( + "TwelveLabs 提供商需要 ffmpeg 将关键帧拼接为视频片段,但未找到 ffmpeg。" + "请安装 ffmpeg 或在配置中设置 ffmpeg_path。", + "ffmpeg_path", + ) + return found + + def _build_client(self): + try: + from twelvelabs import TwelveLabs + except ImportError as exc: # pragma: no cover - 仅在缺少可选依赖时触发 + raise ConfigurationError( + "未安装 twelvelabs SDK。请运行 `pip install twelvelabs>=1.2.8` 后重试。", + "twelvelabs", + ) from exc + return TwelveLabs(api_key=self.api_key) + + async def analyze_images( + self, + images: List[Union[str, Path, PIL.Image.Image]], + prompt: str, + batch_size: int = 10, + max_concurrency: int = 1, + **kwargs, + ) -> List[str]: + logger.info( + f"开始使用 TwelveLabs Pegasus ({self.model_name}) 分析 {len(images)} 张关键帧" + ) + + processed_images = self._prepare_images(images) + if not processed_images: + return [] + + bounded_concurrency = max(1, int(max_concurrency)) + semaphore = asyncio.Semaphore(bounded_concurrency) + batches = [ + (index // batch_size, processed_images[index : index + batch_size]) + for index in range(0, len(processed_images), batch_size) + ] + + max_tokens = self._resolve_max_tokens(kwargs.get("max_tokens")) + + async def run_batch(batch_index: int, batch: List[PIL.Image.Image]) -> tuple[int, str]: + logger.info(f"处理第 {batch_index + 1} 批,共 {len(batch)} 张关键帧") + async with semaphore: + try: + # SDK 为同步实现,放到线程池中执行以免阻塞事件循环。 + result = await asyncio.to_thread( + self._analyze_batch_sync, batch, prompt, max_tokens + ) + return batch_index, result + except Exception as exc: # 与其它 provider 保持一致:批次级降级,不整体失败。 + logger.error(f"批次 {batch_index + 1} 处理失败: {exc}") + return batch_index, f"批次处理失败: {exc}" + + completed = await asyncio.gather( + *(run_batch(index, batch) for index, batch in batches) + ) + completed.sort(key=lambda item: item[0]) + return [result for _, result in completed] + + def _resolve_max_tokens(self, override: Any) -> int: + configured = override if override is not None else config.app.get( + "vision_twelvelabs_max_tokens", 1024 + ) + try: + value = int(configured) + except (TypeError, ValueError): + value = 1024 + return max(_MIN_MAX_TOKENS, value) + + def _analyze_batch_sync( + self, batch: List[PIL.Image.Image], prompt: str, max_tokens: int + ) -> str: + """把一批关键帧拼成短视频,上传为 Asset,调用 Pegasus 分析后返回文本。""" + from twelvelabs.types.video_context import VideoContext_AssetId + from twelvelabs.errors import ( + BadRequestError, + ForbiddenError, + TooManyRequestsError, + ) + + client = self._build_client() + asset_id: Optional[str] = None + + with tempfile.TemporaryDirectory(prefix="tl_pegasus_") as tmp_dir: + clip_path = self._frames_to_clip(batch, tmp_dir) + size = os.path.getsize(clip_path) + if size > _DIRECT_UPLOAD_LIMIT_BYTES: + raise APICallError( + f"拼接片段过大({size} 字节),超过直传上限 {_DIRECT_UPLOAD_LIMIT_BYTES} 字节。" + "请减小 vision_batch_size 或降低关键帧分辨率。" + ) + + try: + with open(clip_path, "rb") as fh: + asset = client.assets.create( + method="direct", file=fh, filename="keyframes.mp4" + ) + asset_id = asset.id + self._wait_for_asset_ready(client, asset_id) + + response = client.analyze( + model_name=self.model_name, + video=VideoContext_AssetId(asset_id=asset_id), + prompt=prompt, + max_tokens=max_tokens, + ) + text = (response.data or "").strip() + if not text: + raise APICallError("TwelveLabs Pegasus 返回空响应") + return text + except ForbiddenError as exc: + raise AuthenticationError(str(exc)) + except TooManyRequestsError as exc: + raise RateLimitError(str(exc)) + except BadRequestError as exc: + raise APICallError(f"请求错误: {getattr(exc, 'body', exc)}") + finally: + # 尽力清理远端 Asset,避免占用配额。 + if asset_id: + try: + client.assets.delete(asset_id=asset_id) + except Exception as exc: # pragma: no cover - 清理失败不影响结果 + logger.debug(f"清理 TwelveLabs Asset 失败 {asset_id}: {exc}") + + def _frames_to_clip(self, batch: List[PIL.Image.Image], tmp_dir: str) -> str: + """用 ffmpeg 把关键帧序列拼成 >= 4s 的视频片段(满足 Pegasus 最短窗口要求)。""" + frame_paths: List[str] = [] + for idx, img in enumerate(batch): + frame_path = os.path.join(tmp_dir, f"frame_{idx:04d}.jpg") + img.convert("RGB").save(frame_path, format="JPEG", quality=85) + frame_paths.append(frame_path) + + # 每帧停留时长,保证总时长不少于 _MIN_CLIP_SECONDS。 + per_frame_seconds = max(1.0, _MIN_CLIP_SECONDS / max(1, len(frame_paths))) + list_file = os.path.join(tmp_dir, "frames.txt") + with open(list_file, "w", encoding="utf-8") as fh: + for frame_path in frame_paths: + fh.write(f"file '{frame_path}'\n") + fh.write(f"duration {per_frame_seconds}\n") + # concat demuxer 需要重复最后一帧才能让其显示完整时长。 + fh.write(f"file '{frame_paths[-1]}'\n") + + clip_path = os.path.join(tmp_dir, "clip.mp4") + cmd = [ + self._ffmpeg_bin, + "-y", + "-loglevel", "error", + "-f", "concat", + "-safe", "0", + "-i", list_file, + # 强制偶数尺寸 + yuv420p,保证 H.264 兼容。 + "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", + "-pix_fmt", "yuv420p", + "-r", "24", + "-c:v", "libx264", + clip_path, + ] + try: + subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=120) + except subprocess.CalledProcessError as exc: + raise APICallError(f"ffmpeg 拼接关键帧失败: {exc.stderr or exc}") + except subprocess.TimeoutExpired: + raise APICallError("ffmpeg 拼接关键帧超时") + return clip_path + + @staticmethod + def _wait_for_asset_ready(client, asset_id: str, timeout_seconds: int = 180) -> None: + """轮询等待上传的 Asset 进入 ready 状态。""" + import time + + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + asset = client.assets.retrieve(asset_id=asset_id) + status = (asset.status or "").lower() + if status == "ready": + return + if status == "failed": + raise APICallError(f"TwelveLabs Asset 处理失败: {asset_id}") + time.sleep(3) + raise APICallError(f"等待 TwelveLabs Asset 就绪超时: {asset_id}") + + async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]: + # 本提供商直接使用官方 SDK,不走通用 payload 通道。 + return payload diff --git a/config.example.toml b/config.example.toml index 46735fe..5774c1f 100644 --- a/config.example.toml +++ b/config.example.toml @@ -32,6 +32,16 @@ vision_openai_max_tokens = 65536 vision_openai_thinking_level = "auto" # auto/off/low/medium/high + # ===== 可选:TwelveLabs Pegasus 视频理解 ===== + # 将 vision_llm_provider 改为 "twelvelabs" 即可启用(默认仍为 openai,不影响现有行为)。 + # Pegasus 是原生的视频理解模型,会把每批关键帧拼成短片后整体理解,更擅长把握镜头内的 + # 动作、时序与高光片段,从而生成更贴合画面的解说。需要本地可用的 ffmpeg。 + # 免费 API Key 获取地址:https://twelvelabs.io (有较充裕的免费额度) + # vision_llm_provider = "twelvelabs" + vision_twelvelabs_api_key = "" # 填入 TwelveLabs API Key + vision_twelvelabs_model_name = "pegasus1.5" # pegasus1.5 / pegasus1.2 + vision_twelvelabs_max_tokens = 1024 # Pegasus 1.5 有效区间 512-98304 + # ===== 文本模型配置 ===== text_llm_provider = "openai" diff --git a/requirements.txt b/requirements.txt index be125ac..e39d339 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,9 @@ tqdm>=4.66.6 tenacity>=9.0.0 # 可选依赖(根据功能需要) +# 如果使用 TwelveLabs Pegasus 视频理解作为视觉提供商,取消注释下面的行 +# twelvelabs>=1.2.8 + # 如果需要本地语音识别,取消注释下面的行 # faster-whisper>=1.0.1