diff --git a/app/controllers/v1/video.py b/app/controllers/v1/video.py index 0430707..336084f 100644 --- a/app/controllers/v1/video.py +++ b/app/controllers/v1/video.py @@ -163,109 +163,109 @@ def delete_video(request: Request, task_id: str = Path(..., description="Task ID ) -@router.get( - "/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files" -) -def get_bgm_list(request: Request): - suffix = "*.mp3" - song_dir = utils.song_dir() - files = glob.glob(os.path.join(song_dir, suffix)) - bgm_list = [] - for file in files: - bgm_list.append( - { - "name": os.path.basename(file), - "size": os.path.getsize(file), - "file": file, - } - ) - response = {"files": bgm_list} - return utils.get_response(200, response) +# @router.get( +# "/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files" +# ) +# def get_bgm_list(request: Request): +# suffix = "*.mp3" +# song_dir = utils.song_dir() +# files = glob.glob(os.path.join(song_dir, suffix)) +# bgm_list = [] +# for file in files: +# bgm_list.append( +# { +# "name": os.path.basename(file), +# "size": os.path.getsize(file), +# "file": file, +# } +# ) +# response = {"files": bgm_list} +# return utils.get_response(200, response) +# - -@router.post( - "/musics", - response_model=BgmUploadResponse, - summary="Upload the BGM file to the songs directory", -) -def upload_bgm_file(request: Request, file: UploadFile = File(...)): - request_id = base.get_task_id(request) - # check file ext - if file.filename.endswith("mp3"): - song_dir = utils.song_dir() - save_path = os.path.join(song_dir, file.filename) - # save file - with open(save_path, "wb+") as buffer: - # If the file already exists, it will be overwritten - file.file.seek(0) - buffer.write(file.file.read()) - response = {"file": save_path} - return utils.get_response(200, response) - - raise HttpException( - "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded" - ) - - -@router.get("/stream/{file_path:path}") -async def stream_video(request: Request, file_path: str): - tasks_dir = utils.task_dir() - video_path = os.path.join(tasks_dir, file_path) - range_header = request.headers.get("Range") - video_size = os.path.getsize(video_path) - start, end = 0, video_size - 1 - - length = video_size - if range_header: - range_ = range_header.split("bytes=")[1] - start, end = [int(part) if part else None for part in range_.split("-")] - if start is None: - start = video_size - end - end = video_size - 1 - if end is None: - end = video_size - 1 - length = end - start + 1 - - def file_iterator(file_path, offset=0, bytes_to_read=None): - with open(file_path, "rb") as f: - f.seek(offset, os.SEEK_SET) - remaining = bytes_to_read or video_size - while remaining > 0: - bytes_to_read = min(4096, remaining) - data = f.read(bytes_to_read) - if not data: - break - remaining -= len(data) - yield data - - response = StreamingResponse( - file_iterator(video_path, start, length), media_type="video/mp4" - ) - response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}" - response.headers["Accept-Ranges"] = "bytes" - response.headers["Content-Length"] = str(length) - response.status_code = 206 # Partial Content - - return response - - -@router.get("/download/{file_path:path}") -async def download_video(_: Request, file_path: str): - """ - download video - :param _: Request request - :param file_path: video file path, eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4 - :return: video file - """ - tasks_dir = utils.task_dir() - video_path = os.path.join(tasks_dir, file_path) - file_path = pathlib.Path(video_path) - filename = file_path.stem - extension = file_path.suffix - headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"} - return FileResponse( - path=video_path, - headers=headers, - filename=f"{filename}{extension}", - media_type=f"video/{extension[1:]}", - ) +# @router.post( +# "/musics", +# response_model=BgmUploadResponse, +# summary="Upload the BGM file to the songs directory", +# ) +# def upload_bgm_file(request: Request, file: UploadFile = File(...)): +# request_id = base.get_task_id(request) +# # check file ext +# if file.filename.endswith("mp3"): +# song_dir = utils.song_dir() +# save_path = os.path.join(song_dir, file.filename) +# # save file +# with open(save_path, "wb+") as buffer: +# # If the file already exists, it will be overwritten +# file.file.seek(0) +# buffer.write(file.file.read()) +# response = {"file": save_path} +# return utils.get_response(200, response) +# +# raise HttpException( +# "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded" +# ) +# +# +# @router.get("/stream/{file_path:path}") +# async def stream_video(request: Request, file_path: str): +# tasks_dir = utils.task_dir() +# video_path = os.path.join(tasks_dir, file_path) +# range_header = request.headers.get("Range") +# video_size = os.path.getsize(video_path) +# start, end = 0, video_size - 1 +# +# length = video_size +# if range_header: +# range_ = range_header.split("bytes=")[1] +# start, end = [int(part) if part else None for part in range_.split("-")] +# if start is None: +# start = video_size - end +# end = video_size - 1 +# if end is None: +# end = video_size - 1 +# length = end - start + 1 +# +# def file_iterator(file_path, offset=0, bytes_to_read=None): +# with open(file_path, "rb") as f: +# f.seek(offset, os.SEEK_SET) +# remaining = bytes_to_read or video_size +# while remaining > 0: +# bytes_to_read = min(4096, remaining) +# data = f.read(bytes_to_read) +# if not data: +# break +# remaining -= len(data) +# yield data +# +# response = StreamingResponse( +# file_iterator(video_path, start, length), media_type="video/mp4" +# ) +# response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}" +# response.headers["Accept-Ranges"] = "bytes" +# response.headers["Content-Length"] = str(length) +# response.status_code = 206 # Partial Content +# +# return response +# +# +# @router.get("/download/{file_path:path}") +# async def download_video(_: Request, file_path: str): +# """ +# download video +# :param _: Request request +# :param file_path: video file path, eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4 +# :return: video file +# """ +# tasks_dir = utils.task_dir() +# video_path = os.path.join(tasks_dir, file_path) +# file_path = pathlib.Path(video_path) +# filename = file_path.stem +# extension = file_path.suffix +# headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"} +# return FileResponse( +# path=video_path, +# headers=headers, +# filename=f"{filename}{extension}", +# media_type=f"video/{extension[1:]}", +# ) diff --git a/app/controllers/v2/base.py b/app/controllers/v2/base.py new file mode 100644 index 0000000..4612983 --- /dev/null +++ b/app/controllers/v2/base.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter, Depends + + +def v2_router(dependencies=None): + router = APIRouter() + router.tags = ["V2"] + router.prefix = "/api/v2" + # 将认证依赖项应用于所有路由 + if dependencies: + router.dependencies = dependencies + return router diff --git a/app/controllers/v2/script.py b/app/controllers/v2/script.py new file mode 100644 index 0000000..85f4238 --- /dev/null +++ b/app/controllers/v2/script.py @@ -0,0 +1,45 @@ +from fastapi import APIRouter, BackgroundTasks +from loguru import logger + +from app.models.schema_v2 import GenerateScriptRequest, GenerateScriptResponse +from app.services.script_service import ScriptGenerator +from app.utils import utils +from app.controllers.v2.base import v2_router + +# router = APIRouter(prefix="/api/v2", tags=["Script Generation V2"]) +router = v2_router() + +@router.post( + "/scripts/generate", + response_model=GenerateScriptResponse, + summary="生成视频脚本 (V2)" +) +async def generate_script( + request: GenerateScriptRequest, + background_tasks: BackgroundTasks +): + """ + 生成视频脚本的V2版本API + """ + task_id = utils.get_uuid() + + try: + generator = ScriptGenerator() + script = await generator.generate_script( + video_path=request.video_path, + video_theme=request.video_theme, + custom_prompt=request.custom_prompt, + skip_seconds=request.skip_seconds, + threshold=request.threshold, + vision_batch_size=request.vision_batch_size, + vision_llm_provider=request.vision_llm_provider + ) + + return { + "task_id": task_id, + "script": script + } + + except Exception as e: + logger.exception(f"Generate script failed: {str(e)}") + raise \ No newline at end of file diff --git a/app/models/schema_v2.py b/app/models/schema_v2.py new file mode 100644 index 0000000..786c018 --- /dev/null +++ b/app/models/schema_v2.py @@ -0,0 +1,15 @@ +from typing import Optional, List +from pydantic import BaseModel + +class GenerateScriptRequest(BaseModel): + video_path: str + video_theme: Optional[str] = "" + custom_prompt: Optional[str] = "" + skip_seconds: Optional[int] = 0 + threshold: Optional[int] = 30 + vision_batch_size: Optional[int] = 5 + vision_llm_provider: Optional[str] = "gemini" + +class GenerateScriptResponse(BaseModel): + task_id: str + script: List[dict] \ No newline at end of file diff --git a/app/router.py b/app/router.py index cf84037..df60500 100644 --- a/app/router.py +++ b/app/router.py @@ -10,8 +10,12 @@ Resources: from fastapi import APIRouter from app.controllers.v1 import llm, video +from app.controllers.v2 import script root_api_router = APIRouter() # v1 root_api_router.include_router(video.router) root_api_router.include_router(llm.router) + +# v2 +root_api_router.include_router(script.router) diff --git a/app/services/script_service.py b/app/services/script_service.py new file mode 100644 index 0000000..1693cbc --- /dev/null +++ b/app/services/script_service.py @@ -0,0 +1,378 @@ +import os +import json +import time +import asyncio +import requests +from loguru import logger +from typing import List, Dict, Any, Callable + +from app.utils import utils, vision_analyzer, video_processor, video_processor_v2 +from app.utils.script_generator import ScriptProcessor +from app.config import config + + +class ScriptGenerator: + def __init__(self): + self.temp_dir = utils.temp_dir() + self.keyframes_dir = os.path.join(self.temp_dir, "keyframes") + + async def generate_script( + self, + video_path: str, + video_theme: str = "", + custom_prompt: str = "", + skip_seconds: int = 0, + threshold: int = 30, + vision_batch_size: int = 5, + vision_llm_provider: str = "gemini", + progress_callback: Callable[[float, str], None] = None + ) -> List[Dict[Any, Any]]: + """ + 生成视频脚本的核心逻辑 + + Args: + video_path: 视频文件路径 + video_theme: 视频主题 + custom_prompt: 自定义提示词 + skip_seconds: 跳过开始的秒数 + threshold: 差异阈值 + vision_batch_size: 视觉处理批次大小 + vision_llm_provider: 视觉模型提供商 + progress_callback: 进度回调函数 + + Returns: + List[Dict]: 生成的视频脚本 + """ + if progress_callback is None: + progress_callback = lambda p, m: None + + try: + # 提取关键帧 + progress_callback(10, "正在提取关键帧...") + keyframe_files = await self._extract_keyframes( + video_path, + skip_seconds, + threshold + ) + + if vision_llm_provider == "gemini": + script = await self._process_with_gemini( + keyframe_files, + video_theme, + custom_prompt, + vision_batch_size, + progress_callback + ) + elif vision_llm_provider == "narratoapi": + script = await self._process_with_narrato( + keyframe_files, + video_theme, + custom_prompt, + vision_batch_size, + progress_callback + ) + else: + raise ValueError(f"Unsupported vision provider: {vision_llm_provider}") + + return json.loads(script) if isinstance(script, str) else script + + except Exception as e: + logger.exception("Generate script failed") + raise + + async def _extract_keyframes( + self, + video_path: str, + skip_seconds: int, + threshold: int + ) -> List[str]: + """提取视频关键帧""" + video_hash = utils.md5(video_path + str(os.path.getmtime(video_path))) + video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash) + + # 检查缓存 + keyframe_files = [] + if os.path.exists(video_keyframes_dir): + for filename in sorted(os.listdir(video_keyframes_dir)): + if filename.endswith('.jpg'): + keyframe_files.append(os.path.join(video_keyframes_dir, filename)) + + if keyframe_files: + logger.info(f"Using cached keyframes: {video_keyframes_dir}") + return keyframe_files + + # 提取新的关键帧 + os.makedirs(video_keyframes_dir, exist_ok=True) + + try: + if config.frames.get("version") == "v2": + processor = video_processor_v2.VideoProcessor(video_path) + processor.process_video_pipeline( + output_dir=video_keyframes_dir, + skip_seconds=skip_seconds, + threshold=threshold + ) + else: + processor = video_processor.VideoProcessor(video_path) + processor.process_video( + output_dir=video_keyframes_dir, + skip_seconds=skip_seconds + ) + + for filename in sorted(os.listdir(video_keyframes_dir)): + if filename.endswith('.jpg'): + keyframe_files.append(os.path.join(video_keyframes_dir, filename)) + + return keyframe_files + + except Exception as e: + if os.path.exists(video_keyframes_dir): + import shutil + shutil.rmtree(video_keyframes_dir) + raise + + async def _process_with_gemini( + self, + keyframe_files: List[str], + video_theme: str, + custom_prompt: str, + vision_batch_size: int, + progress_callback: Callable[[float, str], None] + ) -> str: + """使用Gemini处理视频帧""" + progress_callback(30, "正在初始化视觉分析器...") + + # 获取Gemini配置 + vision_api_key = config.app.get("vision_gemini_api_key") + vision_model = config.app.get("vision_gemini_model_name") + + if not vision_api_key or not vision_model: + raise ValueError("未配置 Gemini API Key 或者模型") + + analyzer = vision_analyzer.VisionAnalyzer( + model_name=vision_model, + api_key=vision_api_key, + ) + + progress_callback(40, "正在分析关键帧...") + + # 执行异步分析 + results = await analyzer.analyze_images( + images=keyframe_files, + prompt=config.app.get('vision_analysis_prompt'), + batch_size=vision_batch_size + ) + + progress_callback(60, "正在整理分析结果...") + + # 合并所有批次的分析结果 + frame_analysis = "" + prev_batch_files = None + + for result in results: + if 'error' in result: + logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") + continue + + batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size) + first_timestamp, last_timestamp, _ = self._get_batch_timestamps(batch_files, prev_batch_files) + + # 添加带时间戳的分析结果 + frame_analysis += f"\n=== {first_timestamp}-{last_timestamp} ===\n" + frame_analysis += result['response'] + frame_analysis += "\n" + + prev_batch_files = batch_files + + if not frame_analysis.strip(): + raise Exception("未能生成有效的帧分析结果") + + progress_callback(70, "正在生成脚本...") + + # 构建帧内容列表 + frame_content_list = [] + prev_batch_files = None + + for result in results: + if 'error' in result: + continue + + batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size) + _, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files) + + frame_content = { + "timestamp": timestamp_range, + "picture": result['response'], + "narration": "", + "OST": 2 + } + frame_content_list.append(frame_content) + prev_batch_files = batch_files + + if not frame_content_list: + raise Exception("没有有效的帧内容可以处理") + + progress_callback(90, "正在生成文案...") + + # 获取文本生成配置 + text_provider = config.app.get('text_llm_provider', 'gemini').lower() + text_api_key = config.app.get(f'text_{text_provider}_api_key') + text_model = config.app.get(f'text_{text_provider}_model_name') + + processor = ScriptProcessor( + model_name=text_model, + api_key=text_api_key, + prompt=custom_prompt, + video_theme=video_theme + ) + + return processor.process_frames(frame_content_list) + + async def _process_with_narrato( + self, + keyframe_files: List[str], + video_theme: str, + custom_prompt: str, + vision_batch_size: int, + progress_callback: Callable[[float, str], None] + ) -> str: + """使用NarratoAPI处理视频帧""" + # 创建临时目录 + temp_dir = utils.temp_dir("narrato") + + # 打包关键帧 + progress_callback(30, "正在打包关键帧...") + zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip") + + try: + if not utils.create_zip(keyframe_files, zip_path): + raise Exception("打包关键帧失败") + + # 获取API配置 + api_url = config.app.get("narrato_api_url") + api_key = config.app.get("narrato_api_key") + + if not api_key: + raise ValueError("未配置 Narrato API Key") + + headers = { + 'X-API-Key': api_key, + 'accept': 'application/json' + } + + api_params = { + 'batch_size': vision_batch_size, + 'use_ai': False, + 'start_offset': 0, + 'vision_model': config.app.get('narrato_vision_model', 'gemini-1.5-flash'), + 'vision_api_key': config.app.get('narrato_vision_key'), + 'llm_model': config.app.get('narrato_llm_model', 'qwen-plus'), + 'llm_api_key': config.app.get('narrato_llm_key'), + 'custom_prompt': custom_prompt + } + + progress_callback(40, "正在上传文件...") + with open(zip_path, 'rb') as f: + files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')} + response = requests.post( + f"{api_url}/video/analyze", + headers=headers, + params=api_params, + files=files, + timeout=30 + ) + response.raise_for_status() + + task_data = response.json() + task_id = task_data["data"].get('task_id') + if not task_id: + raise Exception(f"无效的API响应: {response.text}") + + progress_callback(50, "正在等待分析结果...") + retry_count = 0 + max_retries = 60 + + while retry_count < max_retries: + try: + status_response = requests.get( + f"{api_url}/video/tasks/{task_id}", + headers=headers, + timeout=10 + ) + status_response.raise_for_status() + task_status = status_response.json()['data'] + + if task_status['status'] == 'SUCCESS': + return task_status['result']['data'] + elif task_status['status'] in ['FAILURE', 'RETRY']: + raise Exception(f"任务失败: {task_status.get('error')}") + + retry_count += 1 + time.sleep(2) + + except requests.RequestException as e: + logger.warning(f"获取任务状态失败,重试中: {str(e)}") + retry_count += 1 + time.sleep(2) + continue + + raise Exception("任务执行超时") + + finally: + # 清理临时文件 + try: + if os.path.exists(zip_path): + os.remove(zip_path) + except Exception as e: + logger.warning(f"清理临时文件失败: {str(e)}") + + def _get_batch_files( + self, + keyframe_files: List[str], + result: Dict[str, Any], + batch_size: int + ) -> List[str]: + """获取当前批次的图片文件""" + batch_start = result['batch_index'] * batch_size + batch_end = min(batch_start + batch_size, len(keyframe_files)) + return keyframe_files[batch_start:batch_end] + + def _get_batch_timestamps( + self, + batch_files: List[str], + prev_batch_files: List[str] = None + ) -> tuple[str, str, str]: + """获取一批文件的时间戳范围""" + if not batch_files: + logger.warning("Empty batch files") + return "00:00", "00:00", "00:00-00:00" + + if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0: + first_frame = os.path.basename(prev_batch_files[-1]) + last_frame = os.path.basename(batch_files[0]) + else: + first_frame = os.path.basename(batch_files[0]) + last_frame = os.path.basename(batch_files[-1]) + + first_time = first_frame.split('_')[2].replace('.jpg', '') + last_time = last_frame.split('_')[2].replace('.jpg', '') + + def format_timestamp(time_str: str) -> str: + if len(time_str) < 4: + logger.warning(f"Invalid timestamp format: {time_str}") + return "00:00" + + minutes = int(time_str[-4:-2]) + seconds = int(time_str[-2:]) + + if seconds >= 60: + minutes += seconds // 60 + seconds = seconds % 60 + + return f"{minutes:02d}:{seconds:02d}" + + first_timestamp = format_timestamp(first_time) + last_timestamp = format_timestamp(last_time) + timestamp_range = f"{first_timestamp}-{last_timestamp}" + + return first_timestamp, last_timestamp, timestamp_range \ No newline at end of file diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index 68b968a..db17ccc 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -103,7 +103,6 @@ "Video Quality": "视频质量", "Custom prompt for LLM, leave empty to use default prompt": "自定义提示词,留空则使用默认提示词", "Proxy Settings": "代理设置", - "Language": "界面语言", "HTTP_PROXY": "HTTP 代理", "HTTPs_PROXY": "HTTPS 代理", "Vision Model Settings": "视频分析模型设置",