fix: 优化短剧混剪字幕上传逻辑,与短剧解说保持一致

This commit is contained in:
linyq 2025-12-25 10:43:28 +08:00
parent 26f0dfeab5
commit 08f682bb50
5 changed files with 320 additions and 67 deletions

View File

@ -1,43 +1,125 @@
""" """
视频脚本生成pipeline串联各个处理步骤 视频脚本生成pipeline串联各个处理步骤
""" """
import os from typing import Any, Dict, Optional
from loguru import logger
from .utils.step1_subtitle_analyzer_openai import analyze_subtitle from .utils.step1_subtitle_analyzer_openai import analyze_subtitle
from .utils.step5_merge_script import merge_script from .utils.step5_merge_script import merge_script
from app.services.upload_validation import InputValidationError, resolve_subtitle_input
def generate_script_result(
    api_key: str,
    model_name: str,
    output_path: str,
    base_url: str = None,
    custom_clips: int = 5,
    provider: str = None,
    *,
    srt_path: Optional[str] = None,
    subtitle_content: Optional[str] = None,
    subtitle_file_path: Optional[str] = None,
) -> Dict[str, Any]:
    """Generate a short-drama remix script and report the outcome as a dict.

    This is the non-raising entry point: every failure is logged and folded
    into the returned dictionary instead of propagating to the caller.

    Args:
        api_key: API key forwarded to ``analyze_subtitle``.
        model_name: Model name forwarded to ``analyze_subtitle``.
        output_path: Output file path forwarded to ``merge_script``.
        base_url: Optional API base URL forwarded to ``analyze_subtitle``.
        custom_clips: Number of clips to request (defaults to 5).
        provider: Optional LLM provider name forwarded to ``analyze_subtitle``.
        srt_path: Subtitle file path (legacy keyword kept for compatibility).
        subtitle_content: Subtitle text supplied directly.
        subtitle_file_path: Subtitle file path (preferred keyword).

    Returns:
        On success ``{"status": "success", "script": <merge_script result>}``;
        on failure ``{"status": "error", "message": <error text>}``.
    """
    try:
        # Reduce the three possible inputs to a single (content, path) pair.
        content, path = resolve_subtitle_input(
            subtitle_content=subtitle_content,
            subtitle_file_path=subtitle_file_path,
            srt_path=srt_path,
        )

        logger.info("开始分析字幕内容...")
        analysis = analyze_subtitle(
            model_name=model_name,
            api_key=api_key,
            base_url=base_url,
            custom_clips=custom_clips,
            provider=provider,
            srt_path=path,
            subtitle_content=content,
        )

        plot_points = analysis['plot_points']
        script = merge_script(plot_points, output_path)
    except InputValidationError as err:
        # Validation problems are reported verbatim, without a prefix.
        logger.error(f"输入验证失败: {err}")
        return {"status": "error", "message": str(err)}
    except Exception as err:
        # Top-level boundary: log with traceback and convert to an error dict.
        logger.exception(f"SDP 脚本生成失败: {err}")
        return {"status": "error", "message": f"生成脚本失败: {str(err)}"}
    else:
        return {"status": "success", "script": script}
def generate_script(
    srt_path: Optional[str] = None,
    api_key: str = None,
    model_name: str = None,
    output_path: str = None,
    base_url: str = None,
    custom_clips: int = 5,
    provider: str = None,
    *,
    subtitle_content: Optional[str] = None,
    subtitle_file_path: Optional[str] = None,
):
    """Generate a short-drama remix script (raising, backward-compatible API).

    Thin wrapper around ``generate_script_result`` that converts its error
    dictionary back into exceptions.

    Args:
        srt_path: Subtitle file path (legacy positional parameter, optional).
        api_key: API key.
        model_name: Model name.
        output_path: Output file path.
        base_url: Optional API base URL.
        custom_clips: Number of clips to request (defaults to 5).
        provider: Optional LLM provider name.
        subtitle_content: Subtitle text supplied directly (optional).
        subtitle_file_path: Subtitle file path (preferred keyword, optional).

    Returns:
        The ``"script"`` entry of the successful result dictionary.

    Raises:
        FileNotFoundError: The error message contains "不存在" and a file path
            argument was supplied.
        ValueError: Any other failure reported by ``generate_script_result``.
    """
    outcome = generate_script_result(
        api_key=api_key,
        model_name=model_name,
        output_path=output_path,
        base_url=base_url,
        custom_clips=custom_clips,
        provider=provider,
        srt_path=srt_path,
        subtitle_content=subtitle_content,
        subtitle_file_path=subtitle_file_path,
    )

    if outcome.get("status") == "success":
        return outcome["script"]

    error_message = outcome.get("message", "生成脚本失败")
    # Missing-file failures keep raising FileNotFoundError for older callers;
    # the failure kind is detected from the message text.
    path_was_given = srt_path or subtitle_file_path
    if "不存在" in error_message and path_was_given:
        raise FileNotFoundError(error_message)
    raise ValueError(error_message)

