NarratoAI/webui/utils/vision_analyzer.py
linyqh f44d56110e feat(vision): 添加 QwenVL 视觉分析支持
- 新增 QwenVL 视觉分析器类,实现对阿里云 Qwen 模型的支持
- 更新基础设置界面,增加代理配置和 QwenVL 模型可用性检测
- 修改脚本生成逻辑,支持 QwenVL 模型的图像分析
- 重构视觉分析器初始化和调用接口,提高代码复用性和可维护性
2024-12-05 21:43:26 +08:00

100 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import List, Dict, Any, Optional
from app.utils import gemini_analyzer, qwenvl_analyzer
logger = logging.getLogger(__name__)
class VisionAnalyzer:
    """Provider-agnostic facade over the Gemini and QwenVL image analyzers.

    A new instance has no backend attached. Call ``initialize_gemini`` or
    ``initialize_qwenvl`` to attach one; ``analyze_images`` then forwards
    its arguments to that backend.
    """

    def __init__(self):
        """Create an unconfigured analyzer with every setting set to None."""
        self.provider: Optional[str] = None
        self.api_key: Optional[str] = None
        self.model: Optional[str] = None
        self.base_url: Optional[str] = None
        # Backend object; it must expose an awaitable ``analyze_images``.
        self.analyzer: Optional[Any] = None

    def _configure(self, provider: str, api_key: str, model: str,
                   base_url: str, analyzer: Any) -> None:
        """Record the provider settings together with the ready backend."""
        self.provider = provider
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.analyzer = analyzer

    def initialize_gemini(self, api_key: str, model: str, base_url: str) -> None:
        """Attach a Gemini backend to this analyzer.

        Args:
            api_key: Gemini API key.
            model: Model name, passed to the backend as ``model_name``.
            base_url: API base URL. Stored on this object only.
        """
        # NOTE(review): base_url is recorded but not passed to the backend
        # constructor -- confirm whether the backend is expected to use it.
        backend = gemini_analyzer.VisionAnalyzer(
            model_name=model,
            api_key=api_key,
        )
        # Commit state only after the backend was built, so a failing
        # constructor cannot leave the wrapper half-configured.
        self._configure('gemini', api_key, model, base_url, backend)

    def initialize_qwenvl(self, api_key: str, model: str, base_url: str) -> None:
        """Attach a QwenVL backend to this analyzer.

        Args:
            api_key: Alibaba Cloud API key.
            model: Model name, passed to the backend as ``model_name``.
            base_url: API base URL. Stored on this object only.
        """
        # NOTE(review): base_url is recorded but not passed to the backend
        # constructor -- confirm whether the backend is expected to use it.
        backend = qwenvl_analyzer.QwenAnalyzer(
            model_name=model,
            api_key=api_key,
        )
        # Commit state only after the backend was built, so a failing
        # constructor cannot leave the wrapper half-configured.
        self._configure('qwenvl', api_key, model, base_url, backend)

    async def analyze_images(self, images: List[str], prompt: str, batch_size: int = 5) -> Dict[str, Any]:
        """Analyze images by delegating to the attached backend.

        Args:
            images: List of image paths.
            prompt: Prompt text for the analysis.
            batch_size: Number of images per batch; defaults to 5.

        Returns:
            The value produced by the backend's ``analyze_images``.

        Raises:
            ValueError: If no backend has been attached yet.
        """
        # Compare against None explicitly: a backend object that happens to
        # be falsy (e.g. defines __len__) is still a valid backend.
        if self.analyzer is None:
            raise ValueError("未初始化视觉分析器")
        return await self.analyzer.analyze_images(
            images=images,
            prompt=prompt,
            batch_size=batch_size,
        )
def create_vision_analyzer(provider: str, **kwargs) -> "VisionAnalyzer":
    """Build a ``VisionAnalyzer`` initialized for the given provider.

    Args:
        provider: Provider name, ``'gemini'`` or ``'qwenvl'``. Matching is
            case-insensitive and ignores surrounding whitespace.
        **kwargs: Provider settings. Only ``api_key``, ``model`` and
            ``base_url`` are read; a missing key is forwarded as ``None``.

    Returns:
        VisionAnalyzer: The analyzer with the requested backend attached.

    Raises:
        ValueError: If ``provider`` does not name a supported provider,
            including non-string values such as ``None``.
    """
    # Provider name -> name of the VisionAnalyzer method that sets it up.
    initializers = {
        'gemini': 'initialize_gemini',
        'qwenvl': 'initialize_qwenvl',
    }

    # Normalize once; anything that is not a string cannot match a provider.
    provider_key = provider.strip().lower() if isinstance(provider, str) else None
    method_name = initializers.get(provider_key)
    if method_name is None:
        # Reject before constructing anything so bad input fails fast.
        raise ValueError(f"不支持的视觉分析提供商: {provider}")

    analyzer = VisionAnalyzer()
    getattr(analyzer, method_name)(
        api_key=kwargs.get('api_key'),
        model=kwargs.get('model'),
        base_url=kwargs.get('base_url'),
    )
    return analyzer