mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 18:42:49 +00:00
完成了gemini 生成视频脚本的逻辑
This commit is contained in:
parent
ec282adb1b
commit
8267a0b3eb
@ -4,6 +4,7 @@ import glob
|
||||
import random
|
||||
from typing import List
|
||||
from typing import Union
|
||||
import traceback
|
||||
|
||||
from loguru import logger
|
||||
from moviepy.editor import *
|
||||
@ -145,7 +146,7 @@ def combine_videos(
|
||||
return combined_video_path
|
||||
|
||||
|
||||
def wrap_text(text, max_width, font="Arial", fontsize=60):
|
||||
def wrap_text(text, max_width, font, fontsize=60):
|
||||
# 创建字体对象
|
||||
font = ImageFont.truetype(font, fontsize)
|
||||
|
||||
@ -158,7 +159,7 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
|
||||
if width <= max_width:
|
||||
return text, height
|
||||
|
||||
# logger.warning(f"wrapping text, max_width: {max_width}, text_width: {width}, text: {text}")
|
||||
logger.debug(f"换行文本, 最大宽度: {max_width}, 文本宽度: {width}, 文本: {text}")
|
||||
|
||||
processed = True
|
||||
|
||||
@ -199,7 +200,7 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
|
||||
_wrapped_lines_.append(_txt_)
|
||||
result = "\n".join(_wrapped_lines_).strip()
|
||||
height = len(_wrapped_lines_) * height
|
||||
# logger.warning(f"wrapped text: {result}")
|
||||
logger.debug(f"换行文本: {result}")
|
||||
return result, height
|
||||
|
||||
|
||||
@ -233,7 +234,7 @@ def generate_video_v2(
|
||||
Returns:
|
||||
|
||||
"""
|
||||
total_steps = 4 # 总步<E680BB><E6ADA5><EFBFBD>
|
||||
total_steps = 4
|
||||
current_step = 0
|
||||
|
||||
def update_progress(step_name):
|
||||
@ -506,7 +507,7 @@ def combine_clip_videos(combined_video_path: str,
|
||||
temp_audiofile=os.path.join(output_dir, "temp-audio.m4a")
|
||||
)
|
||||
finally:
|
||||
# 确保资源被正确释放
|
||||
# 确保资源被正确<EFBFBD><EFBFBD><EFBFBD>放
|
||||
video_clip.close()
|
||||
for clip in clips:
|
||||
clip.close()
|
||||
|
||||
399
app/utils/script_generator.py
Normal file
399
app/utils/script_generator.py
Normal file
@ -0,0 +1,399 @@
|
||||
import os
|
||||
import json
|
||||
import traceback
|
||||
from loguru import logger
|
||||
import tiktoken
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
from openai import OpenAI
|
||||
import google.generativeai as genai
|
||||
|
||||
|
||||
class BaseGenerator:
|
||||
def __init__(self, model_name: str, api_key: str, prompt: str):
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.base_prompt = prompt
|
||||
self.conversation_history = []
|
||||
self.chunk_overlap = 50
|
||||
self.last_chunk_ending = ""
|
||||
|
||||
def generate_script(self, scene_description: str, word_count: int) -> str:
|
||||
raise NotImplementedError("Subclasses must implement generate_script method")
|
||||
|
||||
|
||||
class OpenAIGenerator(BaseGenerator):
|
||||
def __init__(self, model_name: str, api_key: str, prompt: str):
|
||||
super().__init__(model_name, api_key, prompt)
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
self.max_tokens = 7000
|
||||
try:
|
||||
self.encoding = tiktoken.encoding_for_model(self.model_name)
|
||||
except KeyError:
|
||||
logger.info(f"警告:未找到模型 {self.model_name} 的专用编码器,使用认编码器")
|
||||
self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def _count_tokens(self, messages: list) -> int:
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += 3
|
||||
for key, value in message.items():
|
||||
num_tokens += len(self.encoding.encode(str(value)))
|
||||
if key == "role":
|
||||
num_tokens += 1
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
|
||||
def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None:
|
||||
base_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": new_user_prompt}
|
||||
]
|
||||
base_tokens = self._count_tokens(base_messages)
|
||||
|
||||
temp_history = []
|
||||
current_tokens = base_tokens
|
||||
|
||||
for message in reversed(self.conversation_history):
|
||||
message_tokens = self._count_tokens([message])
|
||||
if current_tokens + message_tokens > self.max_tokens:
|
||||
break
|
||||
temp_history.insert(0, message)
|
||||
current_tokens += message_tokens
|
||||
|
||||
self.conversation_history = temp_history
|
||||
|
||||
def generate_script(self, scene_description: str, word_count: int) -> str:
|
||||
max_attempts = 3
|
||||
tolerance = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
system_prompt, user_prompt = self._create_prompt(scene_description, word_count)
|
||||
self._trim_conversation_history(system_prompt, user_prompt)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*self.conversation_history,
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=500,
|
||||
top_p=0.9,
|
||||
frequency_penalty=0.3,
|
||||
presence_penalty=0.5
|
||||
)
|
||||
|
||||
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
|
||||
'').replace(
|
||||
'\n', '')
|
||||
|
||||
current_length = len(generated_script)
|
||||
if abs(current_length - word_count) <= tolerance:
|
||||
self.conversation_history.append({"role": "user", "content": user_prompt})
|
||||
self.conversation_history.append({"role": "assistant", "content": generated_script})
|
||||
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
|
||||
generated_script) > self.chunk_overlap else generated_script
|
||||
return generated_script
|
||||
|
||||
return generated_script
|
||||
|
||||
def _create_prompt(self, scene_description: str, word_count: int) -> tuple:
|
||||
system_prompt = self.base_prompt.format(word_count=word_count)
|
||||
|
||||
user_prompt = f"""上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
|
||||
|
||||
当前画面描述:{scene_description}
|
||||
|
||||
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
|
||||
严格字数要求:{word_count}字,允许误差±5字。"""
|
||||
|
||||
return system_prompt, user_prompt
|
||||
|
||||
|
||||
class GeminiGenerator(BaseGenerator):
|
||||
def __init__(self, model_name: str, api_key: str, prompt: str):
|
||||
super().__init__(model_name, api_key, prompt)
|
||||
genai.configure(api_key=api_key)
|
||||
self.model = genai.GenerativeModel(model_name)
|
||||
|
||||
def generate_script(self, scene_description: str, word_count: int) -> str:
|
||||
max_attempts = 3
|
||||
tolerance = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
prompt = f"""{self.base_prompt}
|
||||
|
||||
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
|
||||
|
||||
当前画面描述:{scene_description}
|
||||
|
||||
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
|
||||
严格字数要求:{word_count}字,允许误差±5字。"""
|
||||
|
||||
response = self.model.generate_content(prompt)
|
||||
generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
|
||||
|
||||
current_length = len(generated_script)
|
||||
if abs(current_length - word_count) <= tolerance:
|
||||
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
|
||||
generated_script) > self.chunk_overlap else generated_script
|
||||
return generated_script
|
||||
|
||||
return generated_script
|
||||
|
||||
|
||||
class QwenGenerator(BaseGenerator):
|
||||
def __init__(self, model_name: str, api_key: str, prompt: str):
|
||||
super().__init__(model_name, api_key, prompt)
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
|
||||
def generate_script(self, scene_description: str, word_count: int) -> str:
|
||||
max_attempts = 3
|
||||
tolerance = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
prompt = f"""{self.base_prompt}
|
||||
|
||||
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
|
||||
|
||||
当前画面描述:{scene_description}
|
||||
|
||||
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
|
||||
严格字数要求:{word_count}字,允许误差±5字。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.base_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name, # 如 "qwen-plus"
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=500,
|
||||
top_p=0.9,
|
||||
frequency_penalty=0.3,
|
||||
presence_penalty=0.5
|
||||
)
|
||||
|
||||
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
|
||||
'').replace(
|
||||
'\n', '')
|
||||
|
||||
current_length = len(generated_script)
|
||||
if abs(current_length - word_count) <= tolerance:
|
||||
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
|
||||
generated_script) > self.chunk_overlap else generated_script
|
||||
return generated_script
|
||||
|
||||
return generated_script
|
||||
|
||||
|
||||
class MoonshotGenerator(BaseGenerator):
|
||||
def __init__(self, model_name: str, api_key: str, prompt: str):
|
||||
super().__init__(model_name, api_key, prompt)
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.moonshot.cn/v1"
|
||||
)
|
||||
|
||||
def generate_script(self, scene_description: str, word_count: int) -> str:
|
||||
max_attempts = 3
|
||||
tolerance = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
prompt = f"""{self.base_prompt}
|
||||
|
||||
上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
|
||||
|
||||
当前画面描述:{scene_description}
|
||||
|
||||
请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
|
||||
严格字数要求:{word_count}字,允许误差±5字。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.base_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name, # 如 "moonshot-v1-8k"
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=500,
|
||||
top_p=0.9,
|
||||
frequency_penalty=0.3,
|
||||
presence_penalty=0.5
|
||||
)
|
||||
|
||||
generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
|
||||
'').replace(
|
||||
'\n', '')
|
||||
|
||||
current_length = len(generated_script)
|
||||
if abs(current_length - word_count) <= tolerance:
|
||||
self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
|
||||
generated_script) > self.chunk_overlap else generated_script
|
||||
return generated_script
|
||||
|
||||
return generated_script
|
||||
|
||||
|
||||
class ScriptProcessor:
|
||||
def __init__(self, model_name, api_key=None, prompt=None):
|
||||
self.model_name = model_name
|
||||
# 根据不同模型选择对应的环境变量
|
||||
default_api_key = {
|
||||
'gemini': 'GOOGLE_API_KEY',
|
||||
'gpt': 'OPENAI_API_KEY',
|
||||
'qwen': 'DASHSCOPE_API_KEY',
|
||||
'moonshot': 'MOONSHOT_API_KEY'
|
||||
}
|
||||
api_key_env = next((v for k, v in default_api_key.items() if k in model_name.lower()), 'OPENAI_API_KEY')
|
||||
self.api_key = api_key or os.getenv(api_key_env)
|
||||
self.prompt = prompt or self._get_default_prompt()
|
||||
|
||||
# 根据模型名称选择对应的生成器
|
||||
if 'gemini' in model_name.lower():
|
||||
self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
|
||||
elif 'qwen' in model_name.lower():
|
||||
self.generator = QwenGenerator(model_name, self.api_key, self.prompt)
|
||||
elif 'moonshot' in model_name.lower():
|
||||
self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt)
|
||||
else:
|
||||
self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt)
|
||||
|
||||
def _get_default_prompt(self) -> str:
|
||||
return """你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让野外建造视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
|
||||
|
||||
目标受众:热爱野外生活、追求独特体验的18-35岁年轻人
|
||||
文案风格:基于HKRR理论 + 段子手精神
|
||||
主题:野外建造
|
||||
|
||||
【创作核心理念】
|
||||
1. 敢于用"温和的违反"制造笑点,但不能过于冒犯
|
||||
2. 巧妙运用中国式幽默,让观众会心一笑
|
||||
3. 保持轻松愉快的叙事基调
|
||||
|
||||
【爆款内容四要素】
|
||||
|
||||
【快乐元素 Happy】
|
||||
1. 用调侃的语气描述建造过程中的"笨手笨脚"
|
||||
2. 巧妙植入网络流行梗,增加内容的传播性
|
||||
3. 适时自嘲,展现真实且有趣的一面
|
||||
|
||||
【知识价值 Knowledge】
|
||||
1. 用段子手的方式解释专业知识(比如:"这根木头不是一般的木头,它比我前任还难搞...")
|
||||
2. 把复杂的建造技巧转化为生动有趣的比喻
|
||||
3. 在幽默中传递实用的野外生存技能
|
||||
|
||||
【情感共鸣 Resonance】
|
||||
1. 描述"真实但夸张"的建造困境
|
||||
2. 把对自然的感悟融入俏皮话中
|
||||
3. 用接地气的表达方式拉近与观众距离
|
||||
|
||||
【节奏控制 Rhythm】
|
||||
1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字
|
||||
2. 像讲段子一样,注意铺垫和包袱的节奏
|
||||
3. 确保每段都有笑点,但不强求
|
||||
4. 段落结尾干净利落,不拖泥带水
|
||||
|
||||
【连贯性要求】
|
||||
1. 新生成的内容必须自然衔接上一段文案的结尾
|
||||
2. 使用恰当的连接词和过渡语,确保叙事流畅
|
||||
3. 保持人物视角和语气的一致性
|
||||
4. 避免重复上一段已经提到的信息
|
||||
5. 确保情节和建造过程的逻辑连续性
|
||||
|
||||
【字数控制要求】
|
||||
1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字
|
||||
2. 如果内容过长,优先精简修饰性词语
|
||||
3. 如果内容过短,可以适当增加细节描写
|
||||
4. 保持文案结构完整,不因字数限制而牺牲内容质量
|
||||
5. 确保每个笑点和包袱都得到完整表达
|
||||
|
||||
我会按顺序提供多段视频画面描述。请创作既搞笑又能火爆全网的口播文案。
|
||||
记住:要敢于用"温和的违反"制造笑点,但要把握好尺度,让观众在轻松愉快中感受野外建造的乐趣。"""
|
||||
|
||||
def calculate_duration_and_word_count(self, time_range: str) -> int:
|
||||
try:
|
||||
start_str, end_str = time_range.split('-')
|
||||
|
||||
def time_to_seconds(time_str):
|
||||
minutes, seconds = map(int, time_str.split(':'))
|
||||
return minutes * 60 + seconds
|
||||
|
||||
start_seconds = time_to_seconds(start_str)
|
||||
end_seconds = time_to_seconds(end_str)
|
||||
duration = end_seconds - start_seconds
|
||||
word_count = int(duration / 0.2)
|
||||
|
||||
return word_count
|
||||
except Exception as e:
|
||||
logger.info(f"时间格式转换错误: {traceback.format_exc()}")
|
||||
return 100
|
||||
|
||||
def process_frames(self, frame_content_list: List[Dict]) -> List[Dict]:
|
||||
for frame_content in frame_content_list:
|
||||
word_count = self.calculate_duration_and_word_count(frame_content["timestamp"])
|
||||
script = self.generator.generate_script(frame_content["picture"], word_count)
|
||||
frame_content["narration"] = script
|
||||
frame_content["OST"] = 2
|
||||
logger.info(f"时间范围: {frame_content['timestamp']}, 建议字数: {word_count}")
|
||||
logger.info(script)
|
||||
|
||||
self._save_results(frame_content_list)
|
||||
return frame_content_list
|
||||
|
||||
def _save_results(self, frame_content_list: List[Dict]):
|
||||
"""保存处理结果,并添加新的时间戳"""
|
||||
try:
|
||||
# 计算新的时间戳
|
||||
current_time = 0 # 当前时间点(秒)
|
||||
|
||||
for frame in frame_content_list:
|
||||
# 获取原始时间戳的持续时间
|
||||
start_str, end_str = frame['timestamp'].split('-')
|
||||
|
||||
def time_to_seconds(time_str):
|
||||
minutes, seconds = map(int, time_str.split(':'))
|
||||
return minutes * 60 + seconds
|
||||
|
||||
# 计算当前片段的持续时间
|
||||
start_seconds = time_to_seconds(start_str)
|
||||
end_seconds = time_to_seconds(end_str)
|
||||
duration = end_seconds - start_seconds
|
||||
|
||||
# 转换秒数为 MM:SS 格式
|
||||
def seconds_to_time(seconds):
|
||||
minutes = seconds // 60
|
||||
remaining_seconds = seconds % 60
|
||||
return f"{minutes:02d}:{remaining_seconds:02d}"
|
||||
|
||||
# 设置新的时间戳
|
||||
new_start = seconds_to_time(current_time)
|
||||
new_end = seconds_to_time(current_time + duration)
|
||||
frame['new_timestamp'] = f"{new_start}-{new_end}"
|
||||
|
||||
# 更新当前时间点
|
||||
current_time += duration
|
||||
|
||||
# 保存结果
|
||||
file_name = f"storage/json/step2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
os.makedirs(os.path.dirname(file_name), exist_ok=True)
|
||||
|
||||
with open(file_name, 'w', encoding='utf-8') as file:
|
||||
json.dump(frame_content_list, file, ensure_ascii=False, indent=4)
|
||||
|
||||
logger.info(f"保存脚本成功,总时长: {seconds_to_time(current_time)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存结果时发生错误: {str(e)}\n{traceback.format_exc()}")
|
||||
raise
|
||||
@ -56,7 +56,7 @@ def to_json(obj):
|
||||
# 使用serialize函数处理输入对象
|
||||
serialized_obj = serialize(obj)
|
||||
|
||||
# 序列化处理后的对象为JSON字符串
|
||||
# 序列化处理后的对象为JSON<EFBFBD><EFBFBD><EFBFBD>符串
|
||||
return json.dumps(serialized_obj, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
return None
|
||||
@ -100,7 +100,7 @@ def task_dir(sub_dir: str = ""):
|
||||
|
||||
|
||||
def font_dir(sub_dir: str = ""):
|
||||
d = resource_dir(f"fonts")
|
||||
d = resource_dir("fonts")
|
||||
if sub_dir:
|
||||
d = os.path.join(d, sub_dir)
|
||||
if not os.path.exists(d):
|
||||
@ -109,7 +109,7 @@ def font_dir(sub_dir: str = ""):
|
||||
|
||||
|
||||
def song_dir(sub_dir: str = ""):
|
||||
d = resource_dir(f"songs")
|
||||
d = resource_dir("songs")
|
||||
if sub_dir:
|
||||
d = os.path.join(d, sub_dir)
|
||||
if not os.path.exists(d):
|
||||
@ -425,3 +425,102 @@ def cut_video(params, progress_callback=None):
|
||||
except Exception as e:
|
||||
logger.error(f"视频裁剪过程中发生错误: \n{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
|
||||
def temp_dir(sub_dir: str = ""):
|
||||
"""
|
||||
获取临时文件目录
|
||||
Args:
|
||||
sub_dir: 子目录名
|
||||
Returns:
|
||||
str: 临时文件目录路径
|
||||
"""
|
||||
d = os.path.join(storage_dir(), "temp")
|
||||
if sub_dir:
|
||||
d = os.path.join(d, sub_dir)
|
||||
if not os.path.exists(d):
|
||||
os.makedirs(d)
|
||||
return d
|
||||
|
||||
|
||||
def clear_keyframes_cache(video_path: str = None):
|
||||
"""
|
||||
清理关键帧缓存
|
||||
Args:
|
||||
video_path: 视频文件路径,如果指定则只清理该视频的缓存
|
||||
"""
|
||||
try:
|
||||
keyframes_dir = os.path.join(temp_dir(), "keyframes")
|
||||
if not os.path.exists(keyframes_dir):
|
||||
return
|
||||
|
||||
if video_path:
|
||||
# <20><><EFBFBD>理指定视频的缓存
|
||||
video_hash = md5(video_path + str(os.path.getmtime(video_path)))
|
||||
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
import shutil
|
||||
shutil.rmtree(video_keyframes_dir)
|
||||
logger.info(f"已清理视频关键帧缓存: {video_path}")
|
||||
else:
|
||||
# 清理所有缓存
|
||||
import shutil
|
||||
shutil.rmtree(keyframes_dir)
|
||||
logger.info("已清理所有关键帧缓存")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理关键帧缓存失败: {e}")
|
||||
|
||||
|
||||
def init_resources():
|
||||
"""初始化资源文件"""
|
||||
try:
|
||||
# 创建字体目录
|
||||
font_dir = os.path.join(root_dir(), "resource", "fonts")
|
||||
os.makedirs(font_dir, exist_ok=True)
|
||||
|
||||
# 检查字体文件
|
||||
font_files = [
|
||||
("SourceHanSansCN-Regular.otf", "https://github.com/adobe-fonts/source-han-sans/raw/release/OTF/SimplifiedChinese/SourceHanSansSC-Regular.otf"),
|
||||
("simhei.ttf", "C:/Windows/Fonts/simhei.ttf"), # Windows 黑体
|
||||
("simkai.ttf", "C:/Windows/Fonts/simkai.ttf"), # Windows 楷体
|
||||
("simsun.ttc", "C:/Windows/Fonts/simsun.ttc"), # Windows 宋体
|
||||
]
|
||||
|
||||
# 优先使用系统字体
|
||||
system_font_found = False
|
||||
for font_name, source in font_files:
|
||||
if not source.startswith("http") and os.path.exists(source):
|
||||
target_path = os.path.join(font_dir, font_name)
|
||||
if not os.path.exists(target_path):
|
||||
import shutil
|
||||
shutil.copy2(source, target_path)
|
||||
logger.info(f"已复制系统字体: {font_name}")
|
||||
system_font_found = True
|
||||
break
|
||||
|
||||
# 如果没有找到系统字体,则下载思源黑体
|
||||
if not system_font_found:
|
||||
source_han_path = os.path.join(font_dir, "SourceHanSansCN-Regular.otf")
|
||||
if not os.path.exists(source_han_path):
|
||||
download_font(font_files[0][1], source_han_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化资源文件失败: {e}")
|
||||
|
||||
def download_font(url: str, font_path: str):
|
||||
"""下载字体文件"""
|
||||
try:
|
||||
logger.info(f"正在下载字体文件: {url}")
|
||||
import requests
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(font_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
logger.info(f"字体文件下载成功: {font_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载字体文件失败: {e}")
|
||||
raise
|
||||
|
||||
209
app/utils/video_processor.py
Normal file
209
app/utils/video_processor.py
Normal file
@ -0,0 +1,209 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sklearn.cluster import KMeans
|
||||
import os
|
||||
import re
|
||||
from typing import List, Tuple, Generator
|
||||
|
||||
|
||||
class VideoProcessor:
|
||||
def __init__(self, video_path: str):
|
||||
"""
|
||||
初始化视频处理器
|
||||
|
||||
Args:
|
||||
video_path: 视频文件路径
|
||||
"""
|
||||
if not os.path.exists(video_path):
|
||||
raise FileNotFoundError(f"视频文件不存在: {video_path}")
|
||||
|
||||
self.video_path = video_path
|
||||
self.cap = cv2.VideoCapture(video_path)
|
||||
|
||||
if not self.cap.isOpened():
|
||||
raise RuntimeError(f"无法打开视频文件: {video_path}")
|
||||
|
||||
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
|
||||
|
||||
def __del__(self):
|
||||
"""析构函数,确保视频资源被释放"""
|
||||
if hasattr(self, 'cap'):
|
||||
self.cap.release()
|
||||
|
||||
def preprocess_video(self) -> Generator[np.ndarray, None, None]:
|
||||
"""
|
||||
使用生成器方式读取视频帧
|
||||
|
||||
Yields:
|
||||
np.ndarray: 视频帧
|
||||
"""
|
||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始
|
||||
while self.cap.isOpened():
|
||||
ret, frame = self.cap.read()
|
||||
if not ret:
|
||||
break
|
||||
yield frame
|
||||
|
||||
def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
|
||||
"""
|
||||
使用帧差法检测镜头边界
|
||||
|
||||
Args:
|
||||
frames: 视频帧列表
|
||||
threshold: 差异阈值
|
||||
|
||||
Returns:
|
||||
List[int]: 镜头边界帧的索引列表
|
||||
"""
|
||||
shot_boundaries = []
|
||||
for i in range(1, len(frames)):
|
||||
prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
|
||||
curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
|
||||
diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
|
||||
if diff > threshold:
|
||||
shot_boundaries.append(i)
|
||||
return shot_boundaries
|
||||
|
||||
def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
|
||||
"""
|
||||
从每个镜头中提取关键帧
|
||||
|
||||
Args:
|
||||
frames: 视频帧列表
|
||||
shot_boundaries: 镜头边界列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], List[int]]: 关<EFBFBD><EFBFBD><EFBFBD>帧列表和对应的帧索引
|
||||
"""
|
||||
keyframes = []
|
||||
keyframe_indices = []
|
||||
|
||||
for i in range(len(shot_boundaries)):
|
||||
start = shot_boundaries[i - 1] if i > 0 else 0
|
||||
end = shot_boundaries[i]
|
||||
shot_frames = frames[start:end]
|
||||
|
||||
frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
|
||||
for frame in shot_frames])
|
||||
kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
|
||||
center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
|
||||
|
||||
keyframes.append(shot_frames[center_idx])
|
||||
keyframe_indices.append(start + center_idx)
|
||||
|
||||
return keyframes, keyframe_indices
|
||||
|
||||
def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
|
||||
output_dir: str) -> None:
|
||||
"""
|
||||
保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg
|
||||
|
||||
Args:
|
||||
keyframes: 关键帧列表
|
||||
keyframe_indices: 关键帧索引列表
|
||||
output_dir: 输出目录
|
||||
"""
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
for keyframe, frame_idx in zip(keyframes, keyframe_indices):
|
||||
# 计算时间戳(秒)
|
||||
timestamp = frame_idx / self.fps
|
||||
# 将时间戳转换为 HH:MM:SS 格式
|
||||
hours = int(timestamp // 3600)
|
||||
minutes = int((timestamp % 3600) // 60)
|
||||
seconds = int(timestamp % 60)
|
||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
|
||||
|
||||
# 构建新的文件名格式:keyframe_帧序号_时间戳.jpg
|
||||
output_path = os.path.join(output_dir,
|
||||
f'keyframe_{frame_idx:06d}_{time_str}.jpg')
|
||||
cv2.imwrite(output_path, keyframe)
|
||||
|
||||
print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}")
|
||||
|
||||
def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
|
||||
"""
|
||||
根据指定的帧号提取帧
|
||||
|
||||
Args:
|
||||
frame_numbers: 要提取的帧号列表
|
||||
output_folder: 输出文件夹路径
|
||||
"""
|
||||
if not frame_numbers:
|
||||
raise ValueError("未提供帧号列表")
|
||||
|
||||
if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
|
||||
raise ValueError("存在无效的帧号")
|
||||
|
||||
if not os.path.exists(output_folder):
|
||||
os.makedirs(output_folder)
|
||||
|
||||
for frame_number in frame_numbers:
|
||||
self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
ret, frame = self.cap.read()
|
||||
|
||||
if ret:
|
||||
# 计算时间戳
|
||||
timestamp = frame_number / self.fps
|
||||
hours = int(timestamp // 3600)
|
||||
minutes = int((timestamp % 3600) // 60)
|
||||
seconds = int(timestamp % 60)
|
||||
time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
|
||||
|
||||
# 使用与关键帧相同的命名格式
|
||||
output_path = os.path.join(output_folder,
|
||||
f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
|
||||
cv2.imwrite(output_path, frame)
|
||||
print(f"已提取并保存帧 {frame_number}")
|
||||
else:
|
||||
print(f"无法读取帧 {frame_number}")
|
||||
|
||||
@staticmethod
|
||||
def extract_numbers_from_folder(folder_path: str) -> List[int]:
|
||||
"""
|
||||
从文件夹中提取帧号
|
||||
|
||||
Args:
|
||||
folder_path: 关键帧文件夹路径
|
||||
|
||||
Returns:
|
||||
List[int]: 排序后的帧号列表
|
||||
"""
|
||||
files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
|
||||
# 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534.jpg
|
||||
pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
|
||||
numbers = []
|
||||
for f in files:
|
||||
match = pattern.search(f)
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
return sorted(numbers)
|
||||
|
||||
def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
|
||||
"""
|
||||
处理视频并提取关键帧
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
skip_seconds: 跳过视频开头的秒数
|
||||
"""
|
||||
# 计算要跳过的帧数
|
||||
skip_frames = int(skip_seconds * self.fps)
|
||||
|
||||
# 获取所有帧
|
||||
frames = list(self.preprocess_video())
|
||||
|
||||
# 跳过指定秒数的帧
|
||||
frames = frames[skip_frames:]
|
||||
|
||||
if not frames:
|
||||
raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
|
||||
|
||||
shot_boundaries = self.detect_shot_boundaries(frames)
|
||||
keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
|
||||
|
||||
# 调整关键帧索引,加上跳过的帧数
|
||||
adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
|
||||
self.save_keyframes(keyframes, adjusted_indices, output_dir)
|
||||
183
app/utils/vision_analyzer.py
Normal file
183
app/utils/vision_analyzer.py
Normal file
@ -0,0 +1,183 @@
|
||||
import json
|
||||
from typing import List, Union, Dict
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
|
||||
from google.api_core import exceptions
|
||||
import google.generativeai as genai
|
||||
import PIL.Image
|
||||
import traceback
|
||||
|
||||
class VisionAnalyzer:
|
||||
"""视觉分析器类"""
|
||||
|
||||
def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
|
||||
"""初始化视觉分析器"""
|
||||
if not api_key:
|
||||
raise ValueError("必须提供API密钥")
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
|
||||
# 初始化配置
|
||||
self._configure_client()
|
||||
|
||||
def _configure_client(self):
|
||||
"""配置API客户端"""
|
||||
genai.configure(api_key=self.api_key)
|
||||
self.model = genai.GenerativeModel(self.model_name)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(exceptions.ResourceExhausted)
|
||||
)
|
||||
async def _generate_content_with_retry(self, prompt, batch):
|
||||
"""使用重试机制的内部方法来调用 generate_content_async"""
|
||||
try:
|
||||
return await self.model.generate_content_async([prompt, *batch])
|
||||
except exceptions.ResourceExhausted as e:
|
||||
print(f"API配额限制: {str(e)}")
|
||||
raise RetryError("API调用失败")
|
||||
|
||||
async def analyze_images(self,
|
||||
images: Union[List[str], List[PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 5) -> List[Dict]:
|
||||
"""批量分析多张图片"""
|
||||
try:
|
||||
# 加载图片
|
||||
if isinstance(images[0], str):
|
||||
logger.info("正在加载图片...")
|
||||
images = self.load_images(images)
|
||||
|
||||
# 验证图片列表
|
||||
if not images:
|
||||
raise ValueError("图片列表为空")
|
||||
|
||||
# 验证每个图片对象
|
||||
valid_images = []
|
||||
for i, img in enumerate(images):
|
||||
if not isinstance(img, PIL.Image.Image):
|
||||
logger.error(f"无效的图片对象,索引 {i}: {type(img)}")
|
||||
continue
|
||||
valid_images.append(img)
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError("没有有效的图片对象")
|
||||
|
||||
images = valid_images
|
||||
results = []
|
||||
total_batches = (len(images) + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_batches, desc="分析进度") as pbar:
|
||||
for i in range(0, len(images), batch_size):
|
||||
batch = images[i:i + batch_size]
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < 3:
|
||||
try:
|
||||
# 在每个批次处理前添加小延迟
|
||||
if i > 0:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# 确保每个批次的图片都是有效的
|
||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||
if not valid_batch:
|
||||
raise ValueError(f"批次 {i // batch_size} 中没有有效的图片")
|
||||
|
||||
response = await self._generate_content_with_retry(prompt, valid_batch)
|
||||
results.append({
|
||||
'batch_index': i // batch_size,
|
||||
'images_processed': len(valid_batch),
|
||||
'response': response.text,
|
||||
'model_used': self.model_name
|
||||
})
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
if retry_count >= 3:
|
||||
results.append({
|
||||
'batch_index': i // batch_size,
|
||||
'images_processed': len(batch),
|
||||
'error': error_msg,
|
||||
'model_used': self.model_name
|
||||
})
|
||||
else:
|
||||
logger.info(f"批次 {i // batch_size} 处理失败,等待60秒后重试...")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片分析过程中发生错误: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def save_results_to_txt(self, results: List[Dict], output_dir: str):
|
||||
"""将分析结果保存到txt文件"""
|
||||
# 确保输出目录存在
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
for result in results:
|
||||
if not result.get('image_paths'):
|
||||
continue
|
||||
|
||||
response_text = result['response']
|
||||
image_paths = result['image_paths']
|
||||
|
||||
img_name_start = Path(image_paths[0]).stem.split('_')[-1]
|
||||
img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
|
||||
txt_path = os.path.join(output_dir, f"frame_{img_name_start}_{img_name_end}.txt")
|
||||
|
||||
# 保存结果到txt文件
|
||||
with open(txt_path, 'w', encoding='utf-8') as f:
|
||||
f.write(response_text.strip())
|
||||
print(f"已保存分析结果到: {txt_path}")
|
||||
|
||||
def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
|
||||
"""
|
||||
加载多张图片
|
||||
Args:
|
||||
image_paths: 图片路径列表
|
||||
Returns:
|
||||
加载后的PIL Image对象列表
|
||||
"""
|
||||
images = []
|
||||
failed_images = []
|
||||
|
||||
for img_path in image_paths:
|
||||
try:
|
||||
if not os.path.exists(img_path):
|
||||
logger.error(f"图片文件不存在: {img_path}")
|
||||
failed_images.append(img_path)
|
||||
continue
|
||||
|
||||
img = PIL.Image.open(img_path)
|
||||
# 确保图片被完全加载
|
||||
img.load()
|
||||
# 转换为RGB模式
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
images.append(img)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"无法加载图片 {img_path}: {str(e)}")
|
||||
failed_images.append(img_path)
|
||||
|
||||
if failed_images:
|
||||
logger.warning(f"以下图片加载失败:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
|
||||
|
||||
if not images:
|
||||
raise ValueError("没有成功加载任何图片")
|
||||
|
||||
return images
|
||||
@ -1,5 +1,5 @@
|
||||
[app]
|
||||
project_version="0.2.2"
|
||||
project_version="0.3.0"
|
||||
# 支持视频理解的大模型提供商
|
||||
# gemini
|
||||
# qwen2-vl (待增加)
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
requests~=2.31.0
|
||||
moviepy~=2.0.0.dev2
|
||||
openai~=1.13.3
|
||||
faster-whisper~=1.0.1
|
||||
edge_tts~=6.1.15
|
||||
uvicorn~=0.27.1
|
||||
fastapi~=0.115.4
|
||||
tomli~=2.0.1
|
||||
streamlit~=1.39.0
|
||||
streamlit~=1.40.0
|
||||
loguru~=0.7.2
|
||||
aiohttp~=3.10.10
|
||||
urllib3~=2.2.1
|
||||
pillow~=10.4.0
|
||||
pydantic~=2.6.3
|
||||
g4f~=0.3.0.4
|
||||
dashscope~=1.15.0
|
||||
@ -25,3 +23,12 @@ git-changelog~=2.5.2
|
||||
watchdog==5.0.2
|
||||
pydub==0.25.1
|
||||
psutil>=5.9.0
|
||||
opencv-python~=4.10.0.84
|
||||
scikit-learn~=1.5.2
|
||||
google-generativeai~=0.8.3
|
||||
Pillow>=11.0.0
|
||||
python-dotenv~=1.0.1
|
||||
openai~=1.53.0
|
||||
tqdm>=4.66.6
|
||||
tenacity>=9.0.0
|
||||
tiktoken==0.8.0
|
||||
158
webui.py
158
webui.py
@ -4,9 +4,10 @@ import sys
|
||||
from uuid import uuid4
|
||||
from app.config import config
|
||||
from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
|
||||
from webui.utils import cache, file_utils, performance
|
||||
from webui.utils import cache, file_utils
|
||||
from app.utils import utils
|
||||
from app.models.schema import VideoClipParams, VideoAspect
|
||||
from webui.utils.performance import PerformanceMonitor
|
||||
|
||||
# 初始化配置 - 必须是第一个 Streamlit 命令
|
||||
st.set_page_config(
|
||||
@ -34,6 +35,17 @@ def init_log():
|
||||
_lvl = "DEBUG"
|
||||
|
||||
def format_record(record):
|
||||
# 增加更多需要过滤的警告消息
|
||||
ignore_messages = [
|
||||
"Examining the path of torch.classes raised",
|
||||
"torch.cuda.is_available()",
|
||||
"CUDA initialization"
|
||||
]
|
||||
|
||||
for msg in ignore_messages:
|
||||
if msg in record["message"]:
|
||||
return ""
|
||||
|
||||
file_path = record["file"].path
|
||||
relative_path = os.path.relpath(file_path, config.root_dir)
|
||||
record["file"].path = f"./{relative_path}"
|
||||
@ -45,11 +57,21 @@ def init_log():
|
||||
'- <level>{message}</>' + "\n"
|
||||
return _format
|
||||
|
||||
# 优化日志过滤器
|
||||
def log_filter(record):
|
||||
ignore_messages = [
|
||||
"Examining the path of torch.classes raised",
|
||||
"torch.cuda.is_available()",
|
||||
"CUDA initialization"
|
||||
]
|
||||
return not any(msg in record["message"] for msg in ignore_messages)
|
||||
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=_lvl,
|
||||
format=format_record,
|
||||
colorize=True,
|
||||
filter=log_filter
|
||||
)
|
||||
|
||||
def init_global_state():
|
||||
@ -73,77 +95,83 @@ def tr(key):
|
||||
def render_generate_button():
|
||||
"""渲染生成按钮和处理逻辑"""
|
||||
if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
|
||||
from app.services import task as tm
|
||||
|
||||
# 重置日志容器和记录
|
||||
log_container = st.empty()
|
||||
log_records = []
|
||||
|
||||
def log_received(msg):
|
||||
with log_container:
|
||||
log_records.append(msg)
|
||||
st.code("\n".join(log_records))
|
||||
|
||||
from loguru import logger
|
||||
logger.add(log_received)
|
||||
|
||||
config.save_config()
|
||||
task_id = st.session_state.get('task_id')
|
||||
|
||||
if not task_id:
|
||||
st.error(tr("请先裁剪视频"))
|
||||
return
|
||||
if not st.session_state.get('video_clip_json_path'):
|
||||
st.error(tr("脚本文件不能为空"))
|
||||
return
|
||||
if not st.session_state.get('video_origin_path'):
|
||||
st.error(tr("视频文件不能为空"))
|
||||
return
|
||||
|
||||
st.toast(tr("生成视频"))
|
||||
logger.info(tr("开始生成视频"))
|
||||
|
||||
# 获取所有参数
|
||||
script_params = script_settings.get_script_params()
|
||||
video_params = video_settings.get_video_params()
|
||||
audio_params = audio_settings.get_audio_params()
|
||||
subtitle_params = subtitle_settings.get_subtitle_params()
|
||||
|
||||
# 合并所有参数
|
||||
all_params = {
|
||||
**script_params,
|
||||
**video_params,
|
||||
**audio_params,
|
||||
**subtitle_params
|
||||
}
|
||||
|
||||
# 创建参数对象
|
||||
params = VideoClipParams(**all_params)
|
||||
|
||||
result = tm.start_subclip(
|
||||
task_id=task_id,
|
||||
params=params,
|
||||
subclip_path_videos=st.session_state['subclip_videos']
|
||||
)
|
||||
|
||||
video_files = result.get("videos", [])
|
||||
st.success(tr("视频生成完成"))
|
||||
|
||||
try:
|
||||
if video_files:
|
||||
player_cols = st.columns(len(video_files) * 2 + 1)
|
||||
for i, url in enumerate(video_files):
|
||||
player_cols[i * 2 + 1].video(url)
|
||||
except Exception as e:
|
||||
logger.error(f"播放视频失败: {e}")
|
||||
from app.services import task as tm
|
||||
import torch
|
||||
|
||||
# 重置日志容器和记录
|
||||
log_container = st.empty()
|
||||
log_records = []
|
||||
|
||||
file_utils.open_task_folder(config.root_dir, task_id)
|
||||
logger.info(tr("视频生成完成"))
|
||||
def log_received(msg):
|
||||
with log_container:
|
||||
log_records.append(msg)
|
||||
st.code("\n".join(log_records))
|
||||
|
||||
from loguru import logger
|
||||
logger.add(log_received)
|
||||
|
||||
config.save_config()
|
||||
task_id = st.session_state.get('task_id')
|
||||
|
||||
if not task_id:
|
||||
st.error(tr("请先裁剪视频"))
|
||||
return
|
||||
if not st.session_state.get('video_clip_json_path'):
|
||||
st.error(tr("脚本文件不能为空"))
|
||||
return
|
||||
if not st.session_state.get('video_origin_path'):
|
||||
st.error(tr("视频文件不能为空"))
|
||||
return
|
||||
|
||||
st.toast(tr("生成视频"))
|
||||
logger.info(tr("开始生成视频"))
|
||||
|
||||
# 获取所有参数
|
||||
script_params = script_settings.get_script_params()
|
||||
video_params = video_settings.get_video_params()
|
||||
audio_params = audio_settings.get_audio_params()
|
||||
subtitle_params = subtitle_settings.get_subtitle_params()
|
||||
|
||||
# 合并所有参数
|
||||
all_params = {
|
||||
**script_params,
|
||||
**video_params,
|
||||
**audio_params,
|
||||
**subtitle_params
|
||||
}
|
||||
|
||||
# 创建参数对象
|
||||
params = VideoClipParams(**all_params)
|
||||
|
||||
result = tm.start_subclip(
|
||||
task_id=task_id,
|
||||
params=params,
|
||||
subclip_path_videos=st.session_state['subclip_videos']
|
||||
)
|
||||
|
||||
video_files = result.get("videos", [])
|
||||
st.success(tr("视生成完成"))
|
||||
|
||||
try:
|
||||
if video_files:
|
||||
player_cols = st.columns(len(video_files) * 2 + 1)
|
||||
for i, url in enumerate(video_files):
|
||||
player_cols[i * 2 + 1].video(url)
|
||||
except Exception as e:
|
||||
logger.error(f"播放视频失败: {e}")
|
||||
|
||||
file_utils.open_task_folder(config.root_dir, task_id)
|
||||
logger.info(tr("视频生成完成"))
|
||||
|
||||
finally:
|
||||
PerformanceMonitor.cleanup_resources()
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
init_log()
|
||||
init_global_state()
|
||||
utils.init_resources()
|
||||
|
||||
st.title(f"NarratoAI :sunglasses:📽️")
|
||||
st.write(tr("Get Help"))
|
||||
|
||||
@ -3,6 +3,7 @@ import os
|
||||
from app.config import config
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
def render_basic_settings(tr):
|
||||
"""渲染基础设置面板"""
|
||||
with st.expander(tr("Basic Settings"), expanded=False):
|
||||
@ -10,23 +11,24 @@ def render_basic_settings(tr):
|
||||
left_config_panel = config_panels[0]
|
||||
middle_config_panel = config_panels[1]
|
||||
right_config_panel = config_panels[2]
|
||||
|
||||
|
||||
with left_config_panel:
|
||||
render_language_settings(tr)
|
||||
render_proxy_settings(tr)
|
||||
|
||||
|
||||
with middle_config_panel:
|
||||
render_video_llm_settings(tr)
|
||||
|
||||
render_vision_llm_settings(tr) # 视频分析模型设置
|
||||
|
||||
with right_config_panel:
|
||||
render_llm_settings(tr)
|
||||
render_text_llm_settings(tr) # 文案生成模型设置
|
||||
|
||||
|
||||
def render_language_settings(tr):
|
||||
"""渲染语言设置"""
|
||||
system_locale = utils.get_system_locale()
|
||||
i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
|
||||
locales = utils.load_locales(i18n_dir)
|
||||
|
||||
|
||||
display_languages = []
|
||||
selected_index = 0
|
||||
for i, code in enumerate(locales.keys()):
|
||||
@ -35,24 +37,25 @@ def render_language_settings(tr):
|
||||
selected_index = i
|
||||
|
||||
selected_language = st.selectbox(
|
||||
tr("Language"),
|
||||
tr("Language"),
|
||||
options=display_languages,
|
||||
index=selected_index
|
||||
)
|
||||
|
||||
|
||||
if selected_language:
|
||||
code = selected_language.split(" - ")[0].strip()
|
||||
st.session_state['ui_language'] = code
|
||||
config.ui['language'] = code
|
||||
|
||||
|
||||
def render_proxy_settings(tr):
|
||||
"""渲染代理设置"""
|
||||
proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
|
||||
|
||||
|
||||
HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
|
||||
HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
|
||||
|
||||
|
||||
if HTTP_PROXY:
|
||||
config.proxy["http"] = HTTP_PROXY
|
||||
os.environ["HTTP_PROXY"] = HTTP_PROXY
|
||||
@ -60,83 +63,172 @@ def render_proxy_settings(tr):
|
||||
config.proxy["https"] = HTTPS_PROXY
|
||||
os.environ["HTTPS_PROXY"] = HTTPS_PROXY
|
||||
|
||||
def render_video_llm_settings(tr):
|
||||
"""渲染视频LLM设置"""
|
||||
video_llm_providers = ['Gemini', 'NarratoAPI']
|
||||
saved_llm_provider = config.app.get("video_llm_provider", "OpenAI").lower()
|
||||
saved_llm_provider_index = 0
|
||||
|
||||
for i, provider in enumerate(video_llm_providers):
|
||||
if provider.lower() == saved_llm_provider:
|
||||
saved_llm_provider_index = i
|
||||
|
||||
def render_vision_llm_settings(tr):
|
||||
"""渲染视频分析模型设置"""
|
||||
st.subheader(tr("Vision Model Settings"))
|
||||
|
||||
# 视频分析模型提供商选择
|
||||
vision_providers = ['Gemini', 'NarratoAPI']
|
||||
saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
|
||||
saved_provider_index = 0
|
||||
|
||||
for i, provider in enumerate(vision_providers):
|
||||
if provider.lower() == saved_vision_provider:
|
||||
saved_provider_index = i
|
||||
break
|
||||
|
||||
video_llm_provider = st.selectbox(
|
||||
tr("Video LLM Provider"),
|
||||
options=video_llm_providers,
|
||||
index=saved_llm_provider_index
|
||||
vision_provider = st.selectbox(
|
||||
tr("Vision Model Provider"),
|
||||
options=vision_providers,
|
||||
index=saved_provider_index
|
||||
)
|
||||
video_llm_provider = video_llm_provider.lower()
|
||||
config.app["video_llm_provider"] = video_llm_provider
|
||||
vision_provider = vision_provider.lower()
|
||||
config.app["vision_llm_provider"] = vision_provider
|
||||
st.session_state['vision_llm_providers'] = vision_provider
|
||||
|
||||
# 获取已保存的配置
|
||||
video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
|
||||
video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
|
||||
video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
|
||||
|
||||
# 渲染输入框
|
||||
st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
|
||||
st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
|
||||
st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
|
||||
|
||||
# 保存配置
|
||||
if st_llm_api_key:
|
||||
config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
|
||||
if st_llm_base_url:
|
||||
config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
|
||||
if st_llm_model_name:
|
||||
config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
|
||||
# 获取已保存的视觉模型配置
|
||||
vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
|
||||
vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
|
||||
vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
|
||||
|
||||
def render_llm_settings(tr):
|
||||
"""渲染LLM设置"""
|
||||
llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
|
||||
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
|
||||
saved_llm_provider_index = 0
|
||||
|
||||
for i, provider in enumerate(llm_providers):
|
||||
if provider.lower() == saved_llm_provider:
|
||||
saved_llm_provider_index = i
|
||||
# 渲染视觉模型配置输入框
|
||||
st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
|
||||
st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url)
|
||||
st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
|
||||
|
||||
# 保存视觉模型配置
|
||||
if st_vision_api_key:
|
||||
config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
|
||||
st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key # 用于script_settings.py
|
||||
if st_vision_base_url:
|
||||
config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
||||
st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
|
||||
if st_vision_model_name:
|
||||
config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
||||
st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
|
||||
|
||||
# NarratoAPI 特殊配置
|
||||
if vision_provider == 'narratoapi':
|
||||
st.subheader(tr("Narrato Additional Settings"))
|
||||
|
||||
# Narrato API 基础配置
|
||||
narrato_api_key = st.text_input(
|
||||
tr("Narrato API Key"),
|
||||
value=config.app.get("narrato_api_key", ""),
|
||||
type="password",
|
||||
help="用于访问 Narrato API 的密钥"
|
||||
)
|
||||
if narrato_api_key:
|
||||
config.app["narrato_api_key"] = narrato_api_key
|
||||
st.session_state['narrato_api_key'] = narrato_api_key
|
||||
|
||||
narrato_api_url = st.text_input(
|
||||
tr("Narrato API URL"),
|
||||
value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze")
|
||||
)
|
||||
if narrato_api_url:
|
||||
config.app["narrato_api_url"] = narrato_api_url
|
||||
st.session_state['narrato_api_url'] = narrato_api_url
|
||||
|
||||
# 视频分析模型配置
|
||||
st.markdown("##### " + tr("Vision Model Settings"))
|
||||
narrato_vision_model = st.text_input(
|
||||
tr("Vision Model Name"),
|
||||
value=config.app.get("narrato_vision_model", "gemini-1.5-flash")
|
||||
)
|
||||
narrato_vision_key = st.text_input(
|
||||
tr("Vision Model API Key"),
|
||||
value=config.app.get("narrato_vision_key", ""),
|
||||
type="password",
|
||||
help="用于视频分析的模型 API Key"
|
||||
)
|
||||
|
||||
if narrato_vision_model:
|
||||
config.app["narrato_vision_model"] = narrato_vision_model
|
||||
st.session_state['narrato_vision_model'] = narrato_vision_model
|
||||
if narrato_vision_key:
|
||||
config.app["narrato_vision_key"] = narrato_vision_key
|
||||
st.session_state['narrato_vision_key'] = narrato_vision_key
|
||||
|
||||
# 文案生成模型配置
|
||||
st.markdown("##### " + tr("Text Generation Model Settings"))
|
||||
narrato_llm_model = st.text_input(
|
||||
tr("LLM Model Name"),
|
||||
value=config.app.get("narrato_llm_model", "qwen-plus")
|
||||
)
|
||||
narrato_llm_key = st.text_input(
|
||||
tr("LLM Model API Key"),
|
||||
value=config.app.get("narrato_llm_key", ""),
|
||||
type="password",
|
||||
help="用于文案生成的模型 API Key"
|
||||
)
|
||||
|
||||
if narrato_llm_model:
|
||||
config.app["narrato_llm_model"] = narrato_llm_model
|
||||
st.session_state['narrato_llm_model'] = narrato_llm_model
|
||||
if narrato_llm_key:
|
||||
config.app["narrato_llm_key"] = narrato_llm_key
|
||||
st.session_state['narrato_llm_key'] = narrato_llm_key
|
||||
|
||||
# 批处理配置
|
||||
narrato_batch_size = st.number_input(
|
||||
tr("Batch Size"),
|
||||
min_value=1,
|
||||
max_value=50,
|
||||
value=config.app.get("narrato_batch_size", 10),
|
||||
help="每批处理的图片数量"
|
||||
)
|
||||
if narrato_batch_size:
|
||||
config.app["narrato_batch_size"] = narrato_batch_size
|
||||
st.session_state['narrato_batch_size'] = narrato_batch_size
|
||||
|
||||
|
||||
def render_text_llm_settings(tr):
|
||||
"""渲染文案生成模型设置"""
|
||||
st.subheader(tr("Text Generation Model Settings"))
|
||||
|
||||
# 文案生成模型提供商选择
|
||||
text_providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare']
|
||||
saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
|
||||
saved_provider_index = 0
|
||||
|
||||
for i, provider in enumerate(text_providers):
|
||||
if provider.lower() == saved_text_provider:
|
||||
saved_provider_index = i
|
||||
break
|
||||
|
||||
llm_provider = st.selectbox(
|
||||
tr("LLM Provider"),
|
||||
options=llm_providers,
|
||||
index=saved_llm_provider_index
|
||||
text_provider = st.selectbox(
|
||||
tr("Text Model Provider"),
|
||||
options=text_providers,
|
||||
index=saved_provider_index
|
||||
)
|
||||
llm_provider = llm_provider.lower()
|
||||
config.app["llm_provider"] = llm_provider
|
||||
text_provider = text_provider.lower()
|
||||
config.app["text_llm_provider"] = text_provider
|
||||
|
||||
# 获取已保存的配置
|
||||
llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
|
||||
llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
|
||||
llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
|
||||
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
|
||||
|
||||
# 渲染输入框
|
||||
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
|
||||
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
|
||||
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
|
||||
|
||||
# 保存配置
|
||||
if st_llm_api_key:
|
||||
config.app[f"{llm_provider}_api_key"] = st_llm_api_key
|
||||
if st_llm_base_url:
|
||||
config.app[f"{llm_provider}_base_url"] = st_llm_base_url
|
||||
if st_llm_model_name:
|
||||
config.app[f"{llm_provider}_model_name"] = st_llm_model_name
|
||||
# 获取已保存的文本模型配置
|
||||
text_api_key = config.app.get(f"text_{text_provider}_api_key", "")
|
||||
text_base_url = config.app.get(f"text_{text_provider}_base_url", "")
|
||||
text_model_name = config.app.get(f"text_{text_provider}_model_name", "")
|
||||
|
||||
# Cloudflare 特殊处理
|
||||
if llm_provider == 'cloudflare':
|
||||
st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
|
||||
if st_llm_account_id:
|
||||
config.app[f"{llm_provider}_account_id"] = st_llm_account_id
|
||||
# 渲染文本模型配置输入框
|
||||
st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
|
||||
st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
|
||||
st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
|
||||
|
||||
# 保存文本模型配置
|
||||
if st_text_api_key:
|
||||
config.app[f"text_{text_provider}_api_key"] = st_text_api_key
|
||||
if st_text_base_url:
|
||||
config.app[f"text_{text_provider}_base_url"] = st_text_base_url
|
||||
if st_text_model_name:
|
||||
config.app[f"text_{text_provider}_model_name"] = st_text_model_name
|
||||
|
||||
# Cloudflare 特殊配置
|
||||
if text_provider == 'cloudflare':
|
||||
st_account_id = st.text_input(
|
||||
tr("Account ID"),
|
||||
value=config.app.get(f"text_{text_provider}_account_id", "")
|
||||
)
|
||||
if st_account_id:
|
||||
config.app[f"text_{text_provider}_account_id"] = st_account_id
|
||||
|
||||
@ -7,8 +7,10 @@ def render_review_panel(tr):
|
||||
with st.expander(tr("Video Check"), expanded=False):
|
||||
try:
|
||||
video_list = st.session_state.get('video_clip_json', [])
|
||||
subclip_videos = st.session_state.get('subclip_videos', {})
|
||||
except KeyError:
|
||||
video_list = []
|
||||
subclip_videos = {}
|
||||
|
||||
# 计算列数和行数
|
||||
num_videos = len(video_list)
|
||||
@ -22,44 +24,62 @@ def render_review_panel(tr):
|
||||
index = row * cols_per_row + col
|
||||
if index < num_videos:
|
||||
with cols[col]:
|
||||
render_video_item(tr, video_list, index)
|
||||
render_video_item(tr, video_list, subclip_videos, index)
|
||||
|
||||
def render_video_item(tr, video_list, index):
|
||||
def render_video_item(tr, video_list, subclip_videos, index):
|
||||
"""渲染单个视频项"""
|
||||
video_info = video_list[index]
|
||||
video_path = video_info.get('path')
|
||||
if video_path is not None and os.path.exists(video_path):
|
||||
initial_narration = video_info.get('narration', '')
|
||||
initial_picture = video_info.get('picture', '')
|
||||
initial_timestamp = video_info.get('timestamp', '')
|
||||
|
||||
# 显示视频
|
||||
with open(video_path, 'rb') as video_file:
|
||||
video_bytes = video_file.read()
|
||||
st.video(video_bytes)
|
||||
|
||||
# 显示信息(只读)
|
||||
text_panels = st.columns(2)
|
||||
with text_panels[0]:
|
||||
st.text_area(
|
||||
tr("timestamp"),
|
||||
value=initial_timestamp,
|
||||
height=20,
|
||||
key=f"timestamp_{index}",
|
||||
disabled=True
|
||||
)
|
||||
with text_panels[1]:
|
||||
st.text_area(
|
||||
tr("Picture description"),
|
||||
value=initial_picture,
|
||||
height=20,
|
||||
key=f"picture_{index}",
|
||||
disabled=True
|
||||
)
|
||||
st.text_area(
|
||||
tr("Narration"),
|
||||
value=initial_narration,
|
||||
height=100,
|
||||
key=f"narration_{index}",
|
||||
disabled=True
|
||||
)
|
||||
video_script = video_list[index]
|
||||
|
||||
# 显示时间戳
|
||||
timestamp = video_script.get('timestamp', '')
|
||||
st.text_area(
|
||||
tr("Timestamp"),
|
||||
value=timestamp,
|
||||
height=70,
|
||||
disabled=True,
|
||||
key=f"timestamp_{index}"
|
||||
)
|
||||
|
||||
# 显示视频播放器
|
||||
video_path = subclip_videos.get(timestamp)
|
||||
if video_path and os.path.exists(video_path):
|
||||
try:
|
||||
st.video(video_path)
|
||||
except Exception as e:
|
||||
logger.error(f"加载视频失败 {video_path}: {e}")
|
||||
st.error(f"无法加载视频: {os.path.basename(video_path)}")
|
||||
else:
|
||||
st.warning(tr("视频文件未找到"))
|
||||
|
||||
# 显示画面描述
|
||||
st.text_area(
|
||||
tr("Picture Description"),
|
||||
value=video_script.get('picture', ''),
|
||||
height=150,
|
||||
disabled=True,
|
||||
key=f"picture_{index}"
|
||||
)
|
||||
|
||||
# 显示旁白文本
|
||||
narration = st.text_area(
|
||||
tr("Narration"),
|
||||
value=video_script.get('narration', ''),
|
||||
height=150,
|
||||
key=f"narration_{index}"
|
||||
)
|
||||
# 保存修改后的旁白文本
|
||||
if narration != video_script.get('narration', ''):
|
||||
video_script['narration'] = narration
|
||||
st.session_state['video_clip_json'] = video_list
|
||||
|
||||
# 显示剪辑模式
|
||||
ost = st.selectbox(
|
||||
tr("Clip Mode"),
|
||||
options=range(1, 10),
|
||||
index=video_script.get('OST', 1) - 1,
|
||||
key=f"ost_{index}"
|
||||
)
|
||||
# 保存修改后的剪辑模式
|
||||
if ost != video_script.get('OST', 1):
|
||||
video_script['OST'] = ost
|
||||
st.session_state['video_clip_json'] = video_list
|
||||
@ -1,43 +1,50 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import requests
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.models.schema import VideoClipParams
|
||||
from app.services import llm
|
||||
from app.utils import utils, check_script
|
||||
from loguru import logger
|
||||
from app.utils import utils, check_script, vision_analyzer, video_processor
|
||||
from webui.utils import file_utils
|
||||
|
||||
|
||||
def render_script_panel(tr):
|
||||
"""渲染脚本配置面板"""
|
||||
with st.container(border=True):
|
||||
st.write(tr("Video Script Configuration"))
|
||||
params = VideoClipParams()
|
||||
|
||||
|
||||
# 渲染脚本文件选择
|
||||
render_script_file(tr, params)
|
||||
|
||||
|
||||
# 渲染视频文件选择
|
||||
render_video_file(tr, params)
|
||||
|
||||
|
||||
# 渲染视频主题和提示词
|
||||
render_video_details(tr)
|
||||
|
||||
|
||||
# 渲染脚本操作按钮
|
||||
render_script_buttons(tr, params)
|
||||
|
||||
|
||||
def render_script_file(tr, params):
|
||||
"""渲染脚本文件选择"""
|
||||
script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]
|
||||
|
||||
|
||||
# 获取已有脚本文件
|
||||
suffix = "*.json"
|
||||
script_dir = utils.script_dir()
|
||||
files = glob.glob(os.path.join(script_dir, suffix))
|
||||
file_list = []
|
||||
|
||||
|
||||
for file in files:
|
||||
file_list.append({
|
||||
"name": os.path.basename(file),
|
||||
@ -64,15 +71,16 @@ def render_script_file(tr, params):
|
||||
options=range(len(script_list)),
|
||||
format_func=lambda x: script_list[x][0]
|
||||
)
|
||||
|
||||
|
||||
script_path = script_list[selected_script_index][1]
|
||||
st.session_state['video_clip_json_path'] = script_path
|
||||
params.video_clip_json_path = script_path
|
||||
|
||||
|
||||
def render_video_file(tr, params):
|
||||
"""渲染视频文件选择"""
|
||||
video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
|
||||
|
||||
|
||||
# 获取已有视频文件
|
||||
for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
|
||||
video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
|
||||
@ -86,7 +94,7 @@ def render_video_file(tr, params):
|
||||
options=range(len(video_list)),
|
||||
format_func=lambda x: video_list[x][0]
|
||||
)
|
||||
|
||||
|
||||
video_path = video_list[selected_video_index][1]
|
||||
st.session_state['video_origin_path'] = video_path
|
||||
params.video_origin_path = video_path
|
||||
@ -97,16 +105,16 @@ def render_video_file(tr, params):
|
||||
type=["mp4", "mov", "avi", "flv", "mkv"],
|
||||
accept_multiple_files=False,
|
||||
)
|
||||
|
||||
|
||||
if uploaded_file is not None:
|
||||
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
|
||||
file_name, file_extension = os.path.splitext(uploaded_file.name)
|
||||
|
||||
|
||||
if os.path.exists(video_file_path):
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
file_name_with_timestamp = f"{file_name}_{timestamp}"
|
||||
video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
|
||||
|
||||
|
||||
with open(video_file_path, "wb") as f:
|
||||
f.write(uploaded_file.read())
|
||||
st.success(tr("File Uploaded Successfully"))
|
||||
@ -115,6 +123,7 @@ def render_video_file(tr, params):
|
||||
time.sleep(1)
|
||||
st.rerun()
|
||||
|
||||
|
||||
def render_video_details(tr):
|
||||
"""渲染视频主题和提示词"""
|
||||
video_theme = st.text_input(tr("Video Theme"))
|
||||
@ -128,6 +137,7 @@ def render_video_details(tr):
|
||||
st.session_state['video_plot'] = prompt
|
||||
return video_theme, prompt
|
||||
|
||||
|
||||
def render_script_buttons(tr, params):
|
||||
"""渲染脚本操作按钮"""
|
||||
# 生成/加载按钮
|
||||
@ -157,16 +167,17 @@ def render_script_buttons(tr, params):
|
||||
with button_cols[0]:
|
||||
if st.button(tr("Check Format"), key="check_format", use_container_width=True):
|
||||
check_script_format(tr, video_clip_json_details)
|
||||
|
||||
|
||||
with button_cols[1]:
|
||||
if st.button(tr("Save Script"), key="save_script", use_container_width=True):
|
||||
save_script(tr, video_clip_json_details)
|
||||
|
||||
|
||||
with button_cols[2]:
|
||||
script_valid = st.session_state.get('script_format_valid', False)
|
||||
if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
|
||||
crop_video(tr, params)
|
||||
|
||||
|
||||
def check_script_format(tr, script_content):
|
||||
"""检查脚本格式"""
|
||||
try:
|
||||
@ -181,6 +192,7 @@ def check_script_format(tr, script_content):
|
||||
st.error(f"{tr('Script format check error')}: {str(e)}")
|
||||
st.session_state['script_format_valid'] = False
|
||||
|
||||
|
||||
def load_script(tr, script_path):
|
||||
"""加载脚本文件"""
|
||||
try:
|
||||
@ -193,6 +205,7 @@ def load_script(tr, script_path):
|
||||
except Exception as e:
|
||||
st.error(f"{tr('Failed to load script')}: {str(e)}")
|
||||
|
||||
|
||||
def generate_script(tr, params):
|
||||
"""生成视频脚本"""
|
||||
progress_bar = st.progress(0)
|
||||
@ -207,48 +220,354 @@ def generate_script(tr, params):
|
||||
|
||||
try:
|
||||
with st.spinner("正在生成脚本..."):
|
||||
if not st.session_state.get('video_plot'):
|
||||
st.warning("视频剧情为空; 会极大影响生成效果!")
|
||||
if not params.video_origin_path:
|
||||
st.error("请先选择视频文件")
|
||||
st.stop()
|
||||
return
|
||||
|
||||
# ===================提取键帧===================
|
||||
update_progress(10, "正在提取关键帧...")
|
||||
|
||||
# 创建临时目录用于存储关键帧
|
||||
keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
|
||||
video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
|
||||
video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
|
||||
|
||||
# 检查是否已经提取过关键帧
|
||||
keyframe_files = []
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
# 获取已有的关键帧文件
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
|
||||
if params.video_clip_json_path == "" and params.video_origin_path != "":
|
||||
update_progress(10, "压缩视频中...")
|
||||
script = llm.generate_script(
|
||||
video_path=params.video_origin_path,
|
||||
video_plot=st.session_state.get('video_plot', ''),
|
||||
video_name=st.session_state.get('video_name', ''),
|
||||
language=params.video_language,
|
||||
progress_callback=update_progress
|
||||
)
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
st.stop()
|
||||
else:
|
||||
update_progress(90)
|
||||
if keyframe_files:
|
||||
logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
|
||||
st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}")
|
||||
update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧")
|
||||
|
||||
# 如果没有缓存的关键帧,则进行提取
|
||||
if not keyframe_files:
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||
|
||||
# 初始化视频处理器
|
||||
processor = video_processor.VideoProcessor(params.video_origin_path)
|
||||
|
||||
# 处理视频并提取关键帧
|
||||
processor.process_video(
|
||||
output_dir=video_keyframes_dir,
|
||||
skip_seconds=0
|
||||
)
|
||||
|
||||
# 获取所有关键帧文件路径
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
|
||||
if not keyframe_files:
|
||||
raise Exception("未提取到任何关键帧")
|
||||
|
||||
update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧")
|
||||
|
||||
except Exception as e:
|
||||
# 如果提取失败,清理创建的目录
|
||||
try:
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
import shutil
|
||||
shutil.rmtree(video_keyframes_dir)
|
||||
except Exception as cleanup_err:
|
||||
logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
|
||||
|
||||
raise Exception(f"关键帧提取失败: {str(e)}")
|
||||
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
else:
|
||||
# 从本地加载
|
||||
with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
|
||||
update_progress(50)
|
||||
status_text.text("从本地加载中...")
|
||||
script = f.read()
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
update_progress(100)
|
||||
status_text.text("从本地加载成功")
|
||||
# 根据不同的 LLM 提供商处理
|
||||
video_llm_provider = st.session_state.get('video_llm_providers', 'Gemini').lower()
|
||||
|
||||
if video_llm_provider == 'gemini':
|
||||
try:
|
||||
# ===================初始化视觉分析器===================
|
||||
update_progress(30, "正在初始化视觉分析器...")
|
||||
|
||||
# 从配置中获取 Gemini 相关配置
|
||||
vision_api_key = st.session_state.get('vision_gemini_api_key')
|
||||
vision_model = st.session_state.get('vision_gemini_model_name')
|
||||
vision_base_url = st.session_state.get('vision_gemini_base_url')
|
||||
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError("未配置 Gemini API Key 或者 模型,请在基础设置中配置")
|
||||
|
||||
analyzer = vision_analyzer.VisionAnalyzer(
|
||||
model_name=vision_model,
|
||||
api_key=vision_api_key
|
||||
)
|
||||
|
||||
update_progress(40, "正在分析关键帧...")
|
||||
|
||||
# ===================创建异步事件循环===================
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 执行异步分析
|
||||
results = loop.run_until_complete(
|
||||
analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
batch_size=config.app.get("vision_batch_size", 5)
|
||||
)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
# ===================处理分析结果===================
|
||||
update_progress(60, "正在整理分析结果...")
|
||||
|
||||
# 合并所有批次的析结果
|
||||
frame_analysis = ""
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
|
||||
# 获取当前批次的图片文件
|
||||
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
|
||||
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
|
||||
batch_files = keyframe_files[batch_start:batch_end]
|
||||
|
||||
# 提取首帧和尾帧的时间戳
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
# 从文件名中提取时间信息
|
||||
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
|
||||
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
|
||||
|
||||
# 转换为分:秒格式
|
||||
def format_timestamp(time_str):
|
||||
seconds = int(time_str)
|
||||
minutes = seconds // 60
|
||||
seconds = seconds % 60
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
|
||||
first_timestamp = format_timestamp(first_time)
|
||||
last_timestamp = format_timestamp(last_time)
|
||||
|
||||
# 添加带时间戳的分析结果
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += result['response']
|
||||
frame_analysis += "\n"
|
||||
|
||||
if not frame_analysis.strip():
|
||||
raise Exception("未能生成有效的帧分析结果")
|
||||
|
||||
# 保存分析结果
|
||||
analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
|
||||
with open(analysis_path, 'w', encoding='utf-8') as f:
|
||||
f.write(frame_analysis)
|
||||
|
||||
update_progress(70, "正在生成脚本...")
|
||||
|
||||
# 构建完整的上下文
|
||||
context = {
|
||||
'video_name': st.session_state.get('video_name', ''),
|
||||
'video_plot': st.session_state.get('video_plot', ''),
|
||||
'frame_analysis': frame_analysis,
|
||||
'total_frames': len(keyframe_files)
|
||||
}
|
||||
# 从配置中获取文本生成相关配置
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name', 'gemini-1.5-pro')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
|
||||
# 构建帧内容列表
|
||||
frame_content_list = []
|
||||
for i, result in enumerate(results):
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
# 获取当前批次的图片文件
|
||||
batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
|
||||
batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
|
||||
batch_files = keyframe_files[batch_start:batch_end]
|
||||
|
||||
# 提取首帧和尾帧的时间戳
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
# 从文件名中提取时间信息
|
||||
first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
|
||||
last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
|
||||
|
||||
# 转换为分:秒格式
|
||||
def format_timestamp(time_str):
|
||||
seconds = int(time_str)
|
||||
minutes = seconds // 60
|
||||
seconds = seconds % 60
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
|
||||
first_timestamp = format_timestamp(first_time)
|
||||
last_timestamp = format_timestamp(last_time)
|
||||
|
||||
# 构建时间戳范围字符串 (MM:SS-MM:SS 格式)
|
||||
timestamp_range = f"{first_timestamp}-{last_timestamp}"
|
||||
|
||||
frame_content = {
|
||||
"timestamp": timestamp_range, # 使用时间范围格式 "MM:SS-MM:SS"
|
||||
"picture": result['response'], # 图片分析结果
|
||||
"narration": "", # 将由 ScriptProcessor 生成
|
||||
"OST": 2 # 默认值
|
||||
}
|
||||
frame_content_list.append(frame_content)
|
||||
|
||||
logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
|
||||
|
||||
if not frame_content_list:
|
||||
raise Exception("没有有效的帧内容可以处理")
|
||||
|
||||
# 使用 ScriptProcessor 生成脚本
|
||||
from app.utils.script_generator import ScriptProcessor
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt="" # 使用默认提示词
|
||||
)
|
||||
|
||||
# 处理帧内容生成脚本
|
||||
script_result = processor.process_frames(frame_content_list)
|
||||
|
||||
# 将结果转换为JSON字符串
|
||||
script = json.dumps(script_result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Gemini 处理过程中发生错误\n{traceback.format_exc()}")
|
||||
raise Exception(f"视觉分析失败: {str(e)}")
|
||||
|
||||
else: # NarratoAPI
|
||||
try:
|
||||
# 创建临时目录
|
||||
temp_dir = utils.temp_dir("narrato")
|
||||
|
||||
# 打包关键帧
|
||||
update_progress(30, "正在打包关键帧...")
|
||||
zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
|
||||
|
||||
if not file_utils.create_zip(keyframe_files, zip_path):
|
||||
raise Exception("打包关键帧失败")
|
||||
|
||||
# 获取API配置
|
||||
api_url = st.session_state.get('narrato_api_url', 'http://127.0.0.1:8000/api/v1/video/analyze')
|
||||
api_key = st.session_state.get('narrato_api_key')
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置 Narrato API Key,请在基础设置中配置")
|
||||
|
||||
# 准备API请求
|
||||
headers = {
|
||||
'X-API-Key': api_key,
|
||||
'accept': 'application/json'
|
||||
}
|
||||
|
||||
api_params = {
|
||||
'batch_size': st.session_state.get('narrato_batch_size', 10),
|
||||
'use_ai': False,
|
||||
'start_offset': 0,
|
||||
'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
|
||||
'vision_api_key': st.session_state.get('narrato_vision_key'),
|
||||
'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
|
||||
'llm_api_key': st.session_state.get('narrato_llm_key'),
|
||||
'custom_prompt': st.session_state.get('video_plot', '')
|
||||
}
|
||||
|
||||
# 发送API请求
|
||||
update_progress(40, "正在上传文件...")
|
||||
with open(zip_path, 'rb') as f:
|
||||
files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
|
||||
try:
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
params=api_params,
|
||||
files=files,
|
||||
timeout=30 # 设置超时时间
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise Exception(f"API请求失败: {str(e)}")
|
||||
|
||||
task_data = response.json()
|
||||
task_id = task_data.get('task_id')
|
||||
if not task_id:
|
||||
raise Exception(f"无效的API响应: {response.text}")
|
||||
|
||||
# 轮询任务状态
|
||||
update_progress(50, "正在等待分析结果...")
|
||||
retry_count = 0
|
||||
max_retries = 60 # 最多等待2分钟
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
status_response = requests.get(
|
||||
f"{api_url}/tasks/{task_id}",
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
status_response.raise_for_status()
|
||||
task_status = status_response.json()['data']
|
||||
|
||||
if task_status['status'] == 'SUCCESS':
|
||||
script = task_status['result']
|
||||
break
|
||||
elif task_status['status'] in ['FAILURE', 'RETRY']:
|
||||
raise Exception(f"任务失败: {task_status.get('error')}")
|
||||
|
||||
retry_count += 1
|
||||
progress = min(70, 50 + (retry_count * 20 / max_retries))
|
||||
update_progress(progress, "正在分析中...")
|
||||
time.sleep(2)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.warning(f"获取任务状态失败,重试中: {str(e)}")
|
||||
retry_count += 1
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
if retry_count >= max_retries:
|
||||
raise Exception("任务执行超时")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("NarratoAPI 处理过程中发生错误")
|
||||
raise Exception(f"NarratoAPI 处理失败: {str(e)}")
|
||||
finally:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理临时文件失败: {str(e)}")
|
||||
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
st.stop()
|
||||
|
||||
script = utils.clean_model_output(script)
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
update_progress(90, "脚本生成完成")
|
||||
|
||||
time.sleep(0.5)
|
||||
progress_bar.progress(100)
|
||||
status_text.text("脚本生成完成!")
|
||||
st.success("视频脚本生成成功!")
|
||||
|
||||
except Exception as err:
|
||||
st.error(f"生成过程中发生错误: {str(err)}")
|
||||
logger.exception("生成脚本时发生错误")
|
||||
finally:
|
||||
time.sleep(2)
|
||||
progress_bar.empty()
|
||||
status_text.empty()
|
||||
|
||||
|
||||
def save_script(tr, video_clip_json_details):
|
||||
"""保存视频脚本"""
|
||||
if not video_clip_json_details:
|
||||
@ -266,21 +585,22 @@ def save_script(tr, video_clip_json_details):
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
st.session_state['video_clip_json'] = data
|
||||
st.session_state['video_clip_json_path'] = save_path
|
||||
|
||||
|
||||
# 更新配置
|
||||
config.app["video_clip_json_path"] = save_path
|
||||
|
||||
|
||||
# 显示成功消息
|
||||
st.success(tr("Script saved successfully"))
|
||||
|
||||
# 强制重新加载页面以更新选择框
|
||||
|
||||
# 强制重新加载页面<EFBFBD><EFBFBD>更新选择框
|
||||
time.sleep(0.5) # 给一点时间让用户看到成功消息
|
||||
st.rerun()
|
||||
|
||||
|
||||
except Exception as err:
|
||||
st.error(f"{tr('Failed to save script')}: {str(err)}")
|
||||
st.stop()
|
||||
|
||||
|
||||
def crop_video(tr, params):
|
||||
"""裁剪视频"""
|
||||
progress_bar = st.progress(0)
|
||||
@ -303,6 +623,7 @@ def crop_video(tr, params):
|
||||
progress_bar.empty()
|
||||
status_text.empty()
|
||||
|
||||
|
||||
def get_script_params():
|
||||
"""获取脚本参数"""
|
||||
return {
|
||||
@ -311,4 +632,4 @@ def get_script_params():
|
||||
'video_origin_path': st.session_state.get('video_origin_path', ''),
|
||||
'video_name': st.session_state.get('video_name', ''),
|
||||
'video_plot': st.session_state.get('video_plot', '')
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,18 @@ class WebUIConfig:
|
||||
project_version: str = "0.1.0"
|
||||
# 项目根目录
|
||||
root_dir: str = None
|
||||
# Gemini API Key
|
||||
gemini_api_key: str = ""
|
||||
# 每批处理的图片数量
|
||||
vision_batch_size: int = 5
|
||||
# 提示词
|
||||
vision_prompt: str = """..."""
|
||||
# Narrato API 配置
|
||||
narrato_api_url: str = "http://127.0.0.1:8000/api/v1/video/analyze"
|
||||
narrato_api_key: str = ""
|
||||
narrato_batch_size: int = 10
|
||||
narrato_vision_model: str = "gemini-1.5-flash"
|
||||
narrato_llm_model: str = "qwen-plus"
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化默认值"""
|
||||
|
||||
@ -1,20 +1,8 @@
|
||||
from .cache import get_fonts_cache, get_video_files_cache, get_songs_cache
|
||||
from .file_utils import (
|
||||
open_task_folder, cleanup_temp_files, get_file_list,
|
||||
save_uploaded_file, create_temp_file, get_file_size, ensure_directory
|
||||
)
|
||||
from .performance import monitor_performance
|
||||
from .performance import monitor_performance, PerformanceMonitor
|
||||
from .cache import *
|
||||
from .file_utils import *
|
||||
|
||||
__all__ = [
|
||||
'get_fonts_cache',
|
||||
'get_video_files_cache',
|
||||
'get_songs_cache',
|
||||
'open_task_folder',
|
||||
'cleanup_temp_files',
|
||||
'get_file_list',
|
||||
'save_uploaded_file',
|
||||
'create_temp_file',
|
||||
'get_file_size',
|
||||
'ensure_directory',
|
||||
'monitor_performance'
|
||||
'monitor_performance',
|
||||
'PerformanceMonitor'
|
||||
]
|
||||
@ -186,4 +186,45 @@ def ensure_directory(directory):
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败: {directory}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def create_zip(files: list, zip_path: str, base_dir: str = None) -> bool:
|
||||
"""
|
||||
创建zip文件
|
||||
Args:
|
||||
files: 要打包的文件列表
|
||||
zip_path: zip文件保存路径
|
||||
base_dir: 基础目录,用于保持目录结构
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
try:
|
||||
import zipfile
|
||||
|
||||
# 确保目标目录存在
|
||||
os.makedirs(os.path.dirname(zip_path), exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
for file in files:
|
||||
if not os.path.exists(file):
|
||||
logger.warning(f"文件不存在,跳过: {file}")
|
||||
continue
|
||||
|
||||
# 计算文件在zip中的路径
|
||||
if base_dir:
|
||||
arcname = os.path.relpath(file, base_dir)
|
||||
else:
|
||||
arcname = os.path.basename(file)
|
||||
|
||||
try:
|
||||
zipf.write(file, arcname)
|
||||
except Exception as e:
|
||||
logger.error(f"添加文件到zip失败: {file}, 错误: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"创建zip文件成功: {zip_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建zip文件失败: {e}")
|
||||
return False
|
||||
@ -1,24 +1,37 @@
|
||||
import time
|
||||
import psutil
|
||||
import os
|
||||
from loguru import logger
|
||||
import torch
|
||||
|
||||
try:
|
||||
import psutil
|
||||
ENABLE_PERFORMANCE_MONITORING = True
|
||||
except ImportError:
|
||||
ENABLE_PERFORMANCE_MONITORING = False
|
||||
logger.warning("psutil not installed. Performance monitoring is disabled.")
|
||||
class PerformanceMonitor:
|
||||
@staticmethod
|
||||
def monitor_memory():
|
||||
process = psutil.Process(os.getpid())
|
||||
memory_info = process.memory_info()
|
||||
|
||||
logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
|
||||
logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
|
||||
|
||||
@staticmethod
|
||||
def cleanup_resources():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
PerformanceMonitor.monitor_memory()
|
||||
|
||||
def monitor_performance():
|
||||
if not ENABLE_PERFORMANCE_MONITORING:
|
||||
return {'execution_time': 0, 'memory_usage': 0}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 # MB
|
||||
except:
|
||||
memory_usage = 0
|
||||
|
||||
return {
|
||||
'execution_time': time.time() - start_time,
|
||||
'memory_usage': memory_usage
|
||||
}
|
||||
def monitor_performance(func):
|
||||
"""性能监控装饰器"""
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
PerformanceMonitor.monitor_memory()
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
PerformanceMonitor.cleanup_resources()
|
||||
return wrapper
|
||||
Loading…
x
Reference in New Issue
Block a user