View File

@ -3,10 +3,9 @@
""" """
import traceback import traceback
import json import json
import asyncio
from loguru import logger from loguru import logger
from .utils import load_srt from .utils import load_srt, load_srt_from_content
# 导入新的提示词管理系统 # 导入新的提示词管理系统
from app.services.prompts import PromptManager from app.services.prompts import PromptManager
# 导入统一LLM服务 # 导入统一LLM服务
@ -16,34 +15,43 @@ from app.services.llm.migration_adapter import _run_async_safely
def analyze_subtitle( def analyze_subtitle(
srt_path: str,
model_name: str, model_name: str,
api_key: str = None, api_key: str = None,
base_url: str = None, base_url: str = None,
custom_clips: int = 5, custom_clips: int = 5,
provider: str = None provider: str = None,
srt_path: str = None,
subtitle_content: str = None
) -> dict: ) -> dict:
"""分析字幕内容,返回完整的分析结果 """分析字幕内容,返回完整的分析结果
Args: Args:
srt_path (str): SRT字幕文件路径
model_name (str): 大模型名称 model_name (str): 大模型名称
api_key (str, optional): 大模型API密钥. Defaults to None. api_key (str, optional): 大模型API密钥. Defaults to None.
base_url (str, optional): 大模型API基础URL. Defaults to None. base_url (str, optional): 大模型API基础URL. Defaults to None.
custom_clips (int): 需要提取的片段数量. Defaults to 5. custom_clips (int): 需要提取的片段数量. Defaults to 5.
provider (str, optional): LLM服务提供商. Defaults to None. provider (str, optional): LLM服务提供商. Defaults to None.
srt_path (str, optional): SRT字幕文件路径与subtitle_content二选一
subtitle_content (str, optional): SRT字幕文本内容与srt_path二选一
Returns: Returns:
dict: 包含剧情梗概和结构化的时间段分析的字典 dict: 包含剧情梗概和结构化的时间段分析的字典
""" """
try: try:
# 加载字幕文件 # 加载字幕文件或内容
subtitles = load_srt(srt_path) if subtitle_content and subtitle_content.strip():
subtitles = load_srt_from_content(subtitle_content)
source_label = "字幕内容(直接传入)"
elif srt_path:
subtitles = load_srt(srt_path)
source_label = f"字幕文件: {srt_path}"
else:
raise ValueError("必须提供 srt_path 或 subtitle_content 参数")
# 检查字幕是否为空 # 检查字幕是否为空
if not subtitles: if not subtitles:
error_msg = ( error_msg = (
f"字幕文件 {srt_path} 解析后无有效内容。\n" f"字幕来源 [{source_label}] 解析后无有效内容。\n"
f"请检查:\n" f"请检查:\n"
f"1. 文件格式是否为标准 SRT\n" f"1. 文件格式是否为标准 SRT\n"
f"2. 文件编码是否为 UTF-8、GBK 或 GB2312\n" f"2. 文件编码是否为 UTF-8、GBK 或 GB2312\n"
@ -52,12 +60,9 @@ def analyze_subtitle(
logger.error(error_msg) logger.error(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
logger.info(f"成功加载字幕文件 {srt_path},共 {len(subtitles)} 条有效字幕") logger.info(f"成功加载字幕来源 [{source_label}],共 {len(subtitles)} 条有效字幕")
subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles]) subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles])
# 初始化统一LLM服务
llm_service = UnifiedLLMService()
# 如果没有指定provider根据model_name推断 # 如果没有指定provider根据model_name推断
if not provider: if not provider:
if "deepseek" in model_name.lower(): if "deepseek" in model_name.lower():

View File

