mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-01-12 19:38:11 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0706b00577 | ||
|
|
08f682bb50 | ||
|
|
26f0dfeab5 |
7
.gitignore
vendored
7
.gitignore
vendored
@ -39,4 +39,9 @@ bug清单.md
|
||||
task.md
|
||||
.claude/*
|
||||
.serena/*
|
||||
CLAUDE.md
|
||||
|
||||
# OpenSpec: 忽略活动的变更提案,但保留归档和规范
|
||||
openspec/*
|
||||
AGENTS.md
|
||||
CLAUDE.md
|
||||
tests/*
|
||||
@ -15,6 +15,7 @@ from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
from app.config import config
|
||||
from app.utils.utils import get_uuid, storage_dir
|
||||
from app.services.subtitle_text import read_subtitle_text
|
||||
# 导入新的提示词管理系统
|
||||
from app.services.prompts import PromptManager
|
||||
|
||||
@ -309,8 +310,13 @@ class SubtitleAnalyzer:
|
||||
}
|
||||
|
||||
# 读取文件内容
|
||||
with open(subtitle_file_path, 'r', encoding='utf-8') as f:
|
||||
subtitle_content = f.read()
|
||||
subtitle_content = read_subtitle_text(subtitle_file_path).text
|
||||
if not subtitle_content:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"字幕文件内容为空或无法读取: {subtitle_file_path}",
|
||||
"temperature": self.temperature
|
||||
}
|
||||
|
||||
# 分析字幕
|
||||
return self.analyze_subtitle(subtitle_content)
|
||||
|
||||
@ -1,43 +1,125 @@
|
||||
"""
|
||||
视频脚本生成pipeline,串联各个处理步骤
|
||||
"""
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
|
||||
from .utils.step1_subtitle_analyzer_openai import analyze_subtitle
|
||||
from .utils.step5_merge_script import merge_script
|
||||
from app.services.upload_validation import InputValidationError, resolve_subtitle_input
|
||||
|
||||
|
||||
def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5, provider: str = None):
|
||||
"""生成视频混剪脚本
|
||||
def generate_script_result(
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
output_path: str,
|
||||
base_url: str = None,
|
||||
custom_clips: int = 5,
|
||||
provider: str = None,
|
||||
*,
|
||||
srt_path: Optional[str] = None,
|
||||
subtitle_content: Optional[str] = None,
|
||||
subtitle_file_path: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成视频混剪脚本(安全版本,返回结果字典)
|
||||
|
||||
Args:
|
||||
srt_path: 字幕文件路径
|
||||
api_key: API密钥
|
||||
model_name: 模型名称
|
||||
output_path: 输出文件路径,可选
|
||||
base_url: API基础URL
|
||||
custom_clips: 自定义片段数量
|
||||
provider: LLM服务提供商
|
||||
output_path: 输出文件路径
|
||||
base_url: API基础URL,可选
|
||||
custom_clips: 自定义片段数量,默认5
|
||||
provider: LLM服务提供商,可选
|
||||
srt_path: 字幕文件路径(向后兼容)
|
||||
subtitle_content: 字幕文本内容
|
||||
subtitle_file_path: 字幕文件路径(推荐)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
成功: {"status": "success", "script": [...]}
|
||||
失败: {"status": "error", "message": "错误信息"}
|
||||
"""
|
||||
try:
|
||||
# 解析字幕输入源(支持内容或文件路径)
|
||||
resolved_content, resolved_path = resolve_subtitle_input(
|
||||
subtitle_content=subtitle_content,
|
||||
subtitle_file_path=subtitle_file_path,
|
||||
srt_path=srt_path,
|
||||
)
|
||||
|
||||
logger.info("开始分析字幕内容...")
|
||||
openai_analysis = analyze_subtitle(
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
custom_clips=custom_clips,
|
||||
provider=provider,
|
||||
srt_path=resolved_path,
|
||||
subtitle_content=resolved_content,
|
||||
)
|
||||
|
||||
adjusted_results = openai_analysis['plot_points']
|
||||
final_script = merge_script(adjusted_results, output_path)
|
||||
|
||||
return {"status": "success", "script": final_script}
|
||||
|
||||
except InputValidationError as e:
|
||||
logger.error(f"输入验证失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
except Exception as e:
|
||||
logger.exception(f"SDP 脚本生成失败: {e}")
|
||||
return {"status": "error", "message": f"生成脚本失败: {str(e)}"}
|
||||
|
||||
|
||||
def generate_script(
|
||||
srt_path: Optional[str] = None,
|
||||
api_key: str = None,
|
||||
model_name: str = None,
|
||||
output_path: str = None,
|
||||
base_url: str = None,
|
||||
custom_clips: int = 5,
|
||||
provider: str = None,
|
||||
*,
|
||||
subtitle_content: Optional[str] = None,
|
||||
subtitle_file_path: Optional[str] = None,
|
||||
):
|
||||
"""生成视频混剪脚本(向后兼容版本)
|
||||
|
||||
Args:
|
||||
srt_path: 字幕文件路径(向后兼容参数,可选)
|
||||
api_key: API密钥
|
||||
model_name: 模型名称
|
||||
output_path: 输出文件路径
|
||||
base_url: API基础URL,可选
|
||||
custom_clips: 自定义片段数量,默认5
|
||||
provider: LLM服务提供商,可选
|
||||
subtitle_content: 字幕文本内容(可选)
|
||||
subtitle_file_path: 字幕文件路径(推荐使用,可选)
|
||||
|
||||
Returns:
|
||||
str: 生成的脚本内容
|
||||
"""
|
||||
# 验证输入文件
|
||||
if not os.path.exists(srt_path):
|
||||
raise FileNotFoundError(f"字幕文件不存在: {srt_path}")
|
||||
|
||||
# 分析字幕
|
||||
print("开始分析...")
|
||||
openai_analysis = analyze_subtitle(
|
||||
srt_path=srt_path,
|
||||
Raises:
|
||||
FileNotFoundError: 字幕文件不存在(向后兼容)
|
||||
ValueError: 输入验证失败或脚本生成失败
|
||||
"""
|
||||
result = generate_script_result(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
output_path=output_path,
|
||||
base_url=base_url,
|
||||
custom_clips=custom_clips,
|
||||
provider=provider
|
||||
provider=provider,
|
||||
srt_path=srt_path,
|
||||
subtitle_content=subtitle_content,
|
||||
subtitle_file_path=subtitle_file_path,
|
||||
)
|
||||
|
||||
# 合并生成最终脚本
|
||||
adjusted_results = openai_analysis['plot_points']
|
||||
final_script = merge_script(adjusted_results, output_path)
|
||||
if result.get("status") != "success":
|
||||
error_message = result.get("message", "生成脚本失败")
|
||||
# 保持向后兼容:如果是文件不存在错误,抛出 FileNotFoundError
|
||||
if "不存在" in error_message and (srt_path or subtitle_file_path):
|
||||
raise FileNotFoundError(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
return final_script
|
||||
return result["script"]
|
||||
|
||||
@ -3,10 +3,9 @@
|
||||
"""
|
||||
import traceback
|
||||
import json
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
|
||||
from .utils import load_srt
|
||||
from app.services.subtitle_text import has_timecodes, normalize_subtitle_text, read_subtitle_text
|
||||
# 导入新的提示词管理系统
|
||||
from app.services.prompts import PromptManager
|
||||
# 导入统一LLM服务
|
||||
@ -16,33 +15,64 @@ from app.services.llm.migration_adapter import _run_async_safely
|
||||
|
||||
|
||||
def analyze_subtitle(
|
||||
srt_path: str,
|
||||
model_name: str,
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
custom_clips: int = 5,
|
||||
provider: str = None
|
||||
provider: str = None,
|
||||
srt_path: str = None,
|
||||
subtitle_content: str = None
|
||||
) -> dict:
|
||||
"""分析字幕内容,返回完整的分析结果
|
||||
|
||||
Args:
|
||||
srt_path (str): SRT字幕文件路径
|
||||
model_name (str): 大模型名称
|
||||
api_key (str, optional): 大模型API密钥. Defaults to None.
|
||||
base_url (str, optional): 大模型API基础URL. Defaults to None.
|
||||
custom_clips (int): 需要提取的片段数量. Defaults to 5.
|
||||
provider (str, optional): LLM服务提供商. Defaults to None.
|
||||
srt_path (str, optional): SRT字幕文件路径(与subtitle_content二选一)
|
||||
subtitle_content (str, optional): SRT字幕文本内容(与srt_path二选一)
|
||||
|
||||
Returns:
|
||||
dict: 包含剧情梗概和结构化的时间段分析的字典
|
||||
"""
|
||||
try:
|
||||
# 加载字幕文件
|
||||
subtitles = load_srt(srt_path)
|
||||
subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles])
|
||||
# 读取并规范化字幕文本(不依赖结构化 SRT 解析,提升兼容性)
|
||||
if subtitle_content and str(subtitle_content).strip():
|
||||
normalized_subtitle_text = normalize_subtitle_text(subtitle_content)
|
||||
source_label = "字幕内容(直接传入)"
|
||||
elif srt_path:
|
||||
decoded = read_subtitle_text(srt_path)
|
||||
normalized_subtitle_text = decoded.text
|
||||
source_label = f"字幕文件: {srt_path} (encoding: {decoded.encoding})"
|
||||
else:
|
||||
raise ValueError("必须提供 srt_path 或 subtitle_content 参数")
|
||||
|
||||
# 初始化统一LLM服务
|
||||
llm_service = UnifiedLLMService()
|
||||
# 基础校验:必须有内容且包含可用于定位的时间码
|
||||
if not normalized_subtitle_text or len(normalized_subtitle_text.strip()) < 10:
|
||||
error_msg = (
|
||||
f"字幕来源 [{source_label}] 内容为空或过短。\n"
|
||||
f"请检查:\n"
|
||||
f"1. 文件格式是否为标准 SRT\n"
|
||||
f"2. 文件编码是否为 UTF-8、UTF-16、GBK 或 GB2312\n"
|
||||
f"3. 文件内容是否为空"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if not has_timecodes(normalized_subtitle_text):
|
||||
error_msg = (
|
||||
f"字幕来源 [{source_label}] 未检测到有效时间码,无法进行时间段定位。\n"
|
||||
f"请确保字幕包含类似以下格式的时间轴:\n"
|
||||
f"00:00:01,000 --> 00:00:02,000\n"
|
||||
f"(若毫秒分隔符为'.',系统会自动规范化为',')"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info(f"成功加载字幕来源 [{source_label}],字符数: {len(normalized_subtitle_text)}")
|
||||
subtitle_content = normalized_subtitle_text
|
||||
|
||||
# 如果没有指定provider,根据model_name推断
|
||||
if not provider:
|
||||
@ -88,11 +118,11 @@ def analyze_subtitle(
|
||||
raise Exception("无法解析LLM返回的JSON数据")
|
||||
|
||||
logger.info(f"字幕分析完成,找到 {len(summary_data.get('plot_titles', []))} 个关键情节")
|
||||
print(json.dumps(summary_data, indent=4, ensure_ascii=False))
|
||||
logger.debug(json.dumps(summary_data, indent=4, ensure_ascii=False))
|
||||
|
||||
# 构建爆点标题列表
|
||||
plot_titles_text = ""
|
||||
print(f"找到 {len(summary_data['plot_titles'])} 个片段")
|
||||
logger.info(f"找到 {len(summary_data.get('plot_titles', []))} 个片段")
|
||||
for i, point in enumerate(summary_data['plot_titles'], 1):
|
||||
plot_titles_text += f"{i}. {point}\n"
|
||||
|
||||
@ -140,4 +170,3 @@ def analyze_subtitle(
|
||||
except Exception as e:
|
||||
logger.error(f"分析字幕时发生错误: {str(e)}")
|
||||
raise Exception(f"分析字幕时发生错误:{str(e)}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def merge_script(
|
||||
@ -19,38 +19,12 @@ def merge_script(
|
||||
Returns:
|
||||
str: 最终合并的脚本
|
||||
"""
|
||||
def parse_timestamp(ts: str) -> Tuple[float, float]:
|
||||
"""解析时间戳,返回开始和结束时间(秒)"""
|
||||
start, end = ts.split('-')
|
||||
|
||||
def parse_time(time_str: str) -> float:
|
||||
time_str = time_str.strip()
|
||||
if ',' in time_str:
|
||||
time_parts, ms_parts = time_str.split(',')
|
||||
ms = float(ms_parts) / 1000
|
||||
else:
|
||||
time_parts = time_str
|
||||
ms = 0
|
||||
|
||||
hours, minutes, seconds = map(int, time_parts.split(':'))
|
||||
return hours * 3600 + minutes * 60 + seconds + ms
|
||||
|
||||
return parse_time(start), parse_time(end)
|
||||
|
||||
def format_timestamp(seconds: float) -> str:
|
||||
"""将秒数转换为时间戳格式 HH:MM:SS"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = int(seconds % 60)
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||
|
||||
# 创建包含所有信息的临时列表
|
||||
final_script = []
|
||||
|
||||
# 处理原生画面条目
|
||||
number = 1
|
||||
for plot_point in plot_points:
|
||||
start, end = parse_timestamp(plot_point["timestamp"])
|
||||
script_item = {
|
||||
"_id": number,
|
||||
"timestamp": plot_point["timestamp"],
|
||||
@ -62,6 +36,11 @@ def merge_script(
|
||||
number += 1
|
||||
|
||||
# 保存结果
|
||||
if not output_path or not str(output_path).strip():
|
||||
raise ValueError("output_path不能为空")
|
||||
|
||||
output_path = str(output_path)
|
||||
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(final_script, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
@ -1,45 +1,123 @@
|
||||
# 公共方法
|
||||
import json
|
||||
import requests # 新增
|
||||
import pysrt
|
||||
from loguru import logger
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def load_srt(file_path: str) -> List[Dict]:
|
||||
"""加载并解析SRT文件
|
||||
"""加载并解析SRT文件(使用 pysrt 库,支持多种编码和格式)
|
||||
|
||||
Args:
|
||||
file_path: SRT文件路径
|
||||
|
||||
Returns:
|
||||
字幕内容列表
|
||||
字幕内容列表,格式:
|
||||
[
|
||||
{
|
||||
'number': int, # 字幕序号
|
||||
'timestamp': str, # "00:00:01,000 --> 00:00:03,000"
|
||||
'text': str, # 字幕文本
|
||||
'start_time': str, # "00:00:01,000"
|
||||
'end_time': str # "00:00:03,000"
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
ValueError: 文件编码不支持或格式错误
|
||||
"""
|
||||
with open(file_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read().strip()
|
||||
# 编码自动检测:依次尝试常见编码
|
||||
encodings = ['utf-8', 'utf-8-sig', 'gbk', 'gb2312']
|
||||
subs = None
|
||||
detected_encoding = None
|
||||
|
||||
# 按空行分割字幕块
|
||||
subtitle_blocks = content.split('\n\n')
|
||||
for encoding in encodings:
|
||||
try:
|
||||
subs = pysrt.open(file_path, encoding=encoding)
|
||||
detected_encoding = encoding
|
||||
logger.info(f"成功加载字幕文件 {file_path},编码:{encoding},共 {len(subs)} 条")
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"使用编码 {encoding} 加载失败: {e}")
|
||||
continue
|
||||
|
||||
if subs is None:
|
||||
# 所有编码都失败
|
||||
raise ValueError(
|
||||
f"无法读取字幕文件 {file_path},"
|
||||
f"请检查文件编码(支持 UTF-8、GBK、GB2312)"
|
||||
)
|
||||
|
||||
# 检查是否为空
|
||||
if not subs:
|
||||
logger.warning(f"字幕文件 {file_path} 解析后无有效内容")
|
||||
return []
|
||||
|
||||
# 转换为原格式(向后兼容)
|
||||
subtitles = []
|
||||
for sub in subs:
|
||||
# 合并多行文本为单行(某些 SRT 文件会有换行)
|
||||
text = sub.text.replace('\n', ' ').strip()
|
||||
|
||||
for block in subtitle_blocks:
|
||||
lines = block.split('\n')
|
||||
if len(lines) >= 3: # 确保块包含足够的行
|
||||
try:
|
||||
number = int(lines[0].strip())
|
||||
timestamp = lines[1]
|
||||
text = ' '.join(lines[2:])
|
||||
# 跳过空字幕
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# 解析时间戳
|
||||
start_time, end_time = timestamp.split(' --> ')
|
||||
|
||||
subtitles.append({
|
||||
'number': number,
|
||||
'timestamp': timestamp,
|
||||
'text': text,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
except ValueError as e:
|
||||
print(f"Warning: 跳过无效的字幕块: {e}")
|
||||
continue
|
||||
subtitles.append({
|
||||
'number': sub.index,
|
||||
'timestamp': f"{sub.start} --> {sub.end}",
|
||||
'text': text,
|
||||
'start_time': str(sub.start),
|
||||
'end_time': str(sub.end)
|
||||
})
|
||||
|
||||
logger.info(f"成功解析 {len(subtitles)} 条有效字幕")
|
||||
return subtitles
|
||||
|
||||
|
||||
def load_srt_from_content(srt_content: str) -> List[Dict]:
    """Parse SRT-formatted text that is already in memory.

    This is the string-based counterpart of ``load_srt`` for callers that
    have subtitle content but no file on disk.

    Args:
        srt_content: Subtitle text in SRT format.

    Returns:
        A list of dicts, one per non-empty cue, with the keys ``number``,
        ``timestamp``, ``text``, ``start_time`` and ``end_time``. An empty
        list is returned when parsing yields no cues.

    Raises:
        ValueError: If the content is ``None``/blank, or if parsing fails.
    """
    raw = "" if srt_content is None else str(srt_content)
    if not raw.strip():
        raise ValueError("字幕内容为空")

    try:
        parsed = pysrt.from_string(raw)
    except Exception as exc:
        logger.error(f"无法解析字幕内容: {exc}")
        raise ValueError("无法解析字幕内容,请确保为标准 SRT 格式") from exc

    if not parsed:
        logger.warning("字幕内容解析后无有效内容")
        return []

    # Multi-line cues are flattened to a single line; cues whose text is
    # empty after flattening are dropped.
    flattened = ((cue, cue.text.replace('\n', ' ').strip()) for cue in parsed)
    entries = [
        {
            'number': cue.index,
            'timestamp': f"{cue.start} --> {cue.end}",
            'text': line,
            'start_time': str(cue.start),
            'end_time': str(cue.end)
        }
        for cue, line in flattened
        if line
    ]

    logger.info(f"成功从内容解析 {len(entries)} 条有效字幕")
    return entries
|
||||
|
||||
124
app/services/subtitle_text.py
Normal file
124
app/services/subtitle_text.py
Normal file
@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
"""
|
||||
Subtitle text utilities.
|
||||
|
||||
This module provides a shared, cross-platform way to read and normalize subtitle
|
||||
content. Both Short Drama Editing (混剪) and Short Drama Narration (解说) should
|
||||
consume subtitle content through this module to avoid platform-specific parsing
|
||||
issues (e.g. Windows UTF-16 SRT, timestamp separators, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
# Matches one full SRT cue time range, e.g. "00:00:01,000 --> 00:00:02,000".
# The millisecond part is optional and may use either ',' or '.'.
_SRT_TIME_RE = re.compile(
    r"\b\d{2}:\d{2}:\d{2}(?:[,.]\d{3})?\s*-->\s*\d{2}:\d{2}:\d{2}(?:[,.]\d{3})?\b"
)
# Matches a single "HH:MM:SS.mmm" time so the '.' can be rewritten to ','.
_SRT_MS_DOT_RE = re.compile(r"(\b\d{2}:\d{2}:\d{2})\.(\d{3}\b)")


@dataclass(frozen=True)
class DecodedSubtitle:
    """Normalized subtitle text together with the codec that produced it."""

    text: str  # decoded text after normalize_subtitle_text()
    encoding: str  # name of the codec used to decode the bytes


def has_timecodes(text: str) -> bool:
    """Return True if the subtitle text contains at least one SRT timecode."""
    if not text:
        return False
    return _SRT_TIME_RE.search(text) is not None


def normalize_subtitle_text(text: str) -> str:
    """
    Normalize subtitle text to improve cross-platform reliability.

    - Removes leading BOM characters and all NUL characters
    - Unifies line endings to LF
    - Normalizes millisecond separators from '.' to ',' in timecodes
    - Strips leading/trailing whitespace

    Args:
        text: Raw subtitle text; ``None`` is treated as empty.

    Returns:
        The normalized text (``""`` for ``None``).
    """
    if text is None:
        return ""

    # lstrip is a no-op when there is no leading BOM.
    normalized = str(text).lstrip("\ufeff")

    # Remove NUL characters (common when UTF-16 is mis-decoded elsewhere).
    normalized = normalized.replace("\x00", "")

    # Normalize newlines.
    normalized = normalized.replace("\r\n", "\n").replace("\r", "\n")

    # Normalize timestamp millisecond separator: 00:00:01.000 -> 00:00:01,000
    normalized = _SRT_MS_DOT_RE.sub(r"\1,\2", normalized)

    return normalized.strip()


def decode_subtitle_bytes(
    data: bytes,
    *,
    encodings: Optional[Iterable[str]] = None,
) -> DecodedSubtitle:
    """
    Decode subtitle bytes using a small set of common encodings.

    Candidates are tried in order. The first decoding whose normalized text
    contains an SRT timecode wins; if none does, the first decoding that
    succeeded at all is returned; if nothing decodes, the bytes are decoded
    as UTF-8 with undecodable bytes replaced.

    Args:
        data: Raw file content; ``None`` yields an empty result.
        encodings: Optional codec names to try instead of the defaults.
            Names Python does not recognise are skipped.

    Returns:
        The normalized text and the name of the codec that produced it.
    """
    if data is None:
        return DecodedSubtitle(text="", encoding="utf-8")

    # NOTE(review): "utf-16*" precedes "gbk", so BOM-less legacy-encoded data
    # without timecodes may fall back to a UTF-16 decoding; only the timecode
    # check below guards against that. Confirm this ordering is intended.
    candidates = list(encodings) if encodings else [
        "utf-8",
        "utf-8-sig",
        "utf-16",
        "utf-16-le",
        "utf-16-be",
        "gbk",
        "gb2312",
    ]

    fallback: Optional[DecodedSubtitle] = None
    for encoding in candidates:
        try:
            decoded_text = data.decode(encoding)
        except (UnicodeDecodeError, LookupError):
            # LookupError: unknown codec name supplied by the caller. Skip it
            # so one bad name does not abort the remaining candidates.
            continue

        candidate = DecodedSubtitle(
            text=normalize_subtitle_text(decoded_text), encoding=encoding
        )

        # Fast path: if we already see timecodes, keep the first such decode.
        if has_timecodes(candidate.text):
            return candidate

        # Remember only the first successful decode for the fallback.
        if fallback is None:
            fallback = candidate

    if fallback is not None:
        return fallback

    # Last resort: replace undecodable bytes.
    return DecodedSubtitle(
        text=normalize_subtitle_text(data.decode("utf-8", errors="replace")),
        encoding="utf-8",
    )


def read_subtitle_text(file_path: str) -> DecodedSubtitle:
    """Read subtitle file from disk, decode and normalize its text.

    Args:
        file_path: Path of the subtitle file. A ``None``/blank path yields an
            empty result instead of an error.

    Returns:
        The decoded, normalized content and the codec used.

    Raises:
        OSError: Propagated from ``open`` when the file cannot be read.
    """
    if not file_path or not str(file_path).strip():
        return DecodedSubtitle(text="", encoding="utf-8")

    normalized_path = os.path.abspath(str(file_path))
    with open(normalized_path, "rb") as f:
        data = f.read()

    return decode_subtitle_bytes(data)
|
||||
|
||||
107
app/services/upload_validation.py
Normal file
107
app/services/upload_validation.py
Normal file
@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
"""
|
||||
@Project: NarratoAI
|
||||
@File : upload_validation.py
|
||||
@Author : AI Assistant
|
||||
@Date : 2025/12/25
|
||||
@Desc : 统一的文件上传验证工具,用于短剧混剪和短剧解说功能
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class InputValidationError(ValueError):
    """Raised when a required user input (path or content) is missing or invalid."""
    pass


def _is_blank(value) -> bool:
    """Return True when *value* is None or is empty/whitespace-only as a string."""
    return value is None or not str(value).strip()


def ensure_existing_file(
    file_path: str,
    *,
    label: str = "文件",
    allowed_exts: Optional[Tuple[str, ...]] = None,
) -> str:
    """
    Validate that a path points to an existing regular file.

    Args:
        file_path: Path to validate.
        label: Human-readable file kind, interpolated into error messages.
        allowed_exts: Optional tuple of accepted extensions (for example
            ``('.srt', '.txt')``); compared case-insensitively.

    Returns:
        str: The absolute form of ``file_path``.

    Raises:
        InputValidationError: The path is blank, does not exist, is not a
            regular file, or has an extension outside ``allowed_exts``.
    """
    if not file_path or not str(file_path).strip():
        raise InputValidationError(f"{label}不能为空,请先上传{label}")

    normalized = os.path.abspath(str(file_path))

    if not os.path.exists(normalized):
        raise InputValidationError(f"{label}文件不存在: {normalized}")

    if not os.path.isfile(normalized):
        raise InputValidationError(f"{label}不是有效文件: {normalized}")

    if allowed_exts:
        ext = os.path.splitext(normalized)[1].lower()
        allowed = tuple(e.lower() for e in allowed_exts)
        if ext not in allowed:
            raise InputValidationError(
                f"{label}格式不支持: {ext},仅支持: {', '.join(allowed_exts)}"
            )

    return normalized


def resolve_subtitle_input(
    *,
    subtitle_content: Optional[str] = None,
    subtitle_file_path: Optional[str] = None,
    srt_path: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Resolve the subtitle input source, requiring exactly one usable source.

    Args:
        subtitle_content: Subtitle text content.
        subtitle_file_path: Subtitle file path (preferred).
        srt_path: Subtitle file path (legacy parameter kept for backward
            compatibility); used only when ``subtitle_file_path`` is blank.

    Returns:
        Tuple[Optional[str], Optional[str]]: ``(content, None)`` when text
        content is used, or ``(None, file_path)`` when a file path is used.
        The returned path is the validated absolute path.

    Raises:
        InputValidationError: No source was given, both content and a path
            were given, or the path failed file validation (must be an
            existing ``.srt`` file).
    """
    # Pick the first non-blank path. A plain ``a or b`` would let a
    # whitespace-only ``subtitle_file_path`` (truthy) hide a valid ``srt_path``.
    file_path = next(
        (p for p in (subtitle_file_path, srt_path) if not _is_blank(p)),
        None,
    )

    has_content = not _is_blank(subtitle_content)
    has_file = file_path is not None

    if has_content and has_file:
        raise InputValidationError("只能提供字幕内容或字幕文件路径之一")

    if not has_content and not has_file:
        raise InputValidationError("必须提供字幕内容或字幕文件路径")

    if has_content:
        # Content is returned as given (not stripped); it is known non-blank here.
        return str(subtitle_content), None

    resolved_path = ensure_existing_file(
        str(file_path),
        label="字幕",
        allowed_exts=(".srt",),
    )
    return None, resolved_path
|
||||
@ -8,6 +8,7 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.models.schema import VideoClipParams
|
||||
from app.services.subtitle_text import decode_subtitle_bytes
|
||||
from app.utils import utils, check_script
|
||||
from webui.tools.generate_script_docu import generate_script_docu
|
||||
from webui.tools.generate_script_short import generate_script_short
|
||||
@ -190,8 +191,9 @@ def render_script_file(tr, params):
|
||||
json_data = json.loads(script_content)
|
||||
|
||||
# 保存到脚本目录
|
||||
script_file_path = os.path.join(script_dir, uploaded_file.name)
|
||||
file_name, file_extension = os.path.splitext(uploaded_file.name)
|
||||
safe_filename = os.path.basename(uploaded_file.name)
|
||||
script_file_path = os.path.join(script_dir, safe_filename)
|
||||
file_name, file_extension = os.path.splitext(safe_filename)
|
||||
|
||||
# 如果文件已存在,添加时间戳
|
||||
if os.path.exists(script_file_path):
|
||||
@ -250,8 +252,9 @@ def render_video_file(tr, params):
|
||||
)
|
||||
|
||||
if uploaded_file is not None:
|
||||
video_file_path = os.path.join(utils.video_dir(), uploaded_file.name)
|
||||
file_name, file_extension = os.path.splitext(uploaded_file.name)
|
||||
safe_filename = os.path.basename(uploaded_file.name)
|
||||
video_file_path = os.path.join(utils.video_dir(), safe_filename)
|
||||
file_name, file_extension = os.path.splitext(safe_filename)
|
||||
|
||||
if os.path.exists(video_file_path):
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
@ -337,18 +340,31 @@ def short_drama_summary(tr):
|
||||
st.info(f"已上传字幕: {os.path.basename(st.session_state['subtitle_path'])}")
|
||||
if st.button(tr("清除已上传字幕")):
|
||||
st.session_state['subtitle_path'] = None
|
||||
st.session_state['subtitle_content'] = None
|
||||
st.session_state['subtitle_file_processed'] = False
|
||||
st.rerun()
|
||||
|
||||
# 只有当有文件上传且尚未处理时才执行处理逻辑
|
||||
if subtitle_file is not None and not st.session_state['subtitle_file_processed']:
|
||||
try:
|
||||
# 读取上传的SRT内容
|
||||
script_content = subtitle_file.read().decode('utf-8')
|
||||
# 清理文件名,防止路径污染和路径遍历攻击
|
||||
safe_filename = os.path.basename(subtitle_file.name)
|
||||
|
||||
decoded = decode_subtitle_bytes(subtitle_file.getvalue())
|
||||
script_content = decoded.text
|
||||
detected_encoding = decoded.encoding
|
||||
|
||||
if not script_content:
|
||||
st.error(tr("无法读取字幕文件,请检查文件编码(支持 UTF-8、UTF-16、GBK、GB2312)"))
|
||||
st.stop()
|
||||
|
||||
# 验证字幕内容(简单检查)
|
||||
if len(script_content.strip()) < 10:
|
||||
st.warning(tr("字幕文件内容似乎为空,请检查文件"))
|
||||
|
||||
# 保存到字幕目录
|
||||
script_file_path = os.path.join(utils.subtitle_dir(), subtitle_file.name)
|
||||
file_name, file_extension = os.path.splitext(subtitle_file.name)
|
||||
script_file_path = os.path.join(utils.subtitle_dir(), safe_filename)
|
||||
file_name, file_extension = os.path.splitext(safe_filename)
|
||||
|
||||
# 如果文件已存在,添加时间戳
|
||||
if os.path.exists(script_file_path):
|
||||
@ -356,18 +372,23 @@ def short_drama_summary(tr):
|
||||
file_name_with_timestamp = f"{file_name}_{timestamp}"
|
||||
script_file_path = os.path.join(utils.subtitle_dir(), file_name_with_timestamp + file_extension)
|
||||
|
||||
# 直接写入SRT内容,不进行JSON转换
|
||||
# 直接写入SRT内容(统一使用 UTF-8)
|
||||
with open(script_file_path, "w", encoding='utf-8') as f:
|
||||
f.write(script_content)
|
||||
|
||||
# 更新状态
|
||||
st.success(tr("字幕上传成功"))
|
||||
st.success(
|
||||
f"{tr('字幕上传成功')} "
|
||||
f"(编码: {detected_encoding.upper()}, "
|
||||
f"大小: {len(script_content)} 字符)"
|
||||
)
|
||||
st.session_state['subtitle_path'] = script_file_path
|
||||
st.session_state['subtitle_content'] = script_content
|
||||
st.session_state['subtitle_file_processed'] = True # 标记已处理
|
||||
|
||||
|
||||
# 避免使用rerun,使用更新状态的方式
|
||||
# st.rerun()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"{tr('Upload failed')}: {str(e)}")
|
||||
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
import requests
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.upload_validation import ensure_existing_file, InputValidationError
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
def generate_script_short(tr, params, custom_clips=5):
|
||||
@ -31,12 +31,47 @@ def generate_script_short(tr, params, custom_clips=5):
|
||||
|
||||
try:
|
||||
with st.spinner("正在生成脚本..."):
|
||||
# ========== 严格验证:必须上传视频和字幕(与短剧解说保持一致)==========
|
||||
# 1. 验证视频文件
|
||||
video_path = getattr(params, "video_origin_path", None)
|
||||
if not video_path or not str(video_path).strip():
|
||||
st.error("请先选择视频文件")
|
||||
st.stop()
|
||||
|
||||
try:
|
||||
ensure_existing_file(
|
||||
str(video_path),
|
||||
label="视频",
|
||||
allowed_exts=(".mp4", ".mov", ".avi", ".flv", ".mkv"),
|
||||
)
|
||||
except InputValidationError as e:
|
||||
st.error(str(e))
|
||||
st.stop()
|
||||
|
||||
# 2. 验证字幕文件(移除推断逻辑,必须上传)
|
||||
subtitle_path = st.session_state.get("subtitle_path")
|
||||
if not subtitle_path or not str(subtitle_path).strip():
|
||||
st.error("请先上传字幕文件")
|
||||
st.stop()
|
||||
|
||||
try:
|
||||
subtitle_path = ensure_existing_file(
|
||||
str(subtitle_path),
|
||||
label="字幕",
|
||||
allowed_exts=(".srt",),
|
||||
)
|
||||
except InputValidationError as e:
|
||||
st.error(str(e))
|
||||
st.stop()
|
||||
|
||||
logger.info(f"使用用户上传的字幕文件: {subtitle_path}")
|
||||
|
||||
# ========== 获取 LLM 配置 ==========
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
|
||||
# 优先从 session_state 获取,若未设置则回退到 config 配置
|
||||
|
||||
vision_llm_provider = st.session_state.get('vision_llm_providers') or config.app.get('vision_llm_provider', 'gemini')
|
||||
vision_llm_provider = vision_llm_provider.lower()
|
||||
vision_api_key = st.session_state.get(f'vision_{vision_llm_provider}_api_key') or config.app.get(f'vision_{vision_llm_provider}_api_key', "")
|
||||
@ -45,48 +80,40 @@ def generate_script_short(tr, params, custom_clips=5):
|
||||
|
||||
update_progress(20, "开始准备生成脚本")
|
||||
|
||||
# 优先使用用户上传的字幕文件
|
||||
uploaded_subtitle = st.session_state.get('subtitle_path')
|
||||
if uploaded_subtitle and os.path.exists(uploaded_subtitle):
|
||||
srt_path = uploaded_subtitle
|
||||
logger.info(f"使用用户上传的字幕文件: {srt_path}")
|
||||
else:
|
||||
# 回退到根据视频路径自动推断
|
||||
srt_path = params.video_origin_path.replace(".mp4", ".srt").replace("videos", "srt").replace("video", "subtitle")
|
||||
if not os.path.exists(srt_path):
|
||||
logger.error(f"{srt_path} 文件不存在请检查或重新转录")
|
||||
st.error(f"{srt_path} 文件不存在,请上传字幕文件或重新转录")
|
||||
st.stop()
|
||||
# ========== 调用后端生成脚本 ==========
|
||||
from app.services.SDP.generate_script_short import generate_script_result
|
||||
|
||||
api_params = {
|
||||
"vision_provider": vision_llm_provider,
|
||||
"vision_api_key": vision_api_key,
|
||||
"vision_model_name": vision_model,
|
||||
"vision_base_url": vision_base_url or "",
|
||||
"text_provider": text_provider,
|
||||
"text_api_key": text_api_key,
|
||||
"text_model_name": text_model,
|
||||
"text_base_url": text_base_url or ""
|
||||
}
|
||||
from app.services.SDP.generate_script_short import generate_script
|
||||
script = generate_script(
|
||||
srt_path=srt_path,
|
||||
output_path="resource/scripts/merged_subtitle.json",
|
||||
api_key=text_api_key,
|
||||
model_name=text_model,
|
||||
base_url=text_base_url,
|
||||
custom_clips=custom_clips,
|
||||
provider=text_provider
|
||||
output_path = os.path.join(utils.script_dir(), "merged_subtitle.json")
|
||||
|
||||
subtitle_content = st.session_state.get("subtitle_content")
|
||||
subtitle_kwargs = (
|
||||
{"subtitle_content": str(subtitle_content)}
|
||||
if subtitle_content is not None and str(subtitle_content).strip()
|
||||
else {"subtitle_file_path": subtitle_path}
|
||||
)
|
||||
|
||||
if script is None:
|
||||
st.error("生成脚本失败,请检查日志")
|
||||
result = generate_script_result(
|
||||
api_key=text_api_key,
|
||||
model_name=text_model,
|
||||
output_path=output_path,
|
||||
base_url=text_base_url,
|
||||
custom_clips=custom_clips,
|
||||
provider=text_provider,
|
||||
**subtitle_kwargs,
|
||||
)
|
||||
|
||||
if result.get("status") != "success":
|
||||
st.error(result.get("message", "生成脚本失败,请检查日志"))
|
||||
st.stop()
|
||||
|
||||
script = result.get("script")
|
||||
logger.info(f"脚本生成完成 {json.dumps(script, ensure_ascii=False, indent=4)}")
|
||||
|
||||
if isinstance(script, list):
|
||||
st.session_state['video_clip_json'] = script
|
||||
elif isinstance(script, str):
|
||||
st.session_state['video_clip_json'] = json.loads(script)
|
||||
|
||||
update_progress(80, "脚本生成完成")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
@ -16,6 +16,7 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.services.SDE.short_drama_explanation import analyze_subtitle, generate_narration_script
|
||||
from app.services.subtitle_text import read_subtitle_text
|
||||
# 导入新的LLM服务模块 - 确保提供商被注册
|
||||
import app.services.llm # 这会触发提供商注册
|
||||
from app.services.llm.migration_adapter import SubtitleAnalyzerAdapter
|
||||
@ -173,8 +174,10 @@ def generate_script_short_sunmmary(params, subtitle_path, video_theme, temperatu
|
||||
text_base_url = config.app.get(f'text_{text_provider}_base_url')
|
||||
|
||||
# 读取字幕文件内容(无论使用哪种实现都需要)
|
||||
with open(subtitle_path, 'r', encoding='utf-8') as f:
|
||||
subtitle_content = f.read()
|
||||
subtitle_content = read_subtitle_text(subtitle_path).text
|
||||
if not subtitle_content:
|
||||
st.error("字幕文件内容为空或无法读取")
|
||||
return
|
||||
|
||||
try:
|
||||
# 优先使用新的LLM服务架构
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user