From 08f682bb503268cd6bceb1df72662672cb8ecf61 Mon Sep 17 00:00:00 2001 From: linyq Date: Thu, 25 Dec 2025 10:43:28 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=E7=9F=AD=E5=89=A7?= =?UTF-8?q?=E6=B7=B7=E5=89=AA=E5=AD=97=E5=B9=95=E4=B8=8A=E4=BC=A0=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E4=B8=8E=E7=9F=AD=E5=89=A7=E8=A7=A3=E8=AF=B4?= =?UTF-8?q?=E4=BF=9D=E6=8C=81=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/services/SDP/generate_script_short.py | 124 +++++++++++++++--- .../utils/step1_subtitle_analyzer_openai.py | 29 ++-- app/services/SDP/utils/utils.py | 43 ++++++ app/services/upload_validation.py | 107 +++++++++++++++ webui/tools/generate_script_short.py | 84 +++++++----- 5 files changed, 320 insertions(+), 67 deletions(-) create mode 100644 app/services/upload_validation.py diff --git a/app/services/SDP/generate_script_short.py b/app/services/SDP/generate_script_short.py index 713d26c..c8c062a 100644 --- a/app/services/SDP/generate_script_short.py +++ b/app/services/SDP/generate_script_short.py @@ -1,43 +1,125 @@ """ 视频脚本生成pipeline,串联各个处理步骤 """ -import os +from typing import Any, Dict, Optional +from loguru import logger + from .utils.step1_subtitle_analyzer_openai import analyze_subtitle from .utils.step5_merge_script import merge_script +from app.services.upload_validation import InputValidationError, resolve_subtitle_input -def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5, provider: str = None): - """生成视频混剪脚本 +def generate_script_result( + api_key: str, + model_name: str, + output_path: str, + base_url: str = None, + custom_clips: int = 5, + provider: str = None, + *, + srt_path: Optional[str] = None, + subtitle_content: Optional[str] = None, + subtitle_file_path: Optional[str] = None, +) -> Dict[str, Any]: + """生成视频混剪脚本(安全版本,返回结果字典) Args: - srt_path: 字幕文件路径 api_key: API密钥 model_name: 模型名称 - output_path: 输出文件路径,可选 - base_url: API基础URL - custom_clips: 自定义片段数量 - provider: LLM服务提供商 + output_path: 输出文件路径 + base_url: API基础URL,可选 + custom_clips: 自定义片段数量,默认5 + provider: LLM服务提供商,可选 + srt_path: 字幕文件路径(向后兼容) + subtitle_content: 字幕文本内容 + subtitle_file_path: 字幕文件路径(推荐) + + Returns: + Dict[str, Any]: + 成功: {"status": "success", "script": [...]} + 失败: {"status": "error", "message": "错误信息"} + """ + try: + # 解析字幕输入源(支持内容或文件路径) + resolved_content, resolved_path = resolve_subtitle_input( + subtitle_content=subtitle_content, + subtitle_file_path=subtitle_file_path, + srt_path=srt_path, + ) + + logger.info("开始分析字幕内容...") + openai_analysis = analyze_subtitle( + model_name=model_name, + api_key=api_key, + base_url=base_url, + custom_clips=custom_clips, + provider=provider, + srt_path=resolved_path, + subtitle_content=resolved_content, + ) + + adjusted_results = openai_analysis['plot_points'] + final_script = merge_script(adjusted_results, output_path) + + return {"status": "success", "script": final_script} + + except InputValidationError as e: + logger.error(f"输入验证失败: {e}") + return {"status": "error", "message": str(e)} + except Exception as e: + logger.exception(f"SDP 脚本生成失败: {e}") + return {"status": "error", "message": f"生成脚本失败: {str(e)}"} + + +def generate_script( + srt_path: Optional[str] = None, + api_key: str = None, + model_name: str = None, + output_path: str = None, + base_url: str = None, + custom_clips: int = 5, + provider: str = None, + *, + subtitle_content: Optional[str] = None, + subtitle_file_path: Optional[str] = None, +): + """生成视频混剪脚本(向后兼容版本) + + Args: + srt_path: 字幕文件路径(向后兼容参数,可选) + api_key: API密钥 + model_name: 模型名称 + output_path: 输出文件路径 + base_url: API基础URL,可选 + custom_clips: 自定义片段数量,默认5 + provider: LLM服务提供商,可选 + subtitle_content: 字幕文本内容(可选) + subtitle_file_path: 字幕文件路径(推荐使用,可选) Returns: str: 生成的脚本内容 - """ - # 验证输入文件 - if not os.path.exists(srt_path): - raise FileNotFoundError(f"字幕文件不存在: {srt_path}") - # 分析字幕 - print("开始分析...") - openai_analysis = analyze_subtitle( - srt_path=srt_path, + Raises: + FileNotFoundError: 字幕文件不存在(向后兼容) + ValueError: 输入验证失败或脚本生成失败 + """ + result = generate_script_result( api_key=api_key, model_name=model_name, + output_path=output_path, base_url=base_url, custom_clips=custom_clips, - provider=provider + provider=provider, + srt_path=srt_path, + subtitle_content=subtitle_content, + subtitle_file_path=subtitle_file_path, ) - # 合并生成最终脚本 - adjusted_results = openai_analysis['plot_points'] - final_script = merge_script(adjusted_results, output_path) + if result.get("status") != "success": + error_message = result.get("message", "生成脚本失败") + # 保持向后兼容:如果是文件不存在错误,抛出 FileNotFoundError + if "不存在" in error_message and (srt_path or subtitle_file_path): + raise FileNotFoundError(error_message) + raise ValueError(error_message) - return final_script + return result["script"] diff --git a/app/services/SDP/utils/step1_subtitle_analyzer_openai.py b/app/services/SDP/utils/step1_subtitle_analyzer_openai.py index f55cb56..2ca5243 100644 --- a/app/services/SDP/utils/step1_subtitle_analyzer_openai.py +++ b/app/services/SDP/utils/step1_subtitle_analyzer_openai.py @@ -3,10 +3,9 @@ """ import traceback import json -import asyncio from loguru import logger -from .utils import load_srt +from .utils import load_srt, load_srt_from_content # 导入新的提示词管理系统 from app.services.prompts import PromptManager # 导入统一LLM服务 @@ -16,34 +15,43 @@ from app.services.llm.migration_adapter import _run_async_safely def analyze_subtitle( - srt_path: str, model_name: str, api_key: str = None, base_url: str = None, custom_clips: int = 5, - provider: str = None + provider: str = None, + srt_path: str = None, + subtitle_content: str = None ) -> dict: """分析字幕内容,返回完整的分析结果 Args: - srt_path (str): SRT字幕文件路径 model_name (str): 大模型名称 api_key (str, optional): 大模型API密钥. Defaults to None. base_url (str, optional): 大模型API基础URL. Defaults to None. custom_clips (int): 需要提取的片段数量. Defaults to 5. provider (str, optional): LLM服务提供商. Defaults to None. + srt_path (str, optional): SRT字幕文件路径(与subtitle_content二选一) + subtitle_content (str, optional): SRT字幕文本内容(与srt_path二选一) Returns: dict: 包含剧情梗概和结构化的时间段分析的字典 """ try: - # 加载字幕文件 - subtitles = load_srt(srt_path) + # 加载字幕文件或内容 + if subtitle_content and subtitle_content.strip(): + subtitles = load_srt_from_content(subtitle_content) + source_label = "字幕内容(直接传入)" + elif srt_path: + subtitles = load_srt(srt_path) + source_label = f"字幕文件: {srt_path}" + else: + raise ValueError("必须提供 srt_path 或 subtitle_content 参数") # 检查字幕是否为空 if not subtitles: error_msg = ( - f"字幕文件 {srt_path} 解析后无有效内容。\n" + f"字幕来源 [{source_label}] 解析后无有效内容。\n" f"请检查:\n" f"1. 文件格式是否为标准 SRT\n" f"2. 文件编码是否为 UTF-8、GBK 或 GB2312\n" @@ -52,12 +60,9 @@ def analyze_subtitle( logger.error(error_msg) raise ValueError(error_msg) - logger.info(f"成功加载字幕文件 {srt_path},共 {len(subtitles)} 条有效字幕") + logger.info(f"成功加载字幕来源 [{source_label}],共 {len(subtitles)} 条有效字幕") subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles]) - # 初始化统一LLM服务 - llm_service = UnifiedLLMService() - # 如果没有指定provider,根据model_name推断 if not provider: if "deepseek" in model_name.lower(): diff --git a/app/services/SDP/utils/utils.py b/app/services/SDP/utils/utils.py index d6e5e38..d2bb54a 100644 --- a/app/services/SDP/utils/utils.py +++ b/app/services/SDP/utils/utils.py @@ -78,3 +78,46 @@ def load_srt(file_path: str) -> List[Dict]: logger.info(f"成功解析 {len(subtitles)} 条有效字幕") return subtitles + + +def load_srt_from_content(srt_content: str) -> List[Dict]: + """从字符串内容解析SRT(用于直接传入字幕内容,无需依赖文件路径) + + Args: + srt_content: SRT格式的字幕文本内容 + + Returns: + 字幕内容列表,格式同 load_srt 函数 + + Raises: + ValueError: 字幕内容为空或格式错误 + """ + if srt_content is None or not str(srt_content).strip(): + raise ValueError("字幕内容为空") + + try: + subs = pysrt.from_string(str(srt_content)) + except Exception as e: + logger.error(f"无法解析字幕内容: {e}") + raise ValueError("无法解析字幕内容,请确保为标准 SRT 格式") from e + + if not subs: + logger.warning("字幕内容解析后无有效内容") + return [] + + subtitles = [] + for sub in subs: + text = sub.text.replace('\n', ' ').strip() + if not text: + continue + + subtitles.append({ + 'number': sub.index, + 'timestamp': f"{sub.start} --> {sub.end}", + 'text': text, + 'start_time': str(sub.start), + 'end_time': str(sub.end) + }) + + logger.info(f"成功从内容解析 {len(subtitles)} 条有效字幕") + return subtitles diff --git a/app/services/upload_validation.py b/app/services/upload_validation.py new file mode 100644 index 0000000..b1d293e --- /dev/null +++ b/app/services/upload_validation.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: NarratoAI +@File : upload_validation.py +@Author : AI Assistant +@Date : 2025/12/25 +@Desc : 统一的文件上传验证工具,用于短剧混剪和短剧解说功能 +""" + +import os +from typing import Optional, Tuple + + +class InputValidationError(ValueError): + """当必需的用户输入(路径/内容)缺失或无效时抛出""" + pass + + +def ensure_existing_file( + file_path: str, + *, + label: str = "文件", + allowed_exts: Optional[Tuple[str, ...]] = None, +) -> str: + """ + 验证文件路径是否存在且有效 + + Args: + file_path: 待验证的文件路径 + label: 文件类型标签(用于错误提示) + allowed_exts: 允许的文件扩展名元组(如 ('.srt', '.txt')) + + Returns: + str: 规范化后的绝对路径 + + Raises: + InputValidationError: 文件路径无效、文件不存在或格式不支持 + """ + if not file_path or not str(file_path).strip(): + raise InputValidationError(f"{label}不能为空,请先上传{label}") + + normalized = os.path.abspath(str(file_path)) + + if not os.path.exists(normalized): + raise InputValidationError(f"{label}文件不存在: {normalized}") + + if not os.path.isfile(normalized): + raise InputValidationError(f"{label}不是有效文件: {normalized}") + + if allowed_exts: + ext = os.path.splitext(normalized)[1].lower() + allowed = tuple(e.lower() for e in allowed_exts) + if ext not in allowed: + raise InputValidationError( + f"{label}格式不支持: {ext},仅支持: {', '.join(allowed_exts)}" + ) + + return normalized + + +def resolve_subtitle_input( + *, + subtitle_content: Optional[str] = None, + subtitle_file_path: Optional[str] = None, + srt_path: Optional[str] = None, +) -> Tuple[Optional[str], Optional[str]]: + """ + 解析字幕输入源,确保只有一个有效来源 + + Args: + subtitle_content: 字幕文本内容 + subtitle_file_path: 字幕文件路径(推荐) + srt_path: 字幕文件路径(向后兼容SDP旧参数) + + Returns: + Tuple[Optional[str], Optional[str]]: (字幕内容, 字幕文件路径) + - 返回 (content, None) 表示使用内容输入 + - 返回 (None, file_path) 表示使用文件路径输入 + + Raises: + InputValidationError: 未提供输入或同时提供多个输入 + """ + file_path = subtitle_file_path or srt_path + + has_content = subtitle_content is not None and bool(str(subtitle_content).strip()) + has_file = file_path is not None and bool(str(file_path).strip()) + + if has_content and has_file: + raise InputValidationError("只能提供字幕内容或字幕文件路径之一") + + if not has_content and not has_file: + raise InputValidationError("必须提供字幕内容或字幕文件路径") + + if has_content: + content = str(subtitle_content) + if not content.strip(): + raise InputValidationError("字幕内容为空") + return content, None + + resolved_path = ensure_existing_file( + str(file_path), + label="字幕", + allowed_exts=(".srt",), + ) + return None, resolved_path diff --git a/webui/tools/generate_script_short.py b/webui/tools/generate_script_short.py index 78fd1ee..a6ab013 100644 --- a/webui/tools/generate_script_short.py +++ b/webui/tools/generate_script_short.py @@ -1,13 +1,11 @@ -import os import json import time -import asyncio import traceback -import requests import streamlit as st from loguru import logger from app.config import config +from app.services.upload_validation import ensure_existing_file, InputValidationError def generate_script_short(tr, params, custom_clips=5): @@ -31,12 +29,47 @@ def generate_script_short(tr, params, custom_clips=5): try: with st.spinner("正在生成脚本..."): + # ========== 严格验证:必须上传视频和字幕(与短剧解说保持一致)========== + # 1. 验证视频文件 + video_path = getattr(params, "video_origin_path", None) + if not video_path or not str(video_path).strip(): + st.error("请先选择视频文件") + st.stop() + + try: + ensure_existing_file( + str(video_path), + label="视频", + allowed_exts=(".mp4", ".mov", ".avi", ".flv", ".mkv"), + ) + except InputValidationError as e: + st.error(str(e)) + st.stop() + + # 2. 验证字幕文件(移除推断逻辑,必须上传) + subtitle_path = st.session_state.get("subtitle_path") + if not subtitle_path or not str(subtitle_path).strip(): + st.error("请先上传字幕文件") + st.stop() + + try: + subtitle_path = ensure_existing_file( + str(subtitle_path), + label="字幕", + allowed_exts=(".srt",), + ) + except InputValidationError as e: + st.error(str(e)) + st.stop() + + logger.info(f"使用用户上传的字幕文件: {subtitle_path}") + + # ========== 获取 LLM 配置 ========== text_provider = config.app.get('text_llm_provider', 'gemini').lower() text_api_key = config.app.get(f'text_{text_provider}_api_key') text_model = config.app.get(f'text_{text_provider}_model_name') text_base_url = config.app.get(f'text_{text_provider}_base_url') - - # 优先从 session_state 获取,若未设置则回退到 config 配置 + vision_llm_provider = st.session_state.get('vision_llm_providers') or config.app.get('vision_llm_provider', 'gemini') vision_llm_provider = vision_llm_provider.lower() vision_api_key = st.session_state.get(f'vision_{vision_llm_provider}_api_key') or config.app.get(f'vision_{vision_llm_provider}_api_key', "") @@ -45,48 +78,31 @@ def generate_script_short(tr, params, custom_clips=5): update_progress(20, "开始准备生成脚本") - # 优先使用用户上传的字幕文件 - uploaded_subtitle = st.session_state.get('subtitle_path') - if uploaded_subtitle and os.path.exists(uploaded_subtitle): - srt_path = uploaded_subtitle - logger.info(f"使用用户上传的字幕文件: {srt_path}") - else: - # 回退到根据视频路径自动推断 - srt_path = params.video_origin_path.replace(".mp4", ".srt").replace("videos", "srt").replace("video", "subtitle") - if not os.path.exists(srt_path): - logger.error(f"{srt_path} 文件不存在请检查或重新转录") - st.error(f"{srt_path} 文件不存在,请上传字幕文件或重新转录") - st.stop() + # ========== 调用后端生成脚本 ========== + from app.services.SDP.generate_script_short import generate_script_result - api_params = { - "vision_provider": vision_llm_provider, - "vision_api_key": vision_api_key, - "vision_model_name": vision_model, - "vision_base_url": vision_base_url or "", - "text_provider": text_provider, - "text_api_key": text_api_key, - "text_model_name": text_model, - "text_base_url": text_base_url or "" - } - from app.services.SDP.generate_script_short import generate_script - script = generate_script( - srt_path=srt_path, - output_path="resource/scripts/merged_subtitle.json", + result = generate_script_result( api_key=text_api_key, model_name=text_model, + output_path="resource/scripts/merged_subtitle.json", base_url=text_base_url, custom_clips=custom_clips, - provider=text_provider + provider=text_provider, + subtitle_file_path=subtitle_path, ) - if script is None: - st.error("生成脚本失败,请检查日志") + if result.get("status") != "success": + st.error(result.get("message", "生成脚本失败,请检查日志")) st.stop() + + script = result.get("script") logger.info(f"脚本生成完成 {json.dumps(script, ensure_ascii=False, indent=4)}") + if isinstance(script, list): st.session_state['video_clip_json'] = script elif isinstance(script, str): st.session_state['video_clip_json'] = json.loads(script) + update_progress(80, "脚本生成完成") time.sleep(0.1)