@ -78,3 +78,46 @@ def load_srt(file_path: str) -> List[Dict]:
logger.info(f"成功解析 {len(subtitles)} 条有效字幕") logger.info(f"成功解析 {len(subtitles)} 条有效字幕")
return subtitles return subtitles
def load_srt_from_content(srt_content: str) -> List[Dict]:
    """Parse SRT-formatted text that is already held in memory.

    Args:
        srt_content: Subtitle text in SRT format.

    Returns:
        One dict per cue whose text is non-empty, with the keys ``number``,
        ``timestamp``, ``text``, ``start_time`` and ``end_time``. An empty
        list is returned when parsing yields no cues.

    Raises:
        ValueError: The content is ``None``/blank, or the parser rejected it.
    """
    if srt_content is None or not str(srt_content).strip():
        raise ValueError("字幕内容为空")

    raw_text = str(srt_content)
    try:
        parsed = pysrt.from_string(raw_text)
    except Exception as exc:
        logger.error(f"无法解析字幕内容: {exc}")
        raise ValueError("无法解析字幕内容,请确保为标准 SRT 格式") from exc

    if not parsed:
        logger.warning("字幕内容解析后无有效内容")
        return []

    # Collapse multi-line cue text to a single line; blank cues are dropped.
    flattened = (
        (item, item.text.replace('\n', ' ').strip()) for item in parsed
    )
    subtitles = [
        {
            'number': item.index,
            'timestamp': f"{item.start} --> {item.end}",
            'text': line,
            'start_time': str(item.start),
            'end_time': str(item.end),
        }
        for item, line in flattened
        if line
    ]

    logger.info(f"成功从内容解析 {len(subtitles)} 条有效字幕")
    return subtitles

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : upload_validation.py
@Author : AI Assistant
@Date : 2025/12/25
@Desc : 统一的文件上传验证工具用于短剧混剪和短剧解说功能
"""
import os
from typing import Optional, Tuple
class InputValidationError(ValueError):
    """Raised when a required user input (path or content) is missing or invalid."""
def ensure_existing_file(
    file_path: str,
    *,
    label: str = "文件",
    allowed_exts: Optional[Tuple[str, ...]] = None,
) -> str:
    """Check that ``file_path`` names an existing regular file.

    Args:
        file_path: Path to validate.
        label: Human-readable file kind, interpolated into error messages.
        allowed_exts: Optional tuple of accepted extensions such as
            ``('.srt', '.txt')``; compared case-insensitively.

    Returns:
        The absolute form of ``file_path``.

    Raises:
        InputValidationError: The path is blank, does not exist, is not a
            regular file, or has an extension outside ``allowed_exts``.
    """
    if not (file_path and str(file_path).strip()):
        raise InputValidationError(f"{label}不能为空,请先上传{label}")

    absolute = os.path.abspath(str(file_path))

    if not os.path.exists(absolute):
        raise InputValidationError(f"{label}文件不存在: {absolute}")
    if not os.path.isfile(absolute):
        raise InputValidationError(f"{label}不是有效文件: {absolute}")

    if allowed_exts:
        _, suffix = os.path.splitext(absolute)
        suffix = suffix.lower()
        # Compare lower-cased on both sides so '.SRT' matches '.srt'.
        if all(suffix != candidate.lower() for candidate in allowed_exts):
            raise InputValidationError(
                f"{label}格式不支持: {suffix},仅支持: {', '.join(allowed_exts)}"
            )

    return absolute
def resolve_subtitle_input(
    *,
    subtitle_content: Optional[str] = None,
    subtitle_file_path: Optional[str] = None,
    srt_path: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the subtitle input to exactly one usable source.

    Args:
        subtitle_content: Subtitle text supplied directly.
        subtitle_file_path: Subtitle file path (preferred keyword).
        srt_path: Subtitle file path (legacy keyword kept for compatibility).

    Returns:
        ``(content, None)`` when the text content is used, or
        ``(None, file_path)`` when a file path is used. The file path is the
        value returned by ``ensure_existing_file``.

    Raises:
        InputValidationError: No source was given, both a content and a file
            source were given, or the file fails ``ensure_existing_file``.
    """

    def _is_blank(value) -> bool:
        # None, empty and whitespace-only values all count as "not provided".
        return value is None or not str(value).strip()

    # Take the first path that is actually filled in. A plain
    # `subtitle_file_path or srt_path` would let a whitespace-only
    # subtitle_file_path hide a valid legacy srt_path.
    file_path = next(
        (
            candidate
            for candidate in (subtitle_file_path, srt_path)
            if not _is_blank(candidate)
        ),
        None,
    )

    has_content = not _is_blank(subtitle_content)
    has_file = file_path is not None

    if has_content and has_file:
        raise InputValidationError("只能提供字幕内容或字幕文件路径之一")
    if not has_content and not has_file:
        raise InputValidationError("必须提供字幕内容或字幕文件路径")

    if has_content:
        # has_content already guarantees non-blank text, so no re-check here.
        return str(subtitle_content), None

    resolved_path = ensure_existing_file(
        str(file_path),
        label="字幕",
        allowed_exts=(".srt",),
    )
    return None, resolved_path

View File

