refactor(webui): 重构视频脚本生成-目录结果

- 将视频脚本生成相关代码从 script_settings.py 移动到新的 generate_script_docu.py 文件
- 新增 base.py 文件,提取公共工具函数
- 优化代码结构,提高可维护性和可读性- 重构函数名称,更清晰地反映功能
This commit is contained in:
linyq 2024-12-06 18:18:23 +08:00
parent 65d5a681ac
commit d2f724217c
3 changed files with 420 additions and 404 deletions

View File

@ -1,120 +1,13 @@
import os
import ssl
import glob
import json
import time
import asyncio
import traceback
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import requests
import streamlit as st
from loguru import logger
from app.config import config
from app.models.schema import VideoClipParams
from app.utils.script_generator import ScriptProcessor
from app.utils import utils, check_script, gemini_analyzer, video_processor, video_processor_v2, qwenvl_analyzer
from webui.utils import file_utils
def get_batch_timestamps(batch_files, prev_batch_files=None):
"""
解析一批文件的时间戳范围,支持毫秒级精度
Args:
batch_files: 当前批次的文件列表
prev_batch_files: 上一个批次的文件列表,用于处理单张图片的情况
Returns:
tuple: (first_timestamp, last_timestamp, timestamp_range)
时间戳格式: HH:MM:SS,mmm (::,毫秒)
例如: 00:00:50,100 表示50秒100毫秒
示例文件名格式:
keyframe_001253_000050100.jpg
其中 000050100 表示 00:00:50,100 (50秒100毫秒)
"""
if not batch_files:
logger.warning("Empty batch files")
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
def get_frame_files():
"""获取首帧和尾帧文件名"""
if len(batch_files) == 1 and prev_batch_files and prev_batch_files:
# 单张图片情况:使用上一批次最后一帧作为首帧
first = os.path.basename(prev_batch_files[-1])
last = os.path.basename(batch_files[0])
logger.debug(f"单张图片批次,使用上一批次最后一帧作为首帧: {first}")
else:
first = os.path.basename(batch_files[0])
last = os.path.basename(batch_files[-1])
return first, last
def extract_time(filename):
"""从文件名提取时间信息"""
try:
# 提取类似 000050100 的时间戳部分
time_str = filename.split('_')[2].replace('.jpg', '')
if len(time_str) < 9: # 处理旧格式
time_str = time_str.ljust(9, '0')
return time_str
except (IndexError, AttributeError) as e:
logger.warning(f"Invalid filename format: {filename}, error: {e}")
return "000000000"
def format_timestamp(time_str):
"""
将时间字符串转换为 HH:MM:SS,mmm 格式
Args:
time_str: 9位数字字符串,格式为 HHMMSSMMM
例如: 000010000 表示 00时00分10秒000毫秒
000043039 表示 00时00分43秒039毫秒
Returns:
str: HH:MM:SS,mmm 格式的时间戳
"""
try:
if len(time_str) < 9:
logger.warning(f"Invalid timestamp format: {time_str}")
return "00:00:00,000"
# 从时间戳中提取时、分、秒和毫秒
hours = int(time_str[0:2]) # 前2位作为小时
minutes = int(time_str[2:4]) # 第3-4位作为分钟
seconds = int(time_str[4:6]) # 第5-6位作为秒数
milliseconds = int(time_str[6:]) # 最后3位作为毫秒
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
except ValueError as e:
logger.warning(f"时间戳格式转换失败: {time_str}, error: {e}")
return "00:00:00,000"
# 获取首帧和尾帧文件名
first_frame, last_frame = get_frame_files()
# 从文件名中提取时间信息
first_time = extract_time(first_frame)
last_time = extract_time(last_frame)
# 转换为标准时间戳格式
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
# logger.debug(f"解析时间戳: {first_frame} -> {first_timestamp}, {last_frame} -> {last_timestamp}")
return first_timestamp, last_timestamp, timestamp_range
def get_batch_files(keyframe_files, result, batch_size=5):
"""
获取当前批次的图片文件
"""
batch_start = result['batch_index'] * batch_size
batch_end = min(batch_start + batch_size, len(keyframe_files))
return keyframe_files[batch_start:batch_end]
from app.utils import utils, check_script
from webui.tools.generate_script_docu import generate_script_docu
def render_script_panel(tr):
@ -330,7 +223,7 @@ def render_script_buttons(tr, params):
if st.button(button_name, key="script_action", disabled=not script_path):
if script_path == "auto":
generate_script(tr, params)
generate_script_docu(tr, params)
else:
load_script(tr, script_path)
@ -385,280 +278,6 @@ def load_script(tr, script_path):
st.error(f"{tr('Failed to load script')}: {str(e)}")
def generate_script(tr, params):
"""生成视频脚本"""
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
progress_bar.progress(progress)
if message:
status_text.text(f"{progress}% - {message}")
else:
status_text.text(f"进度: {progress}%")
try:
with st.spinner("正在生成脚本..."):
if not params.video_origin_path:
st.error("请先选择视频文件")
return
# ===================提取键帧===================
update_progress(10, "正在提取关键帧...")
# 创建临时目录用于存储关键帧
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
# 检查是否已经提取过关键帧
keyframe_files = []
if os.path.exists(video_keyframes_dir):
# 取已有的关键帧文件
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}")
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)}")
# 如果没有缓存的关键帧,则进行提取
if not keyframe_files:
try:
# 确保目录存在
os.makedirs(video_keyframes_dir, exist_ok=True)
# 初始化视频处理器
if config.frames.get("version") == "v2":
processor = video_processor_v2.VideoProcessor(params.video_origin_path)
# 处理视频并提取关键帧
processor.process_video_pipeline(
output_dir=video_keyframes_dir,
skip_seconds=st.session_state.get('skip_seconds'),
threshold=st.session_state.get('threshold')
)
else:
processor = video_processor.VideoProcessor(params.video_origin_path)
# 处理视频并提取关键帧
processor.process_video(
output_dir=video_keyframes_dir,
skip_seconds=0
)
# 获取所有关键文件路径
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if not keyframe_files:
raise Exception("未提取到任何关键帧")
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)}")
except Exception as e:
# 如果提取失败,清理创建的目录
try:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取失败: {str(e)}")
# 根据不同的 LLM 提供商处理
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
try:
# ===================初始化视觉分析器===================
update_progress(30, "正在初始化视觉分析器...")
# 从配置中获取相关配置
if vision_llm_provider == 'gemini':
vision_api_key = st.session_state.get('vision_gemini_api_key')
vision_model = st.session_state.get('vision_gemini_model_name')
vision_base_url = st.session_state.get('vision_gemini_base_url')
elif vision_llm_provider == 'qwenvl':
vision_api_key = st.session_state.get('vision_qwenvl_api_key')
vision_model = st.session_state.get('vision_qwenvl_model_name', 'qwen-vl-max-latest')
vision_base_url = st.session_state.get('vision_qwenvl_base_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
else:
raise ValueError(f"不支持的视觉分析提供商: {vision_llm_provider}")
# 创建视觉分析器实例
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
)
update_progress(40, "正在分析关键帧...")
# ===================创建异步事件循环===================
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 执行异步分析
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=vision_batch_size
)
)
loop.close()
# ===================处理分析结果===================
update_progress(60, "正在整理分析结果...")
# 合并所有批次的析结果
frame_analysis = ""
prev_batch_files = None
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
# 获取当前批次的文件列表 keyframe_001136_000045.jpg 将 000045 精度提升到 毫秒
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
# logger.debug(batch_files)
first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}")
# 添加带时间戳的分析结果
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response']
frame_analysis += "\n"
# 更新上一个批次的文件
prev_batch_files = batch_files
if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果")
# 保存分析结果
analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
with open(analysis_path, 'w', encoding='utf-8') as f:
f.write(frame_analysis)
update_progress(70, "正在生成脚本...")
# 从配置中获取文本生成相关配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 构建帧内容列表
frame_content_list = []
prev_batch_files = None
for i, result in enumerate(results):
if 'error' in result:
continue
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
frame_content = {
"timestamp": timestamp_range,
"picture": result['response'],
"narration": "",
"OST": 2
}
frame_content_list.append(frame_content)
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
# 更新上一个批次的文件
prev_batch_files = batch_files
if not frame_content_list:
raise Exception("没有有效的帧内容可以处理")
# ===================开始生成文案===================
update_progress(80, "正在生成文案...")
# 校验配置
api_params = {
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url or "",
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url or ""
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
session = requests.Session()
retry_strategy = Retry(
total=3,
backoff_factor=1,
status_forcelist=[500, 502, 503, 504]
)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("https://", adapter)
try:
response = session.post(
f"{config.app.get('narrato_api_url')}/video/config",
headers=headers,
json=api_params,
timeout=30,
verify=True
)
except Exception as e:
pass
custom_prompt = st.session_state.get('custom_prompt', '')
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url or "",
video_theme=st.session_state.get('video_theme', '')
)
# 处理帧内容生成脚本
script_result = processor.process_frames(frame_content_list)
# 结果转换为JSON字符串
script = json.dumps(script_result, ensure_ascii=False, indent=2)
except Exception as e:
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
raise Exception(f"分析失败: {str(e)}")
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
logger.info(f"脚本生成完成")
if isinstance(script, list):
st.session_state['video_clip_json'] = script
elif isinstance(script, str):
st.session_state['video_clip_json'] = json.loads(script)
update_progress(80, "脚本生成完成")
time.sleep(0.1)
progress_bar.progress(100)
status_text.text("脚本生成完成!")
st.success("视频脚本生成成功!")
except Exception as err:
st.error(f"生成过程中发生错误: {str(err)}")
logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
finally:
time.sleep(2)
progress_bar.empty()
status_text.empty()
def save_script(tr, video_clip_json_details):
"""保存视频脚本"""
if not video_clip_json_details:
@ -713,23 +332,3 @@ def crop_video(tr, params):
time.sleep(2)
progress_bar.empty()
status_text.empty()
def get_script_params():
"""获取脚本参数"""
return {
'video_language': st.session_state.get('video_language', ''),
'video_clip_json_path': st.session_state.get('video_clip_json_path', ''),
'video_origin_path': st.session_state.get('video_origin_path', ''),
'video_name': st.session_state.get('video_name', ''),
'video_plot': st.session_state.get('video_plot', '')
}
def create_vision_analyzer(provider, api_key, model, base_url):
if provider == 'gemini':
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
elif provider == 'qwenvl':
return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key)
else:
raise ValueError(f"不支持的视觉分析提供商: {provider}")

