优化大模型生成脚本逻辑

This commit is contained in:
linyq 2024-09-26 15:56:50 +08:00
parent 990994e9cd
commit 18d4fff028
3 changed files with 200 additions and 151 deletions

View File

@ -10,11 +10,12 @@ from openai import AzureOpenAI
from openai.types.chat import ChatCompletion
import google.generativeai as gemini
from googleapiclient.errors import ResumableUploadError
from google.api_core.exceptions import FailedPrecondition
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from google.api_core.exceptions import *
from google.generativeai.types import *
import subprocess
from app.config import config
from app.utils.utils import clean_model_output
_max_retries = 5
@ -105,9 +106,39 @@ Method = """
"""
def _generate_response(prompt: str) -> str:
def handle_exception(err):
    """Log a readable diagnosis for a failed LLM / Gemini call.

    Known exception types are mapped to a fixed log message; any other
    exception is logged together with the active traceback and a link to the
    Gemini troubleshooting guide.

    Args:
        err: The exception instance caught by the caller.

    Returns:
        Always the empty string, so that callers can write
        ``return handle_exception(err)`` inside functions that return ``str``.
    """
    if isinstance(err, PermissionDenied):
        logger.error("403 用户没有权限访问该资源")
    elif isinstance(err, ResourceExhausted):
        logger.error("429 您的配额已用尽。请稍后重试。请考虑设置自动重试来处理这些错误")
    elif isinstance(err, InvalidArgument):
        logger.error("400 参数无效。例如,文件过大,超出了载荷大小限制。另一个事件提供了无效的 API 密钥。")
    elif isinstance(err, AlreadyExists):
        logger.error("409 已存在具有相同 ID 的已调参模型。对新模型进行调参时,请指定唯一的模型 ID。")
    elif isinstance(err, RetryError):
        logger.error("使用不支持 gRPC 的代理时可能会引起此错误。请尝试将 REST 传输与 genai.configure(..., transport=rest) 搭配使用。")
    elif isinstance(err, BlockedPromptException):
        logger.error("400 出于安全原因,该提示已被屏蔽。")
    elif isinstance(err, BrokenResponseError):
        logger.error("500 流式传输响应已损坏。在访问需要完整响应的内容(例如聊天记录)时引发。查看堆栈轨迹中提供的错误详情。")
    elif isinstance(err, IncompleteIterationError):
        logger.error("500 访问需要完整 API 响应但流式响应尚未完全迭代的内容时引发。对响应对象调用 resolve() 以使用迭代器。")
    elif isinstance(err, ConnectionError):
        logger.error("网络连接错误,请检查您的网络连接。")
    else:
        # traceback.format_exc() formats the exception currently being
        # handled, so this function is meant to be called from an except block.
        logger.error(f"视频转录失败, 下面是具体报错信息: \n{traceback.format_exc()} \n问题排查指南: https://ai.google.dev/gemini-api/docs/troubleshooting?hl=zh-cn")

    # Returned on every branch (not only the fallback one): the callers hand
    # this value straight back to their own callers as the model response.
    return ""
def _generate_response(prompt: str, llm_provider: str = None) -> str:
"""
调用大模型通用方法
prompt
llm_provider
"""
content = ""
llm_provider = config.app.get("llm_provider", "openai")
if not llm_provider:
llm_provider = config.app.get("llm_provider", "openai")
logger.info(f"llm provider: {llm_provider}")
if llm_provider == "g4f":
model_name = config.app.get("g4f_model_name", "")
@ -223,46 +254,23 @@ def _generate_response(prompt: str) -> str:
genai.configure(api_key=api_key, transport="rest")
generation_config = {
"temperature": 0.5,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH",
},
]
model = genai.GenerativeModel(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings,
)
try:
response = model.generate_content(prompt)
candidates = response.candidates
generated_text = candidates[0].content.parts[0].text
except (AttributeError, IndexError) as e:
print("Gemini Error:", e)
return generated_text
return response.text
except Exception as err:
return handle_exception(err)
if llm_provider == "cloudflare":
import requests
@ -345,6 +353,43 @@ def _generate_response(prompt: str) -> str:
return content.replace("\n", "")
def _generate_response_video(prompt: str, llm_provider: str, video_file: str | File) -> str:
    """Send a prompt together with a video to a multimodal LLM.

    Only the "gemini" provider is handled; its API key and model name are
    read from ``config.app``.

    Args:
        prompt: Instruction text sent alongside the video.
        llm_provider: Provider name; must be "gemini".
        video_file: The video handed to the model with the prompt
            (presumably the object returned by the Gemini file upload --
            confirm against the caller).

    Returns:
        The ``text`` of the model response, or whatever ``handle_exception``
        returns when the request raises.

    Raises:
        ValueError: If ``llm_provider`` is anything other than "gemini".
    """
    # Guard clause: reject unsupported providers before touching any config.
    if llm_provider != "gemini":
        raise ValueError(
            "llm_provider 未设置,请在 config.toml 文件中进行设置。"
        )

    import google.generativeai as genai

    api_key = config.app.get("gemini_api_key")
    model_name = config.app.get("gemini_model_name")
    genai.configure(api_key=api_key, transport="rest")

    # Every harm category is set to BLOCK_NONE for this request.
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    model = genai.GenerativeModel(
        model_name=model_name,
        safety_settings=safety_settings,
    )

    try:
        response = model.generate_content([prompt, video_file])
        return response.text
    except Exception as err:
        # Any failure is logged by handle_exception; its return value is
        # passed back in place of the model text.
        return handle_exception(err)
