feat(vision): add QwenVL vision analysis support

- Add a QwenVL vision analyzer class implementing support for Alibaba Cloud Qwen models
- Update the basic settings UI, adding proxy configuration and a QwenVL model availability check
- Update the script generation logic to support image analysis with QwenVL models
- Refactor the vision analyzer initialization and invocation interfaces for better code reuse and maintainability
linyqh 2024-12-05 21:43:26 +08:00
parent 0caa15e762
commit f44d56110e
5 changed files with 582 additions and 335 deletions

View File

@@ -0,0 +1,255 @@
import json
from typing import List, Union, Dict
import os
from pathlib import Path
from loguru import logger
from tqdm import tqdm
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
from openai import OpenAI
import PIL.Image
import base64
import io
import traceback


class QwenAnalyzer:
    """Qwen vision analyzer"""

    def __init__(self, model_name: str = "qwen-vl-max-latest", api_key: str = None):
        """
        Initialize the Qwen vision analyzer.

        Args:
            model_name: model name, defaults to qwen-vl-max-latest
            api_key: Alibaba Cloud API key
        """
        if not api_key:
            raise ValueError("An API key must be provided")

        self.model_name = model_name
        self.api_key = api_key

        # Configure the API client
        self._configure_client()

    def _configure_client(self):
        """Configure the API client"""
        self.client = OpenAI(
            api_key=self.api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )

    def _image_to_base64(self, image: PIL.Image.Image) -> str:
        """Convert a PIL image object to a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10)
    )
    async def _generate_content_with_retry(self, prompt: str, batch: List[PIL.Image.Image]):
        """Internal helper that calls the Qwen API with retries."""
        try:
            # Build the message content
            content = []

            # Add the images
            for img in batch:
                base64_image = self._image_to_base64(img)
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                })

            # Add the text prompt
            content.append({
                "type": "text",
                "text": prompt
            })

            # Call the API; the SDK client is synchronous, so run it in a worker thread
            response = await asyncio.to_thread(
                self.client.chat.completions.create,
                model=self.model_name,
                messages=[{
                    "role": "user",
                    "content": content
                }]
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"API call error: {str(e)}")
            # Re-raise so tenacity can retry the call
            raise

    async def analyze_images(self,
                             images: Union[List[str], List[PIL.Image.Image]],
                             prompt: str,
                             batch_size: int = 5) -> List[Dict]:
        """
        Analyze multiple images in batches.

        Args:
            images: list of image paths or PIL image objects
            prompt: analysis prompt
            batch_size: batch size

        Returns:
            List of analysis results
        """
        try:
            # Validate the image list first (this also guards the images[0] access below)
            if not images:
                raise ValueError("The image list is empty")

            # Keep the original image paths (if a list of paths was passed in)
            original_paths = images if isinstance(images[0], str) else None

            # Load the images
            if isinstance(images[0], str):
                logger.info("Loading images...")
                images = self.load_images(images)

            # Validate each image object
            valid_images = []
            valid_paths = []
            for i, img in enumerate(images):
                if not isinstance(img, PIL.Image.Image):
                    logger.error(f"Invalid image object at index {i}: {type(img)}")
                    continue
                valid_images.append(img)
                if original_paths:
                    valid_paths.append(original_paths[i])

            if not valid_images:
                raise ValueError("No valid image objects")

            images = valid_images
            results = []
            total_batches = (len(images) + batch_size - 1) // batch_size

            with tqdm(total=total_batches, desc="Analyzing") as pbar:
                for i in range(0, len(images), batch_size):
                    batch = images[i:i + batch_size]
                    batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
                    retry_count = 0
                    while retry_count < 3:
                        try:
                            # Add a short delay before every batch after the first
                            if i > 0:
                                await asyncio.sleep(2)

                            # Make sure every image in the batch is valid
                            valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
                            if not valid_batch:
                                raise ValueError(f"No valid images in batch {i // batch_size}")

                            response = await self._generate_content_with_retry(prompt, valid_batch)
                            result_dict = {
                                'batch_index': i // batch_size,
                                'images_processed': len(valid_batch),
                                'response': response,
                                'model_used': self.model_name
                            }
                            # Attach the image paths (if available)
                            if batch_paths:
                                result_dict['image_paths'] = batch_paths
                            results.append(result_dict)
                            break
                        except Exception as e:
                            retry_count += 1
                            error_msg = f"Error processing batch {i // batch_size}: {str(e)}"
                            logger.error(error_msg)

                            if retry_count >= 3:
                                results.append({
                                    'batch_index': i // batch_size,
                                    'images_processed': len(batch),
                                    'error': error_msg,
                                    'model_used': self.model_name,
                                    'image_paths': batch_paths if batch_paths else []
                                })
                            else:
                                logger.info(f"Batch {i // batch_size} failed; waiting 60 seconds before retrying...")
                                await asyncio.sleep(60)

                    pbar.update(1)

            return results

        except Exception as e:
            error_msg = f"Error during image analysis: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            raise Exception(error_msg)

    def save_results_to_txt(self, results: List[Dict], output_dir: str):
        """Save the analysis results to txt files."""
        # Make sure the output directory exists
        os.makedirs(output_dir, exist_ok=True)

        for result in results:
            # Skip failed batches: they carry an 'error' key and no 'response'
            if 'error' in result:
                logger.warning(f"Skipping failed batch {result['batch_index']}: {result['error']}")
                continue

            response_text = result['response']

            # Use the image path information for the file name when available
            if result.get('image_paths'):
                image_paths = result['image_paths']
                img_name_start = Path(image_paths[0]).stem.split('_')[-1]
                img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
                file_name = f"frame_{img_name_start}_{img_name_end}.txt"
            else:
                # Fall back to the batch index
                file_name = f"batch_{result['batch_index']}.txt"

            txt_path = os.path.join(output_dir, file_name)

            # Write the result to the txt file
            with open(txt_path, 'w', encoding='utf-8') as f:
                f.write(response_text.strip())
            logger.info(f"Saved analysis result to: {txt_path}")

    def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
        """
        Load multiple images.

        Args:
            image_paths: list of image paths

        Returns:
            List of loaded PIL Image objects
        """
        images = []
        failed_images = []

        for img_path in image_paths:
            try:
                if not os.path.exists(img_path):
                    logger.error(f"Image file does not exist: {img_path}")
                    failed_images.append(img_path)
                    continue

                img = PIL.Image.open(img_path)
                # Make sure the image is fully loaded
                img.load()
                # Convert to RGB mode
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                images.append(img)
            except Exception as e:
                logger.error(f"Failed to load image {img_path}: {str(e)}")
                failed_images.append(img_path)

        if failed_images:
            logger.warning(f"The following images failed to load:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")

        if not images:
            raise ValueError("No images were loaded successfully")

        return images
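
For context, a minimal usage sketch of the analyzer above (not part of the commit): the frame paths and prompt are placeholders, and it assumes a valid DashScope key in a DASHSCOPE_API_KEY environment variable.

# Hypothetical driver for QwenAnalyzer; paths, prompt, and env var are assumptions.
import asyncio
import os

from app.utils.qwenvl_analyzer import QwenAnalyzer


async def main():
    analyzer = QwenAnalyzer(
        model_name="qwen-vl-max-latest",
        api_key=os.environ["DASHSCOPE_API_KEY"],  # assumed environment variable
    )
    results = await analyzer.analyze_images(
        images=["frames/keyframe_000001_000000.jpg",   # placeholder frames
                "frames/keyframe_000002_000040.jpg"],
        prompt="Describe what happens in these frames.",
        batch_size=2,
    )
    # Write one txt file per successful batch
    analyzer.save_results_to_txt(results, output_dir="analysis_output")


if __name__ == "__main__":
    asyncio.run(main())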

View File

@@ -52,18 +52,34 @@ def render_language_settings(tr):

 def render_proxy_settings(tr):
     """Render proxy settings"""
-    proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
-    proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
+    # Get the current proxy state
+    proxy_enabled = config.proxy.get("enabled", True)
+    proxy_url_http = config.proxy.get("http")
+    proxy_url_https = config.proxy.get("https")
 
-    HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
-    HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
+    # Add a proxy toggle
+    proxy_enabled = st.checkbox(tr("Enable Proxy"), value=proxy_enabled)
+    # Save the toggle state
+    config.proxy["enabled"] = proxy_enabled
 
-    if HTTP_PROXY:
-        config.proxy["http"] = HTTP_PROXY
-        os.environ["HTTP_PROXY"] = HTTP_PROXY
-    if HTTPS_PROXY:
-        config.proxy["https"] = HTTPS_PROXY
-        os.environ["HTTPS_PROXY"] = HTTPS_PROXY
+    # Only show the proxy input fields while the proxy is enabled
+    if proxy_enabled:
+        HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
+        HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
 
+        if HTTP_PROXY:
+            config.proxy["http"] = HTTP_PROXY
+            os.environ["HTTP_PROXY"] = HTTP_PROXY
+        if HTTPS_PROXY:
+            config.proxy["https"] = HTTPS_PROXY
+            os.environ["HTTPS_PROXY"] = HTTPS_PROXY
+    else:
+        # When the proxy is disabled, clear the environment variables and config
+        os.environ.pop("HTTP_PROXY", None)
+        os.environ.pop("HTTPS_PROXY", None)
+        config.proxy["http"] = ""
+        config.proxy["https"] = ""
 
 def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
@@ -90,6 +106,28 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
         except Exception as e:
             return False, f"{tr('gemini model is not available')}: {str(e)}"
+    elif provider.lower() == 'qwenvl':
+        from openai import OpenAI
+        try:
+            client = OpenAI(
+                api_key=api_key,
+                base_url=base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
+            )
+            # Send a simple test request
+            response = client.chat.completions.create(
+                model=model_name or "qwen-vl-max-latest",
+                messages=[{"role": "user", "content": "Reply directly with the text 'network available'"}]
+            )
+            if response and response.choices:
+                return True, tr("QwenVL model is available")
+            else:
+                return False, tr("QwenVL model returned invalid response")
+        except Exception as e:
+            return False, f"{tr('QwenVL model is not available')}: {str(e)}"
     elif provider.lower() == 'narratoapi':
         import requests
         try:
@@ -116,7 +154,7 @@ def render_vision_llm_settings(tr):
     st.subheader(tr("Vision Model Settings"))
 
     # Video analysis model provider selection
-    vision_providers = ['Gemini', 'NarratoAPI(待发布)', 'QwenVL(待发布)']
+    vision_providers = ['Gemini', 'QwenVL', 'NarratoAPI(待发布)']
     saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
 
     saved_provider_index = 0
@@ -142,18 +180,33 @@ def render_vision_llm_settings(tr):
     # Render the vision model configuration inputs
     st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
 
-    # Disable the base_url input when Gemini is selected
-    if vision_provider.lower() == 'gemini':
+    # Set defaults and help text per provider
+    if vision_provider == 'gemini':
         st_vision_base_url = st.text_input(
             tr("Vision Base URL"),
             value=vision_base_url,
             disabled=True,
             help=tr("Gemini API does not require a base URL")
         )
+        st_vision_model_name = st.text_input(
+            tr("Vision Model Name"),
+            value=vision_model_name or "gemini-1.5-flash",
+            help=tr("Default: gemini-1.5-flash")
+        )
+    elif vision_provider == 'qwenvl':
+        st_vision_base_url = st.text_input(
+            tr("Vision Base URL"),
+            value=vision_base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1",
+            help=tr("Default: https://dashscope.aliyuncs.com/compatible-mode/v1")
+        )
+        st_vision_model_name = st.text_input(
+            tr("Vision Model Name"),
+            value=vision_model_name or "qwen-vl-max-latest",
+            help=tr("Default: qwen-vl-max-latest")
+        )
     else:
         st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url)
-
-    st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
+        st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
 
     # Add a test button after the configuration inputs
     if st.button(tr("Test Connection"), key="test_vision_connection"):
@@ -174,7 +227,7 @@ def render_vision_llm_settings(tr):
     # Save the vision model configuration
     if st_vision_api_key:
         config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
-        st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key  # used by script_settings.py
+        st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
     if st_vision_base_url:
         config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
         st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
@@ -182,81 +235,6 @@ def render_vision_llm_settings(tr):
         config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
         st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
 
-    # # NarratoAPI-specific configuration
-    # if vision_provider == 'narratoapi':
-    #     st.subheader(tr("Narrato Additional Settings"))
-    #
-    #     # Basic Narrato API configuration
-    #     narrato_api_key = st.text_input(
-    #         tr("Narrato API Key"),
-    #         value=config.app.get("narrato_api_key", ""),
-    #         type="password",
-    #         help="Key used to access the Narrato API"
-    #     )
-    #     if narrato_api_key:
-    #         config.app["narrato_api_key"] = narrato_api_key
-    #         st.session_state['narrato_api_key'] = narrato_api_key
-    #
-    #     narrato_api_url = st.text_input(
-    #         tr("Narrato API URL"),
-    #         value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze")
-    #     )
-    #     if narrato_api_url:
-    #         config.app["narrato_api_url"] = narrato_api_url
-    #         st.session_state['narrato_api_url'] = narrato_api_url
-    #
-    #     # Vision analysis model configuration
-    #     st.markdown("##### " + tr("Vision Model Settings"))
-    #     narrato_vision_model = st.text_input(
-    #         tr("Vision Model Name"),
-    #         value=config.app.get("narrato_vision_model", "gemini-1.5-flash")
-    #     )
-    #     narrato_vision_key = st.text_input(
-    #         tr("Vision Model API Key"),
-    #         value=config.app.get("narrato_vision_key", ""),
-    #         type="password",
-    #         help="API key for the vision analysis model"
-    #     )
-    #
-    #     if narrato_vision_model:
-    #         config.app["narrato_vision_model"] = narrato_vision_model
-    #         st.session_state['narrato_vision_model'] = narrato_vision_model
-    #     if narrato_vision_key:
-    #         config.app["narrato_vision_key"] = narrato_vision_key
-    #         st.session_state['narrato_vision_key'] = narrato_vision_key
-    #
-    #     # Text generation model configuration
-    #     st.markdown("##### " + tr("Text Generation Model Settings"))
-    #     narrato_llm_model = st.text_input(
-    #         tr("LLM Model Name"),
-    #         value=config.app.get("narrato_llm_model", "qwen-plus")
-    #     )
-    #     narrato_llm_key = st.text_input(
-    #         tr("LLM Model API Key"),
-    #         value=config.app.get("narrato_llm_key", ""),
-    #         type="password",
-    #         help="API key for the text generation model"
-    #     )
-    #
-    #     if narrato_llm_model:
-    #         config.app["narrato_llm_model"] = narrato_llm_model
-    #         st.session_state['narrato_llm_model'] = narrato_llm_model
-    #     if narrato_llm_key:
-    #         config.app["narrato_llm_key"] = narrato_llm_key
-    #         st.session_state['narrato_llm_key'] = narrato_llm_key
-    #
-    #     # Batch processing configuration
-    #     narrato_batch_size = st.number_input(
-    #         tr("Batch Size"),
-    #         min_value=1,
-    #         max_value=50,
-    #         value=config.app.get("narrato_batch_size", 10),
-    #         help="Number of images processed per batch"
-    #     )
-    #     if narrato_batch_size:
-    #         config.app["narrato_batch_size"] = narrato_batch_size
-    #         st.session_state['narrato_batch_size'] = narrato_batch_size
 
 def test_text_model_connection(api_key, base_url, model_name, provider, tr):
     """Test the text model connection

@@ -328,6 +306,7 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
     except Exception as e:
         return False, f"{tr('Connection failed')}: {str(e)}"
 
+
 def render_text_llm_settings(tr):
     """Render the text generation model settings"""
     st.subheader(tr("Text Generation Model Settings"))

View File

@@ -14,7 +14,7 @@ from loguru import logger
 from app.config import config
 from app.models.schema import VideoClipParams
 from app.utils.script_generator import ScriptProcessor
-from app.utils import utils, check_script, gemini_analyzer, video_processor, video_processor_v2
+from app.utils import utils, check_script, gemini_analyzer, video_processor, video_processor_v2, qwenvl_analyzer
 from webui.utils import file_utils
@@ -472,267 +472,168 @@ def generate_script(tr, params):
     vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
     logger.debug(f"Vision LLM provider: {vision_llm_provider}")
 
-    if vision_llm_provider == 'gemini':
-        try:
-            # =================== Initialize the vision analyzer ===================
-            update_progress(30, "Initializing vision analyzer...")
-
-            # Get the Gemini configuration from config
-            vision_api_key = st.session_state.get('vision_gemini_api_key')
-            vision_model = st.session_state.get('vision_gemini_model_name')
-            vision_base_url = st.session_state.get('vision_gemini_base_url')
-
-            if not vision_api_key or not vision_model:
-                raise ValueError("Gemini API Key or model is not configured; please set them in the basic settings")
-
-            analyzer = gemini_analyzer.VisionAnalyzer(
-                model_name=vision_model,
-                api_key=vision_api_key,
-            )
-
-            update_progress(40, "Analyzing keyframes...")
-
-            # =================== Create an async event loop ===================
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-            # Run the async analysis
-            vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
-            results = loop.run_until_complete(
-                analyzer.analyze_images(
-                    images=keyframe_files,
-                    prompt=config.app.get('vision_analysis_prompt'),
-                    batch_size=vision_batch_size
-                )
-            )
-            loop.close()
-
-            # =================== Process the analysis results ===================
-            update_progress(60, "Organizing analysis results...")
-
-            # Merge the analysis results of all batches
-            frame_analysis = ""
-            prev_batch_files = None
-
-            for result in results:
-                if 'error' in result:
-                    logger.warning(f"Batch {result['batch_index']} finished with a warning: {result['error']}")
-                    continue
-
-                # Get the current batch's file list, e.g. keyframe_001136_000045.jpg (000045 carries millisecond precision)
-                batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
-                logger.debug(f"Batch {result['batch_index']} done, {len(batch_files)} images")
-                # logger.debug(batch_files)
-
-                first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
-                logger.debug(f"Timestamp range: {first_timestamp}-{last_timestamp}")
-
-                # Append the analysis result with its timestamp range
-                frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
-                frame_analysis += result['response']
-                frame_analysis += "\n"
-
-                # Remember the previous batch's files
-                prev_batch_files = batch_files
-
-            if not frame_analysis.strip():
-                raise Exception("Failed to generate valid frame analysis results")
-
-            # Save the analysis results
-            analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
-            with open(analysis_path, 'w', encoding='utf-8') as f:
-                f.write(frame_analysis)
-
-            update_progress(70, "Generating script...")
-
-            # Get the text generation configuration from config
-            text_provider = config.app.get('text_llm_provider', 'gemini').lower()
-            text_api_key = config.app.get(f'text_{text_provider}_api_key')
-            text_model = config.app.get(f'text_{text_provider}_model_name')
-            text_base_url = config.app.get(f'text_{text_provider}_base_url')
-
-            # Build the frame content list
-            frame_content_list = []
-            prev_batch_files = None
-
-            for i, result in enumerate(results):
-                if 'error' in result:
-                    continue
-
-                batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
-                _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
-
-                frame_content = {
-                    "timestamp": timestamp_range,
-                    "picture": result['response'],
-                    "narration": "",
-                    "OST": 2
-                }
-                frame_content_list.append(frame_content)
-
-                logger.debug(f"Added frame content: timestamp range={timestamp_range}, analysis length={len(result['response'])}")
-
-                # Remember the previous batch's files
-                prev_batch_files = batch_files
-
-            if not frame_content_list:
-                raise Exception("No valid frame content to process")
-
-            # =================== Start generating the narration ===================
-            update_progress(80, "Generating narration...")
-            # Validate the configuration
-            api_params = {
-                "vision_api_key": vision_api_key,
-                "vision_model_name": vision_model,
-                "vision_base_url": vision_base_url or "",
-                "text_api_key": text_api_key,
-                "text_model_name": text_model,
-                "text_base_url": text_base_url or ""
-            }
-            headers = {
-                'accept': 'application/json',
-                'Content-Type': 'application/json'
-            }
-            session = requests.Session()
-            retry_strategy = Retry(
-                total=3,
-                backoff_factor=1,
-                status_forcelist=[500, 502, 503, 504]
-            )
-            adapter = HTTPAdapter(max_retries=retry_strategy)
-            session.mount("https://", adapter)
-            try:
-                response = session.post(
-                    f"{config.app.get('narrato_api_url')}/video/config",
-                    headers=headers,
-                    json=api_params,
-                    timeout=30,
-                    verify=True
-                )
-            except Exception as e:
-                pass
-            custom_prompt = st.session_state.get('custom_prompt', '')
-            processor = ScriptProcessor(
-                model_name=text_model,
-                api_key=text_api_key,
-                prompt=custom_prompt,
-                base_url=text_base_url or "",
-                video_theme=st.session_state.get('video_theme', '')
-            )
-
-            # Generate the script from the frame contents
-            script_result = processor.process_frames(frame_content_list)
-
-            # Convert the result to a JSON string
-            script = json.dumps(script_result, ensure_ascii=False, indent=2)
-
-        except Exception as e:
-            logger.exception(f"Error during LLM processing\n{traceback.format_exc()}")
-            raise Exception(f"Analysis failed: {str(e)}")
-
-    elif vision_llm_provider == 'narratoapi':  # NarratoAPI
-        try:
-            # Create a temp directory
-            temp_dir = utils.temp_dir("narrato")
-
-            # Pack the keyframes
-            update_progress(30, "Packing keyframes...")
-            zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
-            if not file_utils.create_zip(keyframe_files, zip_path):
-                raise Exception("Failed to pack keyframes")
-
-            # Get the API configuration
-            api_url = st.session_state.get('narrato_api_url')
-            api_key = st.session_state.get('narrato_api_key')
-
-            if not api_key:
-                raise ValueError("Narrato API Key is not configured; please set it in the basic settings")
-
-            # Prepare the API request
-            headers = {
-                'X-API-Key': api_key,
-                'accept': 'application/json'
-            }
-
-            api_params = {
-                'batch_size': st.session_state.get('narrato_batch_size', 10),
-                'use_ai': False,
-                'start_offset': 0,
-                'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
-                'vision_api_key': st.session_state.get('narrato_vision_key'),
-                'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
-                'llm_api_key': st.session_state.get('narrato_llm_key'),
-                'custom_prompt': st.session_state.get('custom_prompt', '')
-            }
-
-            # Send the API request
-            logger.info(f"Calling NarratoAPI: {api_url}")
-            update_progress(40, "Uploading file...")
-            with open(zip_path, 'rb') as f:
-                files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
-                try:
-                    response = requests.post(
-                        f"{api_url}/video/analyze",
-                        headers=headers,
-                        params=api_params,
-                        files=files,
-                        timeout=30  # request timeout
-                    )
-                    response.raise_for_status()
-                except requests.RequestException as e:
-                    logger.error(f"Narrato API request failed:\n{traceback.format_exc()}")
-                    raise Exception(f"API request failed: {str(e)}")
-
-            task_data = response.json()
-            task_id = task_data["data"].get('task_id')
-            if not task_id:
-                raise Exception(f"Invalid API response: {response.text}")
-
-            # Poll the task status
-            update_progress(50, "Waiting for analysis results...")
-            retry_count = 0
-            max_retries = 60  # wait at most 2 minutes
-            while retry_count < max_retries:
-                try:
-                    status_response = requests.get(
-                        f"{api_url}/video/tasks/{task_id}",
-                        headers=headers,
-                        timeout=10
-                    )
-                    status_response.raise_for_status()
-                    task_status = status_response.json()['data']
-
-                    if task_status['status'] == 'SUCCESS':
-                        script = task_status['result']['data']
-                        break
-                    elif task_status['status'] in ['FAILURE', 'RETRY']:
-                        raise Exception(f"Task failed: {task_status.get('error')}")
-
-                    retry_count += 1
-                    time.sleep(2)
-                except requests.RequestException as e:
-                    logger.warning(f"Failed to get task status, retrying: {str(e)}")
-                    retry_count += 1
-                    time.sleep(2)
-                    continue
-
-            if retry_count >= max_retries:
-                raise Exception("Task timed out")
-
-        except Exception as e:
-            logger.exception(f"Error during NarratoAPI processing\n{traceback.format_exc()}")
-            raise Exception(f"NarratoAPI processing failed: {str(e)}")
-        finally:
-            # Clean up temp files
-            try:
-                if os.path.exists(zip_path):
-                    os.remove(zip_path)
-            except Exception as e:
-                logger.warning(f"Failed to clean up temp files: {str(e)}")
-    else:
-        logger.exception("Vision Model is not enabled; please check the configuration")
+    try:
+        # =================== Initialize the vision analyzer ===================
+        update_progress(30, "Initializing vision analyzer...")
+
+        # Get the provider-specific configuration from config
+        if vision_llm_provider == 'gemini':
+            vision_api_key = st.session_state.get('vision_gemini_api_key')
+            vision_model = st.session_state.get('vision_gemini_model_name')
+            vision_base_url = st.session_state.get('vision_gemini_base_url')
+        elif vision_llm_provider == 'qwenvl':
+            vision_api_key = st.session_state.get('vision_qwenvl_api_key')
+            vision_model = st.session_state.get('vision_qwenvl_model_name', 'qwen-vl-max-latest')
+            vision_base_url = st.session_state.get('vision_qwenvl_base_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
+        else:
+            raise ValueError(f"Unsupported vision analysis provider: {vision_llm_provider}")
+
+        # Create the vision analyzer instance
+        analyzer = create_vision_analyzer(
+            provider=vision_llm_provider,
+            api_key=vision_api_key,
+            model=vision_model,
+            base_url=vision_base_url
+        )
+
+        update_progress(40, "Analyzing keyframes...")
+
+        # =================== Create an async event loop ===================
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        # Run the async analysis
+        vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
+        results = loop.run_until_complete(
+            analyzer.analyze_images(
+                images=keyframe_files,
+                prompt=config.app.get('vision_analysis_prompt'),
+                batch_size=vision_batch_size
+            )
+        )
+        loop.close()
+
+        # =================== Process the analysis results ===================
+        update_progress(60, "Organizing analysis results...")
+
+        # Merge the analysis results of all batches
+        frame_analysis = ""
+        prev_batch_files = None
+
+        for result in results:
+            if 'error' in result:
+                logger.warning(f"Batch {result['batch_index']} finished with a warning: {result['error']}")
+                continue
+
+            # Get the current batch's file list, e.g. keyframe_001136_000045.jpg (000045 carries millisecond precision)
+            batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
+            logger.debug(f"Batch {result['batch_index']} done, {len(batch_files)} images")
+            # logger.debug(batch_files)
+
+            first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
+            logger.debug(f"Timestamp range: {first_timestamp}-{last_timestamp}")
+
+            # Append the analysis result with its timestamp range
+            frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
+            frame_analysis += result['response']
+            frame_analysis += "\n"
+
+            # Remember the previous batch's files
+            prev_batch_files = batch_files
+
+        if not frame_analysis.strip():
+            raise Exception("Failed to generate valid frame analysis results")
+
+        # Save the analysis results
+        analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
+        with open(analysis_path, 'w', encoding='utf-8') as f:
+            f.write(frame_analysis)
+
+        update_progress(70, "Generating script...")
+
+        # Get the text generation configuration from config
+        text_provider = config.app.get('text_llm_provider', 'gemini').lower()
+        text_api_key = config.app.get(f'text_{text_provider}_api_key')
+        text_model = config.app.get(f'text_{text_provider}_model_name')
+        text_base_url = config.app.get(f'text_{text_provider}_base_url')
+
+        # Build the frame content list
+        frame_content_list = []
+        prev_batch_files = None
+
+        for i, result in enumerate(results):
+            if 'error' in result:
+                continue
+
+            batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
+            _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
+
+            frame_content = {
+                "timestamp": timestamp_range,
+                "picture": result['response'],
+                "narration": "",
+                "OST": 2
+            }
+            frame_content_list.append(frame_content)
+
+            logger.debug(f"Added frame content: timestamp range={timestamp_range}, analysis length={len(result['response'])}")
+
+            # Remember the previous batch's files
+            prev_batch_files = batch_files
+
+        if not frame_content_list:
+            raise Exception("No valid frame content to process")
+
+        # =================== Start generating the narration ===================
+        update_progress(80, "Generating narration...")
+        # Validate the configuration
+        api_params = {
+            "vision_api_key": vision_api_key,
+            "vision_model_name": vision_model,
+            "vision_base_url": vision_base_url or "",
+            "text_api_key": text_api_key,
+            "text_model_name": text_model,
+            "text_base_url": text_base_url or ""
+        }
+        headers = {
+            'accept': 'application/json',
+            'Content-Type': 'application/json'
+        }
+        session = requests.Session()
+        retry_strategy = Retry(
+            total=3,
+            backoff_factor=1,
+            status_forcelist=[500, 502, 503, 504]
+        )
+        adapter = HTTPAdapter(max_retries=retry_strategy)
+        session.mount("https://", adapter)
+        try:
+            response = session.post(
+                f"{config.app.get('narrato_api_url')}/video/config",
+                headers=headers,
+                json=api_params,
+                timeout=30,
+                verify=True
+            )
+        except Exception as e:
+            pass
+        custom_prompt = st.session_state.get('custom_prompt', '')
+        processor = ScriptProcessor(
+            model_name=text_model,
+            api_key=text_api_key,
+            prompt=custom_prompt,
+            base_url=text_base_url or "",
+            video_theme=st.session_state.get('video_theme', '')
+        )
+
+        # Generate the script from the frame contents
+        script_result = processor.process_frames(frame_content_list)
+
+        # Convert the result to a JSON string
+        script = json.dumps(script_result, ensure_ascii=False, indent=2)
+
+    except Exception as e:
+        logger.exception(f"Error during LLM processing\n{traceback.format_exc()}")
+        raise Exception(f"Analysis failed: {str(e)}")
 
     if script is None:
         st.error("Failed to generate the script; please check the logs")
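
For reference, each entry that the loop above hands to ScriptProcessor.process_frames has the following shape; the values here are illustrative, and the exact timestamp format produced by get_batch_timestamps is an assumption.

# Illustrative frame_content entry (hypothetical values)
frame_content = {
    "timestamp": "00:01:12-00:01:36",  # assumed range format from get_batch_timestamps
    "picture": "Two characters argue in a dimly lit hallway...",  # vision model response
    "narration": "",  # left empty; filled in later by the text generation model
    "OST": 2
}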
@@ -823,3 +724,12 @@ def get_script_params():
         'video_name': st.session_state.get('video_name', ''),
         'video_plot': st.session_state.get('video_plot', '')
     }
+
+
+def create_vision_analyzer(provider, api_key, model, base_url):
+    if provider == 'gemini':
+        return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
+    elif provider == 'qwenvl':
+        return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key)
+    else:
+        raise ValueError(f"Unsupported vision analysis provider: {provider}")

View File

@@ -163,6 +163,9 @@
     "Error processing subtitle files. Please check if the subtitles are valid SRT files.": "处理字幕文件时出错。请检查字幕是否为有效的SRT文件。",
     "Preview Merged Video": "预览合并后的视频",
     "Video Path": "视频路径",
-    "Subtitle Path": "字幕路径"
+    "Subtitle Path": "字幕路径",
+    "Enable Proxy": "启用代理",
+    "QwenVL model is available": "QwenVL 模型可用",
+    "QwenVL model is not available": "QwenVL 模型不可用"
   }
 }

View File

@@ -0,0 +1,100 @@
import logging
from typing import List, Dict, Any, Optional

from app.utils import gemini_analyzer, qwenvl_analyzer

logger = logging.getLogger(__name__)


class VisionAnalyzer:
    def __init__(self):
        self.provider = None
        self.api_key = None
        self.model = None
        self.base_url = None
        self.analyzer = None

    def initialize_gemini(self, api_key: str, model: str, base_url: str) -> None:
        """
        Initialize the Gemini vision analyzer.

        Args:
            api_key: Gemini API key
            model: model name
            base_url: API base URL
        """
        self.provider = 'gemini'
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.analyzer = gemini_analyzer.VisionAnalyzer(
            model_name=model,
            api_key=api_key
        )

    def initialize_qwenvl(self, api_key: str, model: str, base_url: str) -> None:
        """
        Initialize the QwenVL vision analyzer.

        Args:
            api_key: Alibaba Cloud API key
            model: model name
            base_url: API base URL
        """
        self.provider = 'qwenvl'
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.analyzer = qwenvl_analyzer.QwenAnalyzer(
            model_name=model,
            api_key=api_key
        )

    async def analyze_images(self, images: List[str], prompt: str, batch_size: int = 5) -> List[Dict[str, Any]]:
        """
        Analyze image content.

        Args:
            images: list of image paths
            prompt: analysis prompt
            batch_size: number of images per batch, default 5

        Returns:
            List of analysis results, as returned by the underlying analyzer
        """
        if not self.analyzer:
            raise ValueError("The vision analyzer has not been initialized")

        return await self.analyzer.analyze_images(
            images=images,
            prompt=prompt,
            batch_size=batch_size
        )


def create_vision_analyzer(provider: str, **kwargs) -> VisionAnalyzer:
    """
    Create a vision analyzer instance.

    Args:
        provider: provider name ('gemini' or 'qwenvl')
        **kwargs: provider-specific configuration parameters

    Returns:
        VisionAnalyzer: a configured vision analyzer instance
    """
    analyzer = VisionAnalyzer()

    if provider.lower() == 'gemini':
        analyzer.initialize_gemini(
            api_key=kwargs.get('api_key'),
            model=kwargs.get('model'),
            base_url=kwargs.get('base_url')
        )
    elif provider.lower() == 'qwenvl':
        analyzer.initialize_qwenvl(
            api_key=kwargs.get('api_key'),
            model=kwargs.get('model'),
            base_url=kwargs.get('base_url')
        )
    else:
        raise ValueError(f"Unsupported vision analysis provider: {provider}")

    return analyzer
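
A minimal sketch of how this factory is intended to be called; the module path, API key, and frame path are placeholders.

# Hypothetical caller for create_vision_analyzer; the module path is assumed.
import asyncio

from app.utils.vision_analyzer import create_vision_analyzer


async def main():
    analyzer = create_vision_analyzer(
        'qwenvl',
        api_key='sk-...',  # placeholder Alibaba Cloud key
        model='qwen-vl-max-latest',
        base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
    )
    results = await analyzer.analyze_images(
        images=['frames/keyframe_000001_000000.jpg'],  # placeholder frame
        prompt='Describe the frame.',
        batch_size=5,
    )
    print(results)


asyncio.run(main())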