124
webui/tools/base.py Normal file
View File

@ -0,0 +1,124 @@
import os
import streamlit as st
from loguru import logger
from app.utils import gemini_analyzer, qwenvl_analyzer
def create_vision_analyzer(provider, api_key, model, base_url):
if provider == 'gemini':
return gemini_analyzer.VisionAnalyzer(model_name=model, api_key=api_key)
elif provider == 'qwenvl':
return qwenvl_analyzer.QwenAnalyzer(model_name=model, api_key=api_key)
else:
raise ValueError(f"不支持的视觉分析提供商: {provider}")
def get_script_params():
"""获取脚本参数"""
return {
'video_language': st.session_state.get('video_language', ''),
'video_clip_json_path': st.session_state.get('video_clip_json_path', ''),
'video_origin_path': st.session_state.get('video_origin_path', ''),
'video_name': st.session_state.get('video_name', ''),
'video_plot': st.session_state.get('video_plot', '')
}
def get_batch_timestamps(batch_files, prev_batch_files=None):
"""
解析一批文件的时间戳范围,支持毫秒级精度
Args:
batch_files: 当前批次的文件列表
prev_batch_files: 上一个批次的文件列表,用于处理单张图片的情况
Returns:
tuple: (first_timestamp, last_timestamp, timestamp_range)
时间戳格式: HH:MM:SS,mmm (::,毫秒)
例如: 00:00:50,100 表示50秒100毫秒
示例文件名格式:
keyframe_001253_000050100.jpg
其中 000050100 表示 00:00:50,100 (50秒100毫秒)
"""
if not batch_files:
logger.warning("Empty batch files")
return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"
def get_frame_files():
"""获取首帧和尾帧文件名"""
if len(batch_files) == 1 and prev_batch_files and prev_batch_files:
# 单张图片情况:使用上一批次最后一帧作为首帧
first = os.path.basename(prev_batch_files[-1])
last = os.path.basename(batch_files[0])
logger.debug(f"单张图片批次,使用上一批次最后一帧作为首帧: {first}")
else:
first = os.path.basename(batch_files[0])
last = os.path.basename(batch_files[-1])
return first, last
def extract_time(filename):
"""从文件名提取时间信息"""
try:
# 提取类似 000050100 的时间戳部分
time_str = filename.split('_')[2].replace('.jpg', '')
if len(time_str) < 9: # 处理旧格式
time_str = time_str.ljust(9, '0')
return time_str
except (IndexError, AttributeError) as e:
logger.warning(f"Invalid filename format: {filename}, error: {e}")
return "000000000"
def format_timestamp(time_str):
"""
将时间字符串转换为 HH:MM:SS,mmm 格式
Args:
time_str: 9位数字字符串,格式为 HHMMSSMMM
例如: 000010000 表示 00时00分10秒000毫秒
000043039 表示 00时00分43秒039毫秒
Returns:
str: HH:MM:SS,mmm 格式的时间戳
"""
try:
if len(time_str) < 9:
logger.warning(f"Invalid timestamp format: {time_str}")
return "00:00:00,000"
# 从时间戳中提取时、分、秒和毫秒
hours = int(time_str[0:2]) # 前2位作为小时
minutes = int(time_str[2:4]) # 第3-4位作为分钟
seconds = int(time_str[4:6]) # 第5-6位作为秒数
milliseconds = int(time_str[6:]) # 最后3位作为毫秒
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
except ValueError as e:
logger.warning(f"时间戳格式转换失败: {time_str}, error: {e}")
return "00:00:00,000"
# 获取首帧和尾帧文件名
first_frame, last_frame = get_frame_files()
# 从文件名中提取时间信息
first_time = extract_time(first_frame)
last_time = extract_time(last_frame)
# 转换为标准时间戳格式
first_timestamp = format_timestamp(first_time)
last_timestamp = format_timestamp(last_time)
timestamp_range = f"{first_timestamp}-{last_timestamp}"
# logger.debug(f"解析时间戳: {first_frame} -> {first_timestamp}, {last_frame} -> {last_timestamp}")
return first_timestamp, last_timestamp, timestamp_range
def get_batch_files(keyframe_files, result, batch_size=5):
"""
获取当前批次的图片文件
"""
batch_start = result['batch_index'] * batch_size
batch_end = min(batch_start + batch_size, len(keyframe_files))
return keyframe_files[batch_start:batch_end]