@ -1,13 +1,11 @@
import os
import json import json
import time import time
import asyncio
import traceback import traceback
import requests
import streamlit as st import streamlit as st
from loguru import logger from loguru import logger
from app.config import config from app.config import config
from app.services.upload_validation import ensure_existing_file, InputValidationError
def generate_script_short(tr, params, custom_clips=5): def generate_script_short(tr, params, custom_clips=5):
@ -31,12 +29,47 @@ def generate_script_short(tr, params, custom_clips=5):
try: try:
with st.spinner("正在生成脚本..."): with st.spinner("正在生成脚本..."):
# ========== 严格验证:必须上传视频和字幕(与短剧解说保持一致)==========
# 1. 验证视频文件
video_path = getattr(params, "video_origin_path", None)
if not video_path or not str(video_path).strip():
st.error("请先选择视频文件")
st.stop()
try:
ensure_existing_file(
str(video_path),
label="视频",
allowed_exts=(".mp4", ".mov", ".avi", ".flv", ".mkv"),
)
except InputValidationError as e:
st.error(str(e))
st.stop()
# 2. 验证字幕文件(移除推断逻辑,必须上传)
subtitle_path = st.session_state.get("subtitle_path")
if not subtitle_path or not str(subtitle_path).strip():
st.error("请先上传字幕文件")
st.stop()
try:
subtitle_path = ensure_existing_file(
str(subtitle_path),
label="字幕",
allowed_exts=(".srt",),
)
except InputValidationError as e:
st.error(str(e))
st.stop()
logger.info(f"使用用户上传的字幕文件: {subtitle_path}")
# ========== 获取 LLM 配置 ==========
text_provider = config.app.get('text_llm_provider', 'gemini').lower() text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key') text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name') text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url') text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 优先从 session_state 获取,若未设置则回退到 config 配置
vision_llm_provider = st.session_state.get('vision_llm_providers') or config.app.get('vision_llm_provider', 'gemini') vision_llm_provider = st.session_state.get('vision_llm_providers') or config.app.get('vision_llm_provider', 'gemini')
vision_llm_provider = vision_llm_provider.lower() vision_llm_provider = vision_llm_provider.lower()
vision_api_key = st.session_state.get(f'vision_{vision_llm_provider}_api_key') or config.app.get(f'vision_{vision_llm_provider}_api_key', "") vision_api_key = st.session_state.get(f'vision_{vision_llm_provider}_api_key') or config.app.get(f'vision_{vision_llm_provider}_api_key', "")
@ -45,48 +78,31 @@ def generate_script_short(tr, params, custom_clips=5):
update_progress(20, "开始准备生成脚本") update_progress(20, "开始准备生成脚本")
# 优先使用用户上传的字幕文件 # ========== 调用后端生成脚本 ==========
uploaded_subtitle = st.session_state.get('subtitle_path') from app.services.SDP.generate_script_short import generate_script_result
if uploaded_subtitle and os.path.exists(uploaded_subtitle):
srt_path = uploaded_subtitle
logger.info(f"使用用户上传的字幕文件: {srt_path}")
else:
# 回退到根据视频路径自动推断
srt_path = params.video_origin_path.replace(".mp4", ".srt").replace("videos", "srt").replace("video", "subtitle")
if not os.path.exists(srt_path):
logger.error(f"{srt_path} 文件不存在请检查或重新转录")
st.error(f"{srt_path} 文件不存在,请上传字幕文件或重新转录")
st.stop()
api_params = { result = generate_script_result(
"vision_provider": vision_llm_provider,
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url or "",
"text_provider": text_provider,
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url or ""
}
from app.services.SDP.generate_script_short import generate_script
script = generate_script(
srt_path=srt_path,
output_path="resource/scripts/merged_subtitle.json",
api_key=text_api_key, api_key=text_api_key,
model_name=text_model, model_name=text_model,
output_path="resource/scripts/merged_subtitle.json",
base_url=text_base_url, base_url=text_base_url,
custom_clips=custom_clips, custom_clips=custom_clips,
provider=text_provider provider=text_provider,
subtitle_file_path=subtitle_path,
) )
if script is None: if result.get("status") != "success":
st.error("生成脚本失败,请检查日志") st.error(result.get("message", "生成脚本失败,请检查日志"))
st.stop() st.stop()
script = result.get("script")
logger.info(f"脚本生成完成 {json.dumps(script, ensure_ascii=False, indent=4)}") logger.info(f"脚本生成完成 {json.dumps(script, ensure_ascii=False, indent=4)}")
if isinstance(script, list): if isinstance(script, list):
st.session_state['video_clip_json'] = script st.session_state['video_clip_json'] = script
elif isinstance(script, str): elif isinstance(script, str):
st.session_state['video_clip_json'] = json.loads(script) st.session_state['video_clip_json'] = json.loads(script)
update_progress(80, "脚本生成完成") update_progress(80, "脚本生成完成")
time.sleep(0.1) time.sleep(0.1)