def compress_video(input_path: str, output_path: str):
"""
压缩视频文件
@ -353,7 +398,7 @@ def compress_video(input_path: str, output_path: str):
output_path: 输出压缩后的视频文件路径
"""
# 指定 ffmpeg 的完整路径
ffmpeg_path = os.getenv("FFMPEG_PATH") or config.app.get("ffmpeg_path")
ffmpeg_path = os.getenv("FFMPEG_PATH") or config.app.get("ffmpeg_path") or "ffmpeg"
# 如果压缩后的视频文件已经存在,则直接使用
if os.path.exists(output_path):
@ -370,6 +415,7 @@ def compress_video(input_path: str, output_path: str):
"-b:a", "128k",
output_path
]
logger.info(f"执行命令: {' '.join(command)}")
subprocess.run(command, check=True)
except subprocess.CalledProcessError as e:
logger.error(f"视频压缩失败: {e}")
@ -396,7 +442,13 @@ def generate_script(
compress_video(video_path, compressed_video_path)
# 2. 转录视频
transcription = gemini_video_transcription(video_name=video_name, video_path=compressed_video_path, language=language, progress_text=progress_text)
transcription = gemini_video_transcription(
video_name=video_name,
video_path=compressed_video_path,
language=language,
progress_text=progress_text,
llm_provider="gemini"
)
# # 清理压缩后的视频文件
# try:
@ -406,13 +458,16 @@ def generate_script(
# 3. 编写解说文案
progress_text.text("解说文案中...")
script = writing_short_play(video_plot, video_name)
script = writing_short_play(video_plot, video_name, "openai")
# 4. 文案匹配画面
progress_text.text("画面匹配中...")
matched_script = screen_matching(huamian=transcription, wenan=script)
if transcription != "":
progress_text.text("画面匹配中...")
matched_script = screen_matching(huamian=transcription, wenan=script, llm_provider="openai")
return matched_script
return matched_script
else:
return ""
def generate_terms(video_subject: str, video_script: str, amount: int = 5) -> List[str]:
@ -565,57 +620,52 @@ def gemini_video2json(video_origin_name: str, video_origin_path: str, video_plot
return response
def gemini_video_transcription(video_name: str, video_path: str, language: str, progress_text: st.empty = ""):
def gemini_video_transcription(video_name: str, video_path: str, language: str, llm_provider: str, progress_text: st.empty = ""):
'''
使用 gemini-1.5-xxx 进行视频画面转录
'''
api_key = config.app.get("gemini_api_key")
model_name = config.app.get("gemini_model_name")
gemini.configure(api_key=api_key)
model = gemini.GenerativeModel(model_name=model_name)
prompt = """
Please transcribe the audio, include timestamps, and provide visual descriptions, then output in JSON format.
Please use %s output
Use this JSON schema:
Graphics = {"timestamp": "MM:SS-MM:SS", "picture": "str", "quotes": "str"(If no one says anything, use an empty string instead.)}
Return: list[Graphics]
""" % language
请转录音频包括时间戳并提供视觉描述然后以 JSON 格式输出当前视频中使用的语言为 %s
在转录视频时请通过确保以下条件来完成转录
1. 画面描述使用语言: %s 进行输出
2. 同一个画面合并为一个转录记录
3. 使用以下 JSON schema:
Graphics = {"timestamp": "MM:SS-MM:SS"(时间戳格式), "picture": "str"(画面描述), "speech": "str"(台词如果没有人说话则使用空字符串)}
Return: list[Graphics]
4. 请以严格的 JSON 格式返回数据不要包含任何注释标记或其他字符数据应符合 JSON 语法可以被 json.loads() 函数直接解析 不要添加 ```json 或其他标记
""" % (language, language)
logger.debug(f"视频名称: {video_name}")
try:
progress_text.text("上传视频中...")
gemini_video_file = gemini.upload_file(video_path)
logger.debug(f"上传视频至 Google cloud 成功: {gemini_video_file.name}")
logger.debug(f"视频 {gemini_video_file.name} 上传至 Google cloud 成功, 开始解析...")
while gemini_video_file.state.name == "PROCESSING":
import time
time.sleep(1)
gemini_video_file = gemini.get_file(gemini_video_file.name)
progress_text.text(f"解析视频中, 当前状态: {gemini_video_file.state.name}")
# logger.debug(f"视频当前状态(ACTIVE才可用): {gemini_video_file.state.name}")
if gemini_video_file.state.name == "FAILED":
raise ValueError(gemini_video_file.state.name)
elif gemini_video_file.state.name == "ACTIVE":
progress_text.text("解析完成")
logger.debug("解析完成, 开始转录...")
except ResumableUploadError as err:
logger.error(f"上传视频至 Google cloud 失败, 用户的位置信息不支持用于该API; \n{traceback.format_exc()}")
return ""
return False
except FailedPrecondition as err:
logger.error(f"400 用户位置不支持 Google API 使用。\n{traceback.format_exc()}")
return ""
return False
progress_text.text("视频转录中...")
response = model.generate_content(
[prompt, gemini_video_file],
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
)
logger.success("视频转录成功")
return response.text
try:
response = _generate_response_video(prompt=prompt, llm_provider=llm_provider, video_file=gemini_video_file)
logger.success("视频转录成功")
return response
except Exception as err:
return handle_exception(err)
def writing_movie(video_plot, video_name):
@ -640,33 +690,34 @@ def writing_movie(video_plot, video_name):
3. 仅输出解说文案不输出任何其他内容
4. 不要包含小标题每个段落以 \n 进行分隔
"""
response = model.generate_content(
prompt,
generation_config=gemini.types.GenerationConfig(
candidate_count=1,
temperature=1.3,
),
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
)
logger.debug(response.text)
logger.debug("字数:", len(response.text))
return response.text
try:
response = model.generate_content(
prompt,
generation_config=gemini.types.GenerationConfig(
candidate_count=1,
temperature=1.3,
),
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
)
return response.text
except Exception as err:
return handle_exception(err)
def writing_short_play(video_plot: str, video_name: str):
def writing_short_play(video_plot: str, video_name: str, llm_provider: str):
"""
影视解说短剧解说
"""
api_key = config.app.get("gemini_api_key")
model_name = config.app.get("gemini_model_name")
gemini.configure(api_key=api_key)
model = gemini.GenerativeModel(model_name)
# api_key = config.app.get("gemini_api_key")
# # model_name = config.app.get("gemini_model_name")
#
# gemini.configure(api_key=api_key)
# model = gemini.GenerativeModel(model_name)
if not video_plot:
raise ValueError("短剧的简介不能为空")
@ -686,33 +737,34 @@ def writing_short_play(video_plot: str, video_name: str):
3. 仅输出解说文案不输出任何其他内容
4. 不要包含小标题每个段落以 \\n 进行分隔
"""
response = model.generate_content(
prompt,
generation_config=gemini.types.GenerationConfig(
candidate_count=1,
temperature=1.0,
),
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
)
logger.success("解说文案生成成功")
return response.text
try:
# if "gemini" in model_name:
# response = model.generate_content(
# prompt,
# generation_config=gemini.types.GenerationConfig(
# candidate_count=1,
# temperature=1.0,
# ),
# safety_settings={
# HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
# HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
# HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
# HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
# }
# )
# else:
response = _generate_response(prompt, llm_provider)
logger.success("解说文案生成成功")
logger.debug(response)
return response
except Exception as err:
return handle_exception(err)
def screen_matching(huamian: str, wenan: str):
def screen_matching(huamian: str, wenan: str, llm_provider: str):
"""
画面匹配
"""
api_key = config.app.get("gemini_api_key")
model_name = config.app.get("gemini_model_name")
gemini.configure(api_key=api_key)
model = gemini.GenerativeModel(model_name)
if not huamian:
raise ValueError("画面不能为空")
if not wenan:
@ -731,25 +783,20 @@ def screen_matching(huamian: str, wenan: str):
%s
</COPYWRITER>
Use this JSON schema:
script = {'picture': str, 'timestamp': str, "narration": str, "OST": bool}
Return: list[script]
在匹配的过程中请通过确保以下条件来完成匹配
- 使用以下 JSON schema:
script = {'picture': str, 'timestamp': str(时间戳), "narration": str, "OST": bool(是否开启原声)}
Return: list[script]
- 请以严格的 JSON 格式返回数据不要包含任何注释标记或其他字符数据应符合 JSON 语法可以被 json.loads() 函数直接解析 不要添加 ```json 或其他标记
-
""" % (huamian, wenan)
response = model.generate_content(
prompt,
generation_config=gemini.types.GenerationConfig(
candidate_count=1,
temperature=1.0,
),
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
)
logger.success("匹配成功")
return response.text
try:
response = _generate_response(prompt, llm_provider)
logger.success("匹配成功")
logger.debug(response)
return response
except Exception as err:
return handle_exception(err)
if __name__ == "__main__":
@ -760,7 +807,7 @@ if __name__ == "__main__":
# gemini_video_transcription(video_subject, video_path, language)
# 2. 解说文案
video_path = "E:\\projects\\NarratoAI\\resource\\videos\\2.mp4"
video_path = "/Users/apple/Desktop/home/NarratoAI/resource/videos/1.mp4"
video_plot = """
李自忠拿着儿子李牧名下的存折去银行取钱给儿子救命却被要求证明"你儿子是你儿子"
走投无路时碰到银行被抢劫劫匪给了他两沓钱救命李自忠却因此被银行以抢劫罪起诉并顶格判处20年有期徒刑
@ -768,4 +815,9 @@ if __name__ == "__main__":
"""
res = generate_script(video_path, video_plot, video_name="第二十条之无罪释放")
# res = generate_script(video_path, video_plot, video_name="海岸")
print("res \n", res)
print("脚本生成成功:\n", res)
res = clean_model_output(res)
aaa = json.loads(res)
print(json.dumps(aaa, indent=2, ensure_ascii=False))
# response = _generate_response("你好,介绍一下你自己")
# print(response)

View File

@ -365,3 +365,13 @@ def add_new_timestamps(scenes):
updated_scenes.append(new_scene)
return updated_scenes
def clean_model_output(output):
    """Strip Markdown code-fence markers from a model response.

    Models sometimes wrap JSON in a fenced block (```json ... ``` or a bare
    ``` ... ```); the fence markers are removed so that the remaining text
    can be handed to a JSON parser.

    Args:
        output: Raw text returned by the model. Empty or ``None`` input is
            treated as an empty response.

    Returns:
        The text without fence markers and without leading/trailing
        whitespace.
    """
    if not output:
        return ""
    if "```" in output:
        # Remove the tagged opener first so that no stray "json" is left
        # behind once the bare fence markers are removed.
        output = output.replace("```json", "").replace("```", "")
    return output.strip()

View File

@ -1,20 +1,5 @@
[app]
project_version="0.1.2"
video_source = "pexels" # "pexels" or "pixabay"
# Pexels API Key
# Register at https://www.pexels.com/api/ to get your API key.
# You can use multiple keys to avoid rate limits.
# For example: pexels_api_keys = ["123adsf4567adf89","abd1321cd13efgfdfhi"]
# 特别注意格式Key 用英文双引号括起来多个Key用逗号隔开
pexels_api_keys = []
# Pixabay API Key
# Register at https://pixabay.com/api/docs/ to get your API key.
# You can use multiple keys to avoid rate limits.
# For example: pixabay_api_keys = ["123adsf4567adf89","abd1321cd13efgfdfhi"]
# 特别注意格式Key 用英文双引号括起来多个Key用逗号隔开
pixabay_api_keys = []
project_version="0.2.0"
# 如果你没有 OPENAI API Key,可以使用 g4f 代替,或者使用国内的 Moonshot API
# If you don't have an OPENAI API Key, you can use g4f instead
@ -27,6 +12,8 @@
# qwen (通义千问)
# gemini
llm_provider="openai"
# 支持多模态视频理解能力的大模型
llm_provider_video="gemini"
########## Ollama Settings
# No need to set it unless you want to use your own proxy
@ -184,8 +171,8 @@
### Example: "http://user:pass@proxy:1234"
### Doc: https://requests.readthedocs.io/en/latest/user/advanced/#proxies
# http = "http://10.10.1.10:3128"
# https = "http://10.10.1.10:1080"
http = "http://127.0.0.1:7890"
https = "http://127.0.0.1:7890"
[azure]
# Azure Speech API Key