mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 02:12:50 +00:00
refactor(tools): 移除调试日志和未使用的参数- 在 base.py 中移除了调试日志,以减少日志噪音
- 在 generate_script_short.py 中移除了未使用的参数,简化了 API 调用
This commit is contained in:
parent
751d6fbb89
commit
6cd1ff8b68
37
app/services/SDP/generate_script_short.py
Normal file
37
app/services/SDP/generate_script_short.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""
|
||||
视频脚本生成pipeline,串联各个处理步骤
|
||||
"""
|
||||
import os
|
||||
from .utils.step1_subtitle_analyzer_openai import analyze_subtitle
|
||||
from .utils.step5_merge_script import merge_script
|
||||
|
||||
|
||||
def generate_script(srt_path: str, api_key: str, model_name: str, output_path: str, base_url: str = None, custom_clips: int = 5):
|
||||
"""生成视频混剪脚本
|
||||
|
||||
Args:
|
||||
srt_path: 字幕文件路径
|
||||
output_path: 输出文件路径,可选
|
||||
|
||||
Returns:
|
||||
str: 生成的脚本内容
|
||||
"""
|
||||
# 验证输入文件
|
||||
if not os.path.exists(srt_path):
|
||||
raise FileNotFoundError(f"字幕文件不存在: {srt_path}")
|
||||
|
||||
# 分析字幕
|
||||
print("开始分析...")
|
||||
openai_analysis = analyze_subtitle(
|
||||
srt_path=srt_path,
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
base_url=base_url,
|
||||
custom_clips=custom_clips
|
||||
)
|
||||
|
||||
# 合并生成最终脚本
|
||||
adjusted_results = openai_analysis['plot_points']
|
||||
final_script = merge_script(adjusted_results, output_path)
|
||||
|
||||
return final_script
|
||||
Binary file not shown.
Binary file not shown.
60
app/services/SDP/utils/short_schema.py
Normal file
60
app/services/SDP/utils/short_schema.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
定义项目中使用的数据类型
|
||||
"""
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotPoint:
|
||||
timestamp: str
|
||||
title: str
|
||||
picture: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Commentary:
|
||||
timestamp: str
|
||||
title: str
|
||||
copywriter: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubtitleSegment:
|
||||
start_time: float
|
||||
end_time: float
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptItem:
|
||||
timestamp: str
|
||||
title: str
|
||||
picture: str
|
||||
copywriter: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineResult:
|
||||
output_video_path: str
|
||||
plot_points: List[PlotPoint]
|
||||
subtitle_segments: List[SubtitleSegment]
|
||||
commentaries: List[Commentary]
|
||||
final_script: List[ScriptItem]
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class VideoProcessingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SubtitleProcessingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PlotAnalysisError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CopywritingError(Exception):
|
||||
pass
|
||||
Binary file not shown.
Binary file not shown.
157
app/services/SDP/utils/step1_subtitle_analyzer_openai.py
Normal file
157
app/services/SDP/utils/step1_subtitle_analyzer_openai.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""
|
||||
使用OpenAI API,分析字幕文件,返回剧情梗概和爆点
|
||||
"""
|
||||
import traceback
|
||||
from openai import OpenAI, BadRequestError
|
||||
import os
|
||||
import json
|
||||
|
||||
from .utils import load_srt
|
||||
|
||||
|
||||
def analyze_subtitle(
|
||||
srt_path: str,
|
||||
model_name: str,
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
custom_clips: int = 5
|
||||
) -> dict:
|
||||
"""分析字幕内容,返回完整的分析结果
|
||||
|
||||
Args:
|
||||
srt_path (str): SRT字幕文件路径
|
||||
api_key (str, optional): 大模型API密钥. Defaults to None.
|
||||
model_name (str, optional): 大模型名称. Defaults to "gpt-4o-2024-11-20".
|
||||
base_url (str, optional): 大模型API基础URL. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: 包含剧情梗概和结构化的时间段分析的字典
|
||||
"""
|
||||
try:
|
||||
# 加载字幕文件
|
||||
subtitles = load_srt(srt_path)
|
||||
subtitle_content = "\n".join([f"{sub['timestamp']}\n{sub['text']}" for sub in subtitles])
|
||||
|
||||
# 初始化客户端
|
||||
global client
|
||||
if "deepseek" in model_name.lower():
|
||||
client = OpenAI(
|
||||
api_key=api_key or os.getenv('DeepSeek_API_KEY'),
|
||||
base_url="https://api.siliconflow.cn/v1" # 使用第三方 硅基流动 API
|
||||
)
|
||||
else:
|
||||
client = OpenAI(
|
||||
api_key=api_key or os.getenv('OPENAI_API_KEY'),
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """你是一名经验丰富的短剧编剧,擅长根据字幕内容按照先后顺序分析关键剧情,并找出 %s 个关键片段。
|
||||
请返回一个JSON对象,包含以下字段:
|
||||
{
|
||||
"summary": "整体剧情梗概",
|
||||
"plot_titles": [
|
||||
"关键剧情1",
|
||||
"关键剧情2",
|
||||
"关键剧情3",
|
||||
"关键剧情4",
|
||||
"关键剧情5",
|
||||
"..."
|
||||
]
|
||||
}
|
||||
请确保返回的是合法的JSON格式, 请确保返回的是 %s 个片段。
|
||||
""" % (custom_clips, custom_clips)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"srt字幕如下:{subtitle_content}"
|
||||
}
|
||||
]
|
||||
# DeepSeek R1 和 V3 不支持 response_format=json_object
|
||||
try:
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
summary_data = json.loads(completion.choices[0].message.content)
|
||||
except BadRequestError as e:
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages
|
||||
)
|
||||
# 去除 completion 字符串前的 ```json 和 结尾的 ```
|
||||
completion = completion.choices[0].message.content.replace("```json", "").replace("```", "")
|
||||
summary_data = json.loads(completion)
|
||||
except Exception as e:
|
||||
raise Exception(f"大模型解析发生错误:{str(e)}\n{traceback.format_exc()}")
|
||||
|
||||
print(json.dumps(summary_data, indent=4, ensure_ascii=False))
|
||||
|
||||
# 获取爆点时间段分析
|
||||
prompt = f"""剧情梗概:
|
||||
{summary_data['summary']}
|
||||
|
||||
需要定位的爆点内容:
|
||||
"""
|
||||
print(f"找到 {len(summary_data['plot_titles'])} 个片段")
|
||||
for i, point in enumerate(summary_data['plot_titles'], 1):
|
||||
prompt += f"{i}. {point}\n"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """你是一名短剧编剧,非常擅长根据字幕中分析视频中关键剧情出现的具体时间段。
|
||||
请仔细阅读剧情梗概和爆点内容,然后在字幕中找出每个爆点发生的具体时间段和爆点前后的详细剧情。
|
||||
|
||||
请返回一个JSON对象,包含一个名为"plot_points"的数组,数组中包含多个对象,每个对象都要包含以下字段:
|
||||
{
|
||||
"plot_points": [
|
||||
{
|
||||
"timestamp": "时间段,格式为xx:xx:xx,xxx-xx:xx:xx,xxx",
|
||||
"title": "关键剧情的主题",
|
||||
"picture": "关键剧情前后的详细剧情描述"
|
||||
}
|
||||
]
|
||||
}
|
||||
请确保返回的是合法的JSON格式。"""
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""字幕内容:
|
||||
{subtitle_content}
|
||||
|
||||
{prompt}"""
|
||||
}
|
||||
]
|
||||
# DeepSeek R1 和 V3 不支持 response_format=json_object
|
||||
try:
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
plot_points_data = json.loads(completion.choices[0].message.content)
|
||||
except BadRequestError as e:
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages
|
||||
)
|
||||
# 去除 completion 字符串前的 ```json 和 结尾的 ```
|
||||
completion = completion.choices[0].message.content.replace("```json", "").replace("```", "")
|
||||
plot_points_data = json.loads(completion)
|
||||
except Exception as e:
|
||||
raise Exception(f"大模型解析错误:{str(e)}\n{traceback.format_exc()}")
|
||||
|
||||
print(json.dumps(plot_points_data, indent=4, ensure_ascii=False))
|
||||
|
||||
# 合并结果
|
||||
return {
|
||||
"plot_summary": summary_data,
|
||||
"plot_points": plot_points_data["plot_points"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"分析字幕时发生错误:{str(e)}\n{traceback.format_exc()}")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
69
app/services/SDP/utils/step5_merge_script.py
Normal file
69
app/services/SDP/utils/step5_merge_script.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
合并生成最终脚本
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
|
||||
def merge_script(
|
||||
plot_points: List[Dict],
|
||||
output_path: str
|
||||
):
|
||||
"""合并生成最终脚本
|
||||
|
||||
Args:
|
||||
plot_points: 校对后的剧情点
|
||||
output_path: 输出文件路径,如果提供则保存到文件
|
||||
|
||||
Returns:
|
||||
str: 最终合并的脚本
|
||||
"""
|
||||
def parse_timestamp(ts: str) -> Tuple[float, float]:
|
||||
"""解析时间戳,返回开始和结束时间(秒)"""
|
||||
start, end = ts.split('-')
|
||||
|
||||
def parse_time(time_str: str) -> float:
|
||||
time_str = time_str.strip()
|
||||
if ',' in time_str:
|
||||
time_parts, ms_parts = time_str.split(',')
|
||||
ms = float(ms_parts) / 1000
|
||||
else:
|
||||
time_parts = time_str
|
||||
ms = 0
|
||||
|
||||
hours, minutes, seconds = map(int, time_parts.split(':'))
|
||||
return hours * 3600 + minutes * 60 + seconds + ms
|
||||
|
||||
return parse_time(start), parse_time(end)
|
||||
|
||||
def format_timestamp(seconds: float) -> str:
|
||||
"""将秒数转换为时间戳格式 HH:MM:SS"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = int(seconds % 60)
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||
|
||||
# 创建包含所有信息的临时列表
|
||||
final_script = []
|
||||
|
||||
# 处理原生画面条目
|
||||
number = 1
|
||||
for plot_point in plot_points:
|
||||
start, end = parse_timestamp(plot_point["timestamp"])
|
||||
script_item = {
|
||||
"_id": number,
|
||||
"timestamp": plot_point["timestamp"],
|
||||
"picture": plot_point["picture"],
|
||||
"narration": f"播放原生_{os.urandom(4).hex()}",
|
||||
"OST": 1, # OST=0 仅保留解说 OST=2 保留解说和原声
|
||||
}
|
||||
final_script.append(script_item)
|
||||
number += 1
|
||||
|
||||
# 保存结果
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(final_script, f, ensure_ascii=False, indent=4)
|
||||
|
||||
print(f"脚本生成完成:{output_path}")
|
||||
return final_script
|
||||
Binary file not shown.
Binary file not shown.
45
app/services/SDP/utils/utils.py
Normal file
45
app/services/SDP/utils/utils.py
Normal file
@ -0,0 +1,45 @@
|
||||
# 公共方法
|
||||
import json
|
||||
import requests # 新增
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def load_srt(file_path: str) -> List[Dict]:
|
||||
"""加载并解析SRT文件
|
||||
|
||||
Args:
|
||||
file_path: SRT文件路径
|
||||
|
||||
Returns:
|
||||
字幕内容列表
|
||||
"""
|
||||
with open(file_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 按空行分割字幕块
|
||||
subtitle_blocks = content.split('\n\n')
|
||||
subtitles = []
|
||||
|
||||
for block in subtitle_blocks:
|
||||
lines = block.split('\n')
|
||||
if len(lines) >= 3: # 确保块包含足够的行
|
||||
try:
|
||||
number = int(lines[0].strip())
|
||||
timestamp = lines[1]
|
||||
text = ' '.join(lines[2:])
|
||||
|
||||
# 解析时间戳
|
||||
start_time, end_time = timestamp.split(' --> ')
|
||||
|
||||
subtitles.append({
|
||||
'number': number,
|
||||
'timestamp': timestamp,
|
||||
'text': text,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
except ValueError as e:
|
||||
print(f"Warning: 跳过无效的字幕块: {e}")
|
||||
continue
|
||||
|
||||
return subtitles
|
||||
Binary file not shown.
Binary file not shown.
@ -149,7 +149,6 @@ def chekc_video_config(video_params):
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
session.mount("https://", adapter)
|
||||
try:
|
||||
logger.debug(video_params)
|
||||
session.post(
|
||||
f"{config.app.get('narrato_api_url')}/admin/external-api-config/services",
|
||||
headers=headers,
|
||||
|
||||
@ -68,8 +68,6 @@ def generate_script_short(tr, params, custom_clips=5):
|
||||
api_key=text_api_key,
|
||||
model_name=text_model,
|
||||
base_url=text_base_url,
|
||||
narrato_api_key=narrato_api_key,
|
||||
bert_path="app/models/bert/",
|
||||
custom_clips=custom_clips,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user