diff --git a/README.md b/README.md
index 31b6604..8845a23 100644
--- a/README.md
+++ b/README.md
@@ -130,7 +130,7 @@ docker-compose up
## 开发 💻
1. 安装依赖
```shell
-conda create -n narratoai python=3.10
+conda create -n narratoai python=3.11
conda activate narratoai
cd narratoai
pip install -r requirements.txt
diff --git a/app/controllers/v1/llm.py b/app/controllers/v1/llm.py
index e841d68..b5da6ae 100644
--- a/app/controllers/v1/llm.py
+++ b/app/controllers/v1/llm.py
@@ -1,18 +1,24 @@
-from fastapi import Request
+from fastapi import Request, File, UploadFile
+import os
from app.controllers.v1.base import new_router
from app.models.schema import (
VideoScriptResponse,
VideoScriptRequest,
VideoTermsResponse,
VideoTermsRequest,
+ VideoTranscriptionRequest,
+ VideoTranscriptionResponse,
)
from app.services import llm
from app.utils import utils
+from app.config import config
# 认证依赖项
# router = new_router(dependencies=[Depends(base.verify_token)])
router = new_router()
+# 定义上传目录
+UPLOAD_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "uploads")
@router.post(
"/scripts",
@@ -42,3 +48,46 @@ def generate_video_terms(request: Request, body: VideoTermsRequest):
)
response = {"video_terms": video_terms}
return utils.get_response(200, response)
+
+
+@router.post(
+ "/transcription",
+ response_model=VideoTranscriptionResponse,
+ summary="Transcribe video content using Gemini"
+)
+async def transcribe_video(
+ request: Request,
+ video_name: str,
+ language: str = "zh-CN",
+ video_file: UploadFile = File(...)
+):
+ """
+ 使用 Gemini 转录视频内容,包括时间戳、画面描述和语音内容
+
+ Args:
+ video_name: 视频名称
+ language: 语言代码,默认zh-CN
+ video_file: 上传的视频文件
+ """
+ # 创建临时目录用于存储上传的视频
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+ # 保存上传的视频文件
+ video_path = os.path.join(UPLOAD_DIR, video_file.filename)
+ with open(video_path, "wb") as buffer:
+ content = await video_file.read()
+ buffer.write(content)
+
+ try:
+ transcription = llm.gemini_video_transcription(
+ video_name=video_name,
+ video_path=video_path,
+ language=language,
+ llm_provider_video=config.app.get("video_llm_provider", "gemini")
+ )
+ response = {"transcription": transcription}
+ return utils.get_response(200, response)
+ finally:
+ # 处理完成后删除临时文件
+ if os.path.exists(video_path):
+ os.remove(video_path)
diff --git a/app/models/schema.py b/app/models/schema.py
index 682cd94..9d0c5d4 100644
--- a/app/models/schema.py
+++ b/app/models/schema.py
@@ -347,6 +347,7 @@ class VideoClipParams(BaseModel):
voice_name: Optional[str] = Field(default="zh-CN-YunjianNeural", description="语音名称")
voice_volume: Optional[float] = Field(default=1.0, description="语音音量")
voice_rate: Optional[float] = Field(default=1.0, description="语速")
+ voice_pitch: Optional[float] = Field(default=1.0, description="语调")
bgm_name: Optional[str] = Field(default="random", description="背景音乐名称")
bgm_type: Optional[str] = Field(default="random", description="背景音乐类型")
@@ -365,3 +366,13 @@ class VideoClipParams(BaseModel):
custom_position: float = Field(default=70.0, description="自定义位置")
n_threads: Optional[int] = 8 # 线程数,有助于提升视频处理速度
+
+class VideoTranscriptionRequest(BaseModel):
+ video_name: str
+ language: str = "zh-CN"
+
+ class Config:
+ arbitrary_types_allowed = True
+
+class VideoTranscriptionResponse(BaseModel):
+ transcription: str
diff --git a/app/services/audio_merger.py b/app/services/audio_merger.py
index 80c9aff..f0face0 100644
--- a/app/services/audio_merger.py
+++ b/app/services/audio_merger.py
@@ -73,25 +73,40 @@ def merge_audio_files(task_id: str, audio_file_paths: List[str], total_duration:
def parse_timestamp(timestamp: str):
"""解析时间戳字符串为秒数"""
- # start, end = timestamp.split('-')
+ # 确保使用冒号作为分隔符
+ timestamp = timestamp.replace('_', ':')
return time_to_seconds(timestamp)
def extract_timestamp(filename):
"""从文件名中提取开始和结束时间戳"""
- time_part = filename.split('_')[1].split('.')[0]
- times = time_part.split('-')
-
+ # 从 "audio_00_06-00_24.mp3" 这样的格式中提取时间
+ time_part = filename.split('_', 1)[1].split('.')[0] # 获取 "00_06-00_24" 部分
+ start_time, end_time = time_part.split('-') # 分割成 "00_06" 和 "00_24"
+
+ # 将下划线格式转换回冒号格式
+ start_time = start_time.replace('_', ':')
+ end_time = end_time.replace('_', ':')
+
# 将时间戳转换为秒
- start_seconds = time_to_seconds(times[0])
- end_seconds = time_to_seconds(times[1])
+ start_seconds = time_to_seconds(start_time)
+ end_seconds = time_to_seconds(end_time)
return start_seconds, end_seconds
-def time_to_seconds(times):
- """将 “00:06” 转换为总秒数 """
- times = times.split(':')
- return int(times[0]) * 60 + int(times[1])
+def time_to_seconds(time_str):
+ """将 "00:06" 或 "00_06" 格式转换为总秒数"""
+ # 确保使用冒号作为分隔符
+ time_str = time_str.replace('_', ':')
+ try:
+ parts = time_str.split(':')
+ if len(parts) != 2:
+ logger.error(f"Invalid time format: {time_str}")
+ return 0
+ return int(parts[0]) * 60 + int(parts[1])
+ except (ValueError, IndexError) as e:
+ logger.error(f"Error parsing time {time_str}: {str(e)}")
+ return 0
if __name__ == "__main__":
diff --git a/app/services/llm.py b/app/services/llm.py
index 3e9ba16..d054eb1 100644
--- a/app/services/llm.py
+++ b/app/services/llm.py
@@ -14,6 +14,7 @@ from googleapiclient.errors import ResumableUploadError
from google.api_core.exceptions import *
from google.generativeai.types import *
import subprocess
+from typing import Union, TextIO
from app.config import config
from app.utils.utils import clean_model_output
@@ -353,7 +354,7 @@ def _generate_response(prompt: str, llm_provider: str = None) -> str:
return content.replace("\n", "")
-def _generate_response_video(prompt: str, llm_provider_video: str, video_file: str | File) -> str:
+def _generate_response_video(prompt: str, llm_provider_video: str, video_file: Union[str, TextIO]) -> str:
"""
多模态能力大模型
"""
@@ -780,22 +781,28 @@ def screen_matching(huamian: str, wenan: str, llm_provider: str):
if __name__ == "__main__":
# 1. 视频转录
- # video_subject = "第二十条之无罪释放"
- # video_path = "../../resource/videos/test01.mp4"
- # language = "zh-CN"
- # gemini_video_transcription(video_subject, video_path, language)
+ video_subject = "第二十条之无罪释放"
+ video_path = "/Users/apple/Desktop/home/pipedream_project/downloads/jianzao.mp4"
+ language = "zh-CN"
+ gemini_video_transcription(
+ video_name=video_subject,
+ video_path=video_path,
+ language=language,
+ progress_callback=print,
+ llm_provider_video="gemini"
+ )
- # 2. 解说文案
- video_path = "/Users/apple/Desktop/home/NarratoAI/resource/videos/1.mp4"
- # video_path = "E:\\projects\\NarratoAI\\resource\\videos\\1.mp4"
- video_plot = """
- 李自忠拿着儿子李牧名下的存折,去银行取钱给儿子救命,却被要求证明"你儿子是你儿子"。
- 走投无路时碰到银行被抢劫,劫匪给了他两沓钱救命,李自忠却因此被银行以抢劫罪起诉,并顶格判处20年有期徒刑。
- 苏醒后的李牧坚决为父亲做无罪辩护,面对银行的顶级律师团队,他一个法学院大一学生,能否力挽狂澜,创作奇迹?挥法律之利剑 ,持正义之天平!
- """
- res = generate_script(video_path, video_plot, video_name="第二十条之无罪释放")
- # res = generate_script(video_path, video_plot, video_name="海岸")
- print("脚本生成成功:\n", res)
- res = clean_model_output(res)
- aaa = json.loads(res)
- print(json.dumps(aaa, indent=2, ensure_ascii=False))
+ # # 2. 解说文案
+ # video_path = "/Users/apple/Desktop/home/NarratoAI/resource/videos/1.mp4"
+ # # video_path = "E:\\projects\\NarratoAI\\resource\\videos\\1.mp4"
+ # video_plot = """
+ # 李自忠拿着儿子李牧名下的存折,去银行取钱给儿子救命,却被要求证明"你儿子是你儿子"。
+ # 走投无路时碰到银行被抢劫,劫匪给了他两沓钱救命,李自忠却因此被银行以抢劫罪起诉,并顶格判处20年有期徒刑。
+ # 苏醒后的李牧坚决为父亲做无罪辩护,面对银行的顶级律师团队,他一个法学院大一学生,能否力挽狂澜,创作奇迹?挥法律之利剑 ,持正义之天平!
+ # """
+ # res = generate_script(video_path, video_plot, video_name="第二十条之无罪释放")
+ # # res = generate_script(video_path, video_plot, video_name="海岸")
+ # print("脚本生成成功:\n", res)
+ # res = clean_model_output(res)
+ # aaa = json.loads(res)
+ # print(json.dumps(aaa, indent=2, ensure_ascii=False))
diff --git a/app/services/material.py b/app/services/material.py
index bc4d118..ab0dab0 100644
--- a/app/services/material.py
+++ b/app/services/material.py
@@ -1,6 +1,7 @@
import os
import subprocess
import random
+import traceback
from urllib.parse import urlencode
import requests
@@ -274,25 +275,37 @@ def save_clip_video(timestamp: str, origin_video: str, save_dir: str = "") -> di
logger.info(f"video already exists: {video_path}")
return {timestamp: video_path}
- # 剪辑视频
- start, end = utils.split_timestamp(timestamp)
- video = VideoFileClip(origin_video).subclip(start, end)
- video.write_videofile(video_path, logger=None) # 禁用 MoviePy 的内置日志
+ try:
+ # 剪辑视频
+ start, end = utils.split_timestamp(timestamp)
+ video = VideoFileClip(origin_video).subclip(start, end)
+
+ # 检查视频是否有音频轨道
+ if video.audio is not None:
+ video.write_videofile(video_path, logger=None) # 有音频时正常处理
+ else:
+ # 没有音频时使用不同的写入方式
+ video.write_videofile(video_path, audio=False, logger=None)
+
+ video.close() # 确保关闭视频文件
- if os.path.getsize(video_path) > 0 and os.path.exists(video_path):
- try:
- clip = VideoFileClip(video_path)
- duration = clip.duration
- fps = clip.fps
- clip.close()
- if duration > 0 and fps > 0:
- return {timestamp: video_path}
- except Exception as e:
+ if os.path.getsize(video_path) > 0 and os.path.exists(video_path):
try:
- os.remove(video_path)
+ clip = VideoFileClip(video_path)
+ duration = clip.duration
+ fps = clip.fps
+ clip.close()
+ if duration > 0 and fps > 0:
+ return {timestamp: video_path}
except Exception as e:
- logger.warning(str(e))
- logger.warning(f"无效的视频文件: {video_path}")
+ logger.warning(f"视频文件验证失败: {video_path} => {str(e)}")
+ if os.path.exists(video_path):
+ os.remove(video_path)
+ except Exception as e:
+ logger.warning(f"视频剪辑失败: {str(e)}")
+ if os.path.exists(video_path):
+ os.remove(video_path)
+
return {}
@@ -327,7 +340,7 @@ def clip_videos(task_id: str, timestamp_terms: List[str], origin_video: str, pro
if progress_callback:
progress_callback(index + 1, total_items)
except Exception as e:
- logger.error(f"视频裁剪失败: {utils.to_json(item)} => {str(e)}")
+ logger.error(f"视频裁剪失败: {utils.to_json(item)} =>\n{str(traceback.format_exc())}")
return {}
logger.success(f"裁剪 {len(video_paths)} videos")
return video_paths
diff --git a/app/services/subtitle.py b/app/services/subtitle.py
index c792667..f37eb65 100644
--- a/app/services/subtitle.py
+++ b/app/services/subtitle.py
@@ -29,7 +29,7 @@ def create(audio_file, subtitle_file: str = ""):
返回:
无返回值,但会在指定路径生成字幕文件。
"""
- global model
+ global model, device, compute_type
if not model:
model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v2"
model_bin_file = f"{model_path}/model.bin"
@@ -43,27 +43,45 @@ def create(audio_file, subtitle_file: str = ""):
)
return None
- logger.info(
- f"加载模型: {model_path}, 设备: {device}, 计算类型: {compute_type}"
- )
+ # 尝试使用 CUDA,如果失败则回退到 CPU
try:
+ import torch
+ if torch.cuda.is_available():
+ try:
+ logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
+ model = WhisperModel(
+ model_size_or_path=model_path,
+ device="cuda",
+ compute_type="float16",
+ local_files_only=True
+ )
+ device = "cuda"
+ compute_type = "float16"
+ logger.info("成功使用 CUDA 加载模型")
+ except Exception as e:
+ logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
+ logger.warning("回退到 CPU 模式")
+ device = "cpu"
+ compute_type = "int8"
+ else:
+ logger.info("未检测到 CUDA,使用 CPU 模式")
+ device = "cpu"
+ compute_type = "int8"
+ except ImportError:
+ logger.warning("未安装 torch,使用 CPU 模式")
+ device = "cpu"
+ compute_type = "int8"
+
+ if device == "cpu":
+ logger.info(f"使用 CPU 加载模型: {model_path}")
model = WhisperModel(
model_size_or_path=model_path,
device=device,
compute_type=compute_type,
local_files_only=True
)
- except Exception as e:
- logger.error(
- f"加载模型失败: {e} \n\n"
- f"********************************************\n"
- f"这可能是由网络问题引起的. \n"
- f"请手动下载模型并将其放入 'app/models' 文件夹中。 \n"
- f"see [README.md FAQ](https://github.com/linyqh/NarratoAI) for more details.\n"
- f"********************************************\n\n"
- f"{traceback.format_exc()}"
- )
- return None
+
+ logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")
logger.info(f"start, output file: {subtitle_file}")
if not subtitle_file:
diff --git a/app/services/task.py b/app/services/task.py
index 78941f8..c903047 100644
--- a/app/services/task.py
+++ b/app/services/task.py
@@ -372,12 +372,13 @@ def start_subclip(task_id: str, params: VideoClipParams, subclip_path_videos: li
list_script=list_script,
voice_name=voice_name,
voice_rate=params.voice_rate,
+ voice_pitch=params.voice_pitch,
force_regenerate=True
)
if audio_files is None:
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
- "音频文件为空,可能是网络不可用。如果您在中国,请使用VPN。或者手动选择 zh-CN-Yunjian-男性 音频")
+ "TTS转换音频失败, 可能是网络不可用! 如果您在中国, 请使用VPN.")
return
logger.info(f"合并音频:\n\n {audio_files}")
audio_file = audio_merger.merge_audio_files(task_id, audio_files, total_duration, list_script)
diff --git a/app/services/video.py b/app/services/video.py
index 6bfb9bf..23bf644 100644
--- a/app/services/video.py
+++ b/app/services/video.py
@@ -4,11 +4,13 @@ import glob
import random
from typing import List
from typing import Union
+import traceback
from loguru import logger
from moviepy.editor import *
from moviepy.video.tools.subtitles import SubtitlesClip
from PIL import ImageFont
+from contextlib import contextmanager
from app.models import const
from app.models.schema import MaterialInfo, VideoAspect, VideoConcatMode, VideoParams, VideoClipParams
@@ -144,7 +146,7 @@ def combine_videos(
return combined_video_path
-def wrap_text(text, max_width, font="Arial", fontsize=60):
+def wrap_text(text, max_width, font, fontsize=60):
# 创建字体对象
font = ImageFont.truetype(font, fontsize)
@@ -157,7 +159,7 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
if width <= max_width:
return text, height
- # logger.warning(f"wrapping text, max_width: {max_width}, text_width: {width}, text: {text}")
+ logger.debug(f"换行文本, 最大宽度: {max_width}, 文本宽度: {width}, 文本: {text}")
processed = True
@@ -198,110 +200,17 @@ def wrap_text(text, max_width, font="Arial", fontsize=60):
_wrapped_lines_.append(_txt_)
result = "\n".join(_wrapped_lines_).strip()
height = len(_wrapped_lines_) * height
- # logger.warning(f"wrapped text: {result}")
+ logger.debug(f"换行文本: {result}")
return result, height
-def generate_video(
- video_path: str,
- audio_path: str,
- subtitle_path: str,
- output_file: str,
- params: Union[VideoParams, VideoClipParams],
-):
- aspect = VideoAspect(params.video_aspect)
- video_width, video_height = aspect.to_resolution()
-
- logger.info(f"start, video size: {video_width} x {video_height}")
- logger.info(f" ① video: {video_path}")
- logger.info(f" ② audio: {audio_path}")
- logger.info(f" ③ subtitle: {subtitle_path}")
- logger.info(f" ④ output: {output_file}")
-
- # 写入与输出文件相同的目录
- output_dir = os.path.dirname(output_file)
-
- font_path = ""
- if params.subtitle_enabled:
- if not params.font_name:
- params.font_name = "STHeitiMedium.ttc"
- font_path = os.path.join(utils.font_dir(), params.font_name)
- if os.name == "nt":
- font_path = font_path.replace("\\", "/")
-
- logger.info(f"using font: {font_path}")
-
- def create_text_clip(subtitle_item):
- phrase = subtitle_item[1]
- max_width = video_width * 0.9
- wrapped_txt, txt_height = wrap_text(
- phrase, max_width=max_width, font=font_path, fontsize=params.font_size
- )
- _clip = TextClip(
- wrapped_txt,
- font=font_path,
- fontsize=params.font_size,
- color=params.text_fore_color,
- bg_color=params.text_background_color,
- stroke_color=params.stroke_color,
- stroke_width=params.stroke_width,
- print_cmd=False,
- )
- duration = subtitle_item[0][1] - subtitle_item[0][0]
- _clip = _clip.set_start(subtitle_item[0][0])
- _clip = _clip.set_end(subtitle_item[0][1])
- _clip = _clip.set_duration(duration)
- if params.subtitle_position == "bottom":
- _clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
- elif params.subtitle_position == "top":
- _clip = _clip.set_position(("center", video_height * 0.05))
- elif params.subtitle_position == "custom":
- # 确保字幕完全在屏幕内
- margin = 10 # 额外的边距,单位为像素
- max_y = video_height - _clip.h - margin
- min_y = margin
- custom_y = (video_height - _clip.h) * (params.custom_position / 100)
- custom_y = max(min_y, min(custom_y, max_y)) # 限制 y 值在有效范围内
- _clip = _clip.set_position(("center", custom_y))
- else: # center
- _clip = _clip.set_position(("center", "center"))
- return _clip
-
- video_clip = VideoFileClip(video_path)
- audio_clip = AudioFileClip(audio_path).volumex(params.voice_volume)
-
- if subtitle_path and os.path.exists(subtitle_path):
- sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
- text_clips = []
- for item in sub.subtitles:
- clip = create_text_clip(subtitle_item=item)
- text_clips.append(clip)
- video_clip = CompositeVideoClip([video_clip, *text_clips])
-
- bgm_file = get_bgm_file(bgm_type=params.bgm_type, bgm_file=params.bgm_file)
- if bgm_file:
- try:
- bgm_clip = (
- AudioFileClip(bgm_file).volumex(params.bgm_volume).audio_fadeout(3)
- )
- bgm_clip = afx.audio_loop(bgm_clip, duration=video_clip.duration)
- audio_clip = CompositeAudioClip([audio_clip, bgm_clip])
- except Exception as e:
- logger.error(f"failed to add bgm: {str(e)}")
-
- video_clip = video_clip.set_audio(audio_clip)
- video_clip.write_videofile(
- output_file,
- audio_codec="aac",
- temp_audiofile_path=output_dir,
- threads=params.n_threads,
- logger=None,
- fps=30,
- )
- video_clip.close()
- del video_clip
- logger.success(""
- "completed")
+@contextmanager
+def manage_clip(clip):
+ try:
+ yield clip
+ finally:
+ clip.close()
+ del clip
def generate_video_v2(
@@ -310,6 +219,7 @@ def generate_video_v2(
subtitle_path: str,
output_file: str,
params: Union[VideoParams, VideoClipParams],
+ progress_callback=None,
):
"""
合并所有素材
@@ -319,135 +229,163 @@ def generate_video_v2(
subtitle_path: 字幕文件路径
output_file: 输出文件路径
params: 视频参数
+ progress_callback: 进度回调函数,接收 0-100 的进度值
Returns:
"""
- aspect = VideoAspect(params.video_aspect)
- video_width, video_height = aspect.to_resolution()
+ total_steps = 4
+ current_step = 0
+
+ def update_progress(step_name):
+ nonlocal current_step
+ current_step += 1
+ if progress_callback:
+ progress_callback(int(current_step * 100 / total_steps))
+ logger.info(f"完成步骤: {step_name}")
- logger.info(f"开始,视频尺寸: {video_width} x {video_height}")
- logger.info(f" ① 视频: {video_path}")
- logger.info(f" ② 音频: {audio_path}")
- logger.info(f" ③ 字幕: {subtitle_path}")
- logger.info(f" ④ 输出: {output_file}")
-
- # 写入与输出文件相同的目录
- output_dir = os.path.dirname(output_file)
-
- # 字体设置部分保持不变
- font_path = ""
- if params.subtitle_enabled:
- if not params.font_name:
- params.font_name = "STHeitiMedium.ttc"
- font_path = os.path.join(utils.font_dir(), params.font_name)
- if os.name == "nt":
- font_path = font_path.replace("\\", "/")
- logger.info(f"使用字体: {font_path}")
-
- # create_text_clip 函数保持不变
- def create_text_clip(subtitle_item):
- phrase = subtitle_item[1]
- max_width = video_width * 0.9
- wrapped_txt, txt_height = wrap_text(
- phrase, max_width=max_width, font=font_path, fontsize=params.font_size
- )
- _clip = TextClip(
- wrapped_txt,
- font=font_path,
- fontsize=params.font_size,
- color=params.text_fore_color,
- bg_color=params.text_background_color,
- stroke_color=params.stroke_color,
- stroke_width=params.stroke_width,
- print_cmd=False,
- )
- duration = subtitle_item[0][1] - subtitle_item[0][0]
- _clip = _clip.set_start(subtitle_item[0][0])
- _clip = _clip.set_end(subtitle_item[0][1])
- _clip = _clip.set_duration(duration)
- if params.subtitle_position == "bottom":
- _clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
- elif params.subtitle_position == "top":
- _clip = _clip.set_position(("center", video_height * 0.05))
- elif params.subtitle_position == "custom":
- # 确保字幕完全在屏幕内
- margin = 10 # 额外的边距,单位为像素
- max_y = video_height - _clip.h - margin
- min_y = margin
- custom_y = (video_height - _clip.h) * (params.custom_position / 100)
- custom_y = max(min_y, min(custom_y, max_y)) # 限制 y 值在有效范围内
- _clip = _clip.set_position(("center", custom_y))
- else: # center
- _clip = _clip.set_position(("center", "center"))
- return _clip
-
- video_clip = VideoFileClip(video_path)
- original_audio = video_clip.audio # 保存原始视频的音轨
- video_duration = video_clip.duration
-
- # 处理新的音频文件
- new_audio = AudioFileClip(audio_path).volumex(params.voice_volume)
-
- # 字幕处理部分
- if subtitle_path and os.path.exists(subtitle_path):
- sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
- text_clips = []
+ try:
+ validate_params(video_path, audio_path, output_file, params)
- for item in sub.subtitles:
- clip = create_text_clip(subtitle_item=item)
-
- # 确保字幕的开始时间不早于视频开始
- start_time = max(clip.start, 0)
-
- # 如果字幕的开始时间晚于视频结束时间,则跳过此字幕
- if start_time >= video_duration:
- continue
-
- # 调整字幕的结束时间,但不要超过视频长度
- end_time = min(clip.end, video_duration)
-
- # 调整字幕的时间范围
- clip = clip.set_start(start_time).set_end(end_time)
-
- text_clips.append(clip)
-
- logger.info(f"处理了 {len(text_clips)} 段字幕")
-
- # 创建一个新的视频剪辑,包含所有字幕
- video_clip = CompositeVideoClip([video_clip, *text_clips])
+ with manage_clip(VideoFileClip(video_path)) as video_clip:
+ aspect = VideoAspect(params.video_aspect)
+ video_width, video_height = aspect.to_resolution()
- # 背景音乐处理部分
+ logger.info(f"开始,视频尺寸: {video_width} x {video_height}")
+ logger.info(f" ① 视频: {video_path}")
+ logger.info(f" ② 音频: {audio_path}")
+ logger.info(f" ③ 字幕: {subtitle_path}")
+ logger.info(f" ④ 输出: {output_file}")
+
+ output_dir = os.path.dirname(output_file)
+ update_progress("初始化完成")
+
+ # 字体设置
+ font_path = ""
+ if params.subtitle_enabled:
+ if not params.font_name:
+ params.font_name = "STHeitiMedium.ttc"
+ font_path = os.path.join(utils.font_dir(), params.font_name)
+ if os.name == "nt":
+ font_path = font_path.replace("\\", "/")
+ logger.info(f"使用字体: {font_path}")
+
+ def create_text_clip(subtitle_item):
+ phrase = subtitle_item[1]
+ max_width = video_width * 0.9
+ wrapped_txt, txt_height = wrap_text(
+ phrase, max_width=max_width, font=font_path, fontsize=params.font_size
+ )
+ _clip = TextClip(
+ wrapped_txt,
+ font=font_path,
+ fontsize=params.font_size,
+ color=params.text_fore_color,
+ bg_color=params.text_background_color,
+ stroke_color=params.stroke_color,
+ stroke_width=params.stroke_width,
+ print_cmd=False,
+ )
+ duration = subtitle_item[0][1] - subtitle_item[0][0]
+ _clip = _clip.set_start(subtitle_item[0][0])
+ _clip = _clip.set_end(subtitle_item[0][1])
+ _clip = _clip.set_duration(duration)
+
+ if params.subtitle_position == "bottom":
+ _clip = _clip.set_position(("center", video_height * 0.95 - _clip.h))
+ elif params.subtitle_position == "top":
+ _clip = _clip.set_position(("center", video_height * 0.05))
+ elif params.subtitle_position == "custom":
+ margin = 10
+ max_y = video_height - _clip.h - margin
+ min_y = margin
+ custom_y = (video_height - _clip.h) * (params.custom_position / 100)
+ custom_y = max(min_y, min(custom_y, max_y))
+ _clip = _clip.set_position(("center", custom_y))
+ else: # center
+ _clip = _clip.set_position(("center", "center"))
+ return _clip
+
+ update_progress("字体设置完成")
+
+ # 处理音频
+ original_audio = video_clip.audio
+ video_duration = video_clip.duration
+ new_audio = AudioFileClip(audio_path)
+ final_audio = process_audio_tracks(original_audio, new_audio, params, video_duration)
+ update_progress("音频处理完成")
+
+ # 处理字幕
+ if subtitle_path and os.path.exists(subtitle_path):
+ video_clip = process_subtitles(subtitle_path, video_clip, video_duration, create_text_clip)
+ update_progress("字幕处理完成")
+
+ # 合并音频和导出
+ video_clip = video_clip.set_audio(final_audio)
+ video_clip.write_videofile(
+ output_file,
+ audio_codec="aac",
+ temp_audiofile=os.path.join(output_dir, "temp-audio.m4a"),
+ threads=params.n_threads,
+ logger=None,
+ fps=30,
+ )
+
+ except FileNotFoundError as e:
+ logger.error(f"文件不存在: {str(e)}")
+ raise
+ except Exception as e:
+ logger.error(f"视频生成失败: {str(e)}")
+ raise
+ finally:
+ logger.success("完成")
+
+
+def process_audio_tracks(original_audio, new_audio, params, video_duration):
+ """处理所有音轨"""
+ audio_tracks = []
+
+ if original_audio is not None:
+ audio_tracks.append(original_audio)
+
+ new_audio = new_audio.volumex(params.voice_volume)
+ audio_tracks.append(new_audio)
+
+ # 处理背景音乐
bgm_file = get_bgm_file(bgm_type=params.bgm_type, bgm_file=params.bgm_file)
-
- # 合并音频轨道
- audio_tracks = [original_audio, new_audio]
-
if bgm_file:
try:
- bgm_clip = (
- AudioFileClip(bgm_file).volumex(params.bgm_volume).audio_fadeout(3)
- )
+ bgm_clip = AudioFileClip(bgm_file).volumex(params.bgm_volume).audio_fadeout(3)
bgm_clip = afx.audio_loop(bgm_clip, duration=video_duration)
audio_tracks.append(bgm_clip)
except Exception as e:
logger.error(f"添加背景音乐失败: {str(e)}")
+
+ return CompositeAudioClip(audio_tracks) if audio_tracks else new_audio
- # 合并所有音频轨道
- final_audio = CompositeAudioClip(audio_tracks)
- video_clip = video_clip.set_audio(final_audio)
- video_clip.write_videofile(
- output_file,
- audio_codec="aac",
- temp_audiofile_path=output_dir,
- threads=params.n_threads,
- logger=None,
- fps=30,
- )
- video_clip.close()
- del video_clip
- logger.success("完成")
+def process_subtitles(subtitle_path, video_clip, video_duration, create_text_clip):
+ """处理字幕"""
+ if not (subtitle_path and os.path.exists(subtitle_path)):
+ return video_clip
+
+ sub = SubtitlesClip(subtitles=subtitle_path, encoding="utf-8")
+ text_clips = []
+
+ for item in sub.subtitles:
+ clip = create_text_clip(subtitle_item=item)
+
+ # 时间范围调整
+ start_time = max(clip.start, 0)
+ if start_time >= video_duration:
+ continue
+
+ end_time = min(clip.end, video_duration)
+ clip = clip.set_start(start_time).set_end(end_time)
+ text_clips.append(clip)
+
+ logger.info(f"处理了 {len(text_clips)} 段字幕")
+ return CompositeVideoClip([video_clip, *text_clips])
def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
@@ -499,7 +437,7 @@ def preprocess_video(materials: List[MaterialInfo], clip_duration=4):
def combine_clip_videos(combined_video_path: str,
video_paths: List[str],
- video_ost_list: List[bool],
+ video_ost_list: List[int],
list_script: list,
video_aspect: VideoAspect = VideoAspect.portrait,
threads: int = 2,
@@ -509,92 +447,119 @@ def combine_clip_videos(combined_video_path: str,
Args:
combined_video_path: 合并后的存储路径
video_paths: 子视频路径列表
- video_ost_list: 原声播放列表
+ video_ost_list: 原声播放列表 (0: 不保留原声, 1: 只保留原声, 2: 保留原声并保留解说)
list_script: 剪辑脚本
video_aspect: 屏幕比例
threads: 线程数
Returns:
-
+ str: 合并后的视频路径
"""
from app.utils.utils import calculate_total_duration
audio_duration = calculate_total_duration(list_script)
logger.info(f"音频的最大持续时间: {audio_duration} s")
- # 每个剪辑所需的持续时间
- req_dur = audio_duration / len(video_paths)
- # req_dur = max_clip_duration
- # logger.info(f"每个剪辑的最大长度为 {req_dur} s")
+
output_dir = os.path.dirname(combined_video_path)
-
aspect = VideoAspect(video_aspect)
video_width, video_height = aspect.to_resolution()
clips = []
- video_duration = 0
- # 一遍又一遍地添加下载的剪辑,直到达到音频的持续时间 (max_duration)
- # while video_duration < audio_duration:
for video_path, video_ost in zip(video_paths, video_ost_list):
- cache_video_path = utils.root_dir()
- clip = VideoFileClip(os.path.join(cache_video_path, video_path))
- # 通过 ost 字段判断是否播放原声
- if not video_ost:
- clip = clip.without_audio()
- # # 检查剪辑是否比剩余音频长
- # if (audio_duration - video_duration) < clip.duration:
- # clip = clip.subclip(0, (audio_duration - video_duration))
- # # 仅当计算出的剪辑长度 (req_dur) 短于实际剪辑时,才缩短剪辑以防止静止图像
- # elif req_dur < clip.duration:
- # clip = clip.subclip(0, req_dur)
- clip = clip.set_fps(30)
+ try:
+ clip = VideoFileClip(video_path)
+
+ if video_ost == 0: # 不保留原声
+ clip = clip.without_audio()
+ # video_ost 为 1 或 2 时都保留原声,不需要特殊处理
+
+ clip = clip.set_fps(30)
- # 并非所有视频的大小都相同,因此我们需要调整它们的大小
- clip_w, clip_h = clip.size
- if clip_w != video_width or clip_h != video_height:
- clip_ratio = clip.w / clip.h
- video_ratio = video_width / video_height
+ # 处理视频尺寸
+ clip_w, clip_h = clip.size
+ if clip_w != video_width or clip_h != video_height:
+ clip = resize_video_with_padding(
+ clip,
+ target_width=video_width,
+ target_height=video_height
+ )
+ logger.info(f"视频 {video_path} 已调整尺寸为 {video_width} x {video_height}")
- if clip_ratio == video_ratio:
- # 等比例缩放
- clip = clip.resize((video_width, video_height))
- else:
- # 等比缩放视频
- if clip_ratio > video_ratio:
- # 按照目标宽度等比缩放
- scale_factor = video_width / clip_w
- else:
- # 按照目标高度等比缩放
- scale_factor = video_height / clip_h
+ clips.append(clip)
+
+ except Exception as e:
+ logger.error(f"处理视频 {video_path} 时出错: {str(e)}")
+ continue
- new_width = int(clip_w * scale_factor)
- new_height = int(clip_h * scale_factor)
- clip_resized = clip.resize(newsize=(new_width, new_height))
+ if not clips:
+ raise ValueError("没有有效的视频片段可以合并")
- background = ColorClip(size=(video_width, video_height), color=(0, 0, 0))
- clip = CompositeVideoClip([
- background.set_duration(clip.duration),
- clip_resized.set_position("center")
- ])
+ try:
+ video_clip = concatenate_videoclips(clips)
+ video_clip = video_clip.set_fps(30)
+
+ logger.info("开始合并视频...")
+ video_clip.write_videofile(
+ filename=combined_video_path,
+ threads=threads,
+ logger=None,
+ audio_codec="aac",
+ fps=30,
+ temp_audiofile=os.path.join(output_dir, "temp-audio.m4a")
+ )
+ finally:
+ # 确保资源被正确���放
+ video_clip.close()
+ for clip in clips:
+ clip.close()
- logger.info(f"将视频 {video_path} 大小调整为 {video_width} x {video_height}, 剪辑尺寸: {clip_w} x {clip_h}")
-
- clips.append(clip)
- video_duration += clip.duration
-
- video_clip = concatenate_videoclips(clips)
- video_clip = video_clip.set_fps(30)
- logger.info(f"合并视频中...")
- video_clip.write_videofile(filename=combined_video_path,
- threads=threads,
- logger=None,
- temp_audiofile_path=output_dir,
- audio_codec="aac",
- fps=30,
- )
- video_clip.close()
- logger.success(f"completed")
+ logger.success("视频合并完成")
return combined_video_path
+def resize_video_with_padding(clip, target_width: int, target_height: int):
+ """辅助函数:调整视频尺寸并添加黑边"""
+ clip_ratio = clip.w / clip.h
+ target_ratio = target_width / target_height
+
+ if clip_ratio == target_ratio:
+ return clip.resize((target_width, target_height))
+
+ if clip_ratio > target_ratio:
+ scale_factor = target_width / clip.w
+ else:
+ scale_factor = target_height / clip.h
+
+ new_width = int(clip.w * scale_factor)
+ new_height = int(clip.h * scale_factor)
+ clip_resized = clip.resize(newsize=(new_width, new_height))
+
+ background = ColorClip(
+ size=(target_width, target_height),
+ color=(0, 0, 0)
+ ).set_duration(clip.duration)
+
+ return CompositeVideoClip([
+ background,
+ clip_resized.set_position("center")
+ ])
+
+
+def validate_params(video_path, audio_path, output_file, params):
+ """验证输入参数"""
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"视频文件不存在: {video_path}")
+
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"音频文件不存在: {audio_path}")
+
+ output_dir = os.path.dirname(output_file)
+ if not os.path.exists(output_dir):
+ raise FileNotFoundError(f"输出目录不存在: {output_dir}")
+
+ if not hasattr(params, 'video_aspect'):
+ raise ValueError("params 缺少必要参数 video_aspect")
+
+
if __name__ == "__main__":
# combined_video_path = "../../storage/tasks/12312312/com123.mp4"
#
@@ -635,23 +600,23 @@ if __name__ == "__main__":
# ]
# combine_clip_videos(combined_video_path=combined_video_path, video_paths=video_paths, video_ost_list=video_ost_list, list_script=list_script)
- cfg = VideoClipParams()
- cfg.video_aspect = VideoAspect.portrait
- cfg.font_name = "STHeitiMedium.ttc"
- cfg.font_size = 60
- cfg.stroke_color = "#000000"
- cfg.stroke_width = 1.5
- cfg.text_fore_color = "#FFFFFF"
- cfg.text_background_color = "transparent"
- cfg.bgm_type = "random"
- cfg.bgm_file = ""
- cfg.bgm_volume = 1.0
- cfg.subtitle_enabled = True
- cfg.subtitle_position = "bottom"
- cfg.n_threads = 2
- cfg.paragraph_number = 1
-
- cfg.voice_volume = 1.0
+ # cfg = VideoClipParams()
+ # cfg.video_aspect = VideoAspect.portrait
+ # cfg.font_name = "STHeitiMedium.ttc"
+ # cfg.font_size = 60
+ # cfg.stroke_color = "#000000"
+ # cfg.stroke_width = 1.5
+ # cfg.text_fore_color = "#FFFFFF"
+ # cfg.text_background_color = "transparent"
+ # cfg.bgm_type = "random"
+ # cfg.bgm_file = ""
+ # cfg.bgm_volume = 1.0
+ # cfg.subtitle_enabled = True
+ # cfg.subtitle_position = "bottom"
+ # cfg.n_threads = 2
+ # cfg.paragraph_number = 1
+ #
+ # cfg.voice_volume = 1.0
# generate_video(video_path=video_file,
# audio_path=audio_file,
@@ -659,18 +624,27 @@ if __name__ == "__main__":
# output_file=output_file,
# params=cfg
# )
+ #
+ # video_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/combined-1.mp4"
+ #
+ # audio_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/audio_00-00-00-07.mp3"
+ #
+ # subtitle_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa\subtitle.srt"
+ #
+ # output_file = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/final-123.mp4"
+ #
+ # generate_video_v2(video_path=video_path,
+ # audio_path=audio_path,
+ # subtitle_path=subtitle_path,
+ # output_file=output_file,
+ # params=cfg
+ # )
- video_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/combined-1.mp4"
+ # 合并视频
+ video_list = [
+ './storage/cache_videos/vid-01_03-01_50.mp4',
+ './storage/cache_videos/vid-01_55-02_29.mp4',
+ './storage/cache_videos/vid-03_24-04_04.mp4',
+ './storage/cache_videos/vid-04_50-05_28.mp4'
+ ]
- audio_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/audio_00-00-00-07.mp3"
-
- subtitle_path = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa\subtitle.srt"
-
- output_file = "../../storage/tasks/7f5ae494-abce-43cf-8f4f-4be43320eafa/final-123.mp4"
-
- generate_video_v2(video_path=video_path,
- audio_path=audio_path,
- subtitle_path=subtitle_path,
- output_file=output_file,
- params=cfg
- )
diff --git a/app/services/voice.py b/app/services/voice.py
index e4776bf..02245f6 100644
--- a/app/services/voice.py
+++ b/app/services/voice.py
@@ -1032,11 +1032,11 @@ def is_azure_v2_voice(voice_name: str):
def tts(
- text: str, voice_name: str, voice_rate: float, voice_file: str
+ text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str
) -> [SubMaker, None]:
# if is_azure_v2_voice(voice_name):
# return azure_tts_v2(text, voice_name, voice_file)
- return azure_tts_v1(text, voice_name, voice_rate, voice_file)
+ return azure_tts_v1(text, voice_name, voice_rate, voice_pitch, voice_file)
def convert_rate_to_percent(rate: float) -> str:
@@ -1049,18 +1049,29 @@ def convert_rate_to_percent(rate: float) -> str:
return f"{percent}%"
+def convert_pitch_to_percent(rate: float) -> str:
+ if rate == 1.0:
+ return "+0Hz"
+ percent = round((rate - 1.0) * 100)
+ if percent > 0:
+ return f"+{percent}Hz"
+ else:
+ return f"{percent}Hz"
+
+
def azure_tts_v1(
- text: str, voice_name: str, voice_rate: float, voice_file: str
+ text: str, voice_name: str, voice_rate: float, voice_pitch: float, voice_file: str
) -> [SubMaker, None]:
voice_name = parse_voice_name(voice_name)
text = text.strip()
rate_str = convert_rate_to_percent(voice_rate)
+ pitch_str = convert_pitch_to_percent(voice_pitch)
for i in range(3):
try:
logger.info(f"start, voice name: {voice_name}, try: {i + 1}")
async def _do() -> SubMaker:
- communicate = edge_tts.Communicate(text, voice_name, rate=rate_str)
+ communicate = edge_tts.Communicate(text, voice_name, rate=rate_str, pitch=pitch_str, proxy=config.proxy.get("http"))
sub_maker = edge_tts.SubMaker()
with open(voice_file, "wb") as file:
async for chunk in communicate.stream():
@@ -1392,7 +1403,7 @@ def get_audio_duration(sub_maker: submaker.SubMaker):
return sub_maker.offset[-1][1] / 10000000
-def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: float, force_regenerate: bool = True):
+def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: float, voice_pitch: float, force_regenerate: bool = True):
"""
根据JSON文件中的多段文本进行TTS转换
@@ -1409,13 +1420,13 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
sub_maker_list = []
for item in list_script:
- if not item['OST']:
- # timestamp = item['new_timestamp'].replace(':', '@')
- timestamp = item['new_timestamp']
+ if item['OST'] != 1:
+ # 将时间戳中的冒号替换为下划线
+ timestamp = item['new_timestamp'].replace(':', '_')
audio_file = os.path.join(output_dir, f"audio_{timestamp}.mp3")
# 检查文件是否已存在,如存在且不强制重新生成,则跳过
- if os.path.exists(audio_file):
+ if os.path.exists(audio_file) and not force_regenerate:
logger.info(f"音频文件已存在,跳过生成: {audio_file}")
audio_files.append(audio_file)
continue
@@ -1426,7 +1437,8 @@ def tts_multiple(task_id: str, list_script: list, voice_name: str, voice_rate: f
text=text,
voice_name=voice_name,
voice_rate=voice_rate,
- voice_file=audio_file
+ voice_pitch=voice_pitch,
+ voice_file=audio_file,
)
if sub_maker is None:
diff --git a/app/utils/check_script.py b/app/utils/check_script.py
index 623c42a..00e6c0f 100644
--- a/app/utils/check_script.py
+++ b/app/utils/check_script.py
@@ -1,115 +1,81 @@
import json
-from loguru import logger
-import os
-from datetime import timedelta
+from typing import Dict, Any
-def time_to_seconds(time_str):
- parts = list(map(int, time_str.split(':')))
- if len(parts) == 2:
- return timedelta(minutes=parts[0], seconds=parts[1]).total_seconds()
- elif len(parts) == 3:
- return timedelta(hours=parts[0], minutes=parts[1], seconds=parts[2]).total_seconds()
- raise ValueError(f"无法解析时间字符串: {time_str}")
+def check_format(script_content: str) -> Dict[str, Any]:
+ """检查脚本格式
+ Args:
+ script_content: 脚本内容
+ Returns:
+ Dict: {'success': bool, 'message': str}
+ """
+ try:
+ # 检查是否为有效的JSON
+ data = json.loads(script_content)
+
+ # 检查是否为列表
+ if not isinstance(data, list):
+ return {
+ 'success': False,
+ 'message': '脚本必须是JSON数组格式'
+ }
+
+ # 检查每个片段
+ for i, clip in enumerate(data):
+ # 检查必需字段
+ required_fields = ['narration', 'picture', 'timestamp']
+ for field in required_fields:
+ if field not in clip:
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段缺少必需字段: {field}'
+ }
+
+ # 检查字段类型
+ if not isinstance(clip['narration'], str):
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的narration必须是字符串'
+ }
+ if not isinstance(clip['picture'], str):
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的picture必须是字符串'
+ }
+ if not isinstance(clip['timestamp'], str):
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的timestamp必须是字符串'
+ }
+
+ # 检查字段内容不能为空
+ if not clip['narration'].strip():
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的narration不能为空'
+ }
+ if not clip['picture'].strip():
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的picture不能为空'
+ }
+ if not clip['timestamp'].strip():
+ return {
+ 'success': False,
+ 'message': f'第{i+1}个片段的timestamp不能为空'
+ }
-def seconds_to_time_str(seconds):
- hours, remainder = divmod(int(seconds), 3600)
- minutes, seconds = divmod(remainder, 60)
- if hours > 0:
- return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
- else:
- return f"{minutes:02d}:{seconds:02d}"
+ return {
+ 'success': True,
+ 'message': '脚本格式检查通过'
+ }
-def adjust_timestamp(start_time, duration):
- start_seconds = time_to_seconds(start_time)
- end_seconds = start_seconds + duration
- return f"{start_time}-{seconds_to_time_str(end_seconds)}"
-
-def estimate_audio_duration(text):
- # 假设平均每个字符需要 0.2 秒
- return len(text) * 0.2
-
-def check_script(data, total_duration):
- errors = []
- time_ranges = []
-
- logger.info("开始检查脚本")
- logger.info(f"视频总时长: {total_duration:.2f} 秒")
- logger.info("=" * 50)
-
- for i, item in enumerate(data, 1):
- logger.info(f"\n检查第 {i} 项:")
-
- # 检查所有必需字段
- required_fields = ['picture', 'timestamp', 'narration', 'OST']
- for field in required_fields:
- if field not in item:
- errors.append(f"第 {i} 项缺少 {field} 字段")
- logger.info(f" - 错误: 缺少 {field} 字段")
- else:
- logger.info(f" - {field}: {item[field]}")
-
- # 检查 OST 相关规则
- if item.get('OST') == False:
- if not item.get('narration'):
- errors.append(f"第 {i} 项 OST 为 false,但 narration 为空")
- logger.info(" - 错误: OST 为 false,但 narration 为空")
- elif len(item['narration']) > 60:
- errors.append(f"第 {i} 项 OST 为 false,但 narration 超过 60 字")
- logger.info(f" - 错误: OST 为 false,但 narration 超过 60 字 (当前: {len(item['narration'])} 字)")
- else:
- logger.info(" - OST 为 false,narration 检查通过")
- elif item.get('OST') == True:
- if "原声播放_" not in item.get('narration'):
- errors.append(f"第 {i} 项 OST 为 true,但 narration 不为空")
- logger.info(" - 错误: OST 为 true,但 narration 不为空")
- else:
- logger.info(" - OST 为 true,narration 检查通过")
-
- # 检查 timestamp
- if 'timestamp' in item:
- start, end = map(time_to_seconds, item['timestamp'].split('-'))
- if any((start < existing_end and end > existing_start) for existing_start, existing_end in time_ranges):
- errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 与其他时间段重叠")
- logger.info(f" - 错误: timestamp '{item['timestamp']}' 与其他时间段重叠")
- else:
- logger.info(f" - timestamp '{item['timestamp']}' 检查通过")
- time_ranges.append((start, end))
-
- # if end > total_duration:
- # errors.append(f"第 {i} 项 timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
- # logger.info(f" - 错误: timestamp '{item['timestamp']}' 超过总时长 {total_duration:.2f} 秒")
- # else:
- # logger.info(f" - timestamp 在总时长范围内")
-
- # 处理 narration 字段
- if item.get('OST') == False and item.get('narration'):
- estimated_duration = estimate_audio_duration(item['narration'])
- start_time = item['timestamp'].split('-')[0]
- item['timestamp'] = adjust_timestamp(start_time, estimated_duration)
- logger.info(f" - 已调整 timestamp 为 {item['timestamp']} (估算音频时长: {estimated_duration:.2f} 秒)")
-
- if errors:
- logger.info("检查结果:不通过")
- logger.info("发现以下错误:")
- for error in errors:
- logger.info(f"- {error}")
- else:
- logger.info("检查结果:通过")
- logger.info("所有项目均符合规则要求。")
-
- return errors, data
-
-
-if __name__ == "__main__":
- file_path = "/Users/apple/Desktop/home/NarratoAI/resource/scripts/test004.json"
-
- with open(file_path, 'r', encoding='utf-8') as f:
- data = json.load(f)
-
- total_duration = 280
-
- # check_script(data, total_duration)
-
- from app.utils.utils import add_new_timestamps
- res = add_new_timestamps(data)
- print(json.dumps(res, indent=4, ensure_ascii=False))
+ except json.JSONDecodeError as e:
+ return {
+ 'success': False,
+ 'message': f'JSON格式错误: {str(e)}'
+ }
+ except Exception as e:
+ return {
+ 'success': False,
+ 'message': f'检查过程中发生错误: {str(e)}'
+ }
diff --git a/app/utils/script_generator.py b/app/utils/script_generator.py
new file mode 100644
index 0000000..72e7408
--- /dev/null
+++ b/app/utils/script_generator.py
@@ -0,0 +1,392 @@
+import os
+import json
+import traceback
+from loguru import logger
+import tiktoken
+from typing import List, Dict
+from datetime import datetime
+from openai import OpenAI
+import google.generativeai as genai
+
+
+class BaseGenerator:
+ def __init__(self, model_name: str, api_key: str, prompt: str):
+ self.model_name = model_name
+ self.api_key = api_key
+ self.base_prompt = prompt
+ self.conversation_history = []
+ self.chunk_overlap = 50
+ self.last_chunk_ending = ""
+
+ def generate_script(self, scene_description: str, word_count: int) -> str:
+ raise NotImplementedError("Subclasses must implement generate_script method")
+
+
+class OpenAIGenerator(BaseGenerator):
+ def __init__(self, model_name: str, api_key: str, prompt: str):
+ super().__init__(model_name, api_key, prompt)
+ self.client = OpenAI(api_key=api_key)
+ self.max_tokens = 7000
+ try:
+ self.encoding = tiktoken.encoding_for_model(self.model_name)
+ except KeyError:
+ logger.info(f"警告:未找到模型 {self.model_name} 的专用编码器,使用认编码器")
+ self.encoding = tiktoken.get_encoding("cl100k_base")
+
+ def _count_tokens(self, messages: list) -> int:
+ num_tokens = 0
+ for message in messages:
+ num_tokens += 3
+ for key, value in message.items():
+ num_tokens += len(self.encoding.encode(str(value)))
+ if key == "role":
+ num_tokens += 1
+ num_tokens += 3
+ return num_tokens
+
+ def _trim_conversation_history(self, system_prompt: str, new_user_prompt: str) -> None:
+ base_messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": new_user_prompt}
+ ]
+ base_tokens = self._count_tokens(base_messages)
+
+ temp_history = []
+ current_tokens = base_tokens
+
+ for message in reversed(self.conversation_history):
+ message_tokens = self._count_tokens([message])
+ if current_tokens + message_tokens > self.max_tokens:
+ break
+ temp_history.insert(0, message)
+ current_tokens += message_tokens
+
+ self.conversation_history = temp_history
+
+ def generate_script(self, scene_description: str, word_count: int) -> str:
+ max_attempts = 3
+ tolerance = 5
+
+ for attempt in range(max_attempts):
+ system_prompt, user_prompt = self._create_prompt(scene_description, word_count)
+ self._trim_conversation_history(system_prompt, user_prompt)
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ *self.conversation_history,
+ {"role": "user", "content": user_prompt}
+ ]
+
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=messages,
+ temperature=0.7,
+ max_tokens=500,
+ top_p=0.9,
+ frequency_penalty=0.3,
+ presence_penalty=0.5
+ )
+
+ generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
+ '').replace(
+ '\n', '')
+
+ current_length = len(generated_script)
+ if abs(current_length - word_count) <= tolerance:
+ self.conversation_history.append({"role": "user", "content": user_prompt})
+ self.conversation_history.append({"role": "assistant", "content": generated_script})
+ self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
+ generated_script) > self.chunk_overlap else generated_script
+ return generated_script
+
+ return generated_script
+
+ def _create_prompt(self, scene_description: str, word_count: int) -> tuple:
+ system_prompt = self.base_prompt.format(word_count=word_count)
+
+ user_prompt = f"""上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
+
+当前画面描述:{scene_description}
+
+请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
+严格字数要求:{word_count}字,允许误差±5字。"""
+
+ return system_prompt, user_prompt
+
+
+class GeminiGenerator(BaseGenerator):
+ def __init__(self, model_name: str, api_key: str, prompt: str):
+ super().__init__(model_name, api_key, prompt)
+ genai.configure(api_key=api_key)
+ self.model = genai.GenerativeModel(model_name)
+
+ def generate_script(self, scene_description: str, word_count: int) -> str:
+ max_attempts = 3
+ tolerance = 5
+
+ for attempt in range(max_attempts):
+ prompt = f"""{self.base_prompt}
+
+上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
+
+当前画面描述:{scene_description}
+
+请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
+严格字数要求:{word_count}字,允许误差±5字。"""
+
+ response = self.model.generate_content(prompt)
+ generated_script = response.text.strip().strip('"').strip("'").replace('\"', '').replace('\n', '')
+
+ current_length = len(generated_script)
+ if abs(current_length - word_count) <= tolerance:
+ self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
+ generated_script) > self.chunk_overlap else generated_script
+ return generated_script
+
+ return generated_script
+
+
+class QwenGenerator(BaseGenerator):
+ def __init__(self, model_name: str, api_key: str, prompt: str):
+ super().__init__(model_name, api_key, prompt)
+ self.client = OpenAI(
+ api_key=api_key,
+ base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
+ )
+
+ def generate_script(self, scene_description: str, word_count: int) -> str:
+ max_attempts = 3
+ tolerance = 5
+
+ for attempt in range(max_attempts):
+ prompt = f"""{self.base_prompt}
+
+上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
+
+当前画面描述:{scene_description}
+
+请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
+严格字数要求:{word_count}字,允许误差±5字。"""
+
+ messages = [
+ {"role": "system", "content": self.base_prompt},
+ {"role": "user", "content": prompt}
+ ]
+
+ response = self.client.chat.completions.create(
+ model=self.model_name, # 如 "qwen-plus"
+ messages=messages,
+ temperature=0.7,
+ max_tokens=500,
+ top_p=0.9,
+ frequency_penalty=0.3,
+ presence_penalty=0.5
+ )
+
+ generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
+ '').replace(
+ '\n', '')
+
+ current_length = len(generated_script)
+ if abs(current_length - word_count) <= tolerance:
+ self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
+ generated_script) > self.chunk_overlap else generated_script
+ return generated_script
+
+ return generated_script
+
+
+class MoonshotGenerator(BaseGenerator):
+ def __init__(self, model_name: str, api_key: str, prompt: str):
+ super().__init__(model_name, api_key, prompt)
+ self.client = OpenAI(
+ api_key=api_key,
+ base_url="https://api.moonshot.cn/v1"
+ )
+
+ def generate_script(self, scene_description: str, word_count: int) -> str:
+ max_attempts = 3
+ tolerance = 5
+
+ for attempt in range(max_attempts):
+ prompt = f"""{self.base_prompt}
+
+上一段文案的结尾:{self.last_chunk_ending if self.last_chunk_ending else "这是第一段,无需考虑上文"}
+
+当前画面描述:{scene_description}
+
+请确保新生成的文案与上文自然衔接,保持叙事的连贯性和趣味性。
+严格字数要求:{word_count}字,允许误差±5字。"""
+
+ messages = [
+ {"role": "system", "content": self.base_prompt},
+ {"role": "user", "content": prompt}
+ ]
+
+ response = self.client.chat.completions.create(
+ model=self.model_name, # 如 "moonshot-v1-8k"
+ messages=messages,
+ temperature=0.7,
+ max_tokens=500,
+ top_p=0.9,
+ frequency_penalty=0.3,
+ presence_penalty=0.5
+ )
+
+ generated_script = response.choices[0].message.content.strip().strip('"').strip("'").replace('\"',
+ '').replace(
+ '\n', '')
+
+ current_length = len(generated_script)
+ if abs(current_length - word_count) <= tolerance:
+ self.last_chunk_ending = generated_script[-self.chunk_overlap:] if len(
+ generated_script) > self.chunk_overlap else generated_script
+ return generated_script
+
+ return generated_script
+
+
+class ScriptProcessor:
+ def __init__(self, model_name: str, api_key: str = None, prompt: str = None, video_theme: str = ""):
+ self.model_name = model_name
+ self.api_key = api_key
+ self.video_theme = video_theme
+ self.prompt = prompt or self._get_default_prompt()
+
+ # 根据模型名称选择对应的生成器
+ if 'gemini' in model_name.lower():
+ self.generator = GeminiGenerator(model_name, self.api_key, self.prompt)
+ elif 'qwen' in model_name.lower():
+ self.generator = QwenGenerator(model_name, self.api_key, self.prompt)
+ elif 'moonshot' in model_name.lower():
+ self.generator = MoonshotGenerator(model_name, self.api_key, self.prompt)
+ else:
+ self.generator = OpenAIGenerator(model_name, self.api_key, self.prompt)
+
+ def _get_default_prompt(self, word_count=None) -> str:
+ return f"""你是一位极具幽默感的短视频脚本创作大师,擅长用"温和的违反"制造笑点,让{self.video_theme}视频既有趣又富有传播力。你的任务是将视频画面描述转化为能在社交平台疯狂传播的爆款口播文案。
+
+目标受众:热爱生活、追求独特体验的18-35岁年轻人
+文案风格:基于HKRR理论 + 段子手精神
+主题:{self.video_theme}
+
+【创作核心理念】
+1. 敢于用"温和的违反"制造笑点,但不能过于冒犯
+2. 巧妙运用中国式幽默,让观众会心一笑
+3. 保持轻松愉快的叙事基调
+
+【爆款内容四要素】
+
+【快乐元素 Happy】
+1. 用调侃的语气描述建造过程中的"笨手笨脚"
+2. 巧妙植入网络流行梗,增加内容的传播性
+3. 适时自嘲,展现真实且有趣的一面
+
+【知识价值 Knowledge】
+1. 用段子手的方式解释专业知识(比如:"这根木头不是一般的木头,它比我前任还难搞...")
+2. 把复杂的建造技巧转化为生动有趣的比喻
+3. 在幽默中传递实用的野外生存技能
+
+【情感共鸣 Resonance】
+1. 描述"真实但夸张"的建造困境
+2. 把对自然的感悟融入俏皮话中
+3. 用接地气的表达方式拉近与观众距离
+
+【节奏控制 Rhythm】
+1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字
+2. 像讲段子一样,注意铺垫和包袱的节奏
+3. 确保每段都有笑点,但不强求
+4. 段落结尾干净利落,不拖泥带水
+
+【连贯性要求】
+1. 新生成的内容必须自然衔接上一段文案的结尾
+2. 使用恰当的连接词和过渡语,确保叙事流畅
+3. 保持人物视角和语气的一致性
+4. 避免重复上一段已经提到的信息
+5. 确保情节和建造过程的逻辑连续性
+
+【字数控制要求】
+1. 严格控制文案字数在{word_count}字左右,允许误差不超过5字
+2. 如果内容过长,优先精简修饰性词语
+3. 如果内容过短,可以适当增加细节描写
+4. 保持文案结构完整,不因字数限制而牺牲内容质量
+5. 确保每个笑点和包袱都得到完整表达
+
+我会按顺序提供多段视频画面描述。请创作既搞笑又能火爆全网的口播文案。
+记住:要敢于用"温和的违反"制造笑点,但要把握好尺度,让观众在轻松愉快中感受野外建造的乐趣。"""
+
+ def calculate_duration_and_word_count(self, time_range: str) -> int:
+ try:
+ start_str, end_str = time_range.split('-')
+
+ def time_to_seconds(time_str):
+ minutes, seconds = map(int, time_str.split(':'))
+ return minutes * 60 + seconds
+
+ start_seconds = time_to_seconds(start_str)
+ end_seconds = time_to_seconds(end_str)
+ duration = end_seconds - start_seconds
+ word_count = int(duration / 0.2)
+
+ return word_count
+ except Exception as e:
+ logger.info(f"时间格式转换错误: {traceback.format_exc()}")
+ return 100
+
+ def process_frames(self, frame_content_list: List[Dict]) -> List[Dict]:
+ for frame_content in frame_content_list:
+ word_count = self.calculate_duration_and_word_count(frame_content["timestamp"])
+ script = self.generator.generate_script(frame_content["picture"], word_count)
+ frame_content["narration"] = script
+ frame_content["OST"] = 2
+ logger.info(f"时间范围: {frame_content['timestamp']}, 建议字数: {word_count}")
+ logger.info(script)
+
+ self._save_results(frame_content_list)
+ return frame_content_list
+
+ def _save_results(self, frame_content_list: List[Dict]):
+ """保存处理结果,并添加新的时间戳"""
+ try:
+ # 计算新的时间戳
+ current_time = 0 # 当前时间点(秒)
+
+ for frame in frame_content_list:
+ # 获取原始时间戳的持续时间
+ start_str, end_str = frame['timestamp'].split('-')
+
+ def time_to_seconds(time_str):
+ minutes, seconds = map(int, time_str.split(':'))
+ return minutes * 60 + seconds
+
+ # 计算当前片段的持续时间
+ start_seconds = time_to_seconds(start_str)
+ end_seconds = time_to_seconds(end_str)
+ duration = end_seconds - start_seconds
+
+ # 转换秒数为 MM:SS 格式
+ def seconds_to_time(seconds):
+ minutes = seconds // 60
+ remaining_seconds = seconds % 60
+ return f"{minutes:02d}:{remaining_seconds:02d}"
+
+ # 设置新的时间戳
+ new_start = seconds_to_time(current_time)
+ new_end = seconds_to_time(current_time + duration)
+ frame['new_timestamp'] = f"{new_start}-{new_end}"
+
+ # 更新当前时间点
+ current_time += duration
+
+ # 保存结果
+ file_name = f"storage/json/step2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+ os.makedirs(os.path.dirname(file_name), exist_ok=True)
+
+ with open(file_name, 'w', encoding='utf-8') as file:
+ json.dump(frame_content_list, file, ensure_ascii=False, indent=4)
+
+ logger.info(f"保存脚本成功,总时长: {seconds_to_time(current_time)}")
+
+ except Exception as e:
+ logger.error(f"保存结果时发生错误: {str(e)}\n{traceback.format_exc()}")
+ raise
diff --git a/app/utils/utils.py b/app/utils/utils.py
index 3a0600f..4880e6c 100644
--- a/app/utils/utils.py
+++ b/app/utils/utils.py
@@ -56,7 +56,7 @@ def to_json(obj):
# 使用serialize函数处理输入对象
serialized_obj = serialize(obj)
- # 序列化处理后的对象为JSON字符串
+ # 序列化处理后的对象为JSON���符串
return json.dumps(serialized_obj, ensure_ascii=False, indent=4)
except Exception as e:
return None
@@ -100,7 +100,7 @@ def task_dir(sub_dir: str = ""):
def font_dir(sub_dir: str = ""):
- d = resource_dir(f"fonts")
+ d = resource_dir("fonts")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
@@ -109,7 +109,7 @@ def font_dir(sub_dir: str = ""):
def song_dir(sub_dir: str = ""):
- d = resource_dir(f"songs")
+ d = resource_dir("songs")
if sub_dir:
d = os.path.join(d, sub_dir)
if not os.path.exists(d):
@@ -423,5 +423,104 @@ def cut_video(params, progress_callback=None):
return task_id, subclip_videos
except Exception as e:
- logger.error(f"视频裁剪过程中发生错误: {traceback.format_exc()}")
+ logger.error(f"视频裁剪过程中发生错误: \n{traceback.format_exc()}")
+ raise
+
+
+def temp_dir(sub_dir: str = ""):
+ """
+ 获取临时文件目录
+ Args:
+ sub_dir: 子目录名
+ Returns:
+ str: 临时文件目录路径
+ """
+ d = os.path.join(storage_dir(), "temp")
+ if sub_dir:
+ d = os.path.join(d, sub_dir)
+ if not os.path.exists(d):
+ os.makedirs(d)
+ return d
+
+
+def clear_keyframes_cache(video_path: str = None):
+ """
+ 清理关键帧缓存
+ Args:
+ video_path: 视频文件路径,如果指定则只清理该视频的缓存
+ """
+ try:
+ keyframes_dir = os.path.join(temp_dir(), "keyframes")
+ if not os.path.exists(keyframes_dir):
+ return
+
+ if video_path:
+ # ���理指定视频的缓存
+ video_hash = md5(video_path + str(os.path.getmtime(video_path)))
+ video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
+ if os.path.exists(video_keyframes_dir):
+ import shutil
+ shutil.rmtree(video_keyframes_dir)
+ logger.info(f"已清理视频关键帧缓存: {video_path}")
+ else:
+ # 清理所有缓存
+ import shutil
+ shutil.rmtree(keyframes_dir)
+ logger.info("已清理所有关键帧缓存")
+
+ except Exception as e:
+ logger.error(f"清理关键帧缓存失败: {e}")
+
+
+def init_resources():
+ """初始化资源文件"""
+ try:
+ # 创建字体目录
+ font_dir = os.path.join(root_dir(), "resource", "fonts")
+ os.makedirs(font_dir, exist_ok=True)
+
+ # 检查字体文件
+ font_files = [
+ ("SourceHanSansCN-Regular.otf", "https://github.com/adobe-fonts/source-han-sans/raw/release/OTF/SimplifiedChinese/SourceHanSansSC-Regular.otf"),
+ ("simhei.ttf", "C:/Windows/Fonts/simhei.ttf"), # Windows 黑体
+ ("simkai.ttf", "C:/Windows/Fonts/simkai.ttf"), # Windows 楷体
+ ("simsun.ttc", "C:/Windows/Fonts/simsun.ttc"), # Windows 宋体
+ ]
+
+ # 优先使用系统字体
+ system_font_found = False
+ for font_name, source in font_files:
+ if not source.startswith("http") and os.path.exists(source):
+ target_path = os.path.join(font_dir, font_name)
+ if not os.path.exists(target_path):
+ import shutil
+ shutil.copy2(source, target_path)
+ logger.info(f"已复制系统字体: {font_name}")
+ system_font_found = True
+ break
+
+ # 如果没有找到系统字体,则下载思源黑体
+ if not system_font_found:
+ source_han_path = os.path.join(font_dir, "SourceHanSansCN-Regular.otf")
+ if not os.path.exists(source_han_path):
+ download_font(font_files[0][1], source_han_path)
+
+ except Exception as e:
+ logger.error(f"初始化资源文件失败: {e}")
+
+def download_font(url: str, font_path: str):
+ """下载字体文件"""
+ try:
+ logger.info(f"正在下载字体文件: {url}")
+ import requests
+ response = requests.get(url)
+ response.raise_for_status()
+
+ with open(font_path, 'wb') as f:
+ f.write(response.content)
+
+ logger.info(f"字体文件下载成功: {font_path}")
+
+ except Exception as e:
+ logger.error(f"下载字体文件失败: {e}")
raise
diff --git a/app/utils/video_processor.py b/app/utils/video_processor.py
new file mode 100644
index 0000000..6822b46
--- /dev/null
+++ b/app/utils/video_processor.py
@@ -0,0 +1,209 @@
+import cv2
+import numpy as np
+from sklearn.cluster import KMeans
+import os
+import re
+from typing import List, Tuple, Generator
+
+
+class VideoProcessor:
+ def __init__(self, video_path: str):
+ """
+ 初始化视频处理器
+
+ Args:
+ video_path: 视频文件路径
+ """
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"视频文件不存在: {video_path}")
+
+ self.video_path = video_path
+ self.cap = cv2.VideoCapture(video_path)
+
+ if not self.cap.isOpened():
+ raise RuntimeError(f"无法打开视频文件: {video_path}")
+
+ self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
+
+ def __del__(self):
+ """析构函数,确保视频资源被释放"""
+ if hasattr(self, 'cap'):
+ self.cap.release()
+
+ def preprocess_video(self) -> Generator[np.ndarray, None, None]:
+ """
+ 使用生成器方式读取视频帧
+
+ Yields:
+ np.ndarray: 视频帧
+ """
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # 重置到视频开始
+ while self.cap.isOpened():
+ ret, frame = self.cap.read()
+ if not ret:
+ break
+ yield frame
+
+ def detect_shot_boundaries(self, frames: List[np.ndarray], threshold: int = 30) -> List[int]:
+ """
+ 使用帧差法检测镜头边界
+
+ Args:
+ frames: 视频帧列表
+ threshold: 差异阈值
+
+ Returns:
+ List[int]: 镜头边界帧的索引列表
+ """
+ shot_boundaries = []
+ for i in range(1, len(frames)):
+ prev_frame = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY)
+ curr_frame = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
+ diff = np.mean(np.abs(curr_frame.astype(int) - prev_frame.astype(int)))
+ if diff > threshold:
+ shot_boundaries.append(i)
+ return shot_boundaries
+
+ def extract_keyframes(self, frames: List[np.ndarray], shot_boundaries: List[int]) -> Tuple[List[np.ndarray], List[int]]:
+ """
+ 从每个镜头中提取关键帧
+
+ Args:
+ frames: 视频帧列表
+ shot_boundaries: 镜头边界列表
+
+ Returns:
+ Tuple[List[np.ndarray], List[int]]: 关���帧列表和对应的帧索引
+ """
+ keyframes = []
+ keyframe_indices = []
+
+ for i in range(len(shot_boundaries)):
+ start = shot_boundaries[i - 1] if i > 0 else 0
+ end = shot_boundaries[i]
+ shot_frames = frames[start:end]
+
+ frame_features = np.array([cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).flatten()
+ for frame in shot_frames])
+ kmeans = KMeans(n_clusters=1, random_state=0).fit(frame_features)
+ center_idx = np.argmin(np.sum((frame_features - kmeans.cluster_centers_[0]) ** 2, axis=1))
+
+ keyframes.append(shot_frames[center_idx])
+ keyframe_indices.append(start + center_idx)
+
+ return keyframes, keyframe_indices
+
+ def save_keyframes(self, keyframes: List[np.ndarray], keyframe_indices: List[int],
+ output_dir: str) -> None:
+ """
+ 保存关键帧到指定目录,文件名格式为:keyframe_帧序号_时间戳.jpg
+
+ Args:
+ keyframes: 关键帧列表
+ keyframe_indices: 关键帧索引列表
+ output_dir: 输出目录
+ """
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ for keyframe, frame_idx in zip(keyframes, keyframe_indices):
+ # 计算时间戳(秒)
+ timestamp = frame_idx / self.fps
+ # 将时间戳转换为 HH:MM:SS 格式
+ hours = int(timestamp // 3600)
+ minutes = int((timestamp % 3600) // 60)
+ seconds = int(timestamp % 60)
+ time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
+
+ # 构建新的文件名格式:keyframe_帧序号_时间戳.jpg
+ output_path = os.path.join(output_dir,
+ f'keyframe_{frame_idx:06d}_{time_str}.jpg')
+ cv2.imwrite(output_path, keyframe)
+
+ print(f"已保存 {len(keyframes)} 个关键帧到 {output_dir}")
+
+ def extract_frames_by_numbers(self, frame_numbers: List[int], output_folder: str) -> None:
+ """
+ 根据指定的帧号提取帧
+
+ Args:
+ frame_numbers: 要提取的帧号列表
+ output_folder: 输出文件夹路径
+ """
+ if not frame_numbers:
+ raise ValueError("未提供帧号列表")
+
+ if any(fn >= self.total_frames or fn < 0 for fn in frame_numbers):
+ raise ValueError("存在无效的帧号")
+
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+
+ for frame_number in frame_numbers:
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+ ret, frame = self.cap.read()
+
+ if ret:
+ # 计算时间戳
+ timestamp = frame_number / self.fps
+ hours = int(timestamp // 3600)
+ minutes = int((timestamp % 3600) // 60)
+ seconds = int(timestamp % 60)
+ time_str = f"{hours:02d}{minutes:02d}{seconds:02d}"
+
+ # 使用与关键帧相同的命名格式
+ output_path = os.path.join(output_folder,
+ f"extracted_frame_{frame_number:06d}_{time_str}.jpg")
+ cv2.imwrite(output_path, frame)
+ print(f"已提取并保存帧 {frame_number}")
+ else:
+ print(f"无法读取帧 {frame_number}")
+
+ @staticmethod
+ def extract_numbers_from_folder(folder_path: str) -> List[int]:
+ """
+ 从文件夹中提取帧号
+
+ Args:
+ folder_path: 关键帧文件夹路径
+
+ Returns:
+ List[int]: 排序后的帧号列表
+ """
+ files = [f for f in os.listdir(folder_path) if f.endswith('.jpg')]
+ # 更新正则表达式以匹配新的文件名格式:keyframe_000123_010534.jpg
+ pattern = re.compile(r'keyframe_(\d+)_\d+\.jpg$')
+ numbers = []
+ for f in files:
+ match = pattern.search(f)
+ if match:
+ numbers.append(int(match.group(1)))
+ return sorted(numbers)
+
+ def process_video(self, output_dir: str, skip_seconds: float = 0) -> None:
+ """
+ 处理视频并提取关键帧
+
+ Args:
+ output_dir: 输出目录
+ skip_seconds: 跳过视频开头的秒数
+ """
+ # 计算要跳过的帧数
+ skip_frames = int(skip_seconds * self.fps)
+
+ # 获取所有帧
+ frames = list(self.preprocess_video())
+
+ # 跳过指定秒数的帧
+ frames = frames[skip_frames:]
+
+ if not frames:
+ raise ValueError(f"跳过 {skip_seconds} 秒后没有剩余帧可以处理")
+
+ shot_boundaries = self.detect_shot_boundaries(frames)
+ keyframes, keyframe_indices = self.extract_keyframes(frames, shot_boundaries)
+
+ # 调整关键帧索引,加上跳过的帧数
+ adjusted_indices = [idx + skip_frames for idx in keyframe_indices]
+ self.save_keyframes(keyframes, adjusted_indices, output_dir)
\ No newline at end of file
diff --git a/app/utils/vision_analyzer.py b/app/utils/vision_analyzer.py
new file mode 100644
index 0000000..060b29a
--- /dev/null
+++ b/app/utils/vision_analyzer.py
@@ -0,0 +1,183 @@
+import json
+from typing import List, Union, Dict
+import os
+from pathlib import Path
+from loguru import logger
+from tqdm import tqdm
+import asyncio
+from tenacity import retry, stop_after_attempt, RetryError, retry_if_exception_type, wait_exponential
+from google.api_core import exceptions
+import google.generativeai as genai
+import PIL.Image
+import traceback
+
+class VisionAnalyzer:
+ """视觉分析器类"""
+
+ def __init__(self, model_name: str = "gemini-1.5-flash", api_key: str = None):
+ """初始化视觉分析器"""
+ if not api_key:
+ raise ValueError("必须提供API密钥")
+
+ self.model_name = model_name
+ self.api_key = api_key
+
+ # 初始化配置
+ self._configure_client()
+
+ def _configure_client(self):
+ """配置API客户端"""
+ genai.configure(api_key=self.api_key)
+ self.model = genai.GenerativeModel(self.model_name)
+
+ @retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(exceptions.ResourceExhausted)
+ )
+ async def _generate_content_with_retry(self, prompt, batch):
+ """使用重试机制的内部方法来调用 generate_content_async"""
+ try:
+ return await self.model.generate_content_async([prompt, *batch])
+ except exceptions.ResourceExhausted as e:
+ print(f"API配额限制: {str(e)}")
+ raise RetryError("API调用失败")
+
+ async def analyze_images(self,
+ images: Union[List[str], List[PIL.Image.Image]],
+ prompt: str,
+ batch_size: int = 5) -> List[Dict]:
+ """批量分析多张图片"""
+ try:
+ # 加载图片
+ if isinstance(images[0], str):
+ logger.info("正在加载图片...")
+ images = self.load_images(images)
+
+ # 验证图片列表
+ if not images:
+ raise ValueError("图片列表为空")
+
+ # 验证每个图片对象
+ valid_images = []
+ for i, img in enumerate(images):
+ if not isinstance(img, PIL.Image.Image):
+ logger.error(f"无效的图片对象,索引 {i}: {type(img)}")
+ continue
+ valid_images.append(img)
+
+ if not valid_images:
+ raise ValueError("没有有效的图片对象")
+
+ images = valid_images
+ results = []
+ total_batches = (len(images) + batch_size - 1) // batch_size
+
+ with tqdm(total=total_batches, desc="分析进度") as pbar:
+ for i in range(0, len(images), batch_size):
+ batch = images[i:i + batch_size]
+ retry_count = 0
+
+ while retry_count < 3:
+ try:
+ # 在每个批次处理前添加小延迟
+ if i > 0:
+ await asyncio.sleep(2)
+
+ # 确保每个批次的图片都是有效的
+ valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
+ if not valid_batch:
+ raise ValueError(f"批次 {i // batch_size} 中没有有效的图片")
+
+ response = await self._generate_content_with_retry(prompt, valid_batch)
+ results.append({
+ 'batch_index': i // batch_size,
+ 'images_processed': len(valid_batch),
+ 'response': response.text,
+ 'model_used': self.model_name
+ })
+ break
+
+ except Exception as e:
+ retry_count += 1
+ error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+
+ if retry_count >= 3:
+ results.append({
+ 'batch_index': i // batch_size,
+ 'images_processed': len(batch),
+ 'error': error_msg,
+ 'model_used': self.model_name
+ })
+ else:
+ logger.info(f"批次 {i // batch_size} 处理失败,等待60秒后重试...")
+ await asyncio.sleep(60)
+
+ pbar.update(1)
+
+ return results
+
+ except Exception as e:
+ error_msg = f"图片分析过程中发生错误: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ raise Exception(error_msg)
+
+ def save_results_to_txt(self, results: List[Dict], output_dir: str):
+ """将分析结果保存到txt文件"""
+ # 确保输出目录存在
+ os.makedirs(output_dir, exist_ok=True)
+
+ for result in results:
+ if not result.get('image_paths'):
+ continue
+
+ response_text = result['response']
+ image_paths = result['image_paths']
+
+ img_name_start = Path(image_paths[0]).stem.split('_')[-1]
+ img_name_end = Path(image_paths[-1]).stem.split('_')[-1]
+ txt_path = os.path.join(output_dir, f"frame_{img_name_start}_{img_name_end}.txt")
+
+ # 保存结果到txt文件
+ with open(txt_path, 'w', encoding='utf-8') as f:
+ f.write(response_text.strip())
+ print(f"已保存分析结果到: {txt_path}")
+
+ def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
+ """
+ 加载多张图片
+ Args:
+ image_paths: 图片路径列表
+ Returns:
+ 加载后的PIL Image对象列表
+ """
+ images = []
+ failed_images = []
+
+ for img_path in image_paths:
+ try:
+ if not os.path.exists(img_path):
+ logger.error(f"图片文件不存在: {img_path}")
+ failed_images.append(img_path)
+ continue
+
+ img = PIL.Image.open(img_path)
+ # 确保图片被完全加载
+ img.load()
+ # 转换为RGB模式
+ if img.mode != 'RGB':
+ img = img.convert('RGB')
+ images.append(img)
+
+ except Exception as e:
+ logger.error(f"无法加载图片 {img_path}: {str(e)}")
+ failed_images.append(img_path)
+
+ if failed_images:
+ logger.warning(f"以下图片加载失败:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
+
+ if not images:
+ raise ValueError("没有成功加载任何图片")
+
+ return images
\ No newline at end of file
diff --git a/config.example.toml b/config.example.toml
index 1557101..3b5af4f 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -1,94 +1,92 @@
[app]
- project_version="0.2.0"
- # 如果你没有 OPENAI API Key,可以使用 g4f 代替,或者使用国内的 Moonshot API
- # If you don't have an OPENAI API Key, you can use g4f instead
- video_llm_provider="gemini"
+ project_version="0.3.2"
+ # 支持视频理解的大模型提供商
+ # gemini
+ # NarratoAPI
+ # qwen2-vl (待增加)
+ vision_llm_provider="gemini"
+ vision_batch_size = 5
+ vision_analysis_prompt = "你是资深视频内容分析专家,擅长分析视频画面信息,分析下面视频画面内容,只输出客观的画面描述不要给任何总结或评价"
- # 支持的提供商 (Supported providers):
- # openai
+ ########## Vision Gemini API Key
+ vision_gemini_api_key = ""
+ vision_gemini_model_name = "gemini-1.5-flash"
+
+ ########### Vision NarratoAPI Key
+ # NarratoAPI 是为了便捷访问不了 Gemini API 的用户, 提供的代理服务
+ narrato_api_key = ""
+ narrato_api_url = "https://narratoapi.scsmtech.cn/api/v1"
+ narrato_vision_model = "gemini-1.5-flash"
+ narrato_vision_key = ""
+ narrato_llm_model = "gpt-4o"
+ narrato_llm_key = ""
+
+ # 用于生成文案的大模型支持的提供商 (Supported providers):
+ # openai (默认)
# moonshot (月之暗面)
# oneapi
# g4f
# azure
# qwen (通义千问)
# gemini
- llm_provider="openai"
- # 支持多模态视频理解能力的大模型
-
- ########## Ollama Settings
- # No need to set it unless you want to use your own proxy
- ollama_base_url = ""
- # Check your available models at https://ollama.com/library
- ollama_model_name = ""
+ text_llm_provider="openai"
########## OpenAI API Key
# Get your API key at https://platform.openai.com/api-keys
- openai_api_key = ""
+ text_openai_api_key = ""
# No need to set it unless you want to use your own proxy
- openai_base_url = ""
+ text_openai_base_url = ""
# Check your available models at https://platform.openai.com/account/limits
- openai_model_name = "gpt-4-turbo"
+ text_openai_model_name = "gpt-4o-mini"
########## Moonshot API Key
# Visit https://platform.moonshot.cn/console/api-keys to get your API key.
- moonshot_api_key=""
- moonshot_base_url = "https://api.moonshot.cn/v1"
- moonshot_model_name = "moonshot-v1-8k"
-
- ########## OneAPI API Key
- # Visit https://github.com/songquanpeng/one-api to get your API key
- oneapi_api_key=""
- oneapi_base_url=""
- oneapi_model_name=""
+ text_moonshot_api_key=""
+ text_moonshot_base_url = "https://api.moonshot.cn/v1"
+ text_moonshot_model_name = "moonshot-v1-8k"
########## G4F
# Visit https://github.com/xtekky/gpt4free to get more details
# Supported model list: https://github.com/xtekky/gpt4free/blob/main/g4f/models.py
- g4f_model_name = "gpt-3.5-turbo"
+ text_g4f_model_name = "gpt-3.5-turbo"
########## Azure API Key
# Visit https://learn.microsoft.com/zh-cn/azure/ai-services/openai/ to get more details
# API documentation: https://learn.microsoft.com/zh-cn/azure/ai-services/openai/reference
- azure_api_key = ""
- azure_base_url=""
- azure_model_name="gpt-35-turbo" # replace with your model deployment name
- azure_api_version = "2024-02-15-preview"
+ text_azure_api_key = ""
+ text_azure_base_url=""
+ text_azure_model_name="gpt-35-turbo" # replace with your model deployment name
+ text_azure_api_version = "2024-02-15-preview"
########## Gemini API Key
- gemini_api_key=""
- gemini_model_name = "gemini-1.5-flash"
+ text_gemini_api_key=""
+ text_gemini_model_name = "gemini-1.5-flash"
########## Qwen API Key
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key
# Visit below links to get more details
# https://tongyi.aliyun.com/qianwen/
# https://help.aliyun.com/zh/dashscope/developer-reference/model-introduction
- qwen_api_key = ""
- qwen_model_name = "qwen-max"
-
+ text_qwen_api_key = ""
+ text_qwen_model_name = "qwen-max"
########## DeepSeek API Key
# Visit https://platform.deepseek.com/api_keys to get your API key
- deepseek_api_key = ""
- deepseek_base_url = "https://api.deepseek.com"
- deepseek_model_name = "deepseek-chat"
+ text_deepseek_api_key = ""
+ text_deepseek_base_url = "https://api.deepseek.com"
+ text_deepseek_model_name = "deepseek-chat"
- # Subtitle Provider, "whisper"
- # If empty, the subtitle will not be generated
+ # 字幕提供商、可选,支持 whisper 和 faster-whisper-large-v2"whisper"
+ # 默认为 faster-whisper-large-v2 模型地址:https://huggingface.co/guillaumekln/faster-whisper-large-v2
subtitle_provider = "faster-whisper-large-v2"
subtitle_enabled = true
- #
# ImageMagick
- #
- # Once you have installed it, ImageMagick will be automatically detected, except on Windows!
- # On Windows, for example "C:\Program Files (x86)\ImageMagick-7.1.1-Q16-HDRI\magick.exe"
- # Download from https://imagemagick.org/archive/binaries/ImageMagick-7.1.1-29-Q16-x64-static.exe
-
+ # 安装后,将自动检测到 ImageMagick,Windows 除外!
+ # 例如,在 Windows 上 "C:\Program Files (x86)\ImageMagick-7.1.1-Q16-HDRI\magick.exe"
+ # 下载位置 https://imagemagick.org/archive/binaries/ImageMagick-7.1.1-29-Q16-x64-static.exe
# imagemagick_path = "C:\\Program Files (x86)\\ImageMagick-7.1.1-Q16\\magick.exe"
-
- #
# FFMPEG
#
# 通常情况下,ffmpeg 会被自动下载,并且会被自动检测到。
@@ -97,12 +95,6 @@
# Install ffmpeg on your system, or set the IMAGEIO_FFMPEG_EXE environment variable.
# 此时你可以手动下载 ffmpeg 并设置 ffmpeg_path,下载地址:https://www.gyan.dev/ffmpeg/builds/
- # Under normal circumstances, ffmpeg is downloaded automatically and detected automatically.
- # However, if there is an issue with your environment that prevents automatic downloading, you might encounter the following error:
- # RuntimeError: No ffmpeg exe could be found.
- # Install ffmpeg on your system, or set the IMAGEIO_FFMPEG_EXE environment variable.
- # In such cases, you can manually download ffmpeg and set the ffmpeg_path, download link: https://www.gyan.dev/ffmpeg/builds/
-
# ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe"
#########################################################################################
@@ -132,7 +124,7 @@
material_directory = ""
- # Used for state management of the task
+ # 用于任务的状态管理
enable_redis = false
redis_host = "localhost"
redis_port = 6379
@@ -143,7 +135,6 @@
max_concurrent_tasks = 5
# webui界面是否显示配置项
- # webui hide baisc config panel
hide_config = false
@@ -161,7 +152,7 @@
# recommended model_size: "large-v3"
model_size="faster-whisper-large-v2"
- # if you want to use GPU, set device="cuda"
+ # 如果要使用 GPU,请设置 device=“cuda”
device="CPU"
compute_type="int8"
diff --git a/requirements.txt b/requirements.txt
index 0941fb8..2ae1f29 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,28 +1,34 @@
requests~=2.31.0
moviepy~=2.0.0.dev2
-openai~=1.13.3
faster-whisper~=1.0.1
edge_tts~=6.1.15
uvicorn~=0.27.1
-fastapi~=0.110.0
+fastapi~=0.115.4
tomli~=2.0.1
-streamlit~=1.33.0
+streamlit~=1.40.0
loguru~=0.7.2
-aiohttp~=3.9.3
+aiohttp~=3.10.10
urllib3~=2.2.1
-pillow~=10.3.0
pydantic~=2.6.3
g4f~=0.3.0.4
dashscope~=1.15.0
-google.generativeai>=0.8.2
+google.generativeai>=0.8.3
python-multipart~=0.0.9
redis==5.0.3
-# if you use pillow~=10.3.0, you will get "PIL.Image' has no attribute 'ANTIALIAS'" error when resize video
-# please install opencv-python to fix "PIL.Image' has no attribute 'ANTIALIAS'" error
-opencv-python~=4.9.0.80
+opencv-python~=4.10.0.84
# for azure speech
# https://techcommunity.microsoft.com/t5/ai-azure-ai-services-blog/9-more-realistic-ai-voices-for-conversations-now-generally/ba-p/4099471
azure-cognitiveservices-speech~=1.37.0
git-changelog~=2.5.2
watchdog==5.0.2
pydub==0.25.1
+psutil>=5.9.0
+opencv-python~=4.10.0.84
+scikit-learn~=1.5.2
+google-generativeai~=0.8.3
+Pillow>=11.0.0
+python-dotenv~=1.0.1
+openai~=1.53.0
+tqdm>=4.66.6
+tenacity>=9.0.0
+tiktoken==0.8.0
\ No newline at end of file
diff --git a/webui.py b/webui.py
index 5e37dd7..1f4cb97 100644
--- a/webui.py
+++ b/webui.py
@@ -1,6 +1,15 @@
import streamlit as st
+import os
+import sys
+from uuid import uuid4
from app.config import config
+from webui.components import basic_settings, video_settings, audio_settings, subtitle_settings, script_settings, review_settings
+from webui.utils import cache, file_utils
+from app.utils import utils
+from app.models.schema import VideoClipParams, VideoAspect
+from webui.utils.performance import PerformanceMonitor
+# 初始化配置 - 必须是第一个 Streamlit 命令
st.set_page_config(
page_title="NarratoAI",
page_icon="📽️",
@@ -13,126 +22,34 @@ st.set_page_config(
},
)
-import sys
-import os
-import glob
-import json
-import time
-import datetime
-import traceback
-from uuid import uuid4
-import platform
-import streamlit.components.v1 as components
-from loguru import logger
-
-from app.models.const import FILE_TYPE_VIDEOS
-from app.models.schema import VideoClipParams, VideoAspect, VideoConcatMode
-from app.services import task as tm, llm, voice, material
-from app.utils import utils
-
-# # 将项目的根目录添加到系统路径中,以允许从项目导入模块
-root_dir = os.path.dirname(os.path.realpath(__file__))
-if root_dir not in sys.path:
- sys.path.append(root_dir)
- print("******** sys.path ********")
- print(sys.path)
- print("*" * 20)
-
-proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
-proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
-os.environ["HTTP_PROXY"] = proxy_url_http
-os.environ["HTTPS_PROXY"] = proxy_url_https
-
+# 设置页面样式
hide_streamlit_style = """
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
-st.title(f"NarratoAI :sunglasses:📽️")
-support_locales = [
- "zh-CN",
- "zh-HK",
- "zh-TW",
- "en-US",
-]
-font_dir = os.path.join(root_dir, "resource", "fonts")
-song_dir = os.path.join(root_dir, "resource", "songs")
-i18n_dir = os.path.join(root_dir, "webui", "i18n")
-config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
-system_locale = utils.get_system_locale()
-
-if 'video_clip_json' not in st.session_state:
- st.session_state['video_clip_json'] = []
-if 'video_plot' not in st.session_state:
- st.session_state['video_plot'] = ''
-if 'ui_language' not in st.session_state:
- st.session_state['ui_language'] = config.ui.get("language", system_locale)
-if 'subclip_videos' not in st.session_state:
- st.session_state['subclip_videos'] = {}
-
-
-def get_all_fonts():
- fonts = []
- for root, dirs, files in os.walk(font_dir):
- for file in files:
- if file.endswith(".ttf") or file.endswith(".ttc"):
- fonts.append(file)
- fonts.sort()
- return fonts
-
-
-def get_all_songs():
- songs = []
- for root, dirs, files in os.walk(song_dir):
- for file in files:
- if file.endswith(".mp3"):
- songs.append(file)
- return songs
-
-
-def open_task_folder(task_id):
- try:
- sys = platform.system()
- path = os.path.join(root_dir, "storage", "tasks", task_id)
- if os.path.exists(path):
- if sys == 'Windows':
- os.system(f"start {path}")
- if sys == 'Darwin':
- os.system(f"open {path}")
- except Exception as e:
- logger.error(e)
-
-
-def scroll_to_bottom():
- js = f"""
-
- """
- st.components.v1.html(js, height=0, width=0)
-
def init_log():
+ """初始化日志配置"""
+ from loguru import logger
logger.remove()
_lvl = "DEBUG"
def format_record(record):
- # 获取日志记录中的文件全径
+ # 增加更多需要过滤的警告消息
+ ignore_messages = [
+ "Examining the path of torch.classes raised",
+ "torch.cuda.is_available()",
+ "CUDA initialization"
+ ]
+
+ for msg in ignore_messages:
+ if msg in record["message"]:
+ return ""
+
file_path = record["file"].path
- # 将绝对路径转换为相对于项目根目录的路径
- relative_path = os.path.relpath(file_path, root_dir)
- # 更新记录中的文件路径
+ relative_path = os.path.relpath(file_path, config.root_dir)
record["file"].path = f"./{relative_path}"
- # 返回修改后的格式字符串
- # 您可以根据需要调整这里的格式
- record['message'] = record['message'].replace(root_dir, ".")
+ record['message'] = record['message'].replace(config.root_dir, ".")
_format = '{time:%Y-%m-%d %H:%M:%S}> | ' + \
'{level}> | ' + \
@@ -140,671 +57,143 @@ def init_log():
'- {message}>' + "\n"
return _format
+ # 优化日志过滤器
+ def log_filter(record):
+ ignore_messages = [
+ "Examining the path of torch.classes raised",
+ "torch.cuda.is_available()",
+ "CUDA initialization"
+ ]
+ return not any(msg in record["message"] for msg in ignore_messages)
+
logger.add(
sys.stdout,
level=_lvl,
format=format_record,
colorize=True,
+ filter=log_filter
)
-
-init_log()
-
-locales = utils.load_locales(i18n_dir)
-
+def init_global_state():
+ """初始化全局状态"""
+ if 'video_clip_json' not in st.session_state:
+ st.session_state['video_clip_json'] = []
+ if 'video_plot' not in st.session_state:
+ st.session_state['video_plot'] = ''
+ if 'ui_language' not in st.session_state:
+ st.session_state['ui_language'] = config.ui.get("language", utils.get_system_locale())
+ if 'subclip_videos' not in st.session_state:
+ st.session_state['subclip_videos'] = {}
def tr(key):
+ """翻译函数"""
+ i18n_dir = os.path.join(os.path.dirname(__file__), "webui", "i18n")
+ locales = utils.load_locales(i18n_dir)
loc = locales.get(st.session_state['ui_language'], {})
return loc.get("Translation", {}).get(key, key)
+def render_generate_button():
+ """渲染生成按钮和处理逻辑"""
+ if st.button(tr("Generate Video"), use_container_width=True, type="primary"):
+ try:
+ from app.services import task as tm
+ import torch
+
+ # 重置日志容器和记录
+ log_container = st.empty()
+ log_records = []
-st.write(tr("Get Help"))
+ def log_received(msg):
+ with log_container:
+ log_records.append(msg)
+ st.code("\n".join(log_records))
-# 基础设置
-with st.expander(tr("Basic Settings"), expanded=False):
- config_panels = st.columns(3)
- left_config_panel = config_panels[0]
- middle_config_panel = config_panels[1]
- right_config_panel = config_panels[2]
- with left_config_panel:
- display_languages = []
- selected_index = 0
- for i, code in enumerate(locales.keys()):
- display_languages.append(f"{code} - {locales[code].get('Language')}")
- if code == st.session_state['ui_language']:
- selected_index = i
+ from loguru import logger
+ logger.add(log_received)
- selected_language = st.selectbox(tr("Language"), options=display_languages,
- index=selected_index)
- if selected_language:
- code = selected_language.split(" - ")[0].strip()
- st.session_state['ui_language'] = code
- config.ui['language'] = code
+ config.save_config()
+ task_id = st.session_state.get('task_id')
- HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
- HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
- if HTTP_PROXY:
- config.proxy["http"] = HTTP_PROXY
- if HTTPS_PROXY:
- config.proxy["https"] = HTTPS_PROXY
+ if not task_id:
+ st.error(tr("请先裁剪视频"))
+ return
+ if not st.session_state.get('video_clip_json_path'):
+ st.error(tr("脚本文件不能为空"))
+ return
+ if not st.session_state.get('video_origin_path'):
+ st.error(tr("视频文件不能为空"))
+ return
- # 视频转录大模型
- with middle_config_panel:
- video_llm_providers = ['Gemini']
- saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
- saved_llm_provider_index = 0
- for i, provider in enumerate(video_llm_providers):
- if provider.lower() == saved_llm_provider:
- saved_llm_provider_index = i
- break
+ st.toast(tr("生成视频"))
+ logger.info(tr("开始生成视频"))
- video_llm_provider = st.selectbox(tr("Video LLM Provider"), options=video_llm_providers, index=saved_llm_provider_index)
- video_llm_provider = video_llm_provider.lower()
- config.app["video_llm_provider"] = video_llm_provider
+ # 获取所有参数
+ script_params = script_settings.get_script_params()
+ video_params = video_settings.get_video_params()
+ audio_params = audio_settings.get_audio_params()
+ subtitle_params = subtitle_settings.get_subtitle_params()
- video_llm_api_key = config.app.get(f"{video_llm_provider}_api_key", "")
- video_llm_base_url = config.app.get(f"{video_llm_provider}_base_url", "")
- video_llm_model_name = config.app.get(f"{video_llm_provider}_model_name", "")
- video_llm_account_id = config.app.get(f"{video_llm_provider}_account_id", "")
- st_llm_api_key = st.text_input(tr("Video API Key"), value=video_llm_api_key, type="password")
- st_llm_base_url = st.text_input(tr("Video Base Url"), value=video_llm_base_url)
- st_llm_model_name = st.text_input(tr("Video Model Name"), value=video_llm_model_name)
- if st_llm_api_key:
- config.app[f"{video_llm_provider}_api_key"] = st_llm_api_key
- if st_llm_base_url:
- config.app[f"{video_llm_provider}_base_url"] = st_llm_base_url
- if st_llm_model_name:
- config.app[f"{video_llm_provider}_model_name"] = st_llm_model_name
+ # 合并所有参数
+ all_params = {
+ **script_params,
+ **video_params,
+ **audio_params,
+ **subtitle_params
+ }
- # 大语言模型
- with right_config_panel:
- llm_providers = ['Gemini', 'OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
- saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
- saved_llm_provider_index = 0
- for i, provider in enumerate(llm_providers):
- if provider.lower() == saved_llm_provider:
- saved_llm_provider_index = i
- break
+ # 创建参数对象
+ params = VideoClipParams(**all_params)
- llm_provider = st.selectbox(tr("LLM Provider"), options=llm_providers, index=saved_llm_provider_index)
- llm_provider = llm_provider.lower()
- config.app["llm_provider"] = llm_provider
-
- llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
- llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
- llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
- llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
- st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
- st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
- st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
- if st_llm_api_key:
- config.app[f"{llm_provider}_api_key"] = st_llm_api_key
- if st_llm_base_url:
- config.app[f"{llm_provider}_base_url"] = st_llm_base_url
- if st_llm_model_name:
- config.app[f"{llm_provider}_model_name"] = st_llm_model_name
-
- if llm_provider == 'cloudflare':
- st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
- if st_llm_account_id:
- config.app[f"{llm_provider}_account_id"] = st_llm_account_id
-
-panel = st.columns(3)
-left_panel = panel[0]
-middle_panel = panel[1]
-right_panel = panel[2]
-
-params = VideoClipParams()
-
-# 左侧面板
-with left_panel:
- with st.container(border=True):
- st.write(tr("Video Script Configuration"))
- # 脚本语言
- video_languages = [
- (tr("Auto Detect"), ""),
- ]
- for code in ["zh-CN", "en-US", "zh-TW"]:
- video_languages.append((code, code))
-
- selected_index = st.selectbox(tr("Script Language"),
- index=0,
- options=range(len(video_languages)), # 使用索引作为内部选项值
- format_func=lambda x: video_languages[x][0] # 显示给用户的是标签
- )
- params.video_language = video_languages[selected_index][1]
-
- # 脚本路径
- suffix = "*.json"
- song_dir = utils.script_dir()
- files = glob.glob(os.path.join(song_dir, suffix))
- script_list = []
- for file in files:
- script_list.append({
- "name": os.path.basename(file),
- "size": os.path.getsize(file),
- "file": file,
- "ctime": os.path.getctime(file) # 获取文件创建时间
- })
-
- # 按创建时间降序排序
- script_list.sort(key=lambda x: x["ctime"], reverse=True)
-
- # 本文件 下拉框
- script_path = [(tr("Auto Generate"), ""), ]
- for file in script_list:
- display_name = file['file'].replace(root_dir, "")
- script_path.append((display_name, file['file']))
- selected_script_index = st.selectbox(tr("Script Files"),
- index=0,
- options=range(len(script_path)), # 使用索引作为内部选项值
- format_func=lambda x: script_path[x][0] # 显示给用户的是标签
- )
- params.video_clip_json_path = script_path[selected_script_index][1]
- config.app["video_clip_json_path"] = params.video_clip_json_path
- st.session_state['video_clip_json_path'] = params.video_clip_json_path
-
- # 视频文件处理
- video_files = []
- for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
- video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
- video_files = video_files[::-1]
-
- video_list = []
- for video_file in video_files:
- video_list.append({
- "name": os.path.basename(video_file),
- "size": os.path.getsize(video_file),
- "file": video_file,
- "ctime": os.path.getctime(video_file) # 获取文件创建时间
- })
- # 按创建时间降序排序
- video_list.sort(key=lambda x: x["ctime"], reverse=True)
- video_path = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
- for file in video_list:
- display_name = file['file'].replace(root_dir, "")
- video_path.append((display_name, file['file']))
-
- # 视频文件
- selected_video_index = st.selectbox(tr("Video File"),
- index=0,
- options=range(len(video_path)), # 使用索引作为内部选项值
- format_func=lambda x: video_path[x][0] # 显示给用户的是标签
- )
- params.video_origin_path = video_path[selected_video_index][1]
- config.app["video_origin_path"] = params.video_origin_path
- st.session_state['video_origin_path'] = params.video_origin_path
-
- # 从本地上传 mp4 文件
- if params.video_origin_path == "local":
- _supported_types = FILE_TYPE_VIDEOS
- uploaded_file = st.file_uploader(
- tr("Upload Local Files"),
- type=["mp4", "mov", "avi", "flv", "mkv"],
- accept_multiple_files=False,
+ result = tm.start_subclip(
+ task_id=task_id,
+ params=params,
+ subclip_path_videos=st.session_state['subclip_videos']
)
- if uploaded_file is not None:
- # 构造保存路径
- video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
- file_name, file_extension = os.path.splitext(uploaded_file.name)
- # 检查文件是否存在,如果存在则添加时间戳
- if os.path.exists(video_file_path):
- timestamp = time.strftime("%Y%m%d%H%M%S")
- file_name_with_timestamp = f"{file_name}_{timestamp}"
- video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
- # 将文件保存到指定目录
- with open(video_file_path, "wb") as f:
- f.write(uploaded_file.read())
- st.success(tr("File Uploaded Successfully"))
- time.sleep(1)
- st.rerun()
- # 视频名称
- video_name = st.text_input(tr("Video Name"))
- # 剧情内容
- video_plot = st.text_area(
- tr("Plot Description"),
- value=st.session_state['video_plot'],
- height=180
- )
-
- # 生成视频脚本
- if st.session_state['video_clip_json_path']:
- generate_button_name = tr("Video Script Load")
- else:
- generate_button_name = tr("Video Script Generate")
- if st.button(generate_button_name, key="auto_generate_script"):
- progress_bar = st.progress(0)
- status_text = st.empty()
-
- def update_progress(progress: float, message: str = ""):
- progress_bar.progress(progress)
- if message:
- status_text.text(f"{progress}% - {message}")
- else:
- status_text.text(f"进度: {progress}%")
+ video_files = result.get("videos", [])
+ st.success(tr("视生成完成"))
+
try:
- with st.spinner("正在生成脚本..."):
- if not video_plot:
- st.warning("视频剧情为空; 会极大影响生成效果!")
- if params.video_clip_json_path == "" and params.video_origin_path != "":
- update_progress(10, "压缩视频中...")
- # 使用大模型生成视频脚本
- script = llm.generate_script(
- video_path=params.video_origin_path,
- video_plot=video_plot,
- video_name=video_name,
- language=params.video_language,
- progress_callback=update_progress
- )
- if script is None:
- st.error("生成脚本失败,请检查日志")
- st.stop()
- else:
- update_progress(90)
+ if video_files:
+ player_cols = st.columns(len(video_files) * 2 + 1)
+ for i, url in enumerate(video_files):
+ player_cols[i * 2 + 1].video(url)
+ except Exception as e:
+ logger.error(f"播放视频失败: {e}")
- script = utils.clean_model_output(script)
- st.session_state['video_clip_json'] = json.loads(script)
- else:
- # 从本地加载
- with open(params.video_clip_json_path, 'r', encoding='utf-8') as f:
- update_progress(50)
- status_text.text("从本地加载中...")
- script = f.read()
- script = utils.clean_model_output(script)
- st.session_state['video_clip_json'] = json.loads(script)
- update_progress(100)
- status_text.text("从本地加载成功")
+ file_utils.open_task_folder(config.root_dir, task_id)
+ logger.info(tr("视频生成完成"))
- time.sleep(0.5) # 给进度条一点时间到达100%
- progress_bar.progress(100)
- status_text.text("脚本生成完成!")
- st.success("视频脚本生成成功!")
- except Exception as err:
- st.error(f"生成过程中发生错误: {str(err)}")
- finally:
- time.sleep(2) # 给用户一些时间查看最终状态
- progress_bar.empty()
- status_text.empty()
+ finally:
+ PerformanceMonitor.cleanup_resources()
- # 视频脚本
- video_clip_json_details = st.text_area(
- tr("Video Script"),
- value=json.dumps(st.session_state.video_clip_json, indent=2, ensure_ascii=False),
- height=180
- )
+def main():
+ """主函数"""
+ init_log()
+ init_global_state()
+ utils.init_resources()
+
+ st.title(f"NarratoAI :sunglasses:📽️")
+ st.write(tr("Get Help"))
+
+ # 渲染基础设置面板
+ basic_settings.render_basic_settings(tr)
+
+ # 渲染主面板
+ panel = st.columns(3)
+ with panel[0]:
+ script_settings.render_script_panel(tr)
+ with panel[1]:
+ video_settings.render_video_panel(tr)
+ audio_settings.render_audio_panel(tr)
+ with panel[2]:
+ subtitle_settings.render_subtitle_panel(tr)
+
+ # 渲染视频审查面板
+ review_settings.render_review_panel(tr)
+
+ # 渲染生成按钮和处理逻辑
+ render_generate_button()
- # 保存脚本
- button_columns = st.columns(2)
- with button_columns[0]:
- if st.button(tr("Save Script"), key="auto_generate_terms", use_container_width=True):
- if not video_clip_json_details:
- st.error(tr("请输入视频脚本"))
- st.stop()
-
- with st.spinner(tr("Save Script")):
- script_dir = utils.script_dir()
- # 获取当前时间戳,形如 2024-0618-171820
- timestamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M%S")
- save_path = os.path.join(script_dir, f"{timestamp}.json")
-
- try:
- data = utils.add_new_timestamps(json.loads(video_clip_json_details))
- except Exception as err:
- st.error(f"视频脚本格式错误,请检查脚本是否符合 JSON 格式;{err} \n\n{traceback.format_exc()}")
- st.stop()
-
- # 存储为新的 JSON 文件
- with open(save_path, 'w', encoding='utf-8') as file:
- json.dump(data, file, ensure_ascii=False, indent=4)
- # 将data的值存储到 session_state 中,类似缓存
- st.session_state['video_clip_json'] = data
- st.session_state['video_clip_json_path'] = save_path
- # 刷新页面
- st.rerun()
-
- # 裁剪视频
- with button_columns[1]:
- if st.button(tr("Crop Video"), key="auto_crop_video", use_container_width=True):
- progress_bar = st.progress(0)
- status_text = st.empty()
-
- def update_progress(progress):
- progress_bar.progress(progress)
- status_text.text(f"剪辑进度: {progress}%")
-
- try:
- utils.cut_video(params, update_progress)
- time.sleep(0.5) # 给进度条一点时间到达100%
- progress_bar.progress(100)
- status_text.text("剪辑完成!")
- st.success("视频剪辑成功完成!")
- except Exception as e:
- st.error(f"剪辑过程中发生错误: {str(e)}")
- finally:
- time.sleep(2) # 给用户一些时间查看最终状态
- progress_bar.empty()
- status_text.empty()
-
-# 新中间面板
-with middle_panel:
- with st.container(border=True):
- st.write(tr("Video Settings"))
-
- # 视频比例
- video_aspect_ratios = [
- (tr("Portrait"), VideoAspect.portrait.value),
- (tr("Landscape"), VideoAspect.landscape.value),
- ]
- selected_index = st.selectbox(
- tr("Video Ratio"),
- options=range(len(video_aspect_ratios)), # 使用索引作为内部选项值
- format_func=lambda x: video_aspect_ratios[x][0], # 显示给用户的是标签
- )
- params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
-
- # params.video_clip_duration = st.selectbox(
- # tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
- # )
- # params.video_count = st.selectbox(
- # tr("Number of Videos Generated Simultaneously"),
- # options=[1, 2, 3, 4, 5],
- # index=0,
- # )
- with st.container(border=True):
- st.write(tr("Audio Settings"))
-
- # tts_providers = ['edge', 'azure']
- # tts_provider = st.selectbox(tr("TTS Provider"), tts_providers)
-
- voices = voice.get_all_azure_voices(filter_locals=support_locales)
- friendly_names = {
- v: v.replace("Female", tr("Female"))
- .replace("Male", tr("Male"))
- .replace("Neural", "")
- for v in voices
- }
- saved_voice_name = config.ui.get("voice_name", "")
- saved_voice_name_index = 0
- if saved_voice_name in friendly_names:
- saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
- else:
- for i, v in enumerate(voices):
- if (
- v.lower().startswith(st.session_state["ui_language"].lower())
- and "V2" not in v
- ):
- saved_voice_name_index = i
- break
-
- selected_friendly_name = st.selectbox(
- tr("Speech Synthesis"),
- options=list(friendly_names.values()),
- index=saved_voice_name_index,
- )
-
- voice_name = list(friendly_names.keys())[
- list(friendly_names.values()).index(selected_friendly_name)
- ]
- params.voice_name = voice_name
- config.ui["voice_name"] = voice_name
-
- # 试听语言合成
- if st.button(tr("Play Voice")):
- play_content = "这是一段试听语言"
- if not play_content:
- play_content = params.video_script
- if not play_content:
- play_content = tr("Voice Example")
- with st.spinner(tr("Synthesizing Voice")):
- temp_dir = utils.storage_dir("temp", create=True)
- audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
- sub_maker = voice.tts(
- text=play_content,
- voice_name=voice_name,
- voice_rate=params.voice_rate,
- voice_file=audio_file,
- )
- # if the voice file generation failed, try again with a default content.
- if not sub_maker:
- play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
- sub_maker = voice.tts(
- text=play_content,
- voice_name=voice_name,
- voice_rate=params.voice_rate,
- voice_file=audio_file,
- )
-
- if sub_maker and os.path.exists(audio_file):
- st.audio(audio_file, format="audio/mp3")
- if os.path.exists(audio_file):
- os.remove(audio_file)
-
- if voice.is_azure_v2_voice(voice_name):
- saved_azure_speech_region = config.azure.get("speech_region", "")
- saved_azure_speech_key = config.azure.get("speech_key", "")
- azure_speech_region = st.text_input(
- tr("Speech Region"), value=saved_azure_speech_region
- )
- azure_speech_key = st.text_input(
- tr("Speech Key"), value=saved_azure_speech_key, type="password"
- )
- config.azure["speech_region"] = azure_speech_region
- config.azure["speech_key"] = azure_speech_key
-
- params.voice_volume = st.selectbox(
- tr("Speech Volume"),
- options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
- index=2,
- )
-
- params.voice_rate = st.selectbox(
- tr("Speech Rate"),
- options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
- index=2,
- )
-
- bgm_options = [
- (tr("No Background Music"), ""),
- (tr("Random Background Music"), "random"),
- (tr("Custom Background Music"), "custom"),
- ]
- selected_index = st.selectbox(
- tr("Background Music"),
- index=1,
- options=range(len(bgm_options)), # 使用索引作为内部选项值
- format_func=lambda x: bgm_options[x][0], # 显示给用户的是标签
- )
- # 获取选择的背景音乐类型
- params.bgm_type = bgm_options[selected_index][1]
-
- # 根据选择显示或隐藏组件
- if params.bgm_type == "custom":
- custom_bgm_file = st.text_input(tr("Custom Background Music File"))
- if custom_bgm_file and os.path.exists(custom_bgm_file):
- params.bgm_file = custom_bgm_file
- # st.write(f":red[已选择自定义背景音乐]:**{custom_bgm_file}**")
- params.bgm_volume = st.selectbox(
- tr("Background Music Volume"),
- options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
- index=2,
- )
-
-# 新侧面板
-with right_panel:
- with st.container(border=True):
- st.write(tr("Subtitle Settings"))
- params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
- font_names = get_all_fonts()
- saved_font_name = config.ui.get("font_name", "")
- saved_font_name_index = 0
- if saved_font_name in font_names:
- saved_font_name_index = font_names.index(saved_font_name)
- params.font_name = st.selectbox(
- tr("Font"), font_names, index=saved_font_name_index
- )
- config.ui["font_name"] = params.font_name
-
- subtitle_positions = [
- (tr("Top"), "top"),
- (tr("Center"), "center"),
- (tr("Bottom"), "bottom"),
- (tr("Custom"), "custom"),
- ]
- selected_index = st.selectbox(
- tr("Position"),
- index=2,
- options=range(len(subtitle_positions)),
- format_func=lambda x: subtitle_positions[x][0],
- )
- params.subtitle_position = subtitle_positions[selected_index][1]
-
- if params.subtitle_position == "custom":
- custom_position = st.text_input(
- tr("Custom Position (% from top)"), value="70.0"
- )
- try:
- params.custom_position = float(custom_position)
- if params.custom_position < 0 or params.custom_position > 100:
- st.error(tr("Please enter a value between 0 and 100"))
- except ValueError:
- logger.error(f"输入的值无效: {traceback.format_exc()}")
- st.error(tr("Please enter a valid number"))
-
- font_cols = st.columns([0.3, 0.7])
- with font_cols[0]:
- saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
- params.text_fore_color = st.color_picker(
- tr("Font Color"), saved_text_fore_color
- )
- config.ui["text_fore_color"] = params.text_fore_color
-
- with font_cols[1]:
- saved_font_size = config.ui.get("font_size", 60)
- params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
- config.ui["font_size"] = params.font_size
-
- stroke_cols = st.columns([0.3, 0.7])
- with stroke_cols[0]:
- params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000")
- with stroke_cols[1]:
- params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5)
-
-# 视频编辑面板
-with st.expander(tr("Video Check"), expanded=False):
- try:
- video_list = st.session_state.video_clip_json
- except KeyError as e:
- video_list = []
-
- # 计算列数和行数
- num_videos = len(video_list)
- cols_per_row = 3
- rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
-
- # 使用容器展示视频
- for row in range(rows):
- cols = st.columns(cols_per_row)
- for col in range(cols_per_row):
- index = row * cols_per_row + col
- if index < num_videos:
- with cols[col]:
- video_info = video_list[index]
- video_path = video_info.get('path')
- if video_path is not None:
- initial_narration = video_info['narration']
- initial_picture = video_info['picture']
- initial_timestamp = video_info['timestamp']
-
- with open(video_path, 'rb') as video_file:
- video_bytes = video_file.read()
- st.video(video_bytes)
-
- # 可编辑的输入框
- text_panels = st.columns(2)
- with text_panels[0]:
- text1 = st.text_area(tr("timestamp"), value=initial_timestamp, height=20,
- key=f"timestamp_{index}")
- with text_panels[1]:
- text2 = st.text_area(tr("Picture description"), value=initial_picture, height=20,
- key=f"picture_{index}")
- text3 = st.text_area(tr("Narration"), value=initial_narration, height=100,
- key=f"narration_{index}")
-
- # 重新生成按钮
- if st.button(tr("Rebuild"), key=f"rebuild_{index}"):
- # 更新video_list中的对应项
- video_list[index]['timestamp'] = text1
- video_list[index]['picture'] = text2
- video_list[index]['narration'] = text3
-
- for video in video_list:
- if 'path' in video:
- del video['path']
- # 更新session_state以确保更改被保存
- st.session_state['video_clip_json'] = utils.to_json(video_list)
- # 替换原JSON 文件
- with open(params.video_clip_json_path, 'w', encoding='utf-8') as file:
- json.dump(video_list, file, ensure_ascii=False, indent=4)
- utils.cut_video(params, progress_callback=None)
- st.rerun()
-
-# 开始按钮
-start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
-if start_button:
- # 重置日志容器和记录
- log_container = st.empty()
- log_records = []
-
- config.save_config()
- task_id = st.session_state.get('task_id')
- if st.session_state.get('video_script_json_path') is not None:
- params.video_clip_json = st.session_state.get('video_clip_json')
-
- logger.debug(f"当前的脚本文件为:{st.session_state.video_clip_json_path}")
- logger.debug(f"当前的视频文件为:{st.session_state.video_origin_path}")
- logger.debug(f"裁剪后是视频列表:{st.session_state.subclip_videos}")
-
- if not task_id:
- st.error(tr("请先裁剪视频"))
- scroll_to_bottom()
- st.stop()
- if not params.video_clip_json_path:
- st.error(tr("脚本文件不能为空"))
- scroll_to_bottom()
- st.stop()
- if not params.video_origin_path:
- st.error(tr("视频文件不能为空"))
- scroll_to_bottom()
- st.stop()
-
- def log_received(msg):
- with log_container:
- log_records.append(msg)
- st.code("\n".join(log_records))
-
- logger.add(log_received)
-
- st.toast(tr("生成视频"))
- logger.info(tr("开始生成视频"))
- logger.info(utils.to_json(params))
- scroll_to_bottom()
-
- result = tm.start_subclip(task_id=task_id, params=params, subclip_path_videos=st.session_state.subclip_videos)
-
- video_files = result.get("videos", [])
- st.success(tr("视频生成完成"))
- try:
- if video_files:
- # 将视频播放器居中
- player_cols = st.columns(len(video_files) * 2 + 1)
- for i, url in enumerate(video_files):
- player_cols[i * 2 + 1].video(url)
- except Exception as e:
- pass
-
- open_task_folder(task_id)
- logger.info(tr("视频生成完成"))
- scroll_to_bottom()
-
-config.save_config()
+if __name__ == "__main__":
+ main()
diff --git a/webui/__init__.py b/webui/__init__.py
new file mode 100644
index 0000000..3c0a334
--- /dev/null
+++ b/webui/__init__.py
@@ -0,0 +1,22 @@
+"""
+NarratoAI WebUI Package
+"""
+from webui.config.settings import config
+from webui.components import (
+ basic_settings,
+ video_settings,
+ audio_settings,
+ subtitle_settings
+)
+from webui.utils import cache, file_utils, performance
+
+__all__ = [
+ 'config',
+ 'basic_settings',
+ 'video_settings',
+ 'audio_settings',
+ 'subtitle_settings',
+ 'cache',
+ 'file_utils',
+ 'performance'
+]
\ No newline at end of file
diff --git a/webui/components/__init__.py b/webui/components/__init__.py
new file mode 100644
index 0000000..6aafcd7
--- /dev/null
+++ b/webui/components/__init__.py
@@ -0,0 +1,15 @@
+from .basic_settings import render_basic_settings
+from .script_settings import render_script_panel
+from .video_settings import render_video_panel
+from .audio_settings import render_audio_panel
+from .subtitle_settings import render_subtitle_panel
+from .review_settings import render_review_panel
+
+__all__ = [
+ 'render_basic_settings',
+ 'render_script_panel',
+ 'render_video_panel',
+ 'render_audio_panel',
+ 'render_subtitle_panel',
+ 'render_review_panel'
+]
\ No newline at end of file
diff --git a/webui/components/audio_settings.py b/webui/components/audio_settings.py
new file mode 100644
index 0000000..a189f65
--- /dev/null
+++ b/webui/components/audio_settings.py
@@ -0,0 +1,198 @@
+import streamlit as st
+import os
+from uuid import uuid4
+from app.config import config
+from app.services import voice
+from app.utils import utils
+from webui.utils.cache import get_songs_cache
+
+def render_audio_panel(tr):
+ """渲染音频设置面板"""
+ with st.container(border=True):
+ st.write(tr("Audio Settings"))
+
+ # 渲染TTS设置
+ render_tts_settings(tr)
+
+ # 渲染背景音乐设置
+ render_bgm_settings(tr)
+
+def render_tts_settings(tr):
+ """渲染TTS(文本转语音)设置"""
+ # 获取支持的语音列表
+ support_locales = ["zh-CN", "zh-HK", "zh-TW", "en-US"]
+ voices = voice.get_all_azure_voices(filter_locals=support_locales)
+
+ # 创建友好的显示名称
+ friendly_names = {
+ v: v.replace("Female", tr("Female"))
+ .replace("Male", tr("Male"))
+ .replace("Neural", "")
+ for v in voices
+ }
+
+ # 获取保存的语音设置
+ saved_voice_name = config.ui.get("voice_name", "")
+ saved_voice_name_index = 0
+
+ if saved_voice_name in friendly_names:
+ saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
+ else:
+ # 如果没有保存的设置,选择与UI语言匹配的第一个语音
+ for i, v in enumerate(voices):
+ if (v.lower().startswith(st.session_state["ui_language"].lower())
+ and "V2" not in v):
+ saved_voice_name_index = i
+ break
+
+ # 语音选择下拉框
+ selected_friendly_name = st.selectbox(
+ tr("Speech Synthesis"),
+ options=list(friendly_names.values()),
+ index=saved_voice_name_index,
+ )
+
+ # 获取实际的语音名称
+ voice_name = list(friendly_names.keys())[
+ list(friendly_names.values()).index(selected_friendly_name)
+ ]
+
+ # 保存设置
+ config.ui["voice_name"] = voice_name
+
+ # Azure V2语音特殊处理
+ if voice.is_azure_v2_voice(voice_name):
+ render_azure_v2_settings(tr)
+
+ # 语音参数设置
+ render_voice_parameters(tr)
+
+ # 试听按钮
+ render_voice_preview(tr, voice_name)
+
+def render_azure_v2_settings(tr):
+ """渲染Azure V2语音设置"""
+ saved_azure_speech_region = config.azure.get("speech_region", "")
+ saved_azure_speech_key = config.azure.get("speech_key", "")
+
+ azure_speech_region = st.text_input(
+ tr("Speech Region"),
+ value=saved_azure_speech_region
+ )
+ azure_speech_key = st.text_input(
+ tr("Speech Key"),
+ value=saved_azure_speech_key,
+ type="password"
+ )
+
+ config.azure["speech_region"] = azure_speech_region
+ config.azure["speech_key"] = azure_speech_key
+
+def render_voice_parameters(tr):
+ """渲染语音参数设置"""
+ # 音量
+ voice_volume = st.selectbox(
+ tr("Speech Volume"),
+ options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
+ index=2,
+ )
+ st.session_state['voice_volume'] = voice_volume
+
+ # 语速
+ voice_rate = st.selectbox(
+ tr("Speech Rate"),
+ options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
+ index=2,
+ )
+ st.session_state['voice_rate'] = voice_rate
+
+ # 音调
+ voice_pitch = st.selectbox(
+ tr("Speech Pitch"),
+ options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
+ index=2,
+ )
+ st.session_state['voice_pitch'] = voice_pitch
+
+def render_voice_preview(tr, voice_name):
+ """渲染语音试听功能"""
+ if st.button(tr("Play Voice")):
+ play_content = "感谢关注 NarratoAI,有任何问题或建议,可以关注微信公众号,求助或讨论"
+ if not play_content:
+ play_content = st.session_state.get('video_script', '')
+ if not play_content:
+ play_content = tr("Voice Example")
+
+ with st.spinner(tr("Synthesizing Voice")):
+ temp_dir = utils.storage_dir("temp", create=True)
+ audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
+
+ sub_maker = voice.tts(
+ text=play_content,
+ voice_name=voice_name,
+ voice_rate=st.session_state.get('voice_rate', 1.0),
+ voice_pitch=st.session_state.get('voice_pitch', 1.0),
+ voice_file=audio_file,
+ )
+
+ # 如果语音文件生成失败,使用默认内容重试
+ if not sub_maker:
+ play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
+ sub_maker = voice.tts(
+ text=play_content,
+ voice_name=voice_name,
+ voice_rate=st.session_state.get('voice_rate', 1.0),
+ voice_pitch=st.session_state.get('voice_pitch', 1.0),
+ voice_file=audio_file,
+ )
+
+ if sub_maker and os.path.exists(audio_file):
+ st.audio(audio_file, format="audio/mp3")
+ if os.path.exists(audio_file):
+ os.remove(audio_file)
+
+def render_bgm_settings(tr):
+ """渲染背景音乐设置"""
+ # 背景音乐选项
+ bgm_options = [
+ (tr("No Background Music"), ""),
+ (tr("Random Background Music"), "random"),
+ (tr("Custom Background Music"), "custom"),
+ ]
+
+ selected_index = st.selectbox(
+ tr("Background Music"),
+ index=1,
+ options=range(len(bgm_options)),
+ format_func=lambda x: bgm_options[x][0],
+ )
+
+ # 获取选择的背景音乐类型
+ bgm_type = bgm_options[selected_index][1]
+ st.session_state['bgm_type'] = bgm_type
+
+ # 自定义背景音乐处理
+ if bgm_type == "custom":
+ custom_bgm_file = st.text_input(tr("Custom Background Music File"))
+ if custom_bgm_file and os.path.exists(custom_bgm_file):
+ st.session_state['bgm_file'] = custom_bgm_file
+
+ # 背景音乐音量
+ bgm_volume = st.selectbox(
+ tr("Background Music Volume"),
+ options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+ index=2,
+ )
+ st.session_state['bgm_volume'] = bgm_volume
+
+def get_audio_params():
+ """获取音频参数"""
+ return {
+ 'voice_name': config.ui.get("voice_name", ""),
+ 'voice_volume': st.session_state.get('voice_volume', 1.0),
+ 'voice_rate': st.session_state.get('voice_rate', 1.0),
+ 'voice_pitch': st.session_state.get('voice_pitch', 1.0),
+ 'bgm_type': st.session_state.get('bgm_type', 'random'),
+ 'bgm_file': st.session_state.get('bgm_file', ''),
+ 'bgm_volume': st.session_state.get('bgm_volume', 0.2),
+ }
\ No newline at end of file
diff --git a/webui/components/basic_settings.py b/webui/components/basic_settings.py
new file mode 100644
index 0000000..b8e69ad
--- /dev/null
+++ b/webui/components/basic_settings.py
@@ -0,0 +1,236 @@
+import streamlit as st
+import os
+from app.config import config
+from app.utils import utils
+
+
+def render_basic_settings(tr):
+ """渲染基础设置面板"""
+ with st.expander(tr("Basic Settings"), expanded=False):
+ config_panels = st.columns(3)
+ left_config_panel = config_panels[0]
+ middle_config_panel = config_panels[1]
+ right_config_panel = config_panels[2]
+
+ with left_config_panel:
+ render_language_settings(tr)
+ render_proxy_settings(tr)
+
+ with middle_config_panel:
+ render_vision_llm_settings(tr) # 视频分析模型设置
+
+ with right_config_panel:
+ render_text_llm_settings(tr) # 文案生成模型设置
+
+
+def render_language_settings(tr):
+ st.subheader(tr("Proxy Settings"))
+
+ """渲染语言设置"""
+ system_locale = utils.get_system_locale()
+ i18n_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "i18n")
+ locales = utils.load_locales(i18n_dir)
+
+ display_languages = []
+ selected_index = 0
+ for i, code in enumerate(locales.keys()):
+ display_languages.append(f"{code} - {locales[code].get('Language')}")
+ if code == st.session_state.get('ui_language', system_locale):
+ selected_index = i
+
+ selected_language = st.selectbox(
+ tr("Language"),
+ options=display_languages,
+ index=selected_index
+ )
+
+ if selected_language:
+ code = selected_language.split(" - ")[0].strip()
+ st.session_state['ui_language'] = code
+ config.ui['language'] = code
+
+
+def render_proxy_settings(tr):
+ """渲染代理设置"""
+ proxy_url_http = config.proxy.get("http", "") or os.getenv("VPN_PROXY_URL", "")
+ proxy_url_https = config.proxy.get("https", "") or os.getenv("VPN_PROXY_URL", "")
+
+ HTTP_PROXY = st.text_input(tr("HTTP_PROXY"), value=proxy_url_http)
+ HTTPS_PROXY = st.text_input(tr("HTTPs_PROXY"), value=proxy_url_https)
+
+ if HTTP_PROXY:
+ config.proxy["http"] = HTTP_PROXY
+ os.environ["HTTP_PROXY"] = HTTP_PROXY
+ if HTTPS_PROXY:
+ config.proxy["https"] = HTTPS_PROXY
+ os.environ["HTTPS_PROXY"] = HTTPS_PROXY
+
+
+def render_vision_llm_settings(tr):
+ """渲染视频分析模型设置"""
+ st.subheader(tr("Vision Model Settings"))
+
+ # 视频分析模型提供商选择
+ vision_providers = ['Gemini', 'NarratoAPI']
+ saved_vision_provider = config.app.get("vision_llm_provider", "Gemini").lower()
+ saved_provider_index = 0
+
+ for i, provider in enumerate(vision_providers):
+ if provider.lower() == saved_vision_provider:
+ saved_provider_index = i
+ break
+
+ vision_provider = st.selectbox(
+ tr("Vision Model Provider"),
+ options=vision_providers,
+ index=saved_provider_index
+ )
+ vision_provider = vision_provider.lower()
+ config.app["vision_llm_provider"] = vision_provider
+ st.session_state['vision_llm_providers'] = vision_provider
+
+ # 获取已保存的视觉模型配置
+ vision_api_key = config.app.get(f"vision_{vision_provider}_api_key", "")
+ vision_base_url = config.app.get(f"vision_{vision_provider}_base_url", "")
+ vision_model_name = config.app.get(f"vision_{vision_provider}_model_name", "")
+
+ # 渲染视觉模型配置输入框
+ st_vision_api_key = st.text_input(tr("Vision API Key"), value=vision_api_key, type="password")
+ st_vision_base_url = st.text_input(tr("Vision Base URL"), value=vision_base_url)
+ st_vision_model_name = st.text_input(tr("Vision Model Name"), value=vision_model_name)
+
+ # 保存视觉模型配置
+ if st_vision_api_key:
+ config.app[f"vision_{vision_provider}_api_key"] = st_vision_api_key
+ st.session_state[f"vision_{vision_provider}_api_key"] = st_vision_api_key # 用于script_settings.py
+ if st_vision_base_url:
+ config.app[f"vision_{vision_provider}_base_url"] = st_vision_base_url
+ st.session_state[f"vision_{vision_provider}_base_url"] = st_vision_base_url
+ if st_vision_model_name:
+ config.app[f"vision_{vision_provider}_model_name"] = st_vision_model_name
+ st.session_state[f"vision_{vision_provider}_model_name"] = st_vision_model_name
+
+ # NarratoAPI 特殊配置
+ if vision_provider == 'narratoapi':
+ st.subheader(tr("Narrato Additional Settings"))
+
+ # Narrato API 基础配置
+ narrato_api_key = st.text_input(
+ tr("Narrato API Key"),
+ value=config.app.get("narrato_api_key", ""),
+ type="password",
+ help="用于访问 Narrato API 的密钥"
+ )
+ if narrato_api_key:
+ config.app["narrato_api_key"] = narrato_api_key
+ st.session_state['narrato_api_key'] = narrato_api_key
+
+ narrato_api_url = st.text_input(
+ tr("Narrato API URL"),
+ value=config.app.get("narrato_api_url", "http://127.0.0.1:8000/api/v1/video/analyze")
+ )
+ if narrato_api_url:
+ config.app["narrato_api_url"] = narrato_api_url
+ st.session_state['narrato_api_url'] = narrato_api_url
+
+ # 视频分析模型配置
+ st.markdown("##### " + tr("Vision Model Settings"))
+ narrato_vision_model = st.text_input(
+ tr("Vision Model Name"),
+ value=config.app.get("narrato_vision_model", "gemini-1.5-flash")
+ )
+ narrato_vision_key = st.text_input(
+ tr("Vision Model API Key"),
+ value=config.app.get("narrato_vision_key", ""),
+ type="password",
+ help="用于视频分析的模型 API Key"
+ )
+
+ if narrato_vision_model:
+ config.app["narrato_vision_model"] = narrato_vision_model
+ st.session_state['narrato_vision_model'] = narrato_vision_model
+ if narrato_vision_key:
+ config.app["narrato_vision_key"] = narrato_vision_key
+ st.session_state['narrato_vision_key'] = narrato_vision_key
+
+ # 文案生成模型配置
+ st.markdown("##### " + tr("Text Generation Model Settings"))
+ narrato_llm_model = st.text_input(
+ tr("LLM Model Name"),
+ value=config.app.get("narrato_llm_model", "qwen-plus")
+ )
+ narrato_llm_key = st.text_input(
+ tr("LLM Model API Key"),
+ value=config.app.get("narrato_llm_key", ""),
+ type="password",
+ help="用于文案生成的模型 API Key"
+ )
+
+ if narrato_llm_model:
+ config.app["narrato_llm_model"] = narrato_llm_model
+ st.session_state['narrato_llm_model'] = narrato_llm_model
+ if narrato_llm_key:
+ config.app["narrato_llm_key"] = narrato_llm_key
+ st.session_state['narrato_llm_key'] = narrato_llm_key
+
+ # 批处理配置
+ narrato_batch_size = st.number_input(
+ tr("Batch Size"),
+ min_value=1,
+ max_value=50,
+ value=config.app.get("narrato_batch_size", 10),
+ help="每批处理的图片数量"
+ )
+ if narrato_batch_size:
+ config.app["narrato_batch_size"] = narrato_batch_size
+ st.session_state['narrato_batch_size'] = narrato_batch_size
+
+
+def render_text_llm_settings(tr):
+ """渲染文案生成模型设置"""
+ st.subheader(tr("Text Generation Model Settings"))
+
+ # 文案生成模型提供商选择
+ text_providers = ['OpenAI', 'Gemini', 'Moonshot', 'Azure', 'Qwen', 'Ollama', 'G4f', 'OneAPI', 'Cloudflare']
+ saved_text_provider = config.app.get("text_llm_provider", "OpenAI").lower()
+ saved_provider_index = 0
+
+ for i, provider in enumerate(text_providers):
+ if provider.lower() == saved_text_provider:
+ saved_provider_index = i
+ break
+
+ text_provider = st.selectbox(
+ tr("Text Model Provider"),
+ options=text_providers,
+ index=saved_provider_index
+ )
+ text_provider = text_provider.lower()
+ config.app["text_llm_provider"] = text_provider
+
+ # 获取已保存的文本模型配置
+ text_api_key = config.app.get(f"text_{text_provider}_api_key", "")
+ text_base_url = config.app.get(f"text_{text_provider}_base_url", "")
+ text_model_name = config.app.get(f"text_{text_provider}_model_name", "")
+
+ # 渲染文本模型配置输入框
+ st_text_api_key = st.text_input(tr("Text API Key"), value=text_api_key, type="password")
+ st_text_base_url = st.text_input(tr("Text Base URL"), value=text_base_url)
+ st_text_model_name = st.text_input(tr("Text Model Name"), value=text_model_name)
+
+ # 保存文本模型配置
+ if st_text_api_key:
+ config.app[f"text_{text_provider}_api_key"] = st_text_api_key
+ if st_text_base_url:
+ config.app[f"text_{text_provider}_base_url"] = st_text_base_url
+ if st_text_model_name:
+ config.app[f"text_{text_provider}_model_name"] = st_text_model_name
+
+ # Cloudflare 特殊配置
+ if text_provider == 'cloudflare':
+ st_account_id = st.text_input(
+ tr("Account ID"),
+ value=config.app.get(f"text_{text_provider}_account_id", "")
+ )
+ if st_account_id:
+ config.app[f"text_{text_provider}_account_id"] = st_account_id
diff --git a/webui/components/review_settings.py b/webui/components/review_settings.py
new file mode 100644
index 0000000..513f938
--- /dev/null
+++ b/webui/components/review_settings.py
@@ -0,0 +1,85 @@
+import streamlit as st
+import os
+from loguru import logger
+
+def render_review_panel(tr):
+ """渲染视频审查面板"""
+ with st.expander(tr("Video Check"), expanded=False):
+ try:
+ video_list = st.session_state.get('video_clip_json', [])
+ subclip_videos = st.session_state.get('subclip_videos', {})
+ except KeyError:
+ video_list = []
+ subclip_videos = {}
+
+ # 计算列数和行数
+ num_videos = len(video_list)
+ cols_per_row = 3
+ rows = (num_videos + cols_per_row - 1) // cols_per_row # 向上取整计算行数
+
+ # 使用容器展示视频
+ for row in range(rows):
+ cols = st.columns(cols_per_row)
+ for col in range(cols_per_row):
+ index = row * cols_per_row + col
+ if index < num_videos:
+ with cols[col]:
+ render_video_item(tr, video_list, subclip_videos, index)
+
+def render_video_item(tr, video_list, subclip_videos, index):
+ """渲染单个视频项"""
+ video_script = video_list[index]
+
+ # 显示时间戳
+ timestamp = video_script.get('timestamp', '')
+ st.text_area(
+ tr("Timestamp"),
+ value=timestamp,
+ height=70,
+ disabled=True,
+ key=f"timestamp_{index}"
+ )
+
+ # 显示视频播放器
+ video_path = subclip_videos.get(timestamp)
+ if video_path and os.path.exists(video_path):
+ try:
+ st.video(video_path)
+ except Exception as e:
+ logger.error(f"加载视频失败 {video_path}: {e}")
+ st.error(f"无法加载视频: {os.path.basename(video_path)}")
+ else:
+ st.warning(tr("视频文件未找到"))
+
+ # 显示画面描述
+ st.text_area(
+ tr("Picture Description"),
+ value=video_script.get('picture', ''),
+ height=150,
+ disabled=True,
+ key=f"picture_{index}"
+ )
+
+ # 显示旁白文本
+ narration = st.text_area(
+ tr("Narration"),
+ value=video_script.get('narration', ''),
+ height=150,
+ key=f"narration_{index}"
+ )
+ # 保存修改后的旁白文本
+ if narration != video_script.get('narration', ''):
+ video_script['narration'] = narration
+ st.session_state['video_clip_json'] = video_list
+
+ # 显示剪辑模式
+ ost = st.selectbox(
+ tr("Clip Mode"),
+ options=range(1, 10),
+ index=video_script.get('OST', 1) - 1,
+ key=f"ost_{index}"
+ )
+ # 保存修改后的剪辑模式
+ if ost != video_script.get('OST', 1):
+ video_script['OST'] = ost
+ st.session_state['video_clip_json'] = video_list
\ No newline at end of file
diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py
new file mode 100644
index 0000000..fcc1913
--- /dev/null
+++ b/webui/components/script_settings.py
@@ -0,0 +1,652 @@
+import os
+import glob
+import json
+import time
+import asyncio
+import traceback
+
+import requests
+import streamlit as st
+from loguru import logger
+
+from app.config import config
+from app.models.schema import VideoClipParams
+from app.utils.script_generator import ScriptProcessor
+from app.utils import utils, check_script, vision_analyzer, video_processor
+from webui.utils import file_utils
+
+
+def render_script_panel(tr):
+ """渲染脚本配置面板"""
+ with st.container(border=True):
+ st.write(tr("Video Script Configuration"))
+ params = VideoClipParams()
+
+ # 渲染脚本文件选择
+ render_script_file(tr, params)
+
+ # 渲染视频文件选择
+ render_video_file(tr, params)
+
+ # 渲染视频主题和提示词
+ render_video_details(tr)
+
+ # 渲染脚本操作按钮
+ render_script_buttons(tr, params)
+
+
+def render_script_file(tr, params):
+ """渲染脚本文件选择"""
+ script_list = [(tr("None"), ""), (tr("Auto Generate"), "auto")]
+
+ # 获取已有脚本文件
+ suffix = "*.json"
+ script_dir = utils.script_dir()
+ files = glob.glob(os.path.join(script_dir, suffix))
+ file_list = []
+
+ for file in files:
+ file_list.append({
+ "name": os.path.basename(file),
+ "file": file,
+ "ctime": os.path.getctime(file)
+ })
+
+ file_list.sort(key=lambda x: x["ctime"], reverse=True)
+ for file in file_list:
+ display_name = file['file'].replace(config.root_dir, "")
+ script_list.append((display_name, file['file']))
+
+ # 找到保存的脚本文件在列表中的索引
+ saved_script_path = st.session_state.get('video_clip_json_path', '')
+ selected_index = 0
+ for i, (_, path) in enumerate(script_list):
+ if path == saved_script_path:
+ selected_index = i
+ break
+
+ selected_script_index = st.selectbox(
+ tr("Script Files"),
+ index=selected_index, # 使用找到的索引
+ options=range(len(script_list)),
+ format_func=lambda x: script_list[x][0]
+ )
+
+ script_path = script_list[selected_script_index][1]
+ st.session_state['video_clip_json_path'] = script_path
+ params.video_clip_json_path = script_path
+
+
+def render_video_file(tr, params):
+ """渲染视频文件选择"""
+ video_list = [(tr("None"), ""), (tr("Upload Local Files"), "local")]
+
+ # 获取已有视频文件
+ for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
+ video_files = glob.glob(os.path.join(utils.video_dir(), suffix))
+ for file in video_files:
+ display_name = file.replace(config.root_dir, "")
+ video_list.append((display_name, file))
+
+ selected_video_index = st.selectbox(
+ tr("Video File"),
+ index=0,
+ options=range(len(video_list)),
+ format_func=lambda x: video_list[x][0]
+ )
+
+ video_path = video_list[selected_video_index][1]
+ st.session_state['video_origin_path'] = video_path
+ params.video_origin_path = video_path
+
+ if video_path == "local":
+ uploaded_file = st.file_uploader(
+ tr("Upload Local Files"),
+ type=["mp4", "mov", "avi", "flv", "mkv"],
+ accept_multiple_files=False,
+ )
+
+ if uploaded_file is not None:
+ video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
+ file_name, file_extension = os.path.splitext(uploaded_file.name)
+
+ if os.path.exists(video_file_path):
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ file_name_with_timestamp = f"{file_name}_{timestamp}"
+ video_file_path = os.path.join(utils.video_dir(), file_name_with_timestamp + file_extension)
+
+ with open(video_file_path, "wb") as f:
+ f.write(uploaded_file.read())
+ st.success(tr("File Uploaded Successfully"))
+ st.session_state['video_origin_path'] = video_file_path
+ params.video_origin_path = video_file_path
+ time.sleep(1)
+ st.rerun()
+
+
+def render_video_details(tr):
+ """渲染视频主题和提示词"""
+ video_theme = st.text_input(tr("Video Theme"))
+ custom_prompt = st.text_area(
+ tr("Generation Prompt"),
+ value=st.session_state.get('video_plot', ''),
+ help=tr("Custom prompt for LLM, leave empty to use default prompt"),
+ height=180
+ )
+ st.session_state['video_theme'] = video_theme
+ st.session_state['custom_prompt'] = custom_prompt
+ return video_theme, custom_prompt
+
+
+def render_script_buttons(tr, params):
+ """渲染脚本操作按钮"""
+ # 生成/加载按钮
+ script_path = st.session_state.get('video_clip_json_path', '')
+ if script_path == "auto":
+ button_name = tr("Generate Video Script")
+ elif script_path:
+ button_name = tr("Load Video Script")
+ else:
+ button_name = tr("Please Select Script File")
+
+ if st.button(button_name, key="script_action", disabled=not script_path):
+ if script_path == "auto":
+ generate_script(tr, params)
+ else:
+ load_script(tr, script_path)
+
+ # 视频脚本编辑区
+ video_clip_json_details = st.text_area(
+ tr("Video Script"),
+ value=json.dumps(st.session_state.get('video_clip_json', []), indent=2, ensure_ascii=False),
+ height=180
+ )
+
+ # 操作按钮行
+ button_cols = st.columns(3)
+ with button_cols[0]:
+ if st.button(tr("Check Format"), key="check_format", use_container_width=True):
+ check_script_format(tr, video_clip_json_details)
+
+ with button_cols[1]:
+ if st.button(tr("Save Script"), key="save_script", use_container_width=True):
+ save_script(tr, video_clip_json_details)
+
+ with button_cols[2]:
+ script_valid = st.session_state.get('script_format_valid', False)
+ if st.button(tr("Crop Video"), key="crop_video", disabled=not script_valid, use_container_width=True):
+ crop_video(tr, params)
+
+
+def check_script_format(tr, script_content):
+ """检查脚本格式"""
+ try:
+ result = check_script.check_format(script_content)
+ if result.get('success'):
+ st.success(tr("Script format check passed"))
+ st.session_state['script_format_valid'] = True
+ else:
+ st.error(f"{tr('Script format check failed')}: {result.get('message')}")
+ st.session_state['script_format_valid'] = False
+ except Exception as e:
+ st.error(f"{tr('Script format check error')}: {str(e)}")
+ st.session_state['script_format_valid'] = False
+
+
+def load_script(tr, script_path):
+ """加载脚本文件"""
+ try:
+ with open(script_path, 'r', encoding='utf-8') as f:
+ script = f.read()
+ script = utils.clean_model_output(script)
+ st.session_state['video_clip_json'] = json.loads(script)
+ st.success(tr("Script loaded successfully"))
+ st.rerun()
+ except Exception as e:
+ st.error(f"{tr('Failed to load script')}: {str(e)}")
+
+
+def generate_script(tr, params):
+ """生成视频脚本"""
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ def update_progress(progress: float, message: str = ""):
+ progress_bar.progress(progress)
+ if message:
+ status_text.text(f"{progress}% - {message}")
+ else:
+ status_text.text(f"进度: {progress}%")
+
+ try:
+ with st.spinner("正在生成脚本..."):
+ if not params.video_origin_path:
+ st.error("请先选择视频文件")
+ st.stop()
+ return
+
+ # ===================提取键帧===================
+ update_progress(10, "正在提取关键帧...")
+
+ # 创建临时目录用于存储关键帧
+ keyframes_dir = os.path.join(utils.temp_dir(), "keyframes")
+ video_hash = utils.md5(params.video_origin_path + str(os.path.getmtime(params.video_origin_path)))
+ video_keyframes_dir = os.path.join(keyframes_dir, video_hash)
+
+ # 检查是否已经提取过关键帧
+ keyframe_files = []
+ if os.path.exists(video_keyframes_dir):
+ # 获取已有的关键帧文件
+ for filename in sorted(os.listdir(video_keyframes_dir)):
+ if filename.endswith('.jpg'):
+ keyframe_files.append(os.path.join(video_keyframes_dir, filename))
+
+ if keyframe_files:
+ logger.info(f"使用已缓存的关键帧: {video_keyframes_dir}")
+ st.info(f"使用已缓存的关键帧,如需重新提取请删除目录: {video_keyframes_dir}")
+ update_progress(20, f"使用已缓存关键帧,共 {len(keyframe_files)} 帧")
+
+ # 如果没有缓存的关键帧,则进行提取
+ if not keyframe_files:
+ try:
+ # 确保目录存在
+ os.makedirs(video_keyframes_dir, exist_ok=True)
+
+ # 初始化视频处理器
+ processor = video_processor.VideoProcessor(params.video_origin_path)
+
+ # 处理视频并提取关键帧
+ processor.process_video(
+ output_dir=video_keyframes_dir,
+ skip_seconds=0
+ )
+
+ # 获取所有关键帧文件路径
+ for filename in sorted(os.listdir(video_keyframes_dir)):
+ if filename.endswith('.jpg'):
+ keyframe_files.append(os.path.join(video_keyframes_dir, filename))
+
+ if not keyframe_files:
+ raise Exception("未提取到任何关键帧")
+
+ update_progress(20, f"关键帧提取完成,共 {len(keyframe_files)} 帧")
+
+ except Exception as e:
+ # 如果提取失败,清理创建的目录
+ try:
+ if os.path.exists(video_keyframes_dir):
+ import shutil
+ shutil.rmtree(video_keyframes_dir)
+ except Exception as cleanup_err:
+ logger.error(f"清理失败的关键帧目录时出错: {cleanup_err}")
+
+ raise Exception(f"关键帧提取失败: {str(e)}")
+
+ # 根据不同的 LLM 提供商处理
+ vision_llm_provider = st.session_state.get('vision_llm_providers').lower()
+ logger.debug(f"Vision LLM 提供商: {vision_llm_provider}")
+
+ if vision_llm_provider == 'gemini':
+ try:
+ # ===================初始化视觉分析器===================
+ update_progress(30, "正在初始化视觉分析器...")
+
+ # 从配置中获取 Gemini 相关配置
+ vision_api_key = st.session_state.get('vision_gemini_api_key')
+ vision_model = st.session_state.get('vision_gemini_model_name')
+ vision_base_url = st.session_state.get('vision_gemini_base_url')
+
+ if not vision_api_key or not vision_model:
+ raise ValueError("未配置 Gemini API Key 或者 模型,请在基础设置中配置")
+
+ analyzer = vision_analyzer.VisionAnalyzer(
+ model_name=vision_model,
+ api_key=vision_api_key
+ )
+
+ update_progress(40, "正在分析关键帧...")
+
+ # ===================创建异步事件循环===================
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ # 执行异步分析
+ results = loop.run_until_complete(
+ analyzer.analyze_images(
+ images=keyframe_files,
+ prompt=config.app.get('vision_analysis_prompt'),
+ batch_size=config.app.get("vision_batch_size", 5)
+ )
+ )
+ loop.close()
+
+ # ===================处理分析结果===================
+ update_progress(60, "正在整理分析结果...")
+
+ # 合并所有批次的析结果
+ frame_analysis = ""
+ for result in results:
+ if 'error' in result:
+ logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
+ continue
+
+ # 获取当前批次的图片文件
+ batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
+ batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
+ batch_files = keyframe_files[batch_start:batch_end]
+
+ # 提取首帧和尾帧的时间戳
+ first_frame = os.path.basename(batch_files[0])
+ last_frame = os.path.basename(batch_files[-1])
+
+ # 从文件名中提取时间信息
+ first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
+ last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
+
+ # 转换为分:秒格式
+ def format_timestamp(time_str):
+ seconds = int(time_str)
+ minutes = seconds // 60
+ seconds = seconds % 60
+ return f"{minutes:02d}:{seconds:02d}"
+
+ first_timestamp = format_timestamp(first_time)
+ last_timestamp = format_timestamp(last_time)
+
+ # 添加带时间戳的分析结果
+ frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
+ frame_analysis += result['response']
+ frame_analysis += "\n"
+
+ if not frame_analysis.strip():
+ raise Exception("未能生成有效的帧分析结果")
+
+ # 保存分析结果
+ analysis_path = os.path.join(utils.temp_dir(), "frame_analysis.txt")
+ with open(analysis_path, 'w', encoding='utf-8') as f:
+ f.write(frame_analysis)
+
+ update_progress(70, "正在生成脚本...")
+
+ # 从配置中获取文本生成相关配置
+ text_provider = config.app.get('text_llm_provider', 'gemini').lower()
+ text_api_key = config.app.get(f'text_{text_provider}_api_key')
+ text_model = config.app.get(f'text_{text_provider}_model_name')
+ text_base_url = config.app.get(f'text_{text_provider}_base_url')
+
+ # 构建帧内容列表
+ frame_content_list = []
+ for i, result in enumerate(results):
+ if 'error' in result:
+ continue
+
+ # 获取当前批次的图片文件
+ batch_start = result['batch_index'] * config.app.get("vision_batch_size", 5)
+ batch_end = min(batch_start + config.app.get("vision_batch_size", 5), len(keyframe_files))
+ batch_files = keyframe_files[batch_start:batch_end]
+
+ # 提取首帧和尾帧的时间戳
+ first_frame = os.path.basename(batch_files[0])
+ last_frame = os.path.basename(batch_files[-1])
+
+ # 从文件名中提取时间信息
+ first_time = first_frame.split('_')[2].replace('.jpg', '') # 000002
+ last_time = last_frame.split('_')[2].replace('.jpg', '') # 000005
+
+ # 转换为分:秒格式
+ def format_timestamp(time_str):
+ seconds = int(time_str)
+ minutes = seconds // 60
+ seconds = seconds % 60
+ return f"{minutes:02d}:{seconds:02d}"
+
+ first_timestamp = format_timestamp(first_time)
+ last_timestamp = format_timestamp(last_time)
+
+ # 构建时间戳范围字符串 (MM:SS-MM:SS 格式)
+ timestamp_range = f"{first_timestamp}-{last_timestamp}"
+
+ frame_content = {
+ "timestamp": timestamp_range, # 使用时间范围格式 "MM:SS-MM:SS"
+ "picture": result['response'], # 图片分析结果
+ "narration": "", # 将由 ScriptProcessor 生成
+ "OST": 2 # 默认值
+ }
+ frame_content_list.append(frame_content)
+
+ logger.debug(f"添加帧内容: 时间范围={timestamp_range}, 分析结果长度={len(result['response'])}")
+
+ if not frame_content_list:
+ raise Exception("没有有效的帧内容可以处理")
+
+ # ===================开始生成文案===================
+ update_progress(90, "正在生成文案...")
+ # 校验配置
+ api_params = {
+ 'batch_size': st.session_state.get('narrato_batch_size', 10),
+ 'use_ai': False,
+ 'start_offset': 0,
+ 'vision_model': vision_model,
+ 'vision_api_key': vision_api_key,
+ 'llm_model': text_model,
+ 'llm_api_key': text_api_key,
+ 'custom_prompt': st.session_state.get('custom_prompt', '')
+ }
+ response = requests.post(
+ f"{config.app.get('narrato_api_url')}/video/config",
+ params=api_params,
+ timeout=30,
+ verify=False
+ )
+ custom_prompt = st.session_state.get('custom_prompt', '')
+ processor = ScriptProcessor(
+ model_name=text_model,
+ api_key=text_api_key,
+ prompt=custom_prompt,
+ video_theme=st.session_state.get('video_theme', '')
+ )
+
+ # 处理帧内容生成脚本
+ script_result = processor.process_frames(frame_content_list)
+
+ # 将结果转换为JSON字符串
+ script = json.dumps(script_result, ensure_ascii=False, indent=2)
+
+ except Exception as e:
+ logger.exception(f"Gemini 处理过程中发生错误\n{traceback.format_exc()}")
+ raise Exception(f"视觉分析失败: {str(e)}")
+
+ elif vision_llm_provider == 'narratoapi': # NarratoAPI
+ try:
+ # 创建临时目录
+ temp_dir = utils.temp_dir("narrato")
+
+ # 打包关键帧
+ update_progress(30, "正在打包关键帧...")
+ zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
+ if not file_utils.create_zip(keyframe_files, zip_path):
+ raise Exception("打包关键帧失败")
+
+ # 获取API配置
+ api_url = st.session_state.get('narrato_api_url')
+ api_key = st.session_state.get('narrato_api_key')
+
+ if not api_key:
+ raise ValueError("未配置 Narrato API Key,请在基础设置中配置")
+
+ # 准备API请求
+ headers = {
+ 'X-API-Key': api_key,
+ 'accept': 'application/json'
+ }
+
+ api_params = {
+ 'batch_size': st.session_state.get('narrato_batch_size', 10),
+ 'use_ai': False,
+ 'start_offset': 0,
+ 'vision_model': st.session_state.get('narrato_vision_model', 'gemini-1.5-flash'),
+ 'vision_api_key': st.session_state.get('narrato_vision_key'),
+ 'llm_model': st.session_state.get('narrato_llm_model', 'qwen-plus'),
+ 'llm_api_key': st.session_state.get('narrato_llm_key'),
+ 'custom_prompt': st.session_state.get('custom_prompt', '')
+ }
+
+ # 发送API请求
+ logger.info(f"请求NarratoAPI: {api_url}")
+ update_progress(40, "正在上传文件...")
+ with open(zip_path, 'rb') as f:
+ files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
+ try:
+ response = requests.post(
+ f"{api_url}/video/analyze",
+ headers=headers,
+ params=api_params,
+ files=files,
+ timeout=30 # 设置超时时间
+ )
+ response.raise_for_status()
+ except requests.RequestException as e:
+ logger.error(f"Narrato API 请求失败:\n{traceback.format_exc()}")
+ raise Exception(f"API请求失败: {str(e)}")
+
+ task_data = response.json()
+ task_id = task_data["data"].get('task_id')
+ if not task_id:
+ raise Exception(f"无效的API响应: {response.text}")
+
+ # 轮询任务状态
+ update_progress(50, "正在等待分析结果...")
+ retry_count = 0
+ max_retries = 60 # 最多等待2分钟
+
+ while retry_count < max_retries:
+ try:
+ status_response = requests.get(
+ f"{api_url}/video/tasks/{task_id}",
+ headers=headers,
+ timeout=10
+ )
+ status_response.raise_for_status()
+ task_status = status_response.json()['data']
+
+ if task_status['status'] == 'SUCCESS':
+ script = task_status['result']['data']
+ break
+ elif task_status['status'] in ['FAILURE', 'RETRY']:
+ raise Exception(f"任务失败: {task_status.get('error')}")
+
+ retry_count += 1
+ time.sleep(2)
+
+ except requests.RequestException as e:
+ logger.warning(f"获取任务状态失败,重试中: {str(e)}")
+ retry_count += 1
+ time.sleep(2)
+ continue
+
+ if retry_count >= max_retries:
+ raise Exception("任务执行超时")
+
+ except Exception as e:
+ logger.exception(f"NarratoAPI 处理过程中发生错误\n{traceback.format_exc()}")
+ raise Exception(f"NarratoAPI 处理失败: {str(e)}")
+ finally:
+ # 清理临时文件
+ try:
+ if os.path.exists(zip_path):
+ os.remove(zip_path)
+ except Exception as e:
+ logger.warning(f"清理临时文件失败: {str(e)}")
+
+ else:
+ logger.exception("Vision Model 未启用,请检查配置")
+
+ if script is None:
+ st.error("生成脚本失败,请检查日志")
+ st.stop()
+ logger.info(f"脚本生成完成\n{script} \n{type(script)}")
+ if isinstance(script, list):
+ st.session_state['video_clip_json'] = script
+ elif isinstance(script, str):
+ st.session_state['video_clip_json'] = json.loads(script)
+ update_progress(90, "脚本生成完成")
+
+ time.sleep(0.5)
+ progress_bar.progress(100)
+ status_text.text("脚本生成完成!")
+ st.success("视频脚本生成成功!")
+
+ except Exception as err:
+ st.error(f"生成过程中发生错误: {str(err)}")
+ logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}")
+ finally:
+ time.sleep(2)
+ progress_bar.empty()
+ status_text.empty()
+
+
+def save_script(tr, video_clip_json_details):
+ """保存视频脚本"""
+ if not video_clip_json_details:
+ st.error(tr("请输入视频脚本"))
+ st.stop()
+
+ with st.spinner(tr("Save Script")):
+ script_dir = utils.script_dir()
+ timestamp = time.strftime("%Y-%m%d-%H%M%S")
+ save_path = os.path.join(script_dir, f"{timestamp}.json")
+
+ try:
+ data = json.loads(video_clip_json_details)
+ with open(save_path, 'w', encoding='utf-8') as file:
+ json.dump(data, file, ensure_ascii=False, indent=4)
+ st.session_state['video_clip_json'] = data
+ st.session_state['video_clip_json_path'] = save_path
+
+ # 更新配置
+ config.app["video_clip_json_path"] = save_path
+
+ # 显示成功消息
+ st.success(tr("Script saved successfully"))
+
+ # 强制重新加载页面��更新选择框
+ time.sleep(0.5) # 给一点时间让用户看到成功消息
+ st.rerun()
+
+ except Exception as err:
+ st.error(f"{tr('Failed to save script')}: {str(err)}")
+ st.stop()
+
+
+def crop_video(tr, params):
+ """裁剪视频"""
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ def update_progress(progress):
+ progress_bar.progress(progress)
+ status_text.text(f"剪辑进度: {progress}%")
+
+ try:
+ utils.cut_video(params, update_progress)
+ time.sleep(0.5)
+ progress_bar.progress(100)
+ status_text.text("剪辑完成!")
+ st.success("视频剪辑成功完成!")
+ except Exception as e:
+ st.error(f"剪辑过程中发生错误: {str(e)}")
+ finally:
+ time.sleep(2)
+ progress_bar.empty()
+ status_text.empty()
+
+
+def get_script_params():
+ """获取脚本参数"""
+ return {
+ 'video_language': st.session_state.get('video_language', ''),
+ 'video_clip_json_path': st.session_state.get('video_clip_json_path', ''),
+ 'video_origin_path': st.session_state.get('video_origin_path', ''),
+ 'video_name': st.session_state.get('video_name', ''),
+ 'video_plot': st.session_state.get('video_plot', '')
+ }
diff --git a/webui/components/subtitle_settings.py b/webui/components/subtitle_settings.py
new file mode 100644
index 0000000..9b94e3c
--- /dev/null
+++ b/webui/components/subtitle_settings.py
@@ -0,0 +1,129 @@
+import streamlit as st
+from app.config import config
+from webui.utils.cache import get_fonts_cache
+import os
+
+def render_subtitle_panel(tr):
+ """渲染字幕设置面板"""
+ with st.container(border=True):
+ st.write(tr("Subtitle Settings"))
+
+ # 启用字幕选项
+ enable_subtitles = st.checkbox(tr("Enable Subtitles"), value=True)
+ st.session_state['subtitle_enabled'] = enable_subtitles
+
+ if enable_subtitles:
+ render_font_settings(tr)
+ render_position_settings(tr)
+ render_style_settings(tr)
+
+def render_font_settings(tr):
+ """渲染字体设置"""
+ # 获取字体列表
+ font_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "resource", "fonts")
+ font_names = get_fonts_cache(font_dir)
+
+ # 获取保存的字体设置
+ saved_font_name = config.ui.get("font_name", "")
+ saved_font_name_index = 0
+ if saved_font_name in font_names:
+ saved_font_name_index = font_names.index(saved_font_name)
+
+ # 字体选择
+ font_name = st.selectbox(
+ tr("Font"),
+ options=font_names,
+ index=saved_font_name_index
+ )
+ config.ui["font_name"] = font_name
+ st.session_state['font_name'] = font_name
+
+ # 字体大小
+ font_cols = st.columns([0.3, 0.7])
+ with font_cols[0]:
+ saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
+ text_fore_color = st.color_picker(
+ tr("Font Color"),
+ saved_text_fore_color
+ )
+ config.ui["text_fore_color"] = text_fore_color
+ st.session_state['text_fore_color'] = text_fore_color
+
+ with font_cols[1]:
+ saved_font_size = config.ui.get("font_size", 60)
+ font_size = st.slider(
+ tr("Font Size"),
+ min_value=30,
+ max_value=100,
+ value=saved_font_size
+ )
+ config.ui["font_size"] = font_size
+ st.session_state['font_size'] = font_size
+
+def render_position_settings(tr):
+ """渲染位置设置"""
+ subtitle_positions = [
+ (tr("Top"), "top"),
+ (tr("Center"), "center"),
+ (tr("Bottom"), "bottom"),
+ (tr("Custom"), "custom"),
+ ]
+
+ selected_index = st.selectbox(
+ tr("Position"),
+ index=2,
+ options=range(len(subtitle_positions)),
+ format_func=lambda x: subtitle_positions[x][0],
+ )
+
+ subtitle_position = subtitle_positions[selected_index][1]
+ st.session_state['subtitle_position'] = subtitle_position
+
+ # 自定义位置处理
+ if subtitle_position == "custom":
+ custom_position = st.text_input(
+ tr("Custom Position (% from top)"),
+ value="70.0"
+ )
+ try:
+ custom_position_value = float(custom_position)
+ if custom_position_value < 0 or custom_position_value > 100:
+ st.error(tr("Please enter a value between 0 and 100"))
+ else:
+ st.session_state['custom_position'] = custom_position_value
+ except ValueError:
+ st.error(tr("Please enter a valid number"))
+
+def render_style_settings(tr):
+ """渲染样式设置"""
+ stroke_cols = st.columns([0.3, 0.7])
+
+ with stroke_cols[0]:
+ stroke_color = st.color_picker(
+ tr("Stroke Color"),
+ value="#000000"
+ )
+ st.session_state['stroke_color'] = stroke_color
+
+ with stroke_cols[1]:
+ stroke_width = st.slider(
+ tr("Stroke Width"),
+ min_value=0.0,
+ max_value=10.0,
+ value=1.5,
+ step=0.1
+ )
+ st.session_state['stroke_width'] = stroke_width
+
+def get_subtitle_params():
+ """获取字幕参数"""
+ return {
+ 'enabled': st.session_state.get('subtitle_enabled', True),
+ 'font_name': st.session_state.get('font_name', ''),
+ 'font_size': st.session_state.get('font_size', 60),
+ 'text_fore_color': st.session_state.get('text_fore_color', '#FFFFFF'),
+ 'position': st.session_state.get('subtitle_position', 'bottom'),
+ 'custom_position': st.session_state.get('custom_position', 70.0),
+ 'stroke_color': st.session_state.get('stroke_color', '#000000'),
+ 'stroke_width': st.session_state.get('stroke_width', 1.5),
+ }
\ No newline at end of file
diff --git a/webui/components/video_settings.py b/webui/components/video_settings.py
new file mode 100644
index 0000000..7942bee
--- /dev/null
+++ b/webui/components/video_settings.py
@@ -0,0 +1,47 @@
+import streamlit as st
+from app.models.schema import VideoClipParams, VideoAspect
+
+def render_video_panel(tr):
+ """渲染视频配置面板"""
+ with st.container(border=True):
+ st.write(tr("Video Settings"))
+ params = VideoClipParams()
+ render_video_config(tr, params)
+
+def render_video_config(tr, params):
+ """渲染视频配置"""
+ # 视频比例
+ video_aspect_ratios = [
+ (tr("Portrait"), VideoAspect.portrait.value),
+ (tr("Landscape"), VideoAspect.landscape.value),
+ ]
+ selected_index = st.selectbox(
+ tr("Video Ratio"),
+ options=range(len(video_aspect_ratios)),
+ format_func=lambda x: video_aspect_ratios[x][0],
+ )
+ params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
+ st.session_state['video_aspect'] = params.video_aspect.value
+
+ # 视频画质
+ video_qualities = [
+ ("4K (2160p)", "2160p"),
+ ("2K (1440p)", "1440p"),
+ ("Full HD (1080p)", "1080p"),
+ ("HD (720p)", "720p"),
+ ("SD (480p)", "480p"),
+ ]
+ quality_index = st.selectbox(
+ tr("Video Quality"),
+ options=range(len(video_qualities)),
+ format_func=lambda x: video_qualities[x][0],
+ index=2 # 默认选择 1080p
+ )
+ st.session_state['video_quality'] = video_qualities[quality_index][1]
+
+def get_video_params():
+ """获取视频参数"""
+ return {
+ 'video_aspect': st.session_state.get('video_aspect', VideoAspect.portrait.value),
+ 'video_quality': st.session_state.get('video_quality', '1080p')
+ }
\ No newline at end of file
diff --git a/webui/config/settings.py b/webui/config/settings.py
new file mode 100644
index 0000000..a4b3ada
--- /dev/null
+++ b/webui/config/settings.py
@@ -0,0 +1,167 @@
+import os
+import tomli
+from loguru import logger
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+
+@dataclass
+class WebUIConfig:
+ """WebUI配置类"""
+ # UI配置
+ ui: Dict[str, Any] = None
+ # 代理配置
+ proxy: Dict[str, str] = None
+ # 应用配置
+ app: Dict[str, Any] = None
+ # Azure配置
+ azure: Dict[str, str] = None
+ # 项目版本
+ project_version: str = "0.1.0"
+ # 项目根目录
+ root_dir: str = None
+ # Gemini API Key
+ gemini_api_key: str = ""
+ # 每批处理的图片数量
+ vision_batch_size: int = 5
+ # 提示词
+ vision_prompt: str = """..."""
+ # Narrato API 配置
+ narrato_api_url: str = "http://127.0.0.1:8000/api/v1/video/analyze"
+ narrato_api_key: str = ""
+ narrato_batch_size: int = 10
+ narrato_vision_model: str = "gemini-1.5-flash"
+ narrato_llm_model: str = "qwen-plus"
+
+ def __post_init__(self):
+ """初始化默认值"""
+ self.ui = self.ui or {}
+ self.proxy = self.proxy or {}
+ self.app = self.app or {}
+ self.azure = self.azure or {}
+ self.root_dir = self.root_dir or os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+def load_config(config_path: Optional[str] = None) -> WebUIConfig:
+ """加载配置文件
+ Args:
+ config_path: 配置文件路径,如果为None则使用默认路径
+ Returns:
+ WebUIConfig: 配置对象
+ """
+ try:
+ if config_path is None:
+ config_path = os.path.join(
+ os.path.dirname(os.path.dirname(__file__)),
+ ".streamlit",
+ "webui.toml"
+ )
+
+ # 如果配置文件不存在,使用示例配置
+ if not os.path.exists(config_path):
+ example_config = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+ "config.example.toml"
+ )
+ if os.path.exists(example_config):
+ config_path = example_config
+ else:
+ logger.warning(f"配置文件不存在: {config_path}")
+ return WebUIConfig()
+
+ # 读取配置文件
+ with open(config_path, "rb") as f:
+ config_dict = tomli.load(f)
+
+ # 创建配置对象
+ config = WebUIConfig(
+ ui=config_dict.get("ui", {}),
+ proxy=config_dict.get("proxy", {}),
+ app=config_dict.get("app", {}),
+ azure=config_dict.get("azure", {}),
+ project_version=config_dict.get("project_version", "0.1.0")
+ )
+
+ return config
+
+ except Exception as e:
+ logger.error(f"加载配置文件失败: {e}")
+ return WebUIConfig()
+
+def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool:
+ """保存配置到文件
+ Args:
+ config: 配置对象
+ config_path: 配置文件路径,如果为None则使用默认路径
+ Returns:
+ bool: 是否保存成功
+ """
+ try:
+ if config_path is None:
+ config_path = os.path.join(
+ os.path.dirname(os.path.dirname(__file__)),
+ ".streamlit",
+ "webui.toml"
+ )
+
+ # 确保目录存在
+ os.makedirs(os.path.dirname(config_path), exist_ok=True)
+
+ # 转换为字典
+ config_dict = {
+ "ui": config.ui,
+ "proxy": config.proxy,
+ "app": config.app,
+ "azure": config.azure,
+ "project_version": config.project_version
+ }
+
+ # 保存配置
+ with open(config_path, "w", encoding="utf-8") as f:
+ import tomli_w
+ tomli_w.dump(config_dict, f)
+
+ return True
+
+ except Exception as e:
+ logger.error(f"保存配置文件失败: {e}")
+ return False
+
+def get_config() -> WebUIConfig:
+ """获取全局配置对象
+ Returns:
+ WebUIConfig: 配置对象
+ """
+ if not hasattr(get_config, "_config"):
+ get_config._config = load_config()
+ return get_config._config
+
+def update_config(config_dict: Dict[str, Any]) -> bool:
+ """更新配置
+ Args:
+ config_dict: 配置字典
+ Returns:
+ bool: 是否更新成功
+ """
+ try:
+ config = get_config()
+
+ # 更新配置
+ if "ui" in config_dict:
+ config.ui.update(config_dict["ui"])
+ if "proxy" in config_dict:
+ config.proxy.update(config_dict["proxy"])
+ if "app" in config_dict:
+ config.app.update(config_dict["app"])
+ if "azure" in config_dict:
+ config.azure.update(config_dict["azure"])
+ if "project_version" in config_dict:
+ config.project_version = config_dict["project_version"]
+
+ # 保存配置
+ return save_config(config)
+
+ except Exception as e:
+ logger.error(f"更新配置失败: {e}")
+ return False
+
+# 导出全局配置对象
+config = get_config()
\ No newline at end of file
diff --git a/webui/i18n/__init__.py b/webui/i18n/__init__.py
new file mode 100644
index 0000000..0f05c76
--- /dev/null
+++ b/webui/i18n/__init__.py
@@ -0,0 +1 @@
+# 空文件,用于标记包
\ No newline at end of file
diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json
index cbad21b..48b50cf 100644
--- a/webui/i18n/zh.json
+++ b/webui/i18n/zh.json
@@ -2,20 +2,20 @@
"Language": "简体中文",
"Translation": {
"Video Script Configuration": "**视频脚本配置**",
- "Video Script Generate": "生成视频脚本",
+ "Generate Video Script": "生成视频脚本",
"Video Subject": "视频主题(给定一个关键词,:red[AI自动生成]视频文案)",
"Script Language": "生成视频脚本的语言(一般情况AI会自动根据你输入的主题语言输出)",
"Script Files": "脚本文件",
"Generate Video Script and Keywords": "点击使用AI根据**主题**生成 【视频文案】 和 【视频关键词】",
"Auto Detect": "自动检测",
"Auto Generate": "自动生成",
- "Video Name": "视频名称",
- "Video Script": "视频脚本(:blue[①使用AI生成 ②从本机加载])",
+ "Video Theme": "视频主题",
+ "Generation Prompt": "自定义提示词",
"Save Script": "保存脚本",
"Crop Video": "裁剪视频",
"Video File": "视频文件(:blue[1️⃣支持上传视频文件(限制2G) 2️⃣大文件建议直接导入 ./resource/videos 目录])",
"Plot Description": "剧情描述 (:blue[可从 https://www.tvmao.com/ 获取])",
- "Generate Video Keywords": "点击使用AI根据**文案**生成【视频关键词】",
+ "Generate Video Keywords": "点击使用AI根据**文案**生成【视频关键��】",
"Please Enter the Video Subject": "请先填写视频文案",
"Generating Video Script and Keywords": "AI正在生成视频文案和关键词...",
"Generating Video Keywords": "AI正在生成视频关键词...",
@@ -91,6 +91,40 @@
"Picture description": "图片描述",
"Narration": "视频文案",
"Rebuild": "重新生成",
- "Video Script Load": "加载视频脚本"
+ "Load Video Script": "加载视频脚本",
+ "Speech Pitch": "语调",
+ "Please Select Script File": "请选择脚本文件",
+ "Check Format": "脚本格式检查",
+ "Script Loaded Successfully": "脚本加载成功",
+ "Script format check passed": "脚本格式检查通过",
+ "Script format check failed": "脚本格式检查失��",
+ "Failed to Load Script": "加载脚本失败",
+ "Failed to Save Script": "保存脚本失败",
+ "Script saved successfully": "脚本保存成功",
+ "Video Script": "视频脚本",
+ "Video Quality": "视频质量",
+ "Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词",
+ "Basic Settings": "基础设置",
+ "Proxy Settings": "代理设置",
+ "Language": "界面语言",
+ "HTTP_PROXY": "HTTP 代理",
+ "HTTPs_PROXY": "HTTPS 代理",
+ "Vision Model Settings": "视频分析模型设置",
+ "Vision Model Provider": "视频分析模型提供商",
+ "Vision API Key": "视频分析 API 密钥",
+ "Vision Base URL": "视频分析接口地址",
+ "Vision Model Name": "视频分析模型名称",
+ "Narrato Additional Settings": "Narrato 附加设置",
+ "Narrato API Key": "Narrato API 密钥",
+ "Narrato API URL": "Narrato API 地址",
+ "Text Generation Model Settings": "文案生成模型设置",
+ "LLM Model Name": "大语言模型名称",
+ "LLM Model API Key": "大语言模型 API 密钥",
+ "Batch Size": "批处理大小",
+ "Text Model Provider": "文案生成模型提供商",
+ "Text API Key": "文案生成 API 密钥",
+ "Text Base URL": "文案生成接口地址",
+ "Text Model Name": "文案生成模型名称",
+ "Account ID": "账户 ID"
}
}
\ No newline at end of file
diff --git a/webui/utils/__init__.py b/webui/utils/__init__.py
new file mode 100644
index 0000000..74dd09d
--- /dev/null
+++ b/webui/utils/__init__.py
@@ -0,0 +1,8 @@
+from .performance import monitor_performance, PerformanceMonitor
+from .cache import *
+from .file_utils import *
+
+__all__ = [
+ 'monitor_performance',
+ 'PerformanceMonitor'
+]
\ No newline at end of file
diff --git a/webui/utils/cache.py b/webui/utils/cache.py
new file mode 100644
index 0000000..6cc3b05
--- /dev/null
+++ b/webui/utils/cache.py
@@ -0,0 +1,33 @@
+import streamlit as st
+import os
+import glob
+from app.utils import utils
+
+def get_fonts_cache(font_dir):
+ if 'fonts_cache' not in st.session_state:
+ fonts = []
+ for root, dirs, files in os.walk(font_dir):
+ for file in files:
+ if file.endswith(".ttf") or file.endswith(".ttc"):
+ fonts.append(file)
+ fonts.sort()
+ st.session_state['fonts_cache'] = fonts
+ return st.session_state['fonts_cache']
+
+def get_video_files_cache():
+ if 'video_files_cache' not in st.session_state:
+ video_files = []
+ for suffix in ["*.mp4", "*.mov", "*.avi", "*.mkv"]:
+ video_files.extend(glob.glob(os.path.join(utils.video_dir(), suffix)))
+ st.session_state['video_files_cache'] = video_files[::-1]
+ return st.session_state['video_files_cache']
+
+def get_songs_cache(song_dir):
+ if 'songs_cache' not in st.session_state:
+ songs = []
+ for root, dirs, files in os.walk(song_dir):
+ for file in files:
+ if file.endswith(".mp3"):
+ songs.append(file)
+ st.session_state['songs_cache'] = songs
+ return st.session_state['songs_cache']
\ No newline at end of file
diff --git a/webui/utils/file_utils.py b/webui/utils/file_utils.py
new file mode 100644
index 0000000..b6b1238
--- /dev/null
+++ b/webui/utils/file_utils.py
@@ -0,0 +1,230 @@
+import os
+import glob
+import time
+import platform
+import shutil
+from uuid import uuid4
+from loguru import logger
+from app.utils import utils
+
+def open_task_folder(root_dir, task_id):
+ """打开任务文件夹
+ Args:
+ root_dir: 项目根目录
+ task_id: 任务ID
+ """
+ try:
+ sys = platform.system()
+ path = os.path.join(root_dir, "storage", "tasks", task_id)
+ if os.path.exists(path):
+ if sys == 'Windows':
+ os.system(f"start {path}")
+ if sys == 'Darwin':
+ os.system(f"open {path}")
+ if sys == 'Linux':
+ os.system(f"xdg-open {path}")
+ except Exception as e:
+ logger.error(f"打开任务文件夹失败: {e}")
+
+def cleanup_temp_files(temp_dir, max_age=3600):
+ """清理临时文件
+ Args:
+ temp_dir: 临时文件目录
+ max_age: 文件最大保存时间(秒)
+ """
+ if os.path.exists(temp_dir):
+ for file in os.listdir(temp_dir):
+ file_path = os.path.join(temp_dir, file)
+ try:
+ if os.path.getctime(file_path) < time.time() - max_age:
+ if os.path.isfile(file_path):
+ os.remove(file_path)
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path)
+ logger.debug(f"已清理临时文件: {file_path}")
+ except Exception as e:
+ logger.error(f"清理临时文件失败: {file_path}, 错误: {e}")
+
+def get_file_list(directory, file_types=None, sort_by='ctime', reverse=True):
+ """获取指定目录下的文件列表
+ Args:
+ directory: 目录路径
+ file_types: 文件类型列表,如 ['.mp4', '.mov']
+ sort_by: 排序方式,支持 'ctime'(创建时间), 'mtime'(修改时间), 'size'(文件大小), 'name'(文件名)
+ reverse: 是否倒序排序
+ Returns:
+ list: 文件信息列表
+ """
+ if not os.path.exists(directory):
+ return []
+
+ files = []
+ if file_types:
+ for file_type in file_types:
+ files.extend(glob.glob(os.path.join(directory, f"*{file_type}")))
+ else:
+ files = glob.glob(os.path.join(directory, "*"))
+
+ file_list = []
+ for file_path in files:
+ try:
+ file_stat = os.stat(file_path)
+ file_info = {
+ "name": os.path.basename(file_path),
+ "path": file_path,
+ "size": file_stat.st_size,
+ "ctime": file_stat.st_ctime,
+ "mtime": file_stat.st_mtime
+ }
+ file_list.append(file_info)
+ except Exception as e:
+ logger.error(f"获取文件信息失败: {file_path}, 错误: {e}")
+
+ # 排序
+ if sort_by in ['ctime', 'mtime', 'size', 'name']:
+ file_list.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
+
+ return file_list
+
+def save_uploaded_file(uploaded_file, save_dir, allowed_types=None):
+ """保存上传的文件
+ Args:
+ uploaded_file: StreamlitUploadedFile对象
+ save_dir: 保存目录
+ allowed_types: 允许的文件类型列表,如 ['.mp4', '.mov']
+ Returns:
+ str: 保存后的文件路径,失败返回None
+ """
+ try:
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ file_name, file_extension = os.path.splitext(uploaded_file.name)
+
+ # 检查文件类型
+ if allowed_types and file_extension.lower() not in allowed_types:
+ logger.error(f"不支持的文件类型: {file_extension}")
+ return None
+
+ # 如果文件已存在,添加时间戳
+ save_path = os.path.join(save_dir, uploaded_file.name)
+ if os.path.exists(save_path):
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ new_file_name = f"{file_name}_{timestamp}{file_extension}"
+ save_path = os.path.join(save_dir, new_file_name)
+
+ # 保存文件
+ with open(save_path, "wb") as f:
+ f.write(uploaded_file.read())
+
+ logger.info(f"文件保存成功: {save_path}")
+ return save_path
+
+ except Exception as e:
+ logger.error(f"保存上传文件失败: {e}")
+ return None
+
+def create_temp_file(prefix='tmp', suffix='', directory=None):
+ """创建临时文件
+ Args:
+ prefix: 文件名前缀
+ suffix: 文件扩展名
+ directory: 临时文件目录,默认使用系统临时目录
+ Returns:
+ str: 临时文件路径
+ """
+ try:
+ if directory is None:
+ directory = utils.storage_dir("temp", create=True)
+
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ temp_file = os.path.join(directory, f"{prefix}-{str(uuid4())}{suffix}")
+ return temp_file
+
+ except Exception as e:
+ logger.error(f"创建临时文件失败: {e}")
+ return None
+
+def get_file_size(file_path, format='MB'):
+ """获取文件大小
+ Args:
+ file_path: 文件路径
+ format: 返回格式,支持 'B', 'KB', 'MB', 'GB'
+ Returns:
+ float: 文件大小
+ """
+ try:
+ size_bytes = os.path.getsize(file_path)
+
+ if format.upper() == 'B':
+ return size_bytes
+ elif format.upper() == 'KB':
+ return size_bytes / 1024
+ elif format.upper() == 'MB':
+ return size_bytes / (1024 * 1024)
+ elif format.upper() == 'GB':
+ return size_bytes / (1024 * 1024 * 1024)
+ else:
+ return size_bytes
+
+ except Exception as e:
+ logger.error(f"获取文件大小失败: {file_path}, 错误: {e}")
+ return 0
+
+def ensure_directory(directory):
+ """确保目录存在,如果不存在则创建
+ Args:
+ directory: 目录路径
+ Returns:
+ bool: 是否成功
+ """
+ try:
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ return True
+ except Exception as e:
+ logger.error(f"创建目录失败: {directory}, 错误: {e}")
+ return False
+
+def create_zip(files: list, zip_path: str, base_dir: str = None, folder_name: str = "demo") -> bool:
+ """
+ 创建zip文件
+ Args:
+ files: 要打包的文件列表
+ zip_path: zip文件保存路径
+ base_dir: 基础目录,用于保持目录结构
+ folder_name: zip解压后的文件夹名称,默认为frames
+ Returns:
+ bool: 是否成功
+ """
+ try:
+ import zipfile
+
+ # 确保目标目录存在
+ os.makedirs(os.path.dirname(zip_path), exist_ok=True)
+
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+ for file in files:
+ if not os.path.exists(file):
+ logger.warning(f"文件不存在,跳过: {file}")
+ continue
+
+ # 计算文件在zip中的路径,添加folder_name作为前缀目录
+ if base_dir:
+ arcname = os.path.join(folder_name, os.path.relpath(file, base_dir))
+ else:
+ arcname = os.path.join(folder_name, os.path.basename(file))
+
+ try:
+ zipf.write(file, arcname)
+ except Exception as e:
+ logger.error(f"添加文件到zip失败: {file}, 错误: {e}")
+ continue
+
+ return True
+
+ except Exception as e:
+ logger.error(f"创建zip文件失败: {e}")
+ return False
\ No newline at end of file
diff --git a/webui/utils/performance.py b/webui/utils/performance.py
new file mode 100644
index 0000000..0eab5fa
--- /dev/null
+++ b/webui/utils/performance.py
@@ -0,0 +1,37 @@
+import psutil
+import os
+from loguru import logger
+import torch
+
+class PerformanceMonitor:
+ @staticmethod
+ def monitor_memory():
+ process = psutil.Process(os.getpid())
+ memory_info = process.memory_info()
+
+ logger.debug(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
+
+ if torch.cuda.is_available():
+ gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024
+ logger.debug(f"GPU Memory usage: {gpu_memory:.2f} MB")
+
+ @staticmethod
+ def cleanup_resources():
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ import gc
+ gc.collect()
+
+ PerformanceMonitor.monitor_memory()
+
+def monitor_performance(func):
+ """性能监控装饰器"""
+ def wrapper(*args, **kwargs):
+ try:
+ PerformanceMonitor.monitor_memory()
+ result = func(*args, **kwargs)
+ return result
+ finally:
+ PerformanceMonitor.cleanup_resources()
+ return wrapper
\ No newline at end of file