mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
feat(vision): 添加 QwenVL 视觉分析支持
- 新增 QwenVL 视觉分析器类,实现对阿里云 Qwen 模型的支持 - 更新基础设置界面,增加代理配置和 QwenVL 模型可用性检测 - 修改脚本生成逻辑,支持 QwenVL 模型的图像分析 - 重构视觉分析器初始化和调用接口,提高代码复用性和可维护性
This commit is contained in:
parent
0caa15e762
commit
f44d56110e
255
app/utils/qwenvl_analyzer.py
Normal file
255
app/utils/qwenvl_analyzer.py
Normal file
@ -0,0 +1,255 @@
|
||||
import json
|
||||
from typing import List, Union, Dict
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from tenacity import retry, stop_after_attempt, RetryError, wait_exponential
|
||||
from openai import OpenAI
|
||||
import PIL.Image
|
||||
import base64
|
||||
import io
|
||||
import traceback
|
||||
|
||||
|
||||
class QwenAnalyzer:
|
||||
"""千问视觉分析器类"""
|
||||
|
||||
def __init__(self, model_name: str = "qwen-vl-max-latest", api_key: str = None):
|
||||
"""
|
||||
初始化千问视觉分析器
|
||||
Args:
|
||||
model_name: 模型名称,默认使用 qwen-vl-max-latest
|
||||
api_key: 阿里云API密钥
|
||||
"""
|
||||
if not api_key:
|
||||
raise ValueError("必须提供API密钥")
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
|
||||
# 配置API客户端
|
||||
self._configure_client()
|
||||
|
||||
def _configure_client(self):
|
||||
"""配置API客户端"""
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
|
||||
def _image_to_base64(self, image: PIL.Image.Image) -> str:
|
||||
"""
|
||||
将PIL图片对象转换为base64字符串
|
||||
"""
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10)
|
||||
)
|
||||
async def _generate_content_with_retry(self, prompt: str, batch: List[PIL.Image.Image]):
|
||||
"""使用重试机制的内部方法来调用千问API"""
|
||||
try:
|
||||
# 构建消息内容
|
||||
content = []
|
||||
|
||||
# 添加图片
|
||||
for img in batch:
|
||||
base64_image = self._image_to_base64(img)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
})
|
||||
|
||||
# 添加文本提示
|
||||
content.append({
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
})
|
||||
|
||||
# 调用API
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
model=self.model_name,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}]
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用错误: {str(e)}")
|
||||
raise RetryError("API调用失败")
|
||||
|
||||
async def analyze_images(self,
|
||||
images: Union[List[str], List[PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 5) -> List[Dict]:
|
||||
"""
|
||||
批量分析多张图片
|
||||
Args:
|
||||
images: 图片路径列表或PIL图片对象列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
try:
|
||||
# 保存原始图片路径(如果是路径列表的话)
|
||||
original_paths = images if isinstance(images[0], str) else None
|
||||
|
||||
# 加载图片
|
||||
if isinstance(images[0], str):
|
||||
logger.info("正在加载图片...")
|
||||
images = self.load_images(images)
|
||||
|
||||
# 验证图片列表
|
||||
if not images:
|
||||
raise ValueError("图片列表为空")
|
||||
|
||||
# 验证每个图片对象
|
||||
valid_images = []
|
||||
valid_paths = []
|
||||
for i, img in enumerate(images):
|
||||
if not isinstance(img, PIL.Image.Image):
|
||||
logger.error(f"无效的图片对象,索引 {i}: {type(img)}")
|
||||
continue
|
||||
valid_images.append(img)
|
||||
if original_paths:
|
||||
valid_paths.append(original_paths[i])
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError("没有有效的图片对象")
|
||||
|
||||
images = valid_images
|
||||
results = []
|
||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
||||
for i in range(0, len(images), batch_size):
|
||||
batch = images[i:i + batch_size]
|
||||
batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < 3:
|
||||
try:
|
||||
# 在每个批次处理前添加小延迟
|
||||
if i > 0:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# 确保每个批次的图片都是有效的
|
||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||
if not valid_batch:
|
||||
raise ValueError(f"批次 {i // batch_size} 中没有有效的图片")
|
||||
|
||||
response = await self._generate_content_with_retry(prompt, valid_batch)
|
||||
result_dict = {
|
||||
'batch_index': i // batch_size,
|
||||
'images_processed': len(valid_batch),
|
||||
'response': response,
|
||||
'model_used': self.model_name
|
||||
}
|
||||
|
||||
# 添加图片路径信息(如果有的话)
|
||||
if batch_paths:
|
||||
result_dict['image_paths'] = batch_paths
|
||||
|
||||
results.append(result_dict)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
if retry_count >= 3:
|
||||
results.append({
|
||||
'batch_index': i // batch_size,
|
||||
'images_processed': len(batch),
|
||||
'error': error_msg,
|
||||
'model_used': self.model_name,
|
||||
'image_paths': batch_paths if batch_paths else []
|
||||
})
|
||||
else:
|
||||
logger.info(f"批次 {i // batch_size} 处理失败,等待60秒后重试当前批次...")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片分析过程中发生错误: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def save_results_to_txt(self, results: List[Dict], output_dir: str):
|
||||
"""将分析结果保存到txt文件"""
|
||||
# 确保输出目录存在
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
response_text = result['response']
|
||||
|
||||
# 如果有图片路径信息,使用它来生成文件名
|
||||
if result.get('image_paths'):
|
||||
image_paths = result['image_paths']
|
||||
img_name_start = Path(image_paths[0]).stem.split('_')[-1]
|
||||
img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
|
||||
file_name = f"frame_{img_name_start}_{img_name_end}.txt"
|
||||
else:
|
||||
# 如果没有路径信息,使用批次索引
|
||||
file_name = f"batch_{result['batch_index']}.txt"
|
||||
|
||||
txt_path = os.path.join(output_dir, file_name)
|
||||
|
||||
# 保存结果到txt文件
|
||||
with open(txt_path, 'w', encoding='utf-8') as f:
|
||||
f.write(response_text.strip())
|
||||
logger.info(f"已保存分析结果到: {txt_path}")
|
||||
|
||||
def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
|
||||
"""
|
||||
加载多张图片
|
||||
Args:
|
||||
image_paths: 图片路径列表
|
||||
Returns:
|
||||
加载后的PIL Image对象列表
|
||||
"""
|
||||
images = []
|
||||
failed_images = []
|
||||
|
||||
for img_path in image_paths:
|
||||
try:
|
||||
if not os.path.exists(img_path):
|
||||
logger.error(f"图片文件不存在: {img_path}")
|
||||
failed_images.append(img_path)
|
||||
continue
|
||||
|
||||
img = PIL.Image.open(img_path)
|
||||
# 确保图片被完全加载
|
||||
img.load()
|
||||
# 转换为RGB模式
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
images.append(img)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"无法加载图片 {img_path}: {str(e)}")
|
||||
failed_images.append(img_path)
|
||||
|
||||
if failed_images:
|
||||
logger.warning(f"以下图片加载失败:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
|
||||
|
||||
if not images:
|
||||
raise ValueError("没有成功加载任何图片")
|
||||
|
||||
return images
|
||||
@ -52,18 +52,34 @@ def render_language_settings(tr):
|
||||
|
||||
def render_proxy_settings(tr):
|
||||
"""渲染代理设置"""
|
||||
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
# 获取当前代理状态
|
||||
proxy_enabled = config.proxy.get("enabled", True)
|
||||
proxy_url_http = config.proxy.get("http")
|
||||
proxy_url_https = config.proxy.get("https")
|
||||
|
||||
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
|
||||
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
|
||||
# 添加代理开关
|
||||
proxy_enabled = st.checkbox(tr("Enable Proxy"), value=proxy_enabled)
|
||||
|
||||
# 保存代理开关状态
|
||||
config.proxy["enabled"] = proxy_enabled
|
||||
|
||||
if HTTP_PROXY:
|
||||
config.proxy["http"] = HTTP_PROXY
|
||||
os.environ["HTTP_PROXY"] = HTTP_PROXY
|
||||
if HTTPS_PROXY:
|
||||
config.proxy["https"] = HTTPS_PROXY
|
||||
os.environ["HTTPS_PROXY"] = HTTPS_PROXY
|
||||
# 只有在代理启用时才显示代理设置输入框
|
||||
if proxy_enabled:
|
||||
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
|
||||
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
|
||||
|
||||
if HTTP_PROXY:
|
||||
config.proxy["http"] = HTTP_PROXY
|
||||
os.environ["HTTP_PROXY"] = HTTP_PROXY
|
||||
if HTTPS_PROXY:
|
||||
config.proxy["https"] = HTTPS_PROXY
|
||||
os.environ["HTTPS_PROXY"] = HTTPS_PROXY
|
||||
else:
|
||||
# 当代理被禁用时,清除环境变量和配置
|
||||
os.environ.pop("HTTP_PROXY", None)
|
||||
os.environ.pop("HTTPS_PROXY", None)
|
||||
config.proxy["http"] = ""
|
||||
config.proxy["https"] = ""
|
||||
|
||||
|
||||
def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
@ -90,6 +106,28 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
except Exception as e:
|
||||
return False, f"{tr('gemini model is not available')}: {str(e)}"
|
||||
|
||||
elif provider.lower() == 'qwenvl':
|
||||
from openai import OpenAI
|
||||
try:
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
|
||||
# 发送一个简单的测试请求
|
||||
response = client.chat.completions.create(
|
||||
model=model_name or "qwen-vl-max-latest",
|
||||
messages=[{"role": "user", "content": "直接回复我文本'当前网络可用'"}]
|
||||
)
|
||||
|
||||
if response and response.choices:
|
||||
return True, tr("QwenVL model is available")
|
||||
else:
|
||||
return False, tr("QwenVL model returned invalid response")
|
||||
|
||||
except Exception as e:
|
||||
return False, f"{tr('QwenVL model is not available')}: {str(e)}"
|
||||
|
||||
elif provider.lower() == 'narratoapi':
|
||||
import requests
|
||||
try:
|
||||
@ -116,7 +154,7 @@ def render_vision_llm_settings(tr):
|
||||
st.subheader(tr("Vision Model Settings"))
|
||||
|
||||
# 视频分析模型提供商选择
|
||||
vision_providers = ['Gemini', 'NarratoAPI(待发布)', 'QwenVL(待发布)']
|
||||
vision_providers = ['Gemini', 'QwenVL', 'NarratoAPI(待发布)']
|
||||
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
|
||||
saved_provider_index = 0
|
||||
|
||||
@ -142,18 +180,33 @@ def render_vision_llm_settings(tr):
|
||||
# 渲染视觉模型配置输入框
|
||||
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
|
||||
|
||||
# 当选择 Gemini 时禁用 base_url 输入
|
||||
if vision_provider.lower() == 'gemini':
|
||||
# 根据不同提供商设置默认值和帮助信息
|
||||
if vision_provider == 'gemini':
|
||||
st_vision_base_url = st.text_input(
|
||||
tr("Vision Base URL"),
|
||||
value=vision_base_url,
|
||||
disabled=True,
|
||||
help=tr("Gemini API does not require a base URL")
|
||||
)
|
||||
st_vision_model_name = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=vision_model_name or "gemini-1.5-flash",
|
||||
help=tr("Default: gemini-1.5-flash")
|
||||
)
|
||||
elif vision_provider == 'qwenvl':
|
||||
st_vision_base_url = st.text_input(
|
||||
tr("Vision Base URL"),
|
||||
value=vision_base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
help=tr("Default: https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
)
|
||||
st_vision_model_name = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=vision_model_name or "qwen-vl-max-latest",
|
||||
help=tr("Default: qwen-vl-max-latest")
|
||||
)
|
||||
else:
|
||||
st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url)
|
||||
|
||||
st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
|
||||
st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
|
||||
|
||||
# 在配置输入框后添加测试按钮
|
||||
if st.button(tr("Test Connection"), key="test_vision_connection"):
|
||||
@ -174,7 +227,7 @@ def render_vision_llm_settings(tr):
|
||||
# 保存视觉模型配置
|
||||
if st_vision_api_key:
|
||||
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
||||
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key # 用于script_settings.py
|
||||
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
||||
if st_vision_base_url:
|
||||
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
||||
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
||||
@ -182,81 +235,6 @@ def render_vision_llm_settings(tr):
|
||||
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
||||
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
||||
|
||||
# # NarratoAPI 特殊配置
|
||||
# if vision_provider == 'narratoapi':
|
||||
# st.subheader(tr("Narrato Additional Settings"))
|
||||
#
|
||||
# # Narrato API 基础配置
|
||||
# narrato_api_key = st.text_input(
|
||||
# tr("Narrato API Key"),
|
||||
# value=config.app.get("narrato_api_key", ""),
|
||||
# type="password",
|
||||
# help="用于访问 Narrato API 的密钥"
|
||||
# )
|
||||
# if narrato_api_key:
|
||||
# config.app["narrato_api_key"] = narrato_api_key
|
||||
# st.session_state['narrato_api_key'] = narrato_api_key
|
||||
#
|
||||
# narrato_api_url = st.text_input(
|
||||
# tr("Narrato API URL"),
|
||||
# value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze")
|
||||
# )
|
||||
# if narrato_api_url:
|
||||
# config.app["narrato_api_url"] = narrato_api_url
|
||||
# st.session_state['narrato_api_url'] = narrato_api_url
|
||||
#
|
||||
# # 视频分析模型配置
|
||||
# st.markdown("##### " + tr("Vision Model Settings"))
|
||||
# narrato_vision_model = st.text_input(
|
||||
# tr("Vision Model Name"),
|
||||
# value=config.app.get("narrato_vision_model", "gemini-1.5-flash")
|
||||
# )
|
||||
# narrato_vision_key = st.text_input(
|
||||
# tr("Vision Model API Key"),
|
||||
# value=config.app.get("narrato_vision_key", ""),
|
||||
# type="password",
|
||||
# help="用于视频分析的模 API Key"
|
||||
# )
|
||||
#
|
||||
# if narrato_vision_model:
|
||||
# config.app["narrato_vision_model"] = narrato_vision_model
|
||||
# st.session_state['narrato_vision_model'] = narrato_vision_model
|
||||
# if narrato_vision_key:
|
||||
# config.app["narrato_vision_key"] = narrato_vision_key
|
||||
# st.session_state['narrato_vision_key'] = narrato_vision_key
|
||||
#
|
||||
# # 文案生成模型配置
|
||||
# st.markdown("##### " + tr("Text Generation Model Settings"))
|
||||
# narrato_llm_model = st.text_input(
|
||||
# tr("LLM Model Name"),
|
||||
# value=config.app.get("narrato_llm_model", "qwen-plus")
|
||||
# )
|
||||
# narrato_llm_key = st.text_input(
|
||||
# tr("LLM Model API Key"),
|
||||
# value=config.app.get("narrato_llm_key", ""),
|
||||
# type="password",
|
||||
# help="用于文案生成的模型 API Key"
|
||||
# )
|
||||
#
|
||||
# if narrato_llm_model:
|
||||
# config.app["narrato_llm_model"] = narrato_llm_model
|
||||
# st.session_state['narrato_llm_model'] = narrato_llm_model
|
||||
# if narrato_llm_key:
|
||||
# config.app["narrato_llm_key"] = narrato_llm_key
|
||||
# st.session_state['narrato_llm_key'] = narrato_llm_key
|
||||
#
|
||||
# # 批处理配置
|
||||
# narrato_batch_size = st.number_input(
|
||||
# tr("Batch Size"),
|
||||
# min_value=1,
|
||||
# max_value=50,
|
||||
# value=config.app.get("narrato_batch_size", 10),
|
||||
# help="每批处理的图片数量"
|
||||
# )
|
||||
# if narrato_batch_size:
|
||||
# config.app["narrato_batch_size"] = narrato_batch_size
|
||||
# st.session_state['narrato_batch_size'] = narrato_batch_size
|
||||
|
||||
|
||||
def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
"""测试文本模型连接
|
||||
@ -328,6 +306,7 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr):
|
||||
except Exception as e:
|
||||
return False, f"{tr('Connection failed')}: {str(e)}"
|
||||
|
||||
|
||||
def render_text_llm_settings(tr):
|
||||
"""渲染文案生成模型设置"""
|
||||
st.subheader(tr("Text Generation Model Settings"))
|
||||
|
||||
@ -14,7 +14,7 @@ from loguru import logger
|
||||
from app.config import config
|
||||
from app.models.schema import VideoClipParams
|
||||
from app.utils.script_generator import ScriptProcessor
|
||||
from app.utils import utils, check_script, gemini_analyzer, video_processor, video_processor_v2
|
||||
from app.utils import utils, check_script, gemini_analyzer, video_processor, video_processor_v2, qwenvl_analyzer
|
||||
from webui.utils import file_utils
|
||||
|
||||
|
||||
@ -472,267 +472,168 @@ def generate_script(tr, params):
|
||||
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
|
||||
logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
|
||||
|
||||
if vision_llm_provider == 'gemini':
|
||||
try:
|
||||
# ===================初始化视觉分析器===================
|
||||
update_progress(30, "正在初始化视觉分析器...")
|
||||
|
||||
# 从配置中获取 Gemini 相关配置
|
||||
try:
|
||||
# ===================初始化视觉分析器===================
|
||||
update_progress(30, "正在初始化视觉分析器...")
|
||||
|
||||
# 从配置中获取相关配置
|
||||
if vision_llm_provider == 'gemini':
|
||||
vision_api_key = st.session_state.get('vision_gemini_api_key')
|
||||
vision_model = st.session_state.get('vision_gemini_model_name')
|
||||
vision_base_url = st.session_state.get('vision_gemini_base_url')
|
||||
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError("未配置 Gemini API Key 或者 型,请在基础设置配置")
|
||||
elif vision_llm_provider == 'qwenvl':
|
||||
vision_api_key = st.session_state.get('vision_qwenvl_api_key')
|
||||
vision_model = st.session_state.get('vision_qwenvl_model_name', 'qwen-vl-max-latest')
|
||||
vision_base_url = st.session_state.get('vision_qwenvl_base_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
|
||||
else:
|
||||
raise ValueError(f"不支持的视觉分析提供商: {vision_llm_provider}")
|
||||
|
||||
analyzer = gemini_analyzer.VisionAnalyzer(
|
||||
model_name=vision_model,
|
||||
api_key=vision_api_key,
|
||||
# 创建视觉分析器实例
|
||||
analyzer = create_vision_analyzer(
|
||||
provider=vision_llm_provider,
|
||||
api_key=vision_api_key,
|
||||
model=vision_model,
|
||||
base_url=vision_base_url
|
||||
)
|
||||
|
||||
update_progress(40, "正在分析关键帧...")
|
||||
|
||||
# ===================创建异步事件循环===================
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 执行异步分析
|
||||
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
||||
results = loop.run_until_complete(
|
||||
analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
batch_size=vision_batch_size
|
||||
)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
update_progress(40, "正在分析关键帧...")
|
||||
# ===================处理分析结果===================
|
||||
update_progress(60, "正在整理分析结果...")
|
||||
|
||||
# 合并所有批次的析结果
|
||||
frame_analysis = ""
|
||||
prev_batch_files = None
|
||||
|
||||
# ===================创建异步事件循环===================
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 执行异步分析
|
||||
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
|
||||
results = loop.run_until_complete(
|
||||
analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
batch_size=vision_batch_size
|
||||
)
|
||||
)
|
||||
loop.close()
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
|
||||
# ===================处理分析结果===================
|
||||
update_progress(60, "正在整理分析结果...")
|
||||
# 获取当前批次的文件列表 keyframe_001136_000045.jpg 将 000045 精度提升到 毫秒
|
||||
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
||||
# logger.debug(batch_files)
|
||||
|
||||
# 合并所有批次的析结果
|
||||
frame_analysis = ""
|
||||
prev_batch_files = None
|
||||
first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}")
|
||||
|
||||
# 添加带时间戳的分析结果
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += result['response']
|
||||
frame_analysis += "\n"
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_analysis.strip():
|
||||
raise Exception("未能生成有效的帧分析结果")
|
||||
|
||||
# 保存分析结果
|
||||
analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
|
||||
with open(analysis_path, 'w', encoding='utf-8') as f:
|
||||
f.write(frame_analysis)
|
||||
|
||||
update_progress(70, "正在生成脚本...")
|
||||
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
# 获取当前批次的文件列表 keyframe_001136_000045.jpg 将 000045 精度提升到 毫秒
|
||||
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
|
||||
# logger.debug(batch_files)
|
||||
|
||||
first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}")
|
||||
|
||||
# 添加带时间戳的分析结果
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += result['response']
|
||||
frame_analysis += "\n"
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_analysis.strip():
|
||||
raise Exception("未能生成有效的帧分析结果")
|
||||
|
||||
# 保存分析结果
|
||||
analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
|
||||
with open(analysis_path, 'w', encoding='utf-8') as f:
|
||||
f.write(frame_analysis)
|
||||
|
||||
update_progress(70, "正在生成脚本...")
|
||||
# 从配置中获取文本生成相关配置
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
|
||||
# 构建帧内容列表
|
||||
frame_content_list = []
|
||||
prev_batch_files = None
|
||||
|
||||
# 从配置中获取文本生成相关配置
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
for i, result in enumerate(results):
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
# 构建帧内容列表
|
||||
frame_content_list = []
|
||||
prev_batch_files = None
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
frame_content = {
|
||||
"timestamp": timestamp_range,
|
||||
"picture": result['response'],
|
||||
"narration": "",
|
||||
"OST": 2
|
||||
}
|
||||
frame_content_list.append(frame_content)
|
||||
|
||||
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
if not frame_content_list:
|
||||
raise Exception("没有有效的帧内容可以处理")
|
||||
|
||||
# ===================开始生成文案===================
|
||||
update_progress(80, "正在生成文案...")
|
||||
# 校验配置
|
||||
api_params = {
|
||||
"vision_api_key": vision_api_key,
|
||||
"vision_model_name": vision_model,
|
||||
"vision_base_url": vision_base_url or "",
|
||||
"text_api_key": text_api_key,
|
||||
"text_model_name": text_model,
|
||||
"text_base_url": text_base_url or ""
|
||||
frame_content = {
|
||||
"timestamp": timestamp_range,
|
||||
"picture": result['response'],
|
||||
"narration": "",
|
||||
"OST": 2
|
||||
}
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
session = requests.Session()
|
||||
retry_strategy = Retry(
|
||||
total=3,
|
||||
backoff_factor=1,
|
||||
status_forcelist=[500, 502, 503, 504]
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
session.mount("https://", adapter)
|
||||
try:
|
||||
response = session.post(
|
||||
f"{config.app.get('narrato_api_url')}/video/config",
|
||||
headers=headers,
|
||||
json=api_params,
|
||||
timeout=30,
|
||||
verify=True
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
custom_prompt = st.session_state.get('custom_prompt', '')
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt=custom_prompt,
|
||||
base_url=text_base_url or "",
|
||||
video_theme=st.session_state.get('video_theme', '')
|
||||
)
|
||||
|
||||
# 处理帧内容生成脚本
|
||||
script_result = processor.process_frames(frame_content_list)
|
||||
|
||||
# 结果转换为JSON字符串
|
||||
script = json.dumps(script_result, ensure_ascii=False, indent=2)
|
||||
frame_content_list.append(frame_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
|
||||
raise Exception(f"分析失败: {str(e)}")
|
||||
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
|
||||
|
||||
# 更新上一个批次的文件
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_content_list:
|
||||
raise Exception("没有有效的帧内容可以处理")
|
||||
|
||||
elif vision_llm_provider == 'narratoapi': # NarratoAPI
|
||||
# ===================开始生成文案===================
|
||||
update_progress(80, "正在生成文案...")
|
||||
# 校验配置
|
||||
api_params = {
|
||||
"vision_api_key": vision_api_key,
|
||||
"vision_model_name": vision_model,
|
||||
"vision_base_url": vision_base_url or "",
|
||||
"text_api_key": text_api_key,
|
||||
"text_model_name": text_model,
|
||||
"text_base_url": text_base_url or ""
|
||||
}
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
session = requests.Session()
|
||||
retry_strategy = Retry(
|
||||
total=3,
|
||||
backoff_factor=1,
|
||||
status_forcelist=[500, 502, 503, 504]
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
session.mount("https://", adapter)
|
||||
try:
|
||||
# 创建临时目录
|
||||
temp_dir = utils.temp_dir("narrato")
|
||||
|
||||
# 打包关键帧
|
||||
update_progress(30, "正在打包关键帧...")
|
||||
zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
|
||||
if not file_utils.create_zip(keyframe_files, zip_path):
|
||||
raise Exception("打包关键帧失败")
|
||||
|
||||
# 获取API配置
|
||||
api_url = st.session_state.get('narrato_api_url')
|
||||
api_key = st.session_state.get('narrato_api_key')
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置 Narrato API Key,请在基础设置中配置")
|
||||
|
||||
# 准备API请求
|
||||
headers = {
|
||||
'X-API-Key': api_key,
|
||||
'accept': 'application/json'
|
||||
}
|
||||
|
||||
api_params = {
|
||||
'batch_size': st.session_state.get('narrato_batch_size', 10),
|
||||
'use_ai': False,
|
||||
'start_offset': 0,
|
||||
'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
|
||||
'vision_api_key': st.session_state.get('narrato_vision_key'),
|
||||
'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
|
||||
'llm_api_key': st.session_state.get('narrato_llm_key'),
|
||||
'custom_prompt': st.session_state.get('custom_prompt', '')
|
||||
}
|
||||
|
||||
# 发送API请求
|
||||
logger.info(f"请求NarratoAPI: {api_url}")
|
||||
update_progress(40, "正在上传文件...")
|
||||
with open(zip_path, 'rb') as f:
|
||||
files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{api_url}/video/analyze",
|
||||
headers=headers,
|
||||
params=api_params,
|
||||
files=files,
|
||||
timeout=30 # 设置超时时间
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Narrato API 请求失败:\n{traceback.format_exc()}")
|
||||
raise Exception(f"API请求失败: {str(e)}")
|
||||
|
||||
task_data = response.json()
|
||||
task_id = task_data["data"].get('task_id')
|
||||
if not task_id:
|
||||
raise Exception(f"无效的API响应: {response.text}")
|
||||
|
||||
# 轮询任务状态
|
||||
update_progress(50, "正在等待分析结果...")
|
||||
retry_count = 0
|
||||
max_retries = 60 # 最多等待2分钟
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
status_response = requests.get(
|
||||
f"{api_url}/video/tasks/{task_id}",
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
status_response.raise_for_status()
|
||||
task_status = status_response.json()['data']
|
||||
|
||||
if task_status['status'] == 'SUCCESS':
|
||||
script = task_status['result']['data']
|
||||
break
|
||||
elif task_status['status'] in ['FAILURE', 'RETRY']:
|
||||
raise Exception(f"任务失败: {task_status.get('error')}")
|
||||
|
||||
retry_count += 1
|
||||
time.sleep(2)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.warning(f"获取任务状态失败,重试中: {str(e)}")
|
||||
retry_count += 1
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
if retry_count >= max_retries:
|
||||
raise Exception("任务执行超时")
|
||||
|
||||
response = session.post(
|
||||
f"{config.app.get('narrato_api_url')}/video/config",
|
||||
headers=headers,
|
||||
json=api_params,
|
||||
timeout=30,
|
||||
verify=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"NarratoAPI 处理过程中发生错误\n{traceback.format_exc()}")
|
||||
raise Exception(f"NarratoAPI 处理失败: {str(e)}")
|
||||
finally:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理临时文件失败: {str(e)}")
|
||||
pass
|
||||
custom_prompt = st.session_state.get('custom_prompt', '')
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt=custom_prompt,
|
||||
base_url=text_base_url or "",
|
||||
video_theme=st.session_state.get('video_theme', '')
|
||||
)
|
||||
|
||||
else:
|
||||
logger.exception("Vision Model 未启用,请检查配置")
|
||||
# 处理帧内容生成脚本
|
||||
script_result = processor.process_frames(frame_content_list)
|
||||
|
||||
# 结果转换为JSON字符串
|
||||
script = json.dumps(script_result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
|
||||
raise Exception(f"分析失败: {str(e)}")
|
||||
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
@ -823,3 +724,12 @@ def get_script_params():
|
||||
'video_name': st.session_state.get('video_name', ''),
|
||||
'video_plot': st.session_state.get('video_plot', '')
|
||||
}
|
||||
|
||||
|
||||
def create_vision_analyzer(provider, api_key, model, base_url):
|
||||
if provider == 'gemini':
|
||||
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
|
||||
elif provider == 'qwenvl':
|
||||
return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key)
|
||||
else:
|
||||
raise ValueError(f"不支持的视觉分析提供商: {provider}")
|
||||
|
||||
@ -163,6 +163,9 @@
|
||||
"Error processing subtitle files. Please check if the subtitles are valid SRT files.": "处理字幕文件时出错。请检查字幕是否为有效的SRT文件。",
|
||||
"Preview Merged Video": "预览合并后的视频",
|
||||
"Video Path": "视频路径",
|
||||
"Subtitle Path": "字幕路径"
|
||||
"Subtitle Path": "字幕路径",
|
||||
"Enable Proxy": "启用代理",
|
||||
"QwenVL model is available": "QwenVL 模型可用",
|
||||
"QwenVL model is not available": "QwenVL 模型不可用"
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,100 @@
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from app.utils import gemini_analyzer, qwenvl_analyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VisionAnalyzer:
|
||||
def __init__(self):
|
||||
self.provider = None
|
||||
self.api_key = None
|
||||
self.model = None
|
||||
self.base_url = None
|
||||
self.analyzer = None
|
||||
|
||||
def initialize_gemini(self, api_key: str, model: str, base_url: str) -> None:
|
||||
"""
|
||||
初始化Gemini视觉分析器
|
||||
|
||||
Args:
|
||||
api_key: Gemini API密钥
|
||||
model: 模型名称
|
||||
base_url: API基础URL
|
||||
"""
|
||||
self.provider = 'gemini'
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.analyzer = gemini_analyzer.VisionAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
def initialize_qwenvl(self, api_key: str, model: str, base_url: str) -> None:
|
||||
"""
|
||||
初始化QwenVL视觉分析器
|
||||
|
||||
Args:
|
||||
api_key: 阿里云API密钥
|
||||
model: 模型名称
|
||||
base_url: API基础URL
|
||||
"""
|
||||
self.provider = 'qwenvl'
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.analyzer = qwenvl_analyzer.QwenAnalyzer(
|
||||
model_name=model,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
async def analyze_images(self, images: List[str], prompt: str, batch_size: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
分析图片内容
|
||||
|
||||
Args:
|
||||
images: 图片路径列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 每批处理的图片数量,默认为5
|
||||
|
||||
Returns:
|
||||
Dict: 分析结果
|
||||
"""
|
||||
if not self.analyzer:
|
||||
raise ValueError("未初始化视觉分析器")
|
||||
|
||||
return await self.analyzer.analyze_images(
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
def create_vision_analyzer(provider: str, **kwargs) -> VisionAnalyzer:
|
||||
"""
|
||||
创建视觉分析器实例
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 ('gemini' 或 'qwenvl')
|
||||
**kwargs: 提供商特定的配置参数
|
||||
|
||||
Returns:
|
||||
VisionAnalyzer: 配置好的视觉分析器实例
|
||||
"""
|
||||
analyzer = VisionAnalyzer()
|
||||
|
||||
if provider.lower() == 'gemini':
|
||||
analyzer.initialize_gemini(
|
||||
api_key=kwargs.get('api_key'),
|
||||
model=kwargs.get('model'),
|
||||
base_url=kwargs.get('base_url')
|
||||
)
|
||||
elif provider.lower() == 'qwenvl':
|
||||
analyzer.initialize_qwenvl(
|
||||
api_key=kwargs.get('api_key'),
|
||||
model=kwargs.get('model'),
|
||||
base_url=kwargs.get('base_url')
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的视觉分析提供商: {provider}")
|
||||
|
||||
return analyzer
|
||||
Loading…
x
Reference in New Issue
Block a user