mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
feat(app): 新增脚本生成 V2 接口并重构相关功能
- 新增 V2脚本生成接口和相关服务 - 重构脚本生成逻辑,提高可维护性和可扩展性 - 优化关键帧提取和处理流程 - 改进错误处理和日志记录
This commit is contained in:
parent
4621a6729a
commit
8dd4b27fc3
@ -163,109 +163,109 @@ def delete_video(request: Request, task_id: str = Path(..., description="Task ID
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files"
|
||||
)
|
||||
def get_bgm_list(request: Request):
|
||||
suffix = "*.mp3"
|
||||
song_dir = utils.song_dir()
|
||||
files = glob.glob(os.path.join(song_dir, suffix))
|
||||
bgm_list = []
|
||||
for file in files:
|
||||
bgm_list.append(
|
||||
{
|
||||
"name": os.path.basename(file),
|
||||
"size": os.path.getsize(file),
|
||||
"file": file,
|
||||
}
|
||||
)
|
||||
response = {"files": bgm_list}
|
||||
return utils.get_response(200, response)
|
||||
# @router.get(
|
||||
# "/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files"
|
||||
# )
|
||||
# def get_bgm_list(request: Request):
|
||||
# suffix = "*.mp3"
|
||||
# song_dir = utils.song_dir()
|
||||
# files = glob.glob(os.path.join(song_dir, suffix))
|
||||
# bgm_list = []
|
||||
# for file in files:
|
||||
# bgm_list.append(
|
||||
# {
|
||||
# "name": os.path.basename(file),
|
||||
# "size": os.path.getsize(file),
|
||||
# "file": file,
|
||||
# }
|
||||
# )
|
||||
# response = {"files": bgm_list}
|
||||
# return utils.get_response(200, response)
|
||||
#
|
||||
|
||||
|
||||
@router.post(
|
||||
"/musics",
|
||||
response_model=BgmUploadResponse,
|
||||
summary="Upload the BGM file to the songs directory",
|
||||
)
|
||||
def upload_bgm_file(request: Request, file: UploadFile = File(...)):
|
||||
request_id = base.get_task_id(request)
|
||||
# check file ext
|
||||
if file.filename.endswith("mp3"):
|
||||
song_dir = utils.song_dir()
|
||||
save_path = os.path.join(song_dir, file.filename)
|
||||
# save file
|
||||
with open(save_path, "wb+") as buffer:
|
||||
# If the file already exists, it will be overwritten
|
||||
file.file.seek(0)
|
||||
buffer.write(file.file.read())
|
||||
response = {"file": save_path}
|
||||
return utils.get_response(200, response)
|
||||
|
||||
raise HttpException(
|
||||
"", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stream/{file_path:path}")
|
||||
async def stream_video(request: Request, file_path: str):
|
||||
tasks_dir = utils.task_dir()
|
||||
video_path = os.path.join(tasks_dir, file_path)
|
||||
range_header = request.headers.get("Range")
|
||||
video_size = os.path.getsize(video_path)
|
||||
start, end = 0, video_size - 1
|
||||
|
||||
length = video_size
|
||||
if range_header:
|
||||
range_ = range_header.split("bytes=")[1]
|
||||
start, end = [int(part) if part else None for part in range_.split("-")]
|
||||
if start is None:
|
||||
start = video_size - end
|
||||
end = video_size - 1
|
||||
if end is None:
|
||||
end = video_size - 1
|
||||
length = end - start + 1
|
||||
|
||||
def file_iterator(file_path, offset=0, bytes_to_read=None):
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(offset, os.SEEK_SET)
|
||||
remaining = bytes_to_read or video_size
|
||||
while remaining > 0:
|
||||
bytes_to_read = min(4096, remaining)
|
||||
data = f.read(bytes_to_read)
|
||||
if not data:
|
||||
break
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
response = StreamingResponse(
|
||||
file_iterator(video_path, start, length), media_type="video/mp4"
|
||||
)
|
||||
response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
|
||||
response.headers["Accept-Ranges"] = "bytes"
|
||||
response.headers["Content-Length"] = str(length)
|
||||
response.status_code = 206 # Partial Content
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/download/{file_path:path}")
|
||||
async def download_video(_: Request, file_path: str):
|
||||
"""
|
||||
download video
|
||||
:param _: Request request
|
||||
:param file_path: video file path, eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4
|
||||
:return: video file
|
||||
"""
|
||||
tasks_dir = utils.task_dir()
|
||||
video_path = os.path.join(tasks_dir, file_path)
|
||||
file_path = pathlib.Path(video_path)
|
||||
filename = file_path.stem
|
||||
extension = file_path.suffix
|
||||
headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"}
|
||||
return FileResponse(
|
||||
path=video_path,
|
||||
headers=headers,
|
||||
filename=f"{filename}{extension}",
|
||||
media_type=f"video/{extension[1:]}",
|
||||
)
|
||||
# @router.post(
|
||||
# "/musics",
|
||||
# response_model=BgmUploadResponse,
|
||||
# summary="Upload the BGM file to the songs directory",
|
||||
# )
|
||||
# def upload_bgm_file(request: Request, file: UploadFile = File(...)):
|
||||
# request_id = base.get_task_id(request)
|
||||
# # check file ext
|
||||
# if file.filename.endswith("mp3"):
|
||||
# song_dir = utils.song_dir()
|
||||
# save_path = os.path.join(song_dir, file.filename)
|
||||
# # save file
|
||||
# with open(save_path, "wb+") as buffer:
|
||||
# # If the file already exists, it will be overwritten
|
||||
# file.file.seek(0)
|
||||
# buffer.write(file.file.read())
|
||||
# response = {"file": save_path}
|
||||
# return utils.get_response(200, response)
|
||||
#
|
||||
# raise HttpException(
|
||||
# "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
|
||||
# )
|
||||
#
|
||||
#
|
||||
# @router.get("/stream/{file_path:path}")
|
||||
# async def stream_video(request: Request, file_path: str):
|
||||
# tasks_dir = utils.task_dir()
|
||||
# video_path = os.path.join(tasks_dir, file_path)
|
||||
# range_header = request.headers.get("Range")
|
||||
# video_size = os.path.getsize(video_path)
|
||||
# start, end = 0, video_size - 1
|
||||
#
|
||||
# length = video_size
|
||||
# if range_header:
|
||||
# range_ = range_header.split("bytes=")[1]
|
||||
# start, end = [int(part) if part else None for part in range_.split("-")]
|
||||
# if start is None:
|
||||
# start = video_size - end
|
||||
# end = video_size - 1
|
||||
# if end is None:
|
||||
# end = video_size - 1
|
||||
# length = end - start + 1
|
||||
#
|
||||
# def file_iterator(file_path, offset=0, bytes_to_read=None):
|
||||
# with open(file_path, "rb") as f:
|
||||
# f.seek(offset, os.SEEK_SET)
|
||||
# remaining = bytes_to_read or video_size
|
||||
# while remaining > 0:
|
||||
# bytes_to_read = min(4096, remaining)
|
||||
# data = f.read(bytes_to_read)
|
||||
# if not data:
|
||||
# break
|
||||
# remaining -= len(data)
|
||||
# yield data
|
||||
#
|
||||
# response = StreamingResponse(
|
||||
# file_iterator(video_path, start, length), media_type="video/mp4"
|
||||
# )
|
||||
# response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
|
||||
# response.headers["Accept-Ranges"] = "bytes"
|
||||
# response.headers["Content-Length"] = str(length)
|
||||
# response.status_code = 206 # Partial Content
|
||||
#
|
||||
# return response
|
||||
#
|
||||
#
|
||||
# @router.get("/download/{file_path:path}")
|
||||
# async def download_video(_: Request, file_path: str):
|
||||
# """
|
||||
# download video
|
||||
# :param _: Request request
|
||||
# :param file_path: video file path, eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4
|
||||
# :return: video file
|
||||
# """
|
||||
# tasks_dir = utils.task_dir()
|
||||
# video_path = os.path.join(tasks_dir, file_path)
|
||||
# file_path = pathlib.Path(video_path)
|
||||
# filename = file_path.stem
|
||||
# extension = file_path.suffix
|
||||
# headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"}
|
||||
# return FileResponse(
|
||||
# path=video_path,
|
||||
# headers=headers,
|
||||
# filename=f"{filename}{extension}",
|
||||
# media_type=f"video/{extension[1:]}",
|
||||
# )
|
||||
|
||||
11
app/controllers/v2/base.py
Normal file
11
app/controllers/v2/base.py
Normal file
@ -0,0 +1,11 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
|
||||
def v2_router(dependencies=None):
|
||||
router = APIRouter()
|
||||
router.tags = ["V2"]
|
||||
router.prefix = "/api/v2"
|
||||
# 将认证依赖项应用于所有路由
|
||||
if dependencies:
|
||||
router.dependencies = dependencies
|
||||
return router
|
||||
45
app/controllers/v2/script.py
Normal file
45
app/controllers/v2/script.py
Normal file
@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from loguru import logger
|
||||
|
||||
from app.models.schema_v2 import GenerateScriptRequest, GenerateScriptResponse
|
||||
from app.services.script_service import ScriptGenerator
|
||||
from app.utils import utils
|
||||
from app.controllers.v2.base import v2_router
|
||||
|
||||
# router = APIRouter(prefix="/api/v2", tags=["Script Generation V2"])
|
||||
router = v2_router()
|
||||
|
||||
@router.post(
|
||||
"/scripts/generate",
|
||||
response_model=GenerateScriptResponse,
|
||||
summary="生成视频脚本 (V2)"
|
||||
)
|
||||
async def generate_script(
|
||||
request: GenerateScriptRequest,
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
生成视频脚本的V2版本API
|
||||
"""
|
||||
task_id = utils.get_uuid()
|
||||
|
||||
try:
|
||||
generator = ScriptGenerator()
|
||||
script = await generator.generate_script(
|
||||
video_path=request.video_path,
|
||||
video_theme=request.video_theme,
|
||||
custom_prompt=request.custom_prompt,
|
||||
skip_seconds=request.skip_seconds,
|
||||
threshold=request.threshold,
|
||||
vision_batch_size=request.vision_batch_size,
|
||||
vision_llm_provider=request.vision_llm_provider
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"script": script
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Generate script failed: {str(e)}")
|
||||
raise
|
||||
15
app/models/schema_v2.py
Normal file
15
app/models/schema_v2.py
Normal file
@ -0,0 +1,15 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
class GenerateScriptRequest(BaseModel):
|
||||
video_path: str
|
||||
video_theme: Optional[str] = ""
|
||||
custom_prompt: Optional[str] = ""
|
||||
skip_seconds: Optional[int] = 0
|
||||
threshold: Optional[int] = 30
|
||||
vision_batch_size: Optional[int] = 5
|
||||
vision_llm_provider: Optional[str] = "gemini"
|
||||
|
||||
class GenerateScriptResponse(BaseModel):
|
||||
task_id: str
|
||||
script: List[dict]
|
||||
@ -10,8 +10,12 @@ Resources:
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.controllers.v1 import llm, video
|
||||
from app.controllers.v2 import script
|
||||
|
||||
root_api_router = APIRouter()
|
||||
# v1
|
||||
root_api_router.include_router(video.router)
|
||||
root_api_router.include_router(llm.router)
|
||||
|
||||
# v2
|
||||
root_api_router.include_router(script.router)
|
||||
|
||||
378
app/services/script_service.py
Normal file
378
app/services/script_service.py
Normal file
@ -0,0 +1,378 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import requests
|
||||
from loguru import logger
|
||||
from typing import List, Dict, Any, Callable
|
||||
|
||||
from app.utils import utils, vision_analyzer, video_processor, video_processor_v2
|
||||
from app.utils.script_generator import ScriptProcessor
|
||||
from app.config import config
|
||||
|
||||
|
||||
class ScriptGenerator:
|
||||
def __init__(self):
|
||||
self.temp_dir = utils.temp_dir()
|
||||
self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")
|
||||
|
||||
async def generate_script(
|
||||
self,
|
||||
video_path: str,
|
||||
video_theme: str = "",
|
||||
custom_prompt: str = "",
|
||||
skip_seconds: int = 0,
|
||||
threshold: int = 30,
|
||||
vision_batch_size: int = 5,
|
||||
vision_llm_provider: str = "gemini",
|
||||
progress_callback: Callable[[float, str], None] = None
|
||||
) -> List[Dict[Any, Any]]:
|
||||
"""
|
||||
生成视频脚本的核心逻辑
|
||||
|
||||
Args:
|
||||
video_path: 视频文件路径
|
||||
video_theme: 视频主题
|
||||
custom_prompt: 自定义提示词
|
||||
skip_seconds: 跳过开始的秒数
|
||||
threshold: 差异阈值
|
||||
vision_batch_size: 视觉处理批次大小
|
||||
vision_llm_provider: 视觉模型提供商
|
||||
progress_callback: 进度回调函数
|
||||
|
||||
Returns:
|
||||
List[Dict]: 生成的视频脚本
|
||||
"""
|
||||
if progress_callback is None:
|
||||
progress_callback = lambda p, m: None
|
||||
|
||||
try:
|
||||
# 提取关键帧
|
||||
progress_callback(10, "正在提取关键帧...")
|
||||
keyframe_files = await self._extract_keyframes(
|
||||
video_path,
|
||||
skip_seconds,
|
||||
threshold
|
||||
)
|
||||
|
||||
if vision_llm_provider == "gemini":
|
||||
script = await self._process_with_gemini(
|
||||
keyframe_files,
|
||||
video_theme,
|
||||
custom_prompt,
|
||||
vision_batch_size,
|
||||
progress_callback
|
||||
)
|
||||
elif vision_llm_provider == "narratoapi":
|
||||
script = await self._process_with_narrato(
|
||||
keyframe_files,
|
||||
video_theme,
|
||||
custom_prompt,
|
||||
vision_batch_size,
|
||||
progress_callback
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported vision provider: {vision_llm_provider}")
|
||||
|
||||
return json.loads(script) if isinstance(script, str) else script
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Generate script failed")
|
||||
raise
|
||||
|
||||
async def _extract_keyframes(
|
||||
self,
|
||||
video_path: str,
|
||||
skip_seconds: int,
|
||||
threshold: int
|
||||
) -> List[str]:
|
||||
"""提取视频关键帧"""
|
||||
video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
|
||||
video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)
|
||||
|
||||
# 检查缓存
|
||||
keyframe_files = []
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
|
||||
if keyframe_files:
|
||||
logger.info(f"Using cached keyframes: {video_keyframes_dir}")
|
||||
return keyframe_files
|
||||
|
||||
# 提取新的关键帧
|
||||
os.makedirs(video_keyframes_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
if config.frames.get("version") == "v2":
|
||||
processor = video_processor_v2.VideoProcessor(video_path)
|
||||
processor.process_video_pipeline(
|
||||
output_dir=video_keyframes_dir,
|
||||
skip_seconds=skip_seconds,
|
||||
threshold=threshold
|
||||
)
|
||||
else:
|
||||
processor = video_processor.VideoProcessor(video_path)
|
||||
processor.process_video(
|
||||
output_dir=video_keyframes_dir,
|
||||
skip_seconds=skip_seconds
|
||||
)
|
||||
|
||||
for filename in sorted(os.listdir(video_keyframes_dir)):
|
||||
if filename.endswith('.jpg'):
|
||||
keyframe_files.append(os.path.join(video_keyframes_dir, filename))
|
||||
|
||||
return keyframe_files
|
||||
|
||||
except Exception as e:
|
||||
if os.path.exists(video_keyframes_dir):
|
||||
import shutil
|
||||
shutil.rmtree(video_keyframes_dir)
|
||||
raise
|
||||
|
||||
async def _process_with_gemini(
|
||||
self,
|
||||
keyframe_files: List[str],
|
||||
video_theme: str,
|
||||
custom_prompt: str,
|
||||
vision_batch_size: int,
|
||||
progress_callback: Callable[[float, str], None]
|
||||
) -> str:
|
||||
"""使用Gemini处理视频帧"""
|
||||
progress_callback(30, "正在初始化视觉分析器...")
|
||||
|
||||
# 获取Gemini配置
|
||||
vision_api_key = config.app.get("vision_gemini_api_key")
|
||||
vision_model = config.app.get("vision_gemini_model_name")
|
||||
|
||||
if not vision_api_key or not vision_model:
|
||||
raise ValueError("未配置 Gemini API Key 或者模型")
|
||||
|
||||
analyzer = vision_analyzer.VisionAnalyzer(
|
||||
model_name=vision_model,
|
||||
api_key=vision_api_key,
|
||||
)
|
||||
|
||||
progress_callback(40, "正在分析关键帧...")
|
||||
|
||||
# 执行异步分析
|
||||
results = await analyzer.analyze_images(
|
||||
images=keyframe_files,
|
||||
prompt=config.app.get('vision_analysis_prompt'),
|
||||
batch_size=vision_batch_size
|
||||
)
|
||||
|
||||
progress_callback(60, "正在整理分析结果...")
|
||||
|
||||
# 合并所有批次的分析结果
|
||||
frame_analysis = ""
|
||||
prev_batch_files = None
|
||||
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
|
||||
continue
|
||||
|
||||
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
# 添加带时间戳的分析结果
|
||||
frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n"
|
||||
frame_analysis += result['response']
|
||||
frame_analysis += "\n"
|
||||
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_analysis.strip():
|
||||
raise Exception("未能生成有效的帧分析结果")
|
||||
|
||||
progress_callback(70, "正在生成脚本...")
|
||||
|
||||
# 构建帧内容列表
|
||||
frame_content_list = []
|
||||
prev_batch_files = None
|
||||
|
||||
for result in results:
|
||||
if 'error' in result:
|
||||
continue
|
||||
|
||||
batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
|
||||
_, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
|
||||
|
||||
frame_content = {
|
||||
"timestamp": timestamp_range,
|
||||
"picture": result['response'],
|
||||
"narration": "",
|
||||
"OST": 2
|
||||
}
|
||||
frame_content_list.append(frame_content)
|
||||
prev_batch_files = batch_files
|
||||
|
||||
if not frame_content_list:
|
||||
raise Exception("没有有效的帧内容可以处理")
|
||||
|
||||
progress_callback(90, "正在生成文案...")
|
||||
|
||||
# 获取文本生成配置
|
||||
text_provider = config.app.get('text_llm_provider', 'gemini').lower()
|
||||
text_api_key = config.app.get(f'text_{text_provider}_api_key')
|
||||
text_model = config.app.get(f'text_{text_provider}_model_name')
|
||||
|
||||
processor = ScriptProcessor(
|
||||
model_name=text_model,
|
||||
api_key=text_api_key,
|
||||
prompt=custom_prompt,
|
||||
video_theme=video_theme
|
||||
)
|
||||
|
||||
return processor.process_frames(frame_content_list)
|
||||
|
||||
async def _process_with_narrato(
|
||||
self,
|
||||
keyframe_files: List[str],
|
||||
video_theme: str,
|
||||
custom_prompt: str,
|
||||
vision_batch_size: int,
|
||||
progress_callback: Callable[[float, str], None]
|
||||
) -> str:
|
||||
"""使用NarratoAPI处理视频帧"""
|
||||
# 创建临时目录
|
||||
temp_dir = utils.temp_dir("narrato")
|
||||
|
||||
# 打包关键帧
|
||||
progress_callback(30, "正在打包关键帧...")
|
||||
zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")
|
||||
|
||||
try:
|
||||
if not utils.create_zip(keyframe_files, zip_path):
|
||||
raise Exception("打包关键帧失败")
|
||||
|
||||
# 获取API配置
|
||||
api_url = config.app.get("narrato_api_url")
|
||||
api_key = config.app.get("narrato_api_key")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置 Narrato API Key")
|
||||
|
||||
headers = {
|
||||
'X-API-Key': api_key,
|
||||
'accept': 'application/json'
|
||||
}
|
||||
|
||||
api_params = {
|
||||
'batch_size': vision_batch_size,
|
||||
'use_ai': False,
|
||||
'start_offset': 0,
|
||||
'vision_model': config.app.get('narrato_vision_model', 'gemini-1.5-flash'),
|
||||
'vision_api_key': config.app.get('narrato_vision_key'),
|
||||
'llm_model': config.app.get('narrato_llm_model', 'qwen-plus'),
|
||||
'llm_api_key': config.app.get('narrato_llm_key'),
|
||||
'custom_prompt': custom_prompt
|
||||
}
|
||||
|
||||
progress_callback(40, "正在上传文件...")
|
||||
with open(zip_path, 'rb') as f:
|
||||
files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
|
||||
response = requests.post(
|
||||
f"{api_url}/video/analyze",
|
||||
headers=headers,
|
||||
params=api_params,
|
||||
files=files,
|
||||
timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
task_data = response.json()
|
||||
task_id = task_data["data"].get('task_id')
|
||||
if not task_id:
|
||||
raise Exception(f"无效的API响应: {response.text}")
|
||||
|
||||
progress_callback(50, "正在等待分析结果...")
|
||||
retry_count = 0
|
||||
max_retries = 60
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
status_response = requests.get(
|
||||
f"{api_url}/video/tasks/{task_id}",
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
status_response.raise_for_status()
|
||||
task_status = status_response.json()['data']
|
||||
|
||||
if task_status['status'] == 'SUCCESS':
|
||||
return task_status['result']['data']
|
||||
elif task_status['status'] in ['FAILURE', 'RETRY']:
|
||||
raise Exception(f"任务失败: {task_status.get('error')}")
|
||||
|
||||
retry_count += 1
|
||||
time.sleep(2)
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.warning(f"获取任务状态失败,重试中: {str(e)}")
|
||||
retry_count += 1
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
raise Exception("任务执行超时")
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
try:
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理临时文件失败: {str(e)}")
|
||||
|
||||
def _get_batch_files(
|
||||
self,
|
||||
keyframe_files: List[str],
|
||||
result: Dict[str, Any],
|
||||
batch_size: int
|
||||
) -> List[str]:
|
||||
"""获取当前批次的图片文件"""
|
||||
batch_start = result['batch_index'] * batch_size
|
||||
batch_end = min(batch_start + batch_size, len(keyframe_files))
|
||||
return keyframe_files[batch_start:batch_end]
|
||||
|
||||
def _get_batch_timestamps(
|
||||
self,
|
||||
batch_files: List[str],
|
||||
prev_batch_files: List[str] = None
|
||||
) -> tuple[str, str, str]:
|
||||
"""获取一批文件的时间戳范围"""
|
||||
if not batch_files:
|
||||
logger.warning("Empty batch files")
|
||||
return "00:00", "00:00", "00:00-00:00"
|
||||
|
||||
if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
|
||||
first_frame = os.path.basename(prev_batch_files[-1])
|
||||
last_frame = os.path.basename(batch_files[0])
|
||||
else:
|
||||
first_frame = os.path.basename(batch_files[0])
|
||||
last_frame = os.path.basename(batch_files[-1])
|
||||
|
||||
first_time = first_frame.split('_')[2].replace('.jpg', '')
|
||||
last_time = last_frame.split('_')[2].replace('.jpg', '')
|
||||
|
||||
def format_timestamp(time_str: str) -> str:
|
||||
if len(time_str) < 4:
|
||||
logger.warning(f"Invalid timestamp format: {time_str}")
|
||||
return "00:00"
|
||||
|
||||
minutes = int(time_str[-4:-2])
|
||||
seconds = int(time_str[-2:])
|
||||
|
||||
if seconds >= 60:
|
||||
minutes += seconds // 60
|
||||
seconds = seconds % 60
|
||||
|
||||
return f"{minutes:02d}:{seconds:02d}"
|
||||
|
||||
first_timestamp = format_timestamp(first_time)
|
||||
last_timestamp = format_timestamp(last_time)
|
||||
timestamp_range = f"{first_timestamp}-{last_timestamp}"
|
||||
|
||||
return first_timestamp, last_timestamp, timestamp_range
|
||||
@ -103,7 +103,6 @@
|
||||
"Video Quality": "视频质量",
|
||||
"Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词",
|
||||
"Proxy Settings": "代理设置",
|
||||
"Language": "界面语言",
|
||||
"HTTP_PROXY": "HTTP 代理",
|
||||
"HTTPs_PROXY": "HTTPS 代理",
|
||||
"Vision Model Settings": "视频分析模型设置",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user