View File

@ -0,0 +1,293 @@
# 纪录片脚本生成
import os
import json
import time
import asyncio
import traceback
import requests
import streamlit as st
from loguru import logger
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from app.config import config
from app.utils.script_generator import ScriptProcessor
from app.utils import utils, video_processor, video_processor_v2, qwenvl_analyzer
from webui.tools.base import create_vision_analyzer, get_batch_files, get_batch_timestamps
def generate_script_docu(tr, params):
"""
生成 纪录片 视频脚本
"""
progress_bar = st.progress(0)
status_text = st.empty()
def update_progress(progress: float, message: str = ""):
progress_bar.progress(progress)
if message:
status_text.text(f"{progress}% - {message}")
else:
status_text.text(f"进度: {progress}%")
try:
with st.spinner("正在生成脚本..."):
if not params.video_origin_path:
st.error("请先选择视频文件")
return
# ===================提取键帧===================
update_progress(10, "正在提取关键帧...")
# 创建临时目录用于存储关键帧
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
# 检查是否已经提取过关键帧
keyframe_files = []
if os.path.exists(video_keyframes_dir):
# 取已有的关键帧文件
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if keyframe_files:
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}")
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)}")
# 如果没有缓存的关键帧,则进行提取
if not keyframe_files:
try:
# 确保目录存在
os.makedirs(video_keyframes_dir, exist_ok=True)
# 初始化视频处理器
if config.frames.get("version") == "v2":
processor = video_processor_v2.VideoProcessor(params.video_origin_path)
# 处理视频并提取关键帧
processor.process_video_pipeline(
output_dir=video_keyframes_dir,
skip_seconds=st.session_state.get('skip_seconds'),
threshold=st.session_state.get('threshold')
)
else:
processor = video_processor.VideoProcessor(params.video_origin_path)
# 处理视频并提取关键帧
processor.process_video(
output_dir=video_keyframes_dir,
skip_seconds=0
)
# 获取所有关键文件路径
for filename in sorted(os.listdir(video_keyframes_dir)):
if filename.endswith('.jpg'):
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
if not keyframe_files:
raise Exception("未提取到任何关键帧")
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)}")
except Exception as e:
# 如果提取失败,清理创建的目录
try:
if os.path.exists(video_keyframes_dir):
import shutil
shutil.rmtree(video_keyframes_dir)
except Exception as cleanup_err:
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
raise Exception(f"关键帧提取失败: {str(e)}")
# 根据不同的 LLM 提供商处理
vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
try:
# ===================初始化视觉分析器===================
update_progress(30, "正在初始化视觉分析器...")
# 从配置中获取相关配置
if vision_llm_provider == 'gemini':
vision_api_key = st.session_state.get('vision_gemini_api_key')
vision_model = st.session_state.get('vision_gemini_model_name')
vision_base_url = st.session_state.get('vision_gemini_base_url')
elif vision_llm_provider == 'qwenvl':
vision_api_key = st.session_state.get('vision_qwenvl_api_key')
vision_model = st.session_state.get('vision_qwenvl_model_name', 'qwen-vl-max-latest')
vision_base_url = st.session_state.get('vision_qwenvl_base_url',
'https://dashscope.aliyuncs.com/compatible-mode/v1')
else:
raise ValueError(f"不支持的视觉分析提供商: {vision_llm_provider}")
# 创建视觉分析器实例
analyzer = create_vision_analyzer(
provider=vision_llm_provider,
api_key=vision_api_key,
model=vision_model,
base_url=vision_base_url
)
update_progress(40, "正在分析关键帧...")
# ===================创建异步事件循环===================
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 执行异步分析
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=vision_batch_size
)
)
loop.close()
# ===================处理分析结果===================
update_progress(60, "正在整理分析结果...")
# 合并所有批次的析结果
frame_analysis = ""
prev_batch_files = None
for result in results:
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
# 获取当前批次的文件列表 keyframe_001136_000045.jpg 将 000045 精度提升到 毫秒
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
# logger.debug(batch_files)
first_timestamp, last_timestamp, _ = get_batch_timestamps(batch_files, prev_batch_files)
logger.debug(f"处理时间戳: {first_timestamp}-{last_timestamp}")
# 添加带时间戳的分析结果
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
frame_analysis += result['response']
frame_analysis += "\n"
# 更新上一个批次的文件
prev_batch_files = batch_files
if not frame_analysis.strip():
raise Exception("未能生成有效的帧分析结果")
# 保存分析结果
analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
with open(analysis_path, 'w', encoding='utf-8') as f:
f.write(frame_analysis)
update_progress(70, "正在生成脚本...")
# 从配置中获取文本生成相关配置
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 构建帧内容列表
frame_content_list = []
prev_batch_files = None
for i, result in enumerate(results):
if 'error' in result:
continue
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)
frame_content = {
"timestamp": timestamp_range,
"picture": result['response'],
"narration": "",
"OST": 2
}
frame_content_list.append(frame_content)
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
# 更新上一个批次的文件
prev_batch_files = batch_files
if not frame_content_list:
raise Exception("没有有效的帧内容可以处理")
# ===================开始生成文案===================
update_progress(80, "正在生成文案...")
# 校验配置
api_params = {
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url or "",
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url or ""
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
session = requests.Session()
retry_strategy = Retry(
total=3,
backoff_factor=1,
status_forcelist=[500, 502, 503, 504]
)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("https://", adapter)
try:
response = session.post(
f"{config.app.get('narrato_api_url')}/video/config",
headers=headers,
json=api_params,
timeout=30,
verify=True
)
except Exception as e:
pass
custom_prompt = st.session_state.get('custom_prompt', '')
processor = ScriptProcessor(
model_name=text_model,
api_key=text_api_key,
prompt=custom_prompt,
base_url=text_base_url or "",
video_theme=st.session_state.get('video_theme', '')
)
# 处理帧内容生成脚本
script_result = processor.process_frames(frame_content_list)
# 结果转换为JSON字符串
script = json.dumps(script_result, ensure_ascii=False, indent=2)
except Exception as e:
logger.exception(f"大模型处理过程中发生错误\n{traceback.format_exc()}")
raise Exception(f"分析失败: {str(e)}")
if script is None:
st.error("生成脚本失败,请检查日志")
st.stop()
logger.info(f"脚本生成完成")
if isinstance(script, list):
st.session_state['video_clip_json'] = script
elif isinstance(script, str):
st.session_state['video_clip_json'] = json.loads(script)
update_progress(80, "脚本生成完成")
time.sleep(0.1)
progress_bar.progress(100)
status_text.text("脚本生成完成!")
st.success("视频脚本生成成功!")
except Exception as err:
st.error(f"生成过程中发生错误: {str(err)}")
logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
finally:
time.sleep(2)
progress_bar.empty()
status_text.empty()