diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index 4025f68..23a0c53 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -1,120 +1,13 @@ import os -import ssl import glob import json import time -import asyncio -import traceback -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry -import requests import streamlit as st -from loguru import logger from app.config import config from app.models.schema import VideoClipParams -from app.utils.script_generator import ScriptProcessor -from app.utils import utils, check_script, gemini_analyzer, video_processor, video_processor_v2, qwenvl_analyzer -from webui.utils import file_utils - - -def get_batch_timestamps(batch_files, prev_batch_files=None): - """ - 解析一批文件的时间戳范围,支持毫秒级精度 - - Args: - batch_files: 当前批次的文件列表 - prev_batch_files: 上一个批次的文件列表,用于处理单张图片的情况 - - Returns: - tuple: (first_timestamp, last_timestamp, timestamp_range) - 时间戳格式: HH:MM:SS,mmm (时:分:秒,毫秒) - 例如: 00:00:50,100 表示50秒100毫秒 - - 示例文件名格式: - keyframe_001253_000050100.jpg - 其中 000050100 表示 00:00:50,100 (50秒100毫秒) - """ - if not batch_files: - logger.warning("Empty batch files") - return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000" - - def get_frame_files(): - """获取首帧和尾帧文件名""" - if len(batch_files) == 1 and prev_batch_files and prev_batch_files: - # 单张图片情况:使用上一批次最后一帧作为首帧 - first = os.path.basename(prev_batch_files[-1]) - last = os.path.basename(batch_files[0]) - logger.debug(f"单张图片批次,使用上一批次最后一帧作为首帧: {first}") - else: - first = os.path.basename(batch_files[0]) - last = os.path.basename(batch_files[-1]) - return first, last - - def extract_time(filename): - """从文件名提取时间信息""" - try: - # 提取类似 000050100 的时间戳部分 - time_str = filename.split('_')[2].replace('.jpg', '') - if len(time_str) < 9: # 处理旧格式 - time_str = time_str.ljust(9, '0') - return time_str - except (IndexError, AttributeError) as e: - logger.warning(f"Invalid filename format: {filename}, error: {e}") - return "000000000" - - def format_timestamp(time_str): - """ - 将时间字符串转换为 HH:MM:SS,mmm 格式 - - Args: - time_str: 9位数字字符串,格式为 HHMMSSMMM - 例如: 000010000 表示 00时00分10秒000毫秒 - 000043039 表示 00时00分43秒039毫秒 - - Returns: - str: HH:MM:SS,mmm 格式的时间戳 - """ - try: - if len(time_str) < 9: - logger.warning(f"Invalid timestamp format: {time_str}") - return "00:00:00,000" - - # 从时间戳中提取时、分、秒和毫秒 - hours = int(time_str[0:2]) # 前2位作为小时 - minutes = int(time_str[2:4]) # 第3-4位作为分钟 - seconds = int(time_str[4:6]) # 第5-6位作为秒数 - milliseconds = int(time_str[6:]) # 最后3位作为毫秒 - - return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}" - - except ValueError as e: - logger.warning(f"时间戳格式转换失败: {time_str}, error: {e}") - return "00:00:00,000" - - # 获取首帧和尾帧文件名 - first_frame, last_frame = get_frame_files() - - # 从文件名中提取时间信息 - first_time = extract_time(first_frame) - last_time = extract_time(last_frame) - - # 转换为标准时间戳格式 - first_timestamp = format_timestamp(first_time) - last_timestamp = format_timestamp(last_time) - timestamp_range = f"{first_timestamp}-{last_timestamp}" - - # logger.debug(f"解析时间戳: {first_frame} -> {first_timestamp}, {last_frame} -> {last_timestamp}") - return first_timestamp, last_timestamp, timestamp_range - - -def get_batch_files(keyframe_files, result, batch_size=5): - """ - 获取当前批次的图片文件 - """ - batch_start = result['batch_index'] * batch_size - batch_end = min(batch_start + batch_size, len(keyframe_files)) - return keyframe_files[batch_start:batch_end] +from app.utils import utils, check_script +from webui.tools.generate_script_docu import generate_script_docu def render_script_panel(tr): @@ -330,7 +223,7 @@ def render_script_buttons(tr, params): if st.button(button_name, key="script_action", disabled=not script_path): if script_path == "auto": - generate_script(tr, params) + generate_script_docu(tr, params) else: load_script(tr, script_path) @@ -385,280 +278,6 @@ def load_script(tr, script_path): st.error(f"{tr('Failed to load script')}: {str(e)}") -def generate_script(tr, params): - """生成视频脚本""" - progress_bar = st.progress(0) - status_text = st.empty() - - def update_progress(progress: float, message: str = ""): - progress_bar.progress(progress) - if message: - status_text.text(f"{progress}% - {message}") - else: - status_text.text(f"进度: {progress}%") - - try: - with st.spinner("正在生成脚本..."): - if not params.video_origin_path: - st.error("请先选择视频文件") - return - - # ===================提取键帧=================== - update_progress(10, "正在提取关键帧...") - - # 创建临时目录用于存储关键帧 - keyframes_dir = os.path.join(utils.temp_dir(), "keyframes") - video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path))) - video_keyframes_dir = os.path.join(keyframes_dir, video_hash) - - # 检查是否已经提取过关键帧 - keyframe_files = [] - if os.path.exists(video_keyframes_dir): - # 取已有的关键帧文件 - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if keyframe_files: - logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}") - st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}") - update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧") - - # 如果没有缓存的关键帧,则进行提取 - if not keyframe_files: - try: - # 确保目录存在 - os.makedirs(video_keyframes_dir, exist_ok=True) - - # 初始化视频处理器 - if config.frames.get("version") == "v2": - processor = video_processor_v2.VideoProcessor(params.video_origin_path) - # 处理视频并提取关键帧 - processor.process_video_pipeline( - output_dir=video_keyframes_dir, - skip_seconds=st.session_state.get('skip_seconds'), - threshold=st.session_state.get('threshold') - ) - else: - processor = video_processor.VideoProcessor(params.video_origin_path) - # 处理视频并提取关键帧 - processor.process_video( - output_dir=video_keyframes_dir, - skip_seconds=0 - ) - - # 获取所有关键文件路径 - for filename in sorted(os.listdir(video_keyframes_dir)): - if filename.endswith('.jpg'): - keyframe_files.append(os.path.join(video_keyframes_dir, filename)) - - if not keyframe_files: - raise Exception("未提取到任何关键帧") - - update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧") - - except Exception as e: - # 如果提取失败,清理创建的目录 - try: - if os.path.exists(video_keyframes_dir): - import shutil - shutil.rmtree(video_keyframes_dir) - except Exception as cleanup_err: - logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}") - - raise Exception(f"关键帧提取失败: {str(e)}") - - # 根据不同的 LLM 提供商处理 - vision_llm_provider = st.session_state.get('vision_llm_providers').lower() - logger.debug(f"Vision LLM 提供商: {vision_llm_provider}") - - try: - # ===================初始化视觉分析器=================== - update_progress(30, "正在初始化视觉分析器...") - - # 从配置中获取相关配置 - if vision_llm_provider == 'gemini': - vision_api_key = st.session_state.get('vision_gemini_api_key') - vision_model = st.session_state.get('vision_gemini_model_name') - vision_base_url = st.session_state.get('vision_gemini_base_url') - elif vision_llm_provider == 'qwenvl': - vision_api_key = st.session_state.get('vision_qwenvl_api_key') - vision_model = st.session_state.get('vision_qwenvl_model_name', 'qwen-vl-max-latest') - vision_base_url = st.session_state.get('vision_qwenvl_base_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1') - else: - raise ValueError(f"不支持的视觉分析提供商: {vision_llm_provider}") - - # 创建视觉分析器实例 - analyzer = create_vision_analyzer( - provider=vision_llm_provider, - api_key=vision_api_key, - model=vision_model, - base_url=vision_base_url - ) - - update_progress(40, "正在分析关键帧...") - - # ===================创建异步事件循环=================== - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # 执行异步分析 - vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size") - results = loop.run_until_complete( - analyzer.analyze_images( - images=keyframe_files, - prompt=config.app.get('vision_analysis_prompt'), - batch_size=vision_batch_size - ) - ) - loop.close() - - # ===================处理分析结果=================== - update_progress(60, "正在整理分析结果...") - - # 合并所有批次的析结果 - frame_analysis = "" - prev_batch_files = None - - for result in results: - if 'error' in result: - logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") - - # 获取当前批次的文件列表 keyframe_001136_000045.jpg 将 000045 精度提升到 毫秒 - batch_files = get_batch_files(keyframe_files, result, vision_batch_size) - logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片") - # logger.debug(batch_files) - - first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files) - logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}") - - # 添加带时间戳的分析结果 - frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" - frame_analysis += result['response'] - frame_analysis += "\n" - - # 更新上一个批次的文件 - prev_batch_files = batch_files - - if not frame_analysis.strip(): - raise Exception("未能生成有效的帧分析结果") - - # 保存分析结果 - analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt") - with open(analysis_path, 'w', encoding='utf-8') as f: - f.write(frame_analysis) - - update_progress(70, "正在生成脚本...") - - # 从配置中获取文本生成相关配置 - text_provider = config.app.get('text_llm_provider', 'gemini').lower() - text_api_key = config.app.get(f'text_{text_provider}_api_key') - text_model = config.app.get(f'text_{text_provider}_model_name') - text_base_url = config.app.get(f'text_{text_provider}_base_url') - - # 构建帧内容列表 - frame_content_list = [] - prev_batch_files = None - - for i, result in enumerate(results): - if 'error' in result: - continue - - batch_files = get_batch_files(keyframe_files, result, vision_batch_size) - _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files) - - frame_content = { - "timestamp": timestamp_range, - "picture": result['response'], - "narration": "", - "OST": 2 - } - frame_content_list.append(frame_content) - - logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}") - - # 更新上一个批次的文件 - prev_batch_files = batch_files - - if not frame_content_list: - raise Exception("没有有效的帧内容可以处理") - - # ===================开始生成文案=================== - update_progress(80, "正在生成文案...") - # 校验配置 - api_params = { - "vision_api_key": vision_api_key, - "vision_model_name": vision_model, - "vision_base_url": vision_base_url or "", - "text_api_key": text_api_key, - "text_model_name": text_model, - "text_base_url": text_base_url or "" - } - headers = { - 'accept': 'application/json', - 'Content-Type': 'application/json' - } - session = requests.Session() - retry_strategy = Retry( - total=3, - backoff_factor=1, - status_forcelist=[500, 502, 503, 504] - ) - adapter = HTTPAdapter(max_retries=retry_strategy) - session.mount("https://", adapter) - try: - response = session.post( - f"{config.app.get('narrato_api_url')}/video/config", - headers=headers, - json=api_params, - timeout=30, - verify=True - ) - except Exception as e: - pass - custom_prompt = st.session_state.get('custom_prompt', '') - processor = ScriptProcessor( - model_name=text_model, - api_key=text_api_key, - prompt=custom_prompt, - base_url=text_base_url or "", - video_theme=st.session_state.get('video_theme', '') - ) - - # 处理帧内容生成脚本 - script_result = processor.process_frames(frame_content_list) - - # 结果转换为JSON字符串 - script = json.dumps(script_result, ensure_ascii=False, indent=2) - - except Exception as e: - logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}") - raise Exception(f"分析失败: {str(e)}") - - if script is None: - st.error("生成脚本失败,请检查日志") - st.stop() - logger.info(f"脚本生成完成") - if isinstance(script, list): - st.session_state['video_clip_json'] = script - elif isinstance(script, str): - st.session_state['video_clip_json'] = json.loads(script) - update_progress(80, "脚本生成完成") - - time.sleep(0.1) - progress_bar.progress(100) - status_text.text("脚本生成完成!") - st.success("视频脚本生成成功!") - - except Exception as err: - st.error(f"生成过程中发生错误: {str(err)}") - logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}") - finally: - time.sleep(2) - progress_bar.empty() - status_text.empty() - - def save_script(tr, video_clip_json_details): """保存视频脚本""" if not video_clip_json_details: @@ -713,23 +332,3 @@ def crop_video(tr, params): time.sleep(2) progress_bar.empty() status_text.empty() - - -def get_script_params(): - """获取脚本参数""" - return { - 'video_language': st.session_state.get('video_language', ''), - 'video_clip_json_path': st.session_state.get('video_clip_json_path', ''), - 'video_origin_path': st.session_state.get('video_origin_path', ''), - 'video_name': st.session_state.get('video_name', ''), - 'video_plot': st.session_state.get('video_plot', '') - } - - -def create_vision_analyzer(provider, api_key, model, base_url): - if provider == 'gemini': - return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key) - elif provider == 'qwenvl': - return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key) - else: - raise ValueError(f"不支持的视觉分析提供商: {provider}") diff --git a/webui/tools/base.py b/webui/tools/base.py new file mode 100644 index 0000000..4148d34 --- /dev/null +++ b/webui/tools/base.py @@ -0,0 +1,124 @@ +import os +import streamlit as st +from loguru import logger + +from app.utils import gemini_analyzer, qwenvl_analyzer + + +def create_vision_analyzer(provider, api_key, model, base_url): + if provider == 'gemini': + return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key) + elif provider == 'qwenvl': + return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key) + else: + raise ValueError(f"不支持的视觉分析提供商: {provider}") + + +def get_script_params(): + """获取脚本参数""" + return { + 'video_language': st.session_state.get('video_language', ''), + 'video_clip_json_path': st.session_state.get('video_clip_json_path', ''), + 'video_origin_path': st.session_state.get('video_origin_path', ''), + 'video_name': st.session_state.get('video_name', ''), + 'video_plot': st.session_state.get('video_plot', '') + } + + +def get_batch_timestamps(batch_files, prev_batch_files=None): + """ + 解析一批文件的时间戳范围,支持毫秒级精度 + + Args: + batch_files: 当前批次的文件列表 + prev_batch_files: 上一个批次的文件列表,用于处理单张图片的情况 + + Returns: + tuple: (first_timestamp, last_timestamp, timestamp_range) + 时间戳格式: HH:MM:SS,mmm (时:分:秒,毫秒) + 例如: 00:00:50,100 表示50秒100毫秒 + + 示例文件名格式: + keyframe_001253_000050100.jpg + 其中 000050100 表示 00:00:50,100 (50秒100毫秒) + """ + if not batch_files: + logger.warning("Empty batch files") + return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000" + + def get_frame_files(): + """获取首帧和尾帧文件名""" + if len(batch_files) == 1 and prev_batch_files and prev_batch_files: + # 单张图片情况:使用上一批次最后一帧作为首帧 + first = os.path.basename(prev_batch_files[-1]) + last = os.path.basename(batch_files[0]) + logger.debug(f"单张图片批次,使用上一批次最后一帧作为首帧: {first}") + else: + first = os.path.basename(batch_files[0]) + last = os.path.basename(batch_files[-1]) + return first, last + + def extract_time(filename): + """从文件名提取时间信息""" + try: + # 提取类似 000050100 的时间戳部分 + time_str = filename.split('_')[2].replace('.jpg', '') + if len(time_str) < 9: # 处理旧格式 + time_str = time_str.ljust(9, '0') + return time_str + except (IndexError, AttributeError) as e: + logger.warning(f"Invalid filename format: {filename}, error: {e}") + return "000000000" + + def format_timestamp(time_str): + """ + 将时间字符串转换为 HH:MM:SS,mmm 格式 + + Args: + time_str: 9位数字字符串,格式为 HHMMSSMMM + 例如: 000010000 表示 00时00分10秒000毫秒 + 000043039 表示 00时00分43秒039毫秒 + + Returns: + str: HH:MM:SS,mmm 格式的时间戳 + """ + try: + if len(time_str) < 9: + logger.warning(f"Invalid timestamp format: {time_str}") + return "00:00:00,000" + + # 从时间戳中提取时、分、秒和毫秒 + hours = int(time_str[0:2]) # 前2位作为小时 + minutes = int(time_str[2:4]) # 第3-4位作为分钟 + seconds = int(time_str[4:6]) # 第5-6位作为秒数 + milliseconds = int(time_str[6:]) # 最后3位作为毫秒 + + return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}" + + except ValueError as e: + logger.warning(f"时间戳格式转换失败: {time_str}, error: {e}") + return "00:00:00,000" + + # 获取首帧和尾帧文件名 + first_frame, last_frame = get_frame_files() + + # 从文件名中提取时间信息 + first_time = extract_time(first_frame) + last_time = extract_time(last_frame) + + # 转换为标准时间戳格式 + first_timestamp = format_timestamp(first_time) + last_timestamp = format_timestamp(last_time) + timestamp_range = f"{first_timestamp}-{last_timestamp}" + + # logger.debug(f"解析时间戳: {first_frame} -> {first_timestamp}, {last_frame} -> {last_timestamp}") + return first_timestamp, last_timestamp, timestamp_range + + +def get_batch_files(keyframe_files, result, batch_size=5): + """ + 获取当前批次的图片文件 + """ + batch_start = result['batch_index'] * batch_size + batch_end = min(batch_start + batch_size, len(keyframe_files)) + return keyframe_files[batch_start:batch_end] diff --git a/webui/tools/generate_script_docu.py b/webui/tools/generate_script_docu.py new file mode 100644 index 0000000..2c72500 --- /dev/null +++ b/webui/tools/generate_script_docu.py @@ -0,0 +1,293 @@ +# 纪录片脚本生成 +import os +import json +import time +import asyncio +import traceback +import requests +import streamlit as st +from loguru import logger +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from app.config import config +from app.utils.script_generator import ScriptProcessor +from app.utils import utils, video_processor, video_processor_v2, qwenvl_analyzer +from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps + + +def generate_script_docu(tr, params): + """ + 生成 纪录片 视频脚本 + """ + progress_bar = st.progress(0) + status_text = st.empty() + + def update_progress(progress: float, message: str = ""): + progress_bar.progress(progress) + if message: + status_text.text(f"{progress}% - {message}") + else: + status_text.text(f"进度: {progress}%") + + try: + with st.spinner("正在生成脚本..."): + if not params.video_origin_path: + st.error("请先选择视频文件") + return + + # ===================提取键帧=================== + update_progress(10, "正在提取关键帧...") + + # 创建临时目录用于存储关键帧 + keyframes_dir = os.path.join(utils.temp_dir(), "keyframes") + video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path))) + video_keyframes_dir = os.path.join(keyframes_dir, video_hash) + + # 检查是否已经提取过关键帧 + keyframe_files = [] + if os.path.exists(video_keyframes_dir): + # 取已有的关键帧文件 + for filename in sorted(os.listdir(video_keyframes_dir)): + if filename.endswith('.jpg'): + keyframe_files.append(os.path.join(video_keyframes_dir, filename)) + + if keyframe_files: + logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}") + st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}") + update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧") + + # 如果没有缓存的关键帧,则进行提取 + if not keyframe_files: + try: + # 确保目录存在 + os.makedirs(video_keyframes_dir, exist_ok=True) + + # 初始化视频处理器 + if config.frames.get("version") == "v2": + processor = video_processor_v2.VideoProcessor(params.video_origin_path) + # 处理视频并提取关键帧 + processor.process_video_pipeline( + output_dir=video_keyframes_dir, + skip_seconds=st.session_state.get('skip_seconds'), + threshold=st.session_state.get('threshold') + ) + else: + processor = video_processor.VideoProcessor(params.video_origin_path) + # 处理视频并提取关键帧 + processor.process_video( + output_dir=video_keyframes_dir, + skip_seconds=0 + ) + + # 获取所有关键文件路径 + for filename in sorted(os.listdir(video_keyframes_dir)): + if filename.endswith('.jpg'): + keyframe_files.append(os.path.join(video_keyframes_dir, filename)) + + if not keyframe_files: + raise Exception("未提取到任何关键帧") + + update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧") + + except Exception as e: + # 如果提取失败,清理创建的目录 + try: + if os.path.exists(video_keyframes_dir): + import shutil + shutil.rmtree(video_keyframes_dir) + except Exception as cleanup_err: + logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}") + + raise Exception(f"关键帧提取失败: {str(e)}") + + # 根据不同的 LLM 提供商处理 + vision_llm_provider = st.session_state.get('vision_llm_providers').lower() + logger.debug(f"Vision LLM 提供商: {vision_llm_provider}") + + try: + # ===================初始化视觉分析器=================== + update_progress(30, "正在初始化视觉分析器...") + + # 从配置中获取相关配置 + if vision_llm_provider == 'gemini': + vision_api_key = st.session_state.get('vision_gemini_api_key') + vision_model = st.session_state.get('vision_gemini_model_name') + vision_base_url = st.session_state.get('vision_gemini_base_url') + elif vision_llm_provider == 'qwenvl': + vision_api_key = st.session_state.get('vision_qwenvl_api_key') + vision_model = st.session_state.get('vision_qwenvl_model_name', 'qwen-vl-max-latest') + vision_base_url = st.session_state.get('vision_qwenvl_base_url', + 'https://dashscope.aliyuncs.com/compatible-mode/v1') + else: + raise ValueError(f"不支持的视觉分析提供商: {vision_llm_provider}") + + # 创建视觉分析器实例 + analyzer = create_vision_analyzer( + provider=vision_llm_provider, + api_key=vision_api_key, + model=vision_model, + base_url=vision_base_url + ) + + update_progress(40, "正在分析关键帧...") + + # ===================创建异步事件循环=================== + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # 执行异步分析 + vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size") + results = loop.run_until_complete( + analyzer.analyze_images( + images=keyframe_files, + prompt=config.app.get('vision_analysis_prompt'), + batch_size=vision_batch_size + ) + ) + loop.close() + + # ===================处理分析结果=================== + update_progress(60, "正在整理分析结果...") + + # 合并所有批次的析结果 + frame_analysis = "" + prev_batch_files = None + + for result in results: + if 'error' in result: + logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") + + # 获取当前批次的文件列表 keyframe_001136_000045.jpg 将 000045 精度提升到 毫秒 + batch_files = get_batch_files(keyframe_files, result, vision_batch_size) + logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片") + # logger.debug(batch_files) + + first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files) + logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}") + + # 添加带时间戳的分析结果 + frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" + frame_analysis += result['response'] + frame_analysis += "\n" + + # 更新上一个批次的文件 + prev_batch_files = batch_files + + if not frame_analysis.strip(): + raise Exception("未能生成有效的帧分析结果") + + # 保存分析结果 + analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt") + with open(analysis_path, 'w', encoding='utf-8') as f: + f.write(frame_analysis) + + update_progress(70, "正在生成脚本...") + + # 从配置中获取文本生成相关配置 + text_provider = config.app.get('text_llm_provider', 'gemini').lower() + text_api_key = config.app.get(f'text_{text_provider}_api_key') + text_model = config.app.get(f'text_{text_provider}_model_name') + text_base_url = config.app.get(f'text_{text_provider}_base_url') + + # 构建帧内容列表 + frame_content_list = [] + prev_batch_files = None + + for i, result in enumerate(results): + if 'error' in result: + continue + + batch_files = get_batch_files(keyframe_files, result, vision_batch_size) + _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files) + + frame_content = { + "timestamp": timestamp_range, + "picture": result['response'], + "narration": "", + "OST": 2 + } + frame_content_list.append(frame_content) + + logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}") + + # 更新上一个批次的文件 + prev_batch_files = batch_files + + if not frame_content_list: + raise Exception("没有有效的帧内容可以处理") + + # ===================开始生成文案=================== + update_progress(80, "正在生成文案...") + # 校验配置 + api_params = { + "vision_api_key": vision_api_key, + "vision_model_name": vision_model, + "vision_base_url": vision_base_url or "", + "text_api_key": text_api_key, + "text_model_name": text_model, + "text_base_url": text_base_url or "" + } + headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json' + } + session = requests.Session() + retry_strategy = Retry( + total=3, + backoff_factor=1, + status_forcelist=[500, 502, 503, 504] + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + session.mount("https://", adapter) + try: + response = session.post( + f"{config.app.get('narrato_api_url')}/video/config", + headers=headers, + json=api_params, + timeout=30, + verify=True + ) + except Exception as e: + pass + custom_prompt = st.session_state.get('custom_prompt', '') + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + prompt=custom_prompt, + base_url=text_base_url or "", + video_theme=st.session_state.get('video_theme', '') + ) + + # 处理帧内容生成脚本 + script_result = processor.process_frames(frame_content_list) + + # 结果转换为JSON字符串 + script = json.dumps(script_result, ensure_ascii=False, indent=2) + + except Exception as e: + logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}") + raise Exception(f"分析失败: {str(e)}") + + if script is None: + st.error("生成脚本失败,请检查日志") + st.stop() + logger.info(f"脚本生成完成") + if isinstance(script, list): + st.session_state['video_clip_json'] = script + elif isinstance(script, str): + st.session_state['video_clip_json'] = json.loads(script) + update_progress(80, "脚本生成完成") + + time.sleep(0.1) + progress_bar.progress(100) + status_text.text("脚本生成完成!") + st.success("视频脚本生成成功!") + + except Exception as err: + st.error(f"生成过程中发生错误: {str(err)}") + logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}") + finally: + time.sleep(2) + progress_bar.empty() + status_text.empty()