diff --git a/app/services/SDP/generate_script_short.py b/app/services/SDP/generate_script_short.py index c8c062a..07108e5 100644 --- a/app/services/SDP/generate_script_short.py +++ b/app/services/SDP/generate_script_short.py @@ -1,6 +1,8 @@ """ 视频脚本生成pipeline,串联各个处理步骤 """ +import json +import os from typing import Any, Dict, Optional from loguru import logger @@ -16,6 +18,10 @@ def generate_script_result( base_url: str = None, custom_clips: int = 5, provider: str = None, + video_paths=None, + plot_analysis: Optional[str] = None, + short_name: str = "", + drama_genre: str = "", *, srt_path: Optional[str] = None, subtitle_content: Optional[str] = None, @@ -30,6 +36,10 @@ def generate_script_result( base_url: API基础URL,可选 custom_clips: 自定义片段数量,默认5 provider: LLM服务提供商,可选 + video_paths: 原始视频路径列表,用于生成 video_id/video_name + plot_analysis: 已完成的剧情理解文本,提供时会跳过混剪内部剧情理解 + short_name: 短剧名称 + drama_genre: 短剧类型 srt_path: 字幕文件路径(向后兼容) subtitle_content: 字幕文本内容 subtitle_file_path: 字幕文件路径(推荐) @@ -56,10 +66,23 @@ def generate_script_result( provider=provider, srt_path=resolved_path, subtitle_content=resolved_content, + plot_analysis=plot_analysis, + video_paths=video_paths, + short_name=short_name, + drama_genre=drama_genre, ) - adjusted_results = openai_analysis['plot_points'] - final_script = merge_script(adjusted_results, output_path) + if openai_analysis.get("script_items"): + final_script = openai_analysis["script_items"] + if not output_path or not str(output_path).strip(): + raise ValueError("output_path不能为空") + os.makedirs(os.path.dirname(str(output_path)) or ".", exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(final_script, f, ensure_ascii=False, indent=4) + logger.info(f"短剧混剪脚本生成完成:{output_path}") + else: + adjusted_results = openai_analysis['plot_points'] + final_script = merge_script(adjusted_results, output_path, video_paths=video_paths) return {"status": "success", "script": final_script} @@ -79,6 +102,10 @@ def generate_script( base_url: str = None, custom_clips: int = 5, provider: str = None, + video_paths=None, + plot_analysis: Optional[str] = None, + short_name: str = "", + drama_genre: str = "", *, subtitle_content: Optional[str] = None, subtitle_file_path: Optional[str] = None, @@ -93,6 +120,10 @@ def generate_script( base_url: API基础URL,可选 custom_clips: 自定义片段数量,默认5 provider: LLM服务提供商,可选 + video_paths: 原始视频路径列表,用于生成 video_id/video_name + plot_analysis: 已完成的剧情理解文本 + short_name: 短剧名称 + drama_genre: 短剧类型 subtitle_content: 字幕文本内容(可选) subtitle_file_path: 字幕文件路径(推荐使用,可选) @@ -110,6 +141,10 @@ def generate_script( base_url=base_url, custom_clips=custom_clips, provider=provider, + video_paths=video_paths, + plot_analysis=plot_analysis, + short_name=short_name, + drama_genre=drama_genre, srt_path=srt_path, subtitle_content=subtitle_content, subtitle_file_path=subtitle_file_path, diff --git a/app/services/SDP/utils/step1_subtitle_analyzer_openai.py b/app/services/SDP/utils/step1_subtitle_analyzer_openai.py index 2d7a7e7..71f5f7d 100644 --- a/app/services/SDP/utils/step1_subtitle_analyzer_openai.py +++ b/app/services/SDP/utils/step1_subtitle_analyzer_openai.py @@ -1,11 +1,16 @@ """ 使用统一LLM服务,分析字幕文件,返回剧情梗概和爆点 """ +import os import traceback import json from loguru import logger from app.services.subtitle_text import has_timecodes, normalize_subtitle_text, read_subtitle_text +from app.services.short_drama_narration_validation import ( + build_subtitle_index, + parse_script_timestamp_range, +) # 导入新的提示词管理系统 from app.services.prompts import PromptManager # 导入统一LLM服务 @@ -14,6 +19,176 @@ from app.services.llm.unified_service import UnifiedLLMService from app.services.llm.migration_adapter import _run_async_safely +def _normalize_paths(paths): + if isinstance(paths, str): + paths = [paths] + if not paths: + return [] + + normalized_paths = [] + seen = set() + for path in paths: + if not isinstance(path, str): + continue + path = path.strip() + if not path or path in seen: + continue + normalized_paths.append(path) + seen.add(path) + return normalized_paths + + +def _coerce_positive_int(value): + try: + number = int(value) + except (TypeError, ValueError): + return None + return number if number > 0 else None + + +def _match_video_id_by_name(video_name, video_paths): + video_name = os.path.basename(str(video_name or "").strip()) + if not video_name: + return None + + for index, video_path in enumerate(video_paths, start=1): + if os.path.basename(video_path) == video_name: + return index + return None + + +def _default_video_name(video_id, video_paths): + if 1 <= video_id <= len(video_paths): + return os.path.basename(video_paths[video_id - 1]) + return "" + + +def _normalize_short_mix_items(items, video_paths, subtitle_content): + if not isinstance(items, list) or not items: + raise ValueError("短剧混剪脚本 items 必须是非空数组") + + normalized_video_paths = _normalize_paths(video_paths) + subtitle_index = build_subtitle_index(subtitle_content, normalized_video_paths) + available_video_ids = {cue.video_id for cue in subtitle_index} + if normalized_video_paths: + available_video_ids.update(range(1, len(normalized_video_paths) + 1)) + + normalized_items = [] + ranges_by_video = {} + for index, raw_item in enumerate(items, start=1): + if not isinstance(raw_item, dict): + raise ValueError(f"第 {index} 个混剪片段必须是对象") + + item_id = index + video_id = ( + _match_video_id_by_name(raw_item.get("video_name") or raw_item.get("source_video"), normalized_video_paths) + or _coerce_positive_int(raw_item.get("video_id") or raw_item.get("video_index")) + or 1 + ) + if available_video_ids and video_id not in available_video_ids: + raise ValueError(f"片段 {item_id} 的 video_id={video_id} 不在已选视频范围内") + + try: + start_ms, end_ms, timestamp = parse_script_timestamp_range(raw_item.get("timestamp", "")) + except ValueError as exc: + raise ValueError(f"片段 {item_id}: {exc}") from exc + if start_ms >= end_ms: + raise ValueError(f"片段 {item_id} 的开始时间必须早于结束时间") + + video_cues = [cue for cue in subtitle_index if cue.video_id == video_id] + if video_cues: + min_start = min(cue.start_ms for cue in video_cues) + max_end = max(cue.end_ms for cue in video_cues) + if start_ms < min_start or end_ms > max_end: + raise ValueError(f"片段 {item_id} 的时间戳不在视频 {video_id} 的字幕范围内") + if not any(start_ms < cue.end_ms and end_ms > cue.start_ms for cue in video_cues): + raise ValueError(f"片段 {item_id} 的时间戳没有命中视频 {video_id} 的字幕内容") + + picture = str( + raw_item.get("picture") + or raw_item.get("title") + or raw_item.get("narrative_function") + or raw_item.get("intent") + or raw_item.get("story_role") + or "" + ).strip() + if not picture: + raise ValueError(f"片段 {item_id} 的 picture 不能为空") + + video_name = str(raw_item.get("video_name") or "").strip() + if normalized_video_paths: + video_name = _default_video_name(video_id, normalized_video_paths) + + normalized_items.append( + { + "_id": item_id, + "video_id": video_id, + "video_name": video_name, + "timestamp": timestamp, + "picture": picture, + "narration": f"播放原片{item_id}", + "OST": 1, + } + ) + ranges_by_video.setdefault(video_id, []).append((start_ms, end_ms, item_id)) + + for video_id, ranges in ranges_by_video.items(): + ranges = sorted(ranges, key=lambda item: (item[0], item[1], item[2])) + previous_start, previous_end, previous_id = ranges[0] + for start_ms, end_ms, item_id in ranges[1:]: + if start_ms < previous_end: + raise ValueError(f"视频 {video_id} 的片段 {item_id} 与片段 {previous_id} 时间戳重叠") + if end_ms > previous_end: + previous_start, previous_end, previous_id = start_ms, end_ms, item_id + + return normalized_items + + +def _generate_short_mix_script( + *, + subtitle_content, + plot_analysis, + custom_clips, + provider, + model_name, + api_key, + base_url, + video_paths=None, + short_name="", + drama_genre="", +): + script_generation_prompt = PromptManager.get_prompt( + category="short_drama_editing", + name="script_generation", + parameters={ + "drama_name": short_name or "短剧", + "drama_genre": drama_genre or "短剧", + "plot_analysis": plot_analysis, + "subtitle_content": subtitle_content, + "custom_clips": int(custom_clips or 5), + }, + ) + + response = _run_async_safely( + UnifiedLLMService.generate_text, + prompt=script_generation_prompt, + provider=provider, + model=model_name, + api_key=api_key, + base_url=base_url, + temperature=0.1, + max_tokens=4000, + ) + + from webui.tools.generate_short_summary import parse_and_fix_json + script_data = parse_and_fix_json(response) + if not script_data: + raise ValueError("无法解析短剧混剪脚本JSON") + + script_items = script_data.get("items") or script_data.get("segments") or script_data.get("plot_points") + return _normalize_short_mix_items(script_items, video_paths, subtitle_content) + + def analyze_subtitle( model_name: str, api_key: str = None, @@ -21,7 +196,11 @@ def analyze_subtitle( custom_clips: int = 5, provider: str = None, srt_path: str = None, - subtitle_content: str = None + subtitle_content: str = None, + plot_analysis: str = None, + video_paths=None, + short_name: str = "", + drama_genre: str = "", ) -> dict: """分析字幕内容,返回完整的分析结果 @@ -33,6 +212,10 @@ def analyze_subtitle( provider (str, optional): LLM服务提供商. Defaults to None. srt_path (str, optional): SRT字幕文件路径(与subtitle_content二选一) subtitle_content (str, optional): SRT字幕文本内容(与srt_path二选一) + plot_analysis (str, optional): 已审核/缓存的剧情理解文本,提供时直接进入混剪脚本生成 + video_paths (list, optional): 原始视频路径列表,用于补齐 video_id/video_name + short_name (str, optional): 短剧名称 + drama_genre (str, optional): 短剧类型 Returns: dict: 包含剧情梗概和结构化的时间段分析的字典 @@ -87,6 +270,27 @@ def analyze_subtitle( logger.info(f"使用LLM服务分析字幕,提供商: {provider}, 模型: {model_name}") + if plot_analysis and str(plot_analysis).strip(): + logger.info("使用已有剧情理解直接生成短剧混剪脚本") + script_items = _generate_short_mix_script( + subtitle_content=subtitle_content, + plot_analysis=str(plot_analysis).strip(), + custom_clips=custom_clips, + provider=provider, + model_name=model_name, + api_key=api_key, + base_url=base_url, + video_paths=video_paths, + short_name=short_name, + drama_genre=drama_genre, + ) + return { + "summary": str(plot_analysis).strip(), + "plot_titles": [], + "plot_points": [], + "script_items": script_items, + } + # 使用新的提示词管理系统 subtitle_analysis_prompt = PromptManager.get_prompt( category="short_drama_editing", @@ -120,6 +324,28 @@ def analyze_subtitle( logger.info(f"字幕分析完成,找到 {len(summary_data.get('plot_titles', []))} 个关键情节") logger.debug(json.dumps(summary_data, indent=4, ensure_ascii=False)) + try: + script_items = _generate_short_mix_script( + subtitle_content=subtitle_content, + plot_analysis=json.dumps(summary_data, ensure_ascii=False, indent=2), + custom_clips=custom_clips, + provider=provider, + model_name=model_name, + api_key=api_key, + base_url=base_url, + video_paths=video_paths, + short_name=short_name, + drama_genre=drama_genre, + ) + return { + "summary": summary_data.get("summary", ""), + "plot_titles": summary_data.get("plot_titles", []), + "plot_points": [], + "script_items": script_items, + } + except Exception as direct_script_error: + logger.warning(f"直接生成短剧混剪脚本失败,回退到时间段定位: {direct_script_error}") + # 构建爆点标题列表 plot_titles_text = "" logger.info(f"找到 {len(summary_data.get('plot_titles', []))} 个片段") diff --git a/app/services/SDP/utils/step5_merge_script.py b/app/services/SDP/utils/step5_merge_script.py index a4d2802..a385bd4 100644 --- a/app/services/SDP/utils/step5_merge_script.py +++ b/app/services/SDP/utils/step5_merge_script.py @@ -8,7 +8,8 @@ from typing import Dict, List def merge_script( plot_points: List[Dict], - output_path: str + output_path: str, + video_paths=None, ): """合并生成最终脚本 @@ -19,6 +20,10 @@ def merge_script( Returns: str: 最终合并的脚本 """ + if isinstance(video_paths, str): + video_paths = [video_paths] + video_paths = [path for path in (video_paths or []) if isinstance(path, str) and path.strip()] + # 创建包含所有信息的临时列表 final_script = [] @@ -29,9 +34,12 @@ def merge_script( "_id": number, "timestamp": plot_point["timestamp"], "picture": plot_point["picture"], - "narration": f"播放原生_{os.urandom(4).hex()}", + "narration": f"播放原片{number}", "OST": 1, # OST=0 仅保留解说 OST=2 保留解说和原声 } + if video_paths: + script_item["video_id"] = 1 + script_item["video_name"] = os.path.basename(video_paths[0]) final_script.append(script_item) number += 1 diff --git a/app/services/prompts/short_drama_editing/__init__.py b/app/services/prompts/short_drama_editing/__init__.py index 0f3bd04..2d53ed5 100644 --- a/app/services/prompts/short_drama_editing/__init__.py +++ b/app/services/prompts/short_drama_editing/__init__.py @@ -11,6 +11,7 @@ from .subtitle_analysis import SubtitleAnalysisPrompt from .plot_extraction import PlotExtractionPrompt +from .script_generation import ScriptGenerationPrompt from ..manager import PromptManager @@ -25,9 +26,14 @@ def register_prompts(): plot_extraction_prompt = PlotExtractionPrompt() PromptManager.register_prompt(plot_extraction_prompt, is_default=True) + # 注册混剪脚本生成提示词 + script_generation_prompt = ScriptGenerationPrompt() + PromptManager.register_prompt(script_generation_prompt, is_default=True) + __all__ = [ "SubtitleAnalysisPrompt", - "PlotExtractionPrompt", + "PlotExtractionPrompt", + "ScriptGenerationPrompt", "register_prompts" ] diff --git a/app/services/prompts/short_drama_editing/script_generation.py b/app/services/prompts/short_drama_editing/script_generation.py new file mode 100644 index 0000000..ca9bbd6 --- /dev/null +++ b/app/services/prompts/short_drama_editing/script_generation.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +@Project: 短剧混剪-剪辑脚本生成 +@File : script_generation.py +@Description: 基于剧情理解和字幕直接生成短剧混剪脚本 +""" + +from ..base import ParameterizedPrompt, PromptMetadata, ModelType, OutputFormat + + +class ScriptGenerationPrompt(ParameterizedPrompt): + """短剧混剪脚本生成提示词""" + + def __init__(self): + metadata = PromptMetadata( + name="script_generation", + category="short_drama_editing", + version="v1.0", + description="基于剧情理解和原始字幕直接生成短剧混剪脚本,不生成解说文案", + model_type=ModelType.TEXT, + output_format=OutputFormat.JSON, + tags=["短剧", "混剪", "剪辑脚本", "时间戳", "多视频", "原声"], + parameters=[ + "drama_name", + "drama_genre", + "plot_analysis", + "subtitle_content", + "custom_clips", + ], + ) + super().__init__( + metadata, + required_parameters=["plot_analysis", "subtitle_content", "custom_clips"], + ) + + self._system_prompt = ( + "你是一名专业短剧混剪剪辑师。你必须严格输出JSON," + "只从字幕中选择真实存在的可剪辑原声片段,不生成解说文案。" + ) + + def get_template(self) -> str: + return """# 短剧混剪脚本生成任务 + +## 目标 +根据剧情理解和原始字幕,为短剧《${drama_name}》生成一份可直接裁剪的混剪 JSON 脚本。 + +短剧混剪与短剧解说的区别: +- 不生成解说文案。 +- 不需要用户审核旁白。 +- 直接从剧情理解中选择能串成故事线的原片片段。 +- 每个片段默认保留原声,OST 必须为 1。 + +## 用户选择的短剧类型 + +${drama_genre} + + +## 需要生成的片段数量 + +${custom_clips} + + +## 剧情理解材料 + +${plot_analysis} + + +## 原始字幕(含视频编号和局部时间戳) + +${subtitle_content} + + +## 选择原则 +1. 选择 ${custom_clips} 个片段,尽量形成“开端 -> 冲突升级 -> 高潮/反转 -> 悬念或阶段结果”的完整观看路径。 +2. 只能使用原始字幕中真实存在的视频编号、视频文件名和时间范围。 +3. timestamp 必须是对应 video_id 内部的局部时间戳,格式为 "HH:MM:SS,mmm-HH:MM:SS,mmm"。 +4. 同一个 video_id 内的片段不得交叉或重叠;整体顺序要服务剧情理解,单个视频内尽量按时间顺序。 +5. 优先选择关键对白、身份揭露、情绪爆发、反转、冲突升级和能看懂前因后果的片段。 +6. 单个片段建议 5-45 秒;不要只截 1-2 秒的孤立金句,也不要截过长的流水账。 +7. 如果两个关键剧情之间跳跃太大,优先选择包含上下文的连续时间段,而不是硬切爆点。 +8. picture 要描述画面中人物、动作、情绪、场景和该片段的剧情作用。 +9. narration 字段必须写成“播放原片+_id”,例如 _id 为 3 时写“播放原片3”。 +10. OST 必须为 1,表示保留原片原声。 + +## 字段规则 +- _id:从 1 开始连续递增。 +- video_id:来自字幕分段标题,例如“视频 2”就填 2;单视频填 1。 +- video_name:对应视频文件名,必须从字幕分段标题提取;单视频也要填写。 +- timestamp:必须来自对应视频字幕时间轴。 +- picture:非空字符串。 +- narration:固定为“播放原片+_id”。 +- OST:固定为 1。 + +## 输出格式 +只输出严格 JSON: + +{ + "items": [ + { + "_id": 1, + "video_id": 1, + "video_name": "1.mp4", + "timestamp": "00:00:01,000-00:00:12,500", + "picture": "女主被当众羞辱仍然强撑,冲突正式爆发,为后续逆袭埋下情绪钩子", + "narration": "播放原片1", + "OST": 1 + } + ] +} + +现在请生成短剧混剪脚本。""" diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index ea345a6..6eb6193 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -26,6 +26,7 @@ from webui.tools.generate_short_summary import ( SCRIPT_TABLE_BASE_COLUMNS = ["_id", "video_id", "video_name", "timestamp", "picture", "narration", "OST"] +SCRIPT_TABLE_TEXT_COLUMNS = {"video_name", "timestamp", "picture", "narration", "value"} MODE_FILE = "file_selection" MODE_AUTO = "auto" MODE_SHORT = "short" @@ -666,16 +667,42 @@ def render_short_generate_options(tr): 渲染Short Generate模式下的特殊选项 在Short Generate模式下,替换原有的输入框为自定义片段选项 """ - summary_narration_panel(tr, SUMMARY_MODE_CONFIGS[MODE_SHORT_SUMMARY]) - # 显示自定义片段数量选择器 - custom_clips = st.number_input( - tr("自定义片段"), - min_value=1, - max_value=20, - value=st.session_state.get('custom_clips', 5), - help=tr("设置需要生成的短视频片段数量"), - key="custom_clips_input" - ) + summary_config = SUMMARY_MODE_CONFIGS[MODE_SHORT_SUMMARY] + summary_narration_panel(tr, summary_config) + + type_option_key = _summary_state_key(summary_config, "type_option") + custom_type_key = _summary_state_key(summary_config, "custom_type") + type_options = [code for code, _ in summary_config["type_options"]] + if st.session_state.get(type_option_key) not in type_options: + st.session_state[type_option_key] = summary_config["default_type"] + + show_custom_type = st.session_state.get(type_option_key, summary_config["default_type"]) == "custom" + option_cols = st.columns([1.1, 1.1, 1], vertical_alignment="bottom") if show_custom_type else st.columns([1.1, 1], vertical_alignment="bottom") + with option_cols[0]: + st.selectbox( + tr(summary_config["type_label_key"]), + options=type_options, + format_func=lambda code: tr(dict(summary_config["type_options"]).get(code, code)), + key=type_option_key, + ) + option_index = 1 + if show_custom_type: + with option_cols[option_index]: + st.text_input( + tr(summary_config["custom_type_label_key"]), + key=custom_type_key, + placeholder=tr(summary_config["custom_type_placeholder_key"]), + ) + option_index += 1 + with option_cols[option_index]: + custom_clips = st.number_input( + tr("自定义片段"), + min_value=1, + max_value=20, + value=st.session_state.get('custom_clips', 5), + help=tr("设置需要生成的短视频片段数量"), + key="custom_clips_input" + ) st.session_state['custom_clips'] = custom_clips @@ -729,6 +756,7 @@ def summary_narration_panel(tr, summary_config): plot_analysis_key = _summary_state_key(summary_config, "plot_analysis") plot_source_key = _summary_state_key(summary_config, "plot_analysis_subtitle_path") plot_signature_key = _summary_state_key(summary_config, "plot_analysis_signature") + pending_plot_key = _summary_state_key(summary_config, "pending_plot_analysis") st.markdown( f""" @@ -815,6 +843,15 @@ def summary_narration_panel(tr, summary_config): st.session_state[plot_analysis_key] = "" st.session_state[plot_source_key] = "" st.session_state[plot_signature_key] = "" + st.session_state.pop(pending_plot_key, None) + else: + pending_plot = st.session_state.pop(pending_plot_key, None) + if isinstance(pending_plot, dict) and pending_plot.get("signature") == current_signature: + pending_analysis = str(pending_plot.get("plot_analysis") or "") + if pending_analysis: + st.session_state[plot_analysis_key] = pending_analysis + st.session_state[plot_source_key] = pending_plot.get("subtitle_path") or current_subtitle_path + st.session_state[plot_signature_key] = current_signature if analyze_plot_clicked: with st.spinner(tr("Analyzing plot...")): @@ -1003,10 +1040,17 @@ def _script_json_to_table(script_data): {"value": json.dumps(item, ensure_ascii=False)} for item in script_data ] - return pd.DataFrame(rows, columns=["value"]) + return _normalize_script_table_types(pd.DataFrame(rows, columns=["value"])) columns = _ordered_script_columns(script_data) - return pd.DataFrame(script_data, columns=columns) + return _normalize_script_table_types(pd.DataFrame(script_data, columns=columns)) + + +def _normalize_script_table_types(table_data): + for column in SCRIPT_TABLE_TEXT_COLUMNS: + if column in table_data.columns: + table_data[column] = table_data[column].where(table_data[column].notna(), "").astype(str).astype("object") + return table_data def _normalize_script_table_value(column, value): @@ -1723,8 +1767,66 @@ def render_script_buttons(tr, params): generate_script_docu(params, tr) elif script_path == "short": # 执行 短剧混剪 脚本生成 + summary_config = SUMMARY_MODE_CONFIGS[MODE_SHORT_SUMMARY] + type_option_key = _summary_state_key(summary_config, "type_option") + custom_type_key = _summary_state_key(summary_config, "custom_type") + web_search_key = _summary_state_key(summary_config, "web_search_enabled") + plot_analysis_key = _summary_state_key(summary_config, "plot_analysis") + plot_source_key = _summary_state_key(summary_config, "plot_analysis_subtitle_path") + plot_signature_key = _summary_state_key(summary_config, "plot_analysis_signature") + pending_plot_key = _summary_state_key(summary_config, "pending_plot_analysis") + if ( + st.session_state.get(type_option_key) == "custom" + and not str(st.session_state.get(custom_type_key, '') or '').strip() + ): + st.error(tr(summary_config["custom_type_empty_key"])) + st.stop() + + subtitle_paths = _selected_subtitle_paths() + subtitle_path = subtitle_paths[0] if subtitle_paths else None + video_theme = st.session_state.get('video_theme') + web_search_enabled = bool(st.session_state.get(web_search_key, False)) + current_signature = _short_drama_plot_analysis_signature( + subtitle_paths, + video_theme, + web_search_enabled, + _selected_video_paths(), + ) + plot_analysis = "" + if st.session_state.get(plot_signature_key) == current_signature: + plot_analysis = st.session_state.get(plot_analysis_key, '') + elif ( + not web_search_enabled + and st.session_state.get(plot_source_key) == subtitle_path + ): + plot_analysis = st.session_state.get(plot_analysis_key, '') + custom_clips = st.session_state.get('custom_clips') - generate_script_short(tr, params, custom_clips) + short_result = generate_script_short( + tr, + params, + custom_clips, + subtitle_paths=subtitle_paths, + video_theme=video_theme, + temperature=st.session_state.get('temperature', 0.7), + plot_analysis=plot_analysis, + subtitle_content=st.session_state.get('subtitle_content', ''), + enable_web_search=web_search_enabled, + video_paths=_selected_video_paths(), + drama_genre=_resolve_short_drama_type(), + prompt_category=summary_config["prompt_category"], + search_keywords=summary_config["search_keywords"], + empty_title_message_key=summary_config["empty_title_message_key"], + web_search_context_description=summary_config["web_search_context_description"], + ) + if short_result and short_result.get("plot_analysis"): + st.session_state[pending_plot_key] = { + "plot_analysis": short_result["plot_analysis"], + "subtitle_path": subtitle_path, + "signature": current_signature, + } + st.session_state[plot_source_key] = subtitle_path + st.session_state[plot_signature_key] = current_signature else: load_script(tr, script_path) diff --git a/webui/tools/generate_script_short.py b/webui/tools/generate_script_short.py index 5c7b43d..e01b921 100644 --- a/webui/tools/generate_script_short.py +++ b/webui/tools/generate_script_short.py @@ -8,9 +8,32 @@ from loguru import logger from app.config import config from app.services.upload_validation import ensure_existing_file, InputValidationError from app.utils import utils +from webui.tools.generate_short_summary import ( + SHORT_DRAMA_PROMPT_CATEGORY, + SHORT_DRAMA_SEARCH_KEYWORDS, + _build_combined_subtitle_content, + _normalize_paths, + analyze_short_drama_plot, +) -def generate_script_short(tr, params, custom_clips=5): +def generate_script_short( + tr, + params, + custom_clips=5, + subtitle_paths=None, + video_theme=None, + temperature=0.7, + plot_analysis=None, + subtitle_content=None, + enable_web_search=False, + video_paths=None, + drama_genre="逆袭/复仇", + prompt_category=SHORT_DRAMA_PROMPT_CATEGORY, + search_keywords=SHORT_DRAMA_SEARCH_KEYWORDS, + empty_title_message_key="Please enter short drama name before web search", + web_search_context_description="短剧名称、人物关系、剧情背景和公开剧情梗概", +): """ 生成短视频脚本 @@ -18,6 +41,14 @@ def generate_script_short(tr, params, custom_clips=5): tr: 翻译函数 params: 视频参数对象 custom_clips: 自定义片段数量,默认为5 + subtitle_paths: 已转写/上传/翻译/校准后的字幕路径列表 + video_theme: 短剧名称 + temperature: LLM温度 + plot_analysis: 已完成的剧情理解文本 + subtitle_content: 已合并的字幕文本 + enable_web_search: 是否在剧情理解前联网搜索 + video_paths: 原始视频路径列表 + drama_genre: 用户选择的短剧类型 """ progress_bar = st.progress(0) status_text = st.empty() @@ -33,38 +64,47 @@ def generate_script_short(tr, params, custom_clips=5): with st.spinner(tr("Generating script...")): # ========== 严格验证:必须上传视频和字幕(与短剧解说保持一致)========== # 1. 验证视频文件 - video_path = getattr(params, "video_origin_path", None) - if not video_path or not str(video_path).strip(): + selected_video_paths = _normalize_paths( + video_paths + or getattr(params, "video_origin_paths", []) + or getattr(params, "video_origin_path", "") + ) + if not selected_video_paths: st.error(tr("Please select video file first")) st.stop() - try: - ensure_existing_file( - str(video_path), - label=tr("Video"), - allowed_exts=(".mp4", ".mov", ".avi", ".flv", ".mkv"), - ) - except InputValidationError as e: - st.error(str(e)) - st.stop() + for video_path in selected_video_paths: + try: + ensure_existing_file( + str(video_path), + label=tr("Video"), + allowed_exts=(".mp4", ".mov", ".avi", ".flv", ".mkv"), + ) + except InputValidationError as e: + st.error(str(e)) + st.stop() # 2. 验证字幕文件(移除推断逻辑,必须上传) - subtitle_path = st.session_state.get("subtitle_path") - if not subtitle_path or not str(subtitle_path).strip(): + subtitle_paths = _normalize_paths(subtitle_paths or st.session_state.get("subtitle_paths") or st.session_state.get("subtitle_path")) + if not subtitle_paths: st.error(tr("Please upload subtitle file first")) st.stop() + validated_subtitle_paths = [] try: - subtitle_path = ensure_existing_file( - str(subtitle_path), - label=tr("Subtitle"), - allowed_exts=(".srt",), - ) + for subtitle_path in subtitle_paths: + validated_subtitle_paths.append( + ensure_existing_file( + str(subtitle_path), + label=tr("Subtitle"), + allowed_exts=(".srt",), + ) + ) except InputValidationError as e: st.error(str(e)) st.stop() - logger.info(f"使用用户上传的字幕文件: {subtitle_path}") + logger.info(f"使用用户处理后的字幕文件: {validated_subtitle_paths}") # ========== 获取 LLM 配置 ========== text_provider = config.app.get('text_llm_provider', 'gemini').lower() @@ -80,18 +120,40 @@ def generate_script_short(tr, params, custom_clips=5): update_progress(20, tr("Preparing script generation")) + subtitle_content = str(subtitle_content or "").strip() or _build_combined_subtitle_content( + validated_subtitle_paths, + selected_video_paths, + ) + if not subtitle_content: + st.error(tr("Subtitle file is empty or unreadable")) + st.stop() + + plot_analysis = str(plot_analysis or "").strip() + if not plot_analysis: + update_progress(35, tr("Analyzing subtitles with model...")) + plot_analysis = analyze_short_drama_plot( + validated_subtitle_paths, + temperature, + tr, + subtitle_content=subtitle_content, + short_name=video_theme, + enable_web_search=enable_web_search, + video_paths=selected_video_paths, + prompt_category=prompt_category, + search_keywords=search_keywords, + empty_title_message_key=empty_title_message_key, + web_search_context_description=web_search_context_description, + ) + if not plot_analysis: + st.error(tr("Script generation failed check logs")) + st.stop() + # ========== 调用后端生成脚本 ========== from app.services.SDP.generate_script_short import generate_script_result output_path = os.path.join(utils.script_dir(), "merged_subtitle.json") - subtitle_content = st.session_state.get("subtitle_content") - subtitle_kwargs = ( - {"subtitle_content": str(subtitle_content)} - if subtitle_content is not None and str(subtitle_content).strip() - else {"subtitle_file_path": subtitle_path} - ) - + update_progress(55, tr("Generating script...")) result = generate_script_result( api_key=text_api_key, model_name=text_model, @@ -99,7 +161,11 @@ def generate_script_short(tr, params, custom_clips=5): base_url=text_base_url, custom_clips=custom_clips, provider=text_provider, - **subtitle_kwargs, + subtitle_content=subtitle_content, + video_paths=selected_video_paths, + plot_analysis=plot_analysis, + short_name=video_theme or "", + drama_genre=drama_genre or "", ) if result.get("status") != "success": @@ -120,8 +186,14 @@ def generate_script_short(tr, params, custom_clips=5): progress_bar.progress(100) status_text.text(tr("Script generation completed!")) st.success(tr("Video script generated successfully")) + return { + "script": st.session_state.get('video_clip_json', []), + "plot_analysis": plot_analysis, + "subtitle_content": subtitle_content, + } except Exception as err: progress_bar.progress(100) st.error(f"{tr('Generation error')}: {str(err)}") logger.exception(f"生成脚本时发生错误\n{traceback.format_exc()}") + return None