diff --git a/app/utils/qwenvl_analyzer.py b/app/utils/qwenvl_analyzer.py
new file mode 100644
index 0000000..254cee1
--- /dev/null
+++ b/app/utils/qwenvl_analyzer.py
@@ -0,0 +1,255 @@
+import json
+from typing import List, Union, Dict
+import os
+from pathlib import Path
+from loguru import logger
+from tqdm import tqdm
+import asyncio
+from tenacity import retry, stop_after_attempt, wait_exponential
+from openai import OpenAI
+import PIL.Image
+import base64
+import io
+import traceback
+
+
+class QwenAnalyzer:
+    """Qwen vision analyzer."""
+
+    def __init__(self, model_name: str = "qwen-vl-max-latest", api_key: str = None):
+        """
+        Initialize the Qwen vision analyzer.
+        Args:
+            model_name: model name; defaults to qwen-vl-max-latest
+            api_key: Alibaba Cloud API key
+        """
+        if not api_key:
+            raise ValueError("An API key is required")
+
+        self.model_name = model_name
+        self.api_key = api_key
+
+        # Configure the API client
+        self._configure_client()
+
+    def _configure_client(self):
+        """Configure the API client (Qwen-VL is served through DashScope's OpenAI-compatible endpoint)."""
+        self.client = OpenAI(
+            api_key=self.api_key,
+            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
+        )
+
+    def _image_to_base64(self, image: PIL.Image.Image) -> str:
+        """Convert a PIL image object to a base64 string."""
+        buffered = io.BytesIO()
+        image.save(buffered, format="JPEG")
+        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10)
+    )
+    async def _generate_content_with_retry(self, prompt: str, batch: List[PIL.Image.Image]):
+        """Internal helper that calls the Qwen API with retries."""
+        try:
+            # Build the message content
+            content = []
+
+            # Add the images
+            for img in batch:
+                base64_image = self._image_to_base64(img)
+                content.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{base64_image}"
+                    }
+                })
+
+            # Add the text prompt
+            content.append({
+                "type": "text",
+                "text": prompt
+            })
+
+            # The SDK call is blocking, so run it in a worker thread
+            response = await asyncio.to_thread(
+                self.client.chat.completions.create,
+                model=self.model_name,
+                messages=[{
+                    "role": "user",
+                    "content": content
+                }]
+            )
+
+            return response.choices[0].message.content
+
+        except Exception as e:
+            logger.error(f"API call failed: {str(e)}")
+            # Re-raise the original exception so tenacity can retry it;
+            # raising RetryError directly here would bypass the retry policy.
+            raise
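+
+    # For reference, a batch of N images produces an OpenAI-style multimodal
+    # payload shaped like:
+    #
+    #     messages=[{"role": "user", "content": [
+    #         {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
+    #         ...,
+    #         {"type": "text", "text": "<prompt>"},
+    #     ]}]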
+
+    async def analyze_images(self,
+                             images: Union[List[str], List[PIL.Image.Image]],
+                             prompt: str,
+                             batch_size: int = 5) -> List[Dict]:
+        """
+        Analyze multiple images in batches.
+        Args:
+            images: list of image paths or PIL image objects
+            prompt: analysis prompt
+            batch_size: batch size
+        Returns:
+            list of analysis results
+        """
+        try:
+            # Keep the original paths (when a list of paths was passed in)
+            original_paths = images if isinstance(images[0], str) else None
+
+            # Load the images
+            if isinstance(images[0], str):
+                logger.info("Loading images...")
+                images = self.load_images(images)
+
+            # Validate the image list
+            if not images:
+                raise ValueError("The image list is empty")
+
+            # Validate each image object
+            valid_images = []
+            valid_paths = []
+            for i, img in enumerate(images):
+                if not isinstance(img, PIL.Image.Image):
+                    logger.error(f"Invalid image object at index {i}: {type(img)}")
+                    continue
+                valid_images.append(img)
+                if original_paths:
+                    valid_paths.append(original_paths[i])
+
+            if not valid_images:
+                raise ValueError("No valid image objects")
+
+            images = valid_images
+            results = []
+            total_batches = (len(images) + batch_size - 1) // batch_size
+
+            with tqdm(total=total_batches, desc="Analyzing") as pbar:
+                for i in range(0, len(images), batch_size):
+                    batch = images[i:i + batch_size]
+                    batch_paths = valid_paths[i:i + batch_size] if valid_paths else None
+                    retry_count = 0
+
+                    while retry_count < 3:
+                        try:
+                            # Short delay before every batch after the first, to ease rate limits
+                            if i > 0:
+                                await asyncio.sleep(2)
+
+                            # Make sure every image in this batch is valid
+                            valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
+                            if not valid_batch:
+                                raise ValueError(f"No valid images in batch {i // batch_size}")
+
+                            response = await self._generate_content_with_retry(prompt, valid_batch)
+                            result_dict = {
+                                'batch_index': i // batch_size,
+                                'images_processed': len(valid_batch),
+                                'response': response,
+                                'model_used': self.model_name
+                            }
+
+                            # Attach the image paths, if available
+                            if batch_paths:
+                                result_dict['image_paths'] = batch_paths
+
+                            results.append(result_dict)
+                            break
+
+                        except Exception as e:
+                            retry_count += 1
+                            error_msg = f"Error processing batch {i // batch_size}: {str(e)}"
+                            logger.error(error_msg)
+
+                            if retry_count >= 3:
+                                results.append({
+                                    'batch_index': i // batch_size,
+                                    'images_processed': len(batch),
+                                    'error': error_msg,
+                                    'model_used': self.model_name,
+                                    'image_paths': batch_paths if batch_paths else []
+                                })
+                            else:
+                                logger.info(f"Batch {i // batch_size} failed; waiting 60 seconds before retrying...")
+                                await asyncio.sleep(60)
+
+                    pbar.update(1)
+
+            return results
+
+        except Exception as e:
+            error_msg = f"Error during image analysis: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            raise Exception(error_msg)
+
+    def save_results_to_txt(self, results: List[Dict], output_dir: str):
+        """Save the analysis results as txt files."""
+        # Make sure the output directory exists
+        os.makedirs(output_dir, exist_ok=True)
+
+        for i, result in enumerate(results):
+            # Failed batches carry an 'error' key instead of a 'response'; skip them
+            if 'response' not in result:
+                continue
+            response_text = result['response']
+
+            # Use the image paths to build the file name, if available
+            if result.get('image_paths'):
+                image_paths = result['image_paths']
+                img_name_start = Path(image_paths[0]).stem.split('_')[-1]
+                img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
+                file_name = f"frame_{img_name_start}_{img_name_end}.txt"
+            else:
+                # Fall back to the batch index
+                file_name = f"batch_{result['batch_index']}.txt"
+
+            txt_path = os.path.join(output_dir, file_name)
+
+            # Write the result to the txt file
+            with open(txt_path, 'w', encoding='utf-8') as f:
+                f.write(response_text.strip())
+            logger.info(f"Saved analysis result to: {txt_path}")
+
+    def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
+        """
+        Load multiple images.
+        Args:
+            image_paths: list of image paths
+        Returns:
+            list of loaded PIL Image objects
+        """
+        images = []
+        failed_images = []
+
+        for img_path in image_paths:
+            try:
+                if not os.path.exists(img_path):
+                    logger.error(f"Image file does not exist: {img_path}")
+                    failed_images.append(img_path)
+                    continue
+
+                img = PIL.Image.open(img_path)
+                # Force the image data to load fully
+                img.load()
+                # Convert to RGB so the frame can be re-encoded as JPEG
+                if img.mode != 'RGB':
+                    img = img.convert('RGB')
+                images.append(img)
+
+            except Exception as e:
+                logger.error(f"Failed to load image {img_path}: {str(e)}")
+                failed_images.append(img_path)
+
+        if failed_images:
+            logger.warning(f"The following images failed to load:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
+
+        if not images:
+            raise ValueError("No images were loaded successfully")
+
+        return images
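+
+
+# Usage sketch (illustrative, not executed; assumes a DashScope key exported as
+# DASHSCOPE_API_KEY and a hypothetical frames/ directory):
+#
+#     analyzer = QwenAnalyzer(api_key=os.environ["DASHSCOPE_API_KEY"])
+#     results = asyncio.run(analyzer.analyze_images(
+#         images=["frames/keyframe_000001_000000.jpg"],
+#         prompt="Describe what happens in these frames.",
+#         batch_size=5,
+#     ))
+#     analyzer.save_results_to_txt(results, "storage/frame_analysis")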
config.proxy["https"] = HTTPS_PROXY - os.environ["HTTPS_PROXY"] = HTTPS_PROXY + # 只有在代理启用时才显示代理设置输入框 + if proxy_enabled: + HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http) + HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https) + + if HTTP_PROXY: + config.proxy["http"] = HTTP_PROXY + os.environ["HTTP_PROXY"] = HTTP_PROXY + if HTTPS_PROXY: + config.proxy["https"] = HTTPS_PROXY + os.environ["HTTPS_PROXY"] = HTTPS_PROXY + else: + # 当代理被禁用时,清除环境变量和配置 + os.environ.pop("HTTP_PROXY", None) + os.environ.pop("HTTPS_PROXY", None) + config.proxy["http"] = "" + config.proxy["https"] = "" def test_vision_model_connection(api_key, base_url, model_name, provider, tr): @@ -90,6 +106,28 @@ def test_vision_model_connection(api_key, base_url, model_name, provider, tr): except Exception as e: return False, f"{tr('gemini model is not available')}: {str(e)}" + elif provider.lower() == 'qwenvl': + from openai import OpenAI + try: + client = OpenAI( + api_key=api_key, + base_url=base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1" + ) + + # 发送一个简单的测试请求 + response = client.chat.completions.create( + model=model_name or "qwen-vl-max-latest", + messages=[{"role": "user", "content": "直接回复我文本'当前网络可用'"}] + ) + + if response and response.choices: + return True, tr("QwenVL model is available") + else: + return False, tr("QwenVL model returned invalid response") + + except Exception as e: + return False, f"{tr('QwenVL model is not available')}: {str(e)}" + elif provider.lower() == 'narratoapi': import requests try: @@ -116,7 +154,7 @@ def render_vision_llm_settings(tr): st.subheader(tr("Vision Model Settings")) # 视频分析模型提供商选择 - vision_providers = ['Gemini', 'NarratoAPI(待发布)', 'QwenVL(待发布)'] + vision_providers = ['Gemini', 'QwenVL', 'NarratoAPI(待发布)'] saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower() saved_provider_index = 0 @@ -142,18 +180,33 @@ def render_vision_llm_settings(tr): # 渲染视觉模型配置输入框 st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password") - # 当选择 Gemini 时禁用 base_url 输入 - if vision_provider.lower() == 'gemini': + # 根据不同提供商设置默认值和帮助信息 + if vision_provider == 'gemini': st_vision_base_url = st.text_input( tr("Vision Base URL"), value=vision_base_url, disabled=True, help=tr("Gemini API does not require a base URL") ) + st_vision_model_name = st.text_input( + tr("Vision Model Name"), + value=vision_model_name or "gemini-1.5-flash", + help=tr("Default: gemini-1.5-flash") + ) + elif vision_provider == 'qwenvl': + st_vision_base_url = st.text_input( + tr("Vision Base URL"), + value=vision_base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1", + help=tr("Default: https://dashscope.aliyuncs.com/compatible-mode/v1") + ) + st_vision_model_name = st.text_input( + tr("Vision Model Name"), + value=vision_model_name or "qwen-vl-max-latest", + help=tr("Default: qwen-vl-max-latest") + ) else: st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url) - - st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name) + st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name) # 在配置输入框后添加测试按钮 if st.button(tr("Test Connection"), key="test_vision_connection"): @@ -174,7 +227,7 @@ def render_vision_llm_settings(tr): # 保存视觉模型配置 if st_vision_api_key: config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key - st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key # 用于script_settings.py + 
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key if st_vision_base_url: config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url @@ -182,81 +235,6 @@ def render_vision_llm_settings(tr): config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name - # # NarratoAPI 特殊配置 - # if vision_provider == 'narratoapi': - # st.subheader(tr("Narrato Additional Settings")) - # - # # Narrato API 基础配置 - # narrato_api_key = st.text_input( - # tr("Narrato API Key"), - # value=config.app.get("narrato_api_key", ""), - # type="password", - # help="用于访问 Narrato API 的密钥" - # ) - # if narrato_api_key: - # config.app["narrato_api_key"] = narrato_api_key - # st.session_state['narrato_api_key'] = narrato_api_key - # - # narrato_api_url = st.text_input( - # tr("Narrato API URL"), - # value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze") - # ) - # if narrato_api_url: - # config.app["narrato_api_url"] = narrato_api_url - # st.session_state['narrato_api_url'] = narrato_api_url - # - # # 视频分析模型配置 - # st.markdown("##### " + tr("Vision Model Settings")) - # narrato_vision_model = st.text_input( - # tr("Vision Model Name"), - # value=config.app.get("narrato_vision_model", "gemini-1.5-flash") - # ) - # narrato_vision_key = st.text_input( - # tr("Vision Model API Key"), - # value=config.app.get("narrato_vision_key", ""), - # type="password", - # help="用于视频分析的模 API Key" - # ) - # - # if narrato_vision_model: - # config.app["narrato_vision_model"] = narrato_vision_model - # st.session_state['narrato_vision_model'] = narrato_vision_model - # if narrato_vision_key: - # config.app["narrato_vision_key"] = narrato_vision_key - # st.session_state['narrato_vision_key'] = narrato_vision_key - # - # # 文案生成模型配置 - # st.markdown("##### " + tr("Text Generation Model Settings")) - # narrato_llm_model = st.text_input( - # tr("LLM Model Name"), - # value=config.app.get("narrato_llm_model", "qwen-plus") - # ) - # narrato_llm_key = st.text_input( - # tr("LLM Model API Key"), - # value=config.app.get("narrato_llm_key", ""), - # type="password", - # help="用于文案生成的模型 API Key" - # ) - # - # if narrato_llm_model: - # config.app["narrato_llm_model"] = narrato_llm_model - # st.session_state['narrato_llm_model'] = narrato_llm_model - # if narrato_llm_key: - # config.app["narrato_llm_key"] = narrato_llm_key - # st.session_state['narrato_llm_key'] = narrato_llm_key - # - # # 批处理配置 - # narrato_batch_size = st.number_input( - # tr("Batch Size"), - # min_value=1, - # max_value=50, - # value=config.app.get("narrato_batch_size", 10), - # help="每批处理的图片数量" - # ) - # if narrato_batch_size: - # config.app["narrato_batch_size"] = narrato_batch_size - # st.session_state['narrato_batch_size'] = narrato_batch_size - def test_text_model_connection(api_key, base_url, model_name, provider, tr): """测试文本模型连接 @@ -328,6 +306,7 @@ def test_text_model_connection(api_key, base_url, model_name, provider, tr): except Exception as e: return False, f"{tr('Connection failed')}: {str(e)}" + def render_text_llm_settings(tr): """渲染文案生成模型设置""" st.subheader(tr("Text Generation Model Settings")) diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index b64979d..4025f68 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -14,7 +14,7 @@ from loguru import logger from app.config 
+
+        # Create the vision analyzer instance
+        analyzer = create_vision_analyzer(
+            provider=vision_llm_provider,
+            api_key=vision_api_key,
+            model=vision_model,
+            base_url=vision_base_url
+        )
+
+        update_progress(40, "Analyzing keyframes...")
+
+        # =================== Create an async event loop ===================
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        # Run the asynchronous analysis
+        vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
+        results = loop.run_until_complete(
+            analyzer.analyze_images(
+                images=keyframe_files,
+                prompt=config.app.get('vision_analysis_prompt'),
+                batch_size=vision_batch_size
+            )
+        )
+        loop.close()
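+
+        # Streamlit runs this callback synchronously, so a private event loop is
+        # spun up just to drive the async analyzer; asyncio.run(...) would be an
+        # equivalent alternative here.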
 
-            update_progress(40, "Analyzing keyframes...")
-
-            # =================== Create an async event loop ===================
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-            # Run the asynchronous analysis
-            vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
-            results = loop.run_until_complete(
-                analyzer.analyze_images(
-                    images=keyframe_files,
-                    prompt=config.app.get('vision_analysis_prompt'),
-                    batch_size=vision_batch_size
-                )
-            )
-            loop.close()
-
-            # =================== Process the analysis results ===================
-            update_progress(60, "Collating analysis results...")
-
-            # Merge the analysis results of all batches
-            frame_analysis = ""
-            prev_batch_files = None
-
-            for result in results:
-                if 'error' in result:
-                    logger.warning(f"Batch {result['batch_index']} finished with a warning: {result['error']}")
-                    continue
-                # File list for this batch, e.g. keyframe_001136_000045.jpg; the
-                # 000045 part is promoted to millisecond precision
-                batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
-                logger.debug(f"Batch {result['batch_index']} done, {len(batch_files)} images")
-                # logger.debug(batch_files)
-
-                first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
-                logger.debug(f"Timestamp range: {first_timestamp}-{last_timestamp}")
-
-                # Append the analysis result under its timestamp range
-                frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
-                frame_analysis += result['response']
-                frame_analysis += "\n"
-
-                # Remember this batch's files
-                prev_batch_files = batch_files
-
-            if not frame_analysis.strip():
-                raise Exception("No valid frame analysis results were generated")
-
-            # Save the analysis results
-            analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
-            with open(analysis_path, 'w', encoding='utf-8') as f:
-                f.write(frame_analysis)
-
-            update_progress(70, "Generating script...")
+        # =================== Process the analysis results ===================
+        update_progress(60, "Collating analysis results...")
+
+        # Merge the analysis results of all batches
+        frame_analysis = ""
+        prev_batch_files = None
+
+        for result in results:
+            if 'error' in result:
+                logger.warning(f"Batch {result['batch_index']} finished with a warning: {result['error']}")
+                continue
+
+            # File list for this batch, e.g. keyframe_001136_000045.jpg; the
+            # 000045 part is promoted to millisecond precision
+            batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
+            logger.debug(f"Batch {result['batch_index']} done, {len(batch_files)} images")
+            # logger.debug(batch_files)
+
+            first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
+            logger.debug(f"Timestamp range: {first_timestamp}-{last_timestamp}")
+
+            # Append the analysis result under its timestamp range
+            frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
+            frame_analysis += result['response']
+            frame_analysis += "\n"
+
+            # Remember this batch's files
+            prev_batch_files = batch_files
+
+        if not frame_analysis.strip():
+            raise Exception("No valid frame analysis results were generated")
+
+        # Save the analysis results
+        analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
+        with open(analysis_path, 'w', encoding='utf-8') as f:
+            f.write(frame_analysis)
+
+        update_progress(70, "Generating script...")
 
-            # Read the text-generation settings from the configuration
-            text_provider = config.app.get('text_llm_provider', 'gemini').lower()
-            text_api_key = config.app.get(f'text_{text_provider}_api_key')
-            text_model = config.app.get(f'text_{text_provider}_model_name')
-            text_base_url = config.app.get(f'text_{text_provider}_base_url')
-
-            # Build the frame content list
-            frame_content_list = []
-            prev_batch_files = None
-
-            for i, result in enumerate(results):
-                if 'error' in result:
-                    continue
-
-                batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
-                _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
-
-                frame_content = {
-                    "timestamp": timestamp_range,
-                    "picture": result['response'],
-                    "narration": "",
-                    "OST": 2
-                }
-                frame_content_list.append(frame_content)
-
-                logger.debug(f"Added frame content: range={timestamp_range}, analysis length={len(result['response'])}")
-
-                # Remember this batch's files
-                prev_batch_files = batch_files
-
-            if not frame_content_list:
-                raise Exception("No valid frame content to process")
+        # Read the text-generation settings from the configuration
+        text_provider = config.app.get('text_llm_provider', 'gemini').lower()
+        text_api_key = config.app.get(f'text_{text_provider}_api_key')
+        text_model = config.app.get(f'text_{text_provider}_model_name')
+        text_base_url = config.app.get(f'text_{text_provider}_base_url')
+
+        # Build the frame content list
+        frame_content_list = []
+        prev_batch_files = None
+
+        for i, result in enumerate(results):
+            if 'error' in result:
+                continue
+
+            batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
+            _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
+
+            frame_content = {
+                "timestamp": timestamp_range,
+                "picture": result['response'],
+                "narration": "",
+                "OST": 2
+            }
+            frame_content_list.append(frame_content)
+
+            logger.debug(f"Added frame content: range={timestamp_range}, analysis length={len(result['response'])}")
+
+            # Remember this batch's files
+            prev_batch_files = batch_files
+
+        if not frame_content_list:
+            raise Exception("No valid frame content to process")
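+
+        # Each entry pairs a timestamp range with the vision model's description
+        # of that segment; narration is left empty for the script-generation step
+        # below to fill in.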
-
-            # =================== Generate the script ===================
-            update_progress(80, "Generating narration...")
-            # Validate the configuration
-            api_params = {
-                "vision_api_key": vision_api_key,
-                "vision_model_name": vision_model,
-                "vision_base_url": vision_base_url or "",
-                "text_api_key": text_api_key,
-                "text_model_name": text_model,
-                "text_base_url": text_base_url or ""
-            }
-            headers = {
-                'accept': 'application/json',
-                'Content-Type': 'application/json'
-            }
-            session = requests.Session()
-            retry_strategy = Retry(
-                total=3,
-                backoff_factor=1,
-                status_forcelist=[500, 502, 503, 504]
-            )
-            adapter = HTTPAdapter(max_retries=retry_strategy)
-            session.mount("https://", adapter)
-            try:
-                response = session.post(
-                    f"{config.app.get('narrato_api_url')}/video/config",
-                    headers=headers,
-                    json=api_params,
-                    timeout=30,
-                    verify=True
-                )
-            except Exception as e:
-                pass
-            custom_prompt = st.session_state.get('custom_prompt', '')
-            processor = ScriptProcessor(
-                model_name=text_model,
-                api_key=text_api_key,
-                prompt=custom_prompt,
-                base_url=text_base_url or "",
-                video_theme=st.session_state.get('video_theme', '')
-            )
-
-            # Generate the script from the frame contents
-            script_result = processor.process_frames(frame_content_list)
-
-            # Convert the result to a JSON string
-            script = json.dumps(script_result, ensure_ascii=False, indent=2)
-
-        except Exception as e:
-            logger.exception(f"Error during LLM processing\n{traceback.format_exc()}")
-            raise Exception(f"Analysis failed: {str(e)}")
-
-    elif vision_llm_provider == 'narratoapi':  # NarratoAPI
-        try:
-            # Create a temporary directory
-            temp_dir = utils.temp_dir("narrato")
-
-            # Pack the keyframes
-            update_progress(30, "Packing keyframes...")
-            zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
-            if not file_utils.create_zip(keyframe_files, zip_path):
-                raise Exception("Failed to pack the keyframes")
-
-            # Get the API configuration
-            api_url = st.session_state.get('narrato_api_url')
-            api_key = st.session_state.get('narrato_api_key')
-
-            if not api_key:
-                raise ValueError("Narrato API Key is not configured; please set it in Basic Settings")
-
-            # Prepare the API request
-            headers = {
-                'X-API-Key': api_key,
-                'accept': 'application/json'
-            }
-
-            api_params = {
-                'batch_size': st.session_state.get('narrato_batch_size', 10),
-                'use_ai': False,
-                'start_offset': 0,
-                'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
-                'vision_api_key': st.session_state.get('narrato_vision_key'),
-                'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
-                'llm_api_key': st.session_state.get('narrato_llm_key'),
-                'custom_prompt': st.session_state.get('custom_prompt', '')
-            }
-
-            # Send the API request
-            logger.info(f"Calling NarratoAPI: {api_url}")
-            update_progress(40, "Uploading file...")
-            with open(zip_path, 'rb') as f:
-                files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
-                try:
-                    response = requests.post(
-                        f"{api_url}/video/analyze",
-                        headers=headers,
-                        params=api_params,
-                        files=files,
-                        timeout=30  # request timeout
-                    )
-                    response.raise_for_status()
-                except requests.RequestException as e:
-                    logger.error(f"Narrato API request failed:\n{traceback.format_exc()}")
-                    raise Exception(f"API request failed: {str(e)}")
-
-            task_data = response.json()
-            task_id = task_data["data"].get('task_id')
-            if not task_id:
-                raise Exception(f"Invalid API response: {response.text}")
-
-            # Poll the task status
-            update_progress(50, "Waiting for analysis results...")
-            retry_count = 0
-            max_retries = 60  # wait at most 2 minutes
-
-            while retry_count < max_retries:
-                try:
-                    status_response = requests.get(
-                        f"{api_url}/video/tasks/{task_id}",
-                        headers=headers,
-                        timeout=10
-                    )
-                    status_response.raise_for_status()
-                    task_status = status_response.json()['data']
-
-                    if task_status['status'] == 'SUCCESS':
-                        script = task_status['result']['data']
-                        break
-                    elif task_status['status'] in ['FAILURE', 'RETRY']:
-                        raise Exception(f"Task failed: {task_status.get('error')}")
-
-                    retry_count += 1
-                    time.sleep(2)
-
-                except requests.RequestException as e:
-                    logger.warning(f"Failed to fetch task status, retrying: {str(e)}")
-                    retry_count += 1
-                    time.sleep(2)
-                    continue
-
-            if retry_count >= max_retries:
-                raise Exception("Task timed out")
-
-        except Exception as e:
-            logger.exception(f"Error during NarratoAPI processing\n{traceback.format_exc()}")
-            raise Exception(f"NarratoAPI processing failed: {str(e)}")
-        finally:
-            # Clean up temporary files
-            try:
-                if os.path.exists(zip_path):
-                    os.remove(zip_path)
-            except Exception as e:
-                logger.warning(f"Failed to clean up temporary files: {str(e)}")
-
-    else:
-        logger.exception("Vision Model is not enabled; please check the configuration")
+        # =================== Generate the script ===================
+        update_progress(80, "Generating narration...")
+        # Validate the configuration
+        api_params = {
+            "vision_api_key": vision_api_key,
+            "vision_model_name": vision_model,
+            "vision_base_url": vision_base_url or "",
+            "text_api_key": text_api_key,
+            "text_model_name": text_model,
+            "text_base_url": text_base_url or ""
+        }
+        headers = {
+            'accept': 'application/json',
+            'Content-Type': 'application/json'
+        }
+        session = requests.Session()
+        retry_strategy = Retry(
+            total=3,
+            backoff_factor=1,
+            status_forcelist=[500, 502, 503, 504]
+        )
+        adapter = HTTPAdapter(max_retries=retry_strategy)
+        session.mount("https://", adapter)
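+        # urllib3's Retry re-sends the POST up to three times with exponential
+        # backoff whenever the server answers 500/502/503/504.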
{response.text}") - - # 轮询任务状态 - update_progress(50, "正在等待分析结果...") - retry_count = 0 - max_retries = 60 # 最多等待2分钟 - - while retry_count < max_retries: - try: - status_response = requests.get( - f"{api_url}/video/tasks/{task_id}", - headers=headers, - timeout=10 - ) - status_response.raise_for_status() - task_status = status_response.json()['data'] - - if task_status['status'] == 'SUCCESS': - script = task_status['result']['data'] - break - elif task_status['status'] in ['FAILURE', 'RETRY']: - raise Exception(f"任务失败: {task_status.get('error')}") - - retry_count += 1 - time.sleep(2) - - except requests.RequestException as e: - logger.warning(f"获取任务状态失败,重试中: {str(e)}") - retry_count += 1 - time.sleep(2) - continue - - if retry_count >= max_retries: - raise Exception("任务执行超时") - + response = session.post( + f"{config.app.get('narrato_api_url')}/video/config", + headers=headers, + json=api_params, + timeout=30, + verify=True + ) except Exception as e: - logger.exception(f"NarratoAPI 处理过程中发生错误\n{traceback.format_exc()}") - raise Exception(f"NarratoAPI 处理失败: {str(e)}") - finally: - # 清理临时文件 - try: - if os.path.exists(zip_path): - os.remove(zip_path) - except Exception as e: - logger.warning(f"清理临时文件失败: {str(e)}") + pass + custom_prompt = st.session_state.get('custom_prompt', '') + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + prompt=custom_prompt, + base_url=text_base_url or "", + video_theme=st.session_state.get('video_theme', '') + ) - else: - logger.exception("Vision Model 未启用,请检查配置") + # 处理帧内容生成脚本 + script_result = processor.process_frames(frame_content_list) + + # 结果转换为JSON字符串 + script = json.dumps(script_result, ensure_ascii=False, indent=2) + + except Exception as e: + logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}") + raise Exception(f"分析失败: {str(e)}") if script is None: st.error("生成脚本失败,请检查日志") @@ -823,3 +724,12 @@ def get_script_params(): 'video_name': st.session_state.get('video_name', ''), 'video_plot': st.session_state.get('video_plot', '') } + + +def create_vision_analyzer(provider, api_key, model, base_url): + if provider == 'gemini': + return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key) + elif provider == 'qwenvl': + return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key) + else: + raise ValueError(f"不支持的视觉分析提供商: {provider}") diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index 0e62ce0..3e4398e 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -163,6 +163,9 @@ "Error processing subtitle files. 
diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json
index 0e62ce0..3e4398e 100644
--- a/webui/i18n/zh.json
+++ b/webui/i18n/zh.json
@@ -163,6 +163,10 @@
     "Error processing subtitle files. Please check if the subtitles are valid SRT files.": "处理字幕文件时出错。请检查字幕是否为有效的SRT文件。",
     "Preview Merged Video": "预览合并后的视频",
     "Video Path": "视频路径",
-    "Subtitle Path": "字幕路径"
+    "Subtitle Path": "字幕路径",
+    "Enable Proxy": "启用代理",
+    "QwenVL model is available": "QwenVL 模型可用",
+    "QwenVL model is not available": "QwenVL 模型不可用",
+    "QwenVL model returned invalid response": "QwenVL 模型返回了无效响应"
   }
 }
diff --git a/webui/utils/vision_analyzer.py b/webui/utils/vision_analyzer.py
index e69de29..3e0fecd 100644
--- a/webui/utils/vision_analyzer.py
+++ b/webui/utils/vision_analyzer.py
@@ -0,0 +1,100 @@
+import logging
+from typing import List, Dict, Any
+from app.utils import gemini_analyzer, qwenvl_analyzer
+
+logger = logging.getLogger(__name__)
+
+class VisionAnalyzer:
+    def __init__(self):
+        self.provider = None
+        self.api_key = None
+        self.model = None
+        self.base_url = None
+        self.analyzer = None
+
+    def initialize_gemini(self, api_key: str, model: str, base_url: str) -> None:
+        """
+        Initialize the Gemini vision analyzer.
+
+        Args:
+            api_key: Gemini API key
+            model: model name
+            base_url: API base URL (stored but not passed through; the
+                underlying analyzer uses its default endpoint)
+        """
+        self.provider = 'gemini'
+        self.api_key = api_key
+        self.model = model
+        self.base_url = base_url
+        self.analyzer = gemini_analyzer.VisionAnalyzer(
+            model_name=model,
+            api_key=api_key
+        )
+
+    def initialize_qwenvl(self, api_key: str, model: str, base_url: str) -> None:
+        """
+        Initialize the QwenVL vision analyzer.
+
+        Args:
+            api_key: Alibaba Cloud API key
+            model: model name
+            base_url: API base URL (stored but not passed through)
+        """
+        self.provider = 'qwenvl'
+        self.api_key = api_key
+        self.model = model
+        self.base_url = base_url
+        self.analyzer = qwenvl_analyzer.QwenAnalyzer(
+            model_name=model,
+            api_key=api_key
+        )
+
+    async def analyze_images(self, images: List[str], prompt: str, batch_size: int = 5) -> List[Dict[str, Any]]:
+        """
+        Analyze image content.
+
+        Args:
+            images: list of image paths
+            prompt: analysis prompt
+            batch_size: number of images per batch; defaults to 5
+
+        Returns:
+            list of analysis results
+        """
+        if not self.analyzer:
+            raise ValueError("The vision analyzer has not been initialized")
+
+        return await self.analyzer.analyze_images(
+            images=images,
+            prompt=prompt,
+            batch_size=batch_size
+        )
+
+def create_vision_analyzer(provider: str, **kwargs) -> VisionAnalyzer:
+    """
+    Create a vision analyzer instance.
+
+    Args:
+        provider: provider name ('gemini' or 'qwenvl')
+        **kwargs: provider-specific configuration
+
+    Returns:
+        VisionAnalyzer: a configured vision analyzer instance
+    """
+    analyzer = VisionAnalyzer()
+
+    if provider.lower() == 'gemini':
+        analyzer.initialize_gemini(
+            api_key=kwargs.get('api_key'),
+            model=kwargs.get('model'),
+            base_url=kwargs.get('base_url')
+        )
+    elif provider.lower() == 'qwenvl':
+        analyzer.initialize_qwenvl(
+            api_key=kwargs.get('api_key'),
+            model=kwargs.get('model'),
+            base_url=kwargs.get('base_url')
+        )
+    else:
+        raise ValueError(f"Unsupported vision analysis provider: {provider}")
+
+    return analyzer
\ No newline at end of file