fix: 优化短剧混剪字幕上传逻辑,与短剧解说保持一致

This commit is contained in:
linyq 2025-12-25 10:43:28 +08:00
parent 26f0dfeab5
commit 08f682bb50
5 changed files with 320 additions and 67 deletions

View File

@ -1,43 +1,125 @@
"""
视频脚本生成pipeline串联各个处理步骤
"""
import os
from typing import Any, Dict, Optional
from loguru import logger
from .utils.step1_subtitle_analyzer_openai import analyze_subtitle
from .utils.step5_merge_script import merge_script
from app.services.upload_validation import InputValidationError, resolve_subtitle_input
def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5, provider: str = None):
"""生成视频混剪脚本
def generate_script_result(
api_key: str,
model_name: str,
output_path: str,
base_url: str = None,
custom_clips: int = 5,
provider: str = None,
*,
srt_path: Optional[str] = None,
subtitle_content: Optional[str] = None,
subtitle_file_path: Optional[str] = None,
) -> Dict[str, Any]:
"""生成视频混剪脚本(安全版本,返回结果字典)
Args:
srt_path: 字幕文件路径
api_key: API密钥
model_name: 模型名称
output_path: 输出文件路径可选
base_url: API基础URL
custom_clips: 自定义片段数量
provider: LLM服务提供商
output_path: 输出文件路径
base_url: API基础URL可选
custom_clips: 自定义片段数量默认5
provider: LLM服务提供商可选
srt_path: 字幕文件路径向后兼容
subtitle_content: 字幕文本内容
subtitle_file_path: 字幕文件路径推荐
Returns:
Dict[str, Any]:
成功: {"status": "success", "script": [...]}
失败: {"status": "error", "message": "错误信息"}
"""
try:
# 解析字幕输入源(支持内容或文件路径)
resolved_content, resolved_path = resolve_subtitle_input(
subtitle_content=subtitle_content,
subtitle_file_path=subtitle_file_path,
srt_path=srt_path,
)
logger.info("开始分析字幕内容...")
openai_analysis = analyze_subtitle(
model_name=model_name,
api_key=api_key,
base_url=base_url,
custom_clips=custom_clips,
provider=provider,
srt_path=resolved_path,
subtitle_content=resolved_content,
)
adjusted_results = openai_analysis['plot_points']
final_script = merge_script(adjusted_results, output_path)
return {"status": "success", "script": final_script}
except InputValidationError as e:
logger.error(f"输入验证失败: {e}")
return {"status": "error", "message": str(e)}
except Exception as e:
logger.exception(f"SDP 脚本生成失败: {e}")
return {"status": "error", "message": f"生成脚本失败: {str(e)}"}
def generate_script(
srt_path: Optional[str] = None,
api_key: str = None,
model_name: str = None,
output_path: str = None,
base_url: str = None,
custom_clips: int = 5,
provider: str = None,
*,
subtitle_content: Optional[str] = None,
subtitle_file_path: Optional[str] = None,
):
"""生成视频混剪脚本(向后兼容版本)
Args:
srt_path: 字幕文件路径向后兼容参数可选
api_key: API密钥
model_name: 模型名称
output_path: 输出文件路径
base_url: API基础URL可选
custom_clips: 自定义片段数量默认5
provider: LLM服务提供商可选
subtitle_content: 字幕文本内容可选
subtitle_file_path: 字幕文件路径推荐使用可选
Returns:
str: 生成的脚本内容
"""
# 验证输入文件
if not os.path.exists(srt_path):
raise FileNotFoundError(f"字幕文件不存在: {srt_path}")
# 分析字幕
print("开始分析...")
openai_analysis = analyze_subtitle(
srt_path=srt_path,
Raises:
FileNotFoundError: 字幕文件不存在向后兼容
ValueError: 输入验证失败或脚本生成失败
"""
result = generate_script_result(
api_key=api_key,
model_name=model_name,
output_path=output_path,
base_url=base_url,
custom_clips=custom_clips,
provider=provider
provider=provider,
srt_path=srt_path,
subtitle_content=subtitle_content,
subtitle_file_path=subtitle_file_path,
)
# 合并生成最终脚本
adjusted_results = openai_analysis['plot_points']
final_script = merge_script(adjusted_results, output_path)
if result.get("status") != "success":
error_message = result.get("message", "生成脚本失败")
# 保持向后兼容:如果是文件不存在错误,抛出 FileNotFoundError
if "不存在" in error_message and (srt_path or subtitle_file_path):
raise FileNotFoundError(error_message)
raise ValueError(error_message)
return final_script
return result["script"]

View File

@ -3,10 +3,9 @@
"""
import traceback
import json
import asyncio
from loguru import logger
from .utils import load_srt
from .utils import load_srt, load_srt_from_content
# 导入新的提示词管理系统
from app.services.prompts import PromptManager
# 导入统一LLM服务
@ -16,34 +15,43 @@ from app.services.llm.migration_adapter import _run_async_safely
def analyze_subtitle(
srt_path: str,
model_name: str,
api_key: str = None,
base_url: str = None,
custom_clips: int = 5,
provider: str = None
provider: str = None,
srt_path: str = None,
subtitle_content: str = None
) -> dict:
"""分析字幕内容,返回完整的分析结果
Args:
srt_path (str): SRT字幕文件路径
model_name (str): 大模型名称
api_key (str, optional): 大模型API密钥. Defaults to None.
base_url (str, optional): 大模型API基础URL. Defaults to None.
custom_clips (int): 需要提取的片段数量. Defaults to 5.
provider (str, optional): LLM服务提供商. Defaults to None.
srt_path (str, optional): SRT字幕文件路径与subtitle_content二选一
subtitle_content (str, optional): SRT字幕文本内容与srt_path二选一
Returns:
dict: 包含剧情梗概和结构化的时间段分析的字典
"""
try:
# 加载字幕文件
subtitles = load_srt(srt_path)
# 加载字幕文件或内容
if subtitle_content and subtitle_content.strip():
subtitles = load_srt_from_content(subtitle_content)
source_label = "字幕内容(直接传入)"
elif srt_path:
subtitles = load_srt(srt_path)
source_label = f"字幕文件: {srt_path}"
else:
raise ValueError("必须提供 srt_path 或 subtitle_content 参数")
# 检查字幕是否为空
if not subtitles:
error_msg = (
f"字幕文件 {srt_path} 解析后无有效内容。\n"
f"字幕来源 [{source_label}] 解析后无有效内容。\n"
f"请检查:\n"
f"1. 文件格式是否为标准 SRT\n"
f"2. 文件编码是否为 UTF-8、GBK 或 GB2312\n"
@ -52,12 +60,9 @@ def analyze_subtitle(
logger.error(error_msg)
raise ValueError(error_msg)
logger.info(f"成功加载字幕文件 {srt_path},共 {len(subtitles)} 条有效字幕")
logger.info(f"成功加载字幕来源 [{source_label}],共 {len(subtitles)} 条有效字幕")
subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles])
# 初始化统一LLM服务
llm_service = UnifiedLLMService()
# 如果没有指定provider根据model_name推断
if not provider:
if "deepseek" in model_name.lower():

View File

@ -78,3 +78,46 @@ def load_srt(file_path: str) -> List[Dict]:
logger.info(f"成功解析 {len(subtitles)} 条有效字幕")
return subtitles
def load_srt_from_content(srt_content: str) -> List[Dict]:
"""从字符串内容解析SRT用于直接传入字幕内容无需依赖文件路径
Args:
srt_content: SRT格式的字幕文本内容
Returns:
字幕内容列表格式同 load_srt 函数
Raises:
ValueError: 字幕内容为空或格式错误
"""
if srt_content is None or not str(srt_content).strip():
raise ValueError("字幕内容为空")
try:
subs = pysrt.from_string(str(srt_content))
except Exception as e:
logger.error(f"无法解析字幕内容: {e}")
raise ValueError("无法解析字幕内容,请确保为标准 SRT 格式") from e
if not subs:
logger.warning("字幕内容解析后无有效内容")
return []
subtitles = []
for sub in subs:
text = sub.text.replace('\n', ' ').strip()
if not text:
continue
subtitles.append({
'number': sub.index,
'timestamp': f"{sub.start} --> {sub.end}",
'text': text,
'start_time': str(sub.start),
'end_time': str(sub.end)
})
logger.info(f"成功从内容解析 {len(subtitles)} 条有效字幕")
return subtitles

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project: NarratoAI
@File : upload_validation.py
@Author : AI Assistant
@Date : 2025/12/25
@Desc : 统一的文件上传验证工具用于短剧混剪和短剧解说功能
"""
import os
from typing import Optional, Tuple
class InputValidationError(ValueError):
"""当必需的用户输入(路径/内容)缺失或无效时抛出"""
pass
def ensure_existing_file(
file_path: str,
*,
label: str = "文件",
allowed_exts: Optional[Tuple[str, ...]] = None,
) -> str:
"""
验证文件路径是否存在且有效
Args:
file_path: 待验证的文件路径
label: 文件类型标签用于错误提示
allowed_exts: 允许的文件扩展名元组 ('.srt', '.txt')
Returns:
str: 规范化后的绝对路径
Raises:
InputValidationError: 文件路径无效文件不存在或格式不支持
"""
if not file_path or not str(file_path).strip():
raise InputValidationError(f"{label}不能为空,请先上传{label}")
normalized = os.path.abspath(str(file_path))
if not os.path.exists(normalized):
raise InputValidationError(f"{label}文件不存在: {normalized}")
if not os.path.isfile(normalized):
raise InputValidationError(f"{label}不是有效文件: {normalized}")
if allowed_exts:
ext = os.path.splitext(normalized)[1].lower()
allowed = tuple(e.lower() for e in allowed_exts)
if ext not in allowed:
raise InputValidationError(
f"{label}格式不支持: {ext},仅支持: {', '.join(allowed_exts)}"
)
return normalized
def resolve_subtitle_input(
*,
subtitle_content: Optional[str] = None,
subtitle_file_path: Optional[str] = None,
srt_path: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""
解析字幕输入源确保只有一个有效来源
Args:
subtitle_content: 字幕文本内容
subtitle_file_path: 字幕文件路径推荐
srt_path: 字幕文件路径向后兼容SDP旧参数
Returns:
Tuple[Optional[str], Optional[str]]: (字幕内容, 字幕文件路径)
- 返回 (content, None) 表示使用内容输入
- 返回 (None, file_path) 表示使用文件路径输入
Raises:
InputValidationError: 未提供输入或同时提供多个输入
"""
file_path = subtitle_file_path or srt_path
has_content = subtitle_content is not None and bool(str(subtitle_content).strip())
has_file = file_path is not None and bool(str(file_path).strip())
if has_content and has_file:
raise InputValidationError("只能提供字幕内容或字幕文件路径之一")
if not has_content and not has_file:
raise InputValidationError("必须提供字幕内容或字幕文件路径")
if has_content:
content = str(subtitle_content)
if not content.strip():
raise InputValidationError("字幕内容为空")
return content, None
resolved_path = ensure_existing_file(
str(file_path),
label="字幕",
allowed_exts=(".srt",),
)
return None, resolved_path

View File

@ -1,13 +1,11 @@
import os
import json
import time
import asyncio
import traceback
import requests
import streamlit as st
from loguru import logger
from app.config import config
from app.services.upload_validation import ensure_existing_file, InputValidationError
def generate_script_short(tr, params, custom_clips=5):
@ -31,12 +29,47 @@ def generate_script_short(tr, params, custom_clips=5):
try:
with st.spinner("正在生成脚本..."):
# ========== 严格验证:必须上传视频和字幕(与短剧解说保持一致)==========
# 1. 验证视频文件
video_path = getattr(params, "video_origin_path", None)
if not video_path or not str(video_path).strip():
st.error("请先选择视频文件")
st.stop()
try:
ensure_existing_file(
str(video_path),
label="视频",
allowed_exts=(".mp4", ".mov", ".avi", ".flv", ".mkv"),
)
except InputValidationError as e:
st.error(str(e))
st.stop()
# 2. 验证字幕文件(移除推断逻辑,必须上传)
subtitle_path = st.session_state.get("subtitle_path")
if not subtitle_path or not str(subtitle_path).strip():
st.error("请先上传字幕文件")
st.stop()
try:
subtitle_path = ensure_existing_file(
str(subtitle_path),
label="字幕",
allowed_exts=(".srt",),
)
except InputValidationError as e:
st.error(str(e))
st.stop()
logger.info(f"使用用户上传的字幕文件: {subtitle_path}")
# ========== 获取 LLM 配置 ==========
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
text_api_key = config.app.get(f'text_{text_provider}_api_key')
text_model = config.app.get(f'text_{text_provider}_model_name')
text_base_url = config.app.get(f'text_{text_provider}_base_url')
# 优先从 session_state 获取,若未设置则回退到 config 配置
vision_llm_provider = st.session_state.get('vision_llm_providers') or config.app.get('vision_llm_provider', 'gemini')
vision_llm_provider = vision_llm_provider.lower()
vision_api_key = st.session_state.get(f'vision_{vision_llm_provider}_api_key') or config.app.get(f'vision_{vision_llm_provider}_api_key', "")
@ -45,48 +78,31 @@ def generate_script_short(tr, params, custom_clips=5):
update_progress(20, "开始准备生成脚本")
# 优先使用用户上传的字幕文件
uploaded_subtitle = st.session_state.get('subtitle_path')
if uploaded_subtitle and os.path.exists(uploaded_subtitle):
srt_path = uploaded_subtitle
logger.info(f"使用用户上传的字幕文件: {srt_path}")
else:
# 回退到根据视频路径自动推断
srt_path = params.video_origin_path.replace(".mp4", ".srt").replace("videos", "srt").replace("video", "subtitle")
if not os.path.exists(srt_path):
logger.error(f"{srt_path} 文件不存在请检查或重新转录")
st.error(f"{srt_path} 文件不存在,请上传字幕文件或重新转录")
st.stop()
# ========== 调用后端生成脚本 ==========
from app.services.SDP.generate_script_short import generate_script_result
api_params = {
"vision_provider": vision_llm_provider,
"vision_api_key": vision_api_key,
"vision_model_name": vision_model,
"vision_base_url": vision_base_url or "",
"text_provider": text_provider,
"text_api_key": text_api_key,
"text_model_name": text_model,
"text_base_url": text_base_url or ""
}
from app.services.SDP.generate_script_short import generate_script
script = generate_script(
srt_path=srt_path,
output_path="resource/scripts/merged_subtitle.json",
result = generate_script_result(
api_key=text_api_key,
model_name=text_model,
output_path="resource/scripts/merged_subtitle.json",
base_url=text_base_url,
custom_clips=custom_clips,
provider=text_provider
provider=text_provider,
subtitle_file_path=subtitle_path,
)
if script is None:
st.error("生成脚本失败,请检查日志")
if result.get("status") != "success":
st.error(result.get("message", "生成脚本失败,请检查日志"))
st.stop()
script = result.get("script")
logger.info(f"脚本生成完成 {json.dumps(script, ensure_ascii=False, indent=4)}")
if isinstance(script, list):
st.session_state['video_clip_json'] = script
elif isinstance(script, str):
st.session_state['video_clip_json'] = json.loads(script)
update_progress(80, "脚本生成完成")
time.sleep(0.1)