mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-06-17 04:42:05 +00:00
chore(app/utils): 移除两个废弃的Gemini视觉分析工具文件
This commit is contained in:
parent
9f28fcfa98
commit
8b1fcbafa5
@ -1,326 +0,0 @@
|
||||
import json
|
||||
from typing import List, Union, Dict
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
|
||||
import requests
|
||||
import PIL.Image
|
||||
import traceback
|
||||
import base64
|
||||
import io
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
class VisionAnalyzer:
|
||||
"""原生Gemini视觉分析器类"""
|
||||
|
||||
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
|
||||
"""初始化视觉分析器"""
|
||||
if not api_key:
|
||||
raise ValueError("必须提供API密钥")
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
# 初始化配置
|
||||
self._configure_client()
|
||||
|
||||
def _configure_client(self):
|
||||
"""配置原生Gemini API客户端"""
|
||||
# 使用原生Gemini REST API
|
||||
self.client = None
|
||||
logger.info(f"配置原生Gemini API,端点: {self.base_url}, 模型: {self.model_name}")
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(requests.exceptions.RequestException)
|
||||
)
|
||||
async def _generate_content_with_retry(self, prompt, batch):
|
||||
"""使用重试机制调用原生Gemini API"""
|
||||
try:
|
||||
return await self._generate_with_gemini_api(prompt, batch)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Gemini API请求异常: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini API生成内容时发生错误: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_with_gemini_api(self, prompt, batch):
|
||||
"""使用原生Gemini REST API生成内容"""
|
||||
# 将PIL图片转换为base64编码
|
||||
image_parts = []
|
||||
for img in batch:
|
||||
# 将PIL图片转换为字节流
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format='JPEG', quality=85) # 优化图片质量
|
||||
img_bytes = img_buffer.getvalue()
|
||||
|
||||
# 转换为base64
|
||||
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
||||
image_parts.append({
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": img_base64
|
||||
}
|
||||
})
|
||||
|
||||
# 构建符合官方文档的请求数据
|
||||
request_data = {
|
||||
"contents": [{
|
||||
"parts": [
|
||||
{"text": prompt},
|
||||
*image_parts
|
||||
]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"temperature": 1.0,
|
||||
"topK": 40,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": 8192,
|
||||
"candidateCount": 1,
|
||||
"stopSequences": []
|
||||
},
|
||||
"safetySettings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 构建请求URL
|
||||
url = f"{self.base_url}/models/{self.model_name}:generateContent"
|
||||
|
||||
# 发送请求
|
||||
response = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url,
|
||||
json=request_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": self.api_key
|
||||
},
|
||||
timeout=120 # 增加超时时间
|
||||
)
|
||||
|
||||
# 处理HTTP错误
|
||||
if response.status_code == 429:
|
||||
raise requests.exceptions.RequestException(f"API配额限制: {response.text}")
|
||||
elif response.status_code == 400:
|
||||
raise Exception(f"请求参数错误: {response.text}")
|
||||
elif response.status_code == 403:
|
||||
raise Exception(f"API密钥无效或权限不足: {response.text}")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(f"Gemini API请求失败: {response.status_code} - {response.text}")
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应格式
|
||||
if "candidates" not in response_data or not response_data["candidates"]:
|
||||
raise Exception("Gemini API返回无效响应,可能触发了安全过滤")
|
||||
|
||||
candidate = response_data["candidates"][0]
|
||||
|
||||
# 检查是否被安全过滤阻止
|
||||
if "finishReason" in candidate and candidate["finishReason"] == "SAFETY":
|
||||
raise Exception("内容被Gemini安全过滤器阻止")
|
||||
|
||||
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||
raise Exception("Gemini API返回内容格式错误")
|
||||
|
||||
# 提取文本内容
|
||||
text_content = ""
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
text_content += part["text"]
|
||||
|
||||
if not text_content.strip():
|
||||
raise Exception("Gemini API返回空内容")
|
||||
|
||||
# 创建兼容的响应对象
|
||||
class CompatibleResponse:
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
return CompatibleResponse(text_content)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: Union[List[str], List[PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int) -> List[Dict]:
|
||||
"""批量分析多张图片"""
|
||||
try:
|
||||
# 加载图片
|
||||
if isinstance(images[0], str):
|
||||
images = self.load_images(images)
|
||||
|
||||
# 验证图片列表
|
||||
if not images:
|
||||
raise ValueError("图片列表为空")
|
||||
|
||||
# 验证每个图片对象
|
||||
valid_images = []
|
||||
for i, img in enumerate(images):
|
||||
if not isinstance(img, PIL.Image.Image):
|
||||
logger.error(f"无效的图片对象,索引 {i}: {type(img)}")
|
||||
continue
|
||||
valid_images.append(img)
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError("没有有效的图片对象")
|
||||
|
||||
images = valid_images
|
||||
results = []
|
||||
# 视频帧总数除以批量处理大小,如果有小数则+1
|
||||
batches_needed = len(images) // batch_size
|
||||
if len(images) % batch_size > 0:
|
||||
batches_needed += 1
|
||||
|
||||
logger.debug(f"视频帧总数:{len(images)}, 每批处理 {batch_size} 帧, 需要访问 VLM {batches_needed} 次")
|
||||
|
||||
with tqdm(total=batches_needed, desc="分析进度") as pbar:
|
||||
for i in range(0, len(images), batch_size):
|
||||
batch = images[i:i + batch_size]
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < 3:
|
||||
try:
|
||||
# 在每个批次处理前添加小延迟
|
||||
# if i > 0:
|
||||
# await asyncio.sleep(2)
|
||||
|
||||
# 确保每个批次的图片都是有效的
|
||||
valid_batch = [img for img in batch if isinstance(img, PIL.Image.Image)]
|
||||
if not valid_batch:
|
||||
raise ValueError(f"批次 {i // batch_size} 中没有有效的图片")
|
||||
|
||||
response = await self._generate_content_with_retry(prompt, valid_batch)
|
||||
results.append({
|
||||
'batch_index': i // batch_size,
|
||||
'images_processed': len(valid_batch),
|
||||
'response': response.text,
|
||||
'model_used': self.model_name
|
||||
})
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
error_msg = f"批次 {i // batch_size} 处理出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
if retry_count >= 3:
|
||||
results.append({
|
||||
'batch_index': i // batch_size,
|
||||
'images_processed': len(batch),
|
||||
'error': error_msg,
|
||||
'model_used': self.model_name
|
||||
})
|
||||
else:
|
||||
logger.info(f"批次 {i // batch_size} 处理失败,等待60秒后重试当前批次...")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片分析过程中发生错误: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def save_results_to_txt(self, results: List[Dict], output_dir: str):
|
||||
"""将分析结果保存到txt文件"""
|
||||
# 确保输出目录存在
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
for result in results:
|
||||
if not result.get('image_paths'):
|
||||
continue
|
||||
|
||||
response_text = result['response']
|
||||
image_paths = result['image_paths']
|
||||
|
||||
# 从文件名中提取时间戳并转换为标准格式
|
||||
def format_timestamp(img_path):
|
||||
# 从文件名中提取时间部分
|
||||
timestamp = Path(img_path).stem.split('_')[-1]
|
||||
try:
|
||||
# 将时间转换为秒
|
||||
seconds = utils.time_to_seconds(timestamp.replace('_', ':'))
|
||||
# 转换为 HH:MM:SS,mmm 格式
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
seconds_remainder = seconds % 60
|
||||
whole_seconds = int(seconds_remainder)
|
||||
milliseconds = int((seconds_remainder - whole_seconds) * 1000)
|
||||
|
||||
return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d},{milliseconds:03d}"
|
||||
except Exception as e:
|
||||
logger.error(f"时间戳格式转换错误: {timestamp}, {str(e)}")
|
||||
return timestamp
|
||||
|
||||
start_timestamp = format_timestamp(image_paths[0])
|
||||
end_timestamp = format_timestamp(image_paths[-1])
|
||||
|
||||
txt_path = os.path.join(output_dir, f"frame_{start_timestamp}_{end_timestamp}.txt")
|
||||
|
||||
# 保存结果到txt文件
|
||||
with open(txt_path, 'w', encoding='utf-8') as f:
|
||||
f.write(response_text.strip())
|
||||
logger.info(f"已保存分析结果到: {txt_path}")
|
||||
|
||||
def load_images(self, image_paths: List[str]) -> List[PIL.Image.Image]:
|
||||
"""
|
||||
加载多张图片
|
||||
Args:
|
||||
image_paths: 图片路径列表
|
||||
Returns:
|
||||
加载后的PIL Image对象列表
|
||||
"""
|
||||
images = []
|
||||
failed_images = []
|
||||
|
||||
for img_path in image_paths:
|
||||
try:
|
||||
if not os.path.exists(img_path):
|
||||
logger.error(f"图片文件不存在: {img_path}")
|
||||
failed_images.append(img_path)
|
||||
continue
|
||||
|
||||
img = PIL.Image.open(img_path)
|
||||
# 确保图片被完全加载
|
||||
img.load()
|
||||
# 转换为RGB模式
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
images.append(img)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"无法加载图片 {img_path}: {str(e)}")
|
||||
failed_images.append(img_path)
|
||||
|
||||
if failed_images:
|
||||
logger.warning(f"以下图片加载失败:\n{json.dumps(failed_images, indent=2, ensure_ascii=False)}")
|
||||
|
||||
if not images:
|
||||
raise ValueError("没有成功加载任何图片")
|
||||
|
||||
return images
|
||||
@ -1,177 +0,0 @@
|
||||
"""
|
||||
OpenAI兼容的Gemini视觉分析器
|
||||
使用标准OpenAI格式调用Gemini代理服务
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Union, Dict
|
||||
import os
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_exponential
|
||||
import requests
|
||||
import PIL.Image
|
||||
import traceback
|
||||
import base64
|
||||
import io
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
class GeminiOpenAIAnalyzer:
|
||||
"""OpenAI兼容的Gemini视觉分析器类"""
|
||||
|
||||
def __init__(self, model_name: str = "gemini-2.0-flash-exp", api_key: str = None, base_url: str = None):
|
||||
"""初始化OpenAI兼容的Gemini分析器"""
|
||||
if not api_key:
|
||||
raise ValueError("必须提供API密钥")
|
||||
|
||||
if not base_url:
|
||||
raise ValueError("必须提供OpenAI兼容的代理端点URL")
|
||||
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip('/')
|
||||
|
||||
# 初始化OpenAI客户端
|
||||
self._configure_client()
|
||||
|
||||
def _configure_client(self):
|
||||
"""配置OpenAI兼容的客户端"""
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
logger.info(f"配置OpenAI兼容Gemini代理,端点: {self.base_url}, 模型: {self.model_name}")
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((requests.exceptions.RequestException, Exception))
|
||||
)
|
||||
async def _generate_content_with_retry(self, prompt, batch):
|
||||
"""使用重试机制调用OpenAI兼容的Gemini代理"""
|
||||
try:
|
||||
return await self._generate_with_openai_api(prompt, batch)
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenAI兼容Gemini代理请求异常: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_with_openai_api(self, prompt, batch):
|
||||
"""使用OpenAI兼容接口生成内容"""
|
||||
# 将PIL图片转换为base64编码
|
||||
image_contents = []
|
||||
for img in batch:
|
||||
# 将PIL图片转换为字节流
|
||||
img_buffer = io.BytesIO()
|
||||
img.save(img_buffer, format='JPEG', quality=85)
|
||||
img_bytes = img_buffer.getvalue()
|
||||
|
||||
# 转换为base64
|
||||
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
||||
image_contents.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{img_base64}"
|
||||
}
|
||||
})
|
||||
|
||||
# 构建OpenAI格式的消息
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
*image_contents
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# 调用OpenAI兼容接口
|
||||
response = await asyncio.to_thread(
|
||||
self.client.chat.completions.create,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=4000,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# 创建兼容的响应对象
|
||||
class CompatibleResponse:
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
return CompatibleResponse(response.choices[0].message.content)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10) -> List[str]:
|
||||
"""
|
||||
分析图片并返回结果
|
||||
|
||||
Args:
|
||||
images: 图片路径列表或PIL图片对象列表
|
||||
prompt: 分析提示词
|
||||
batch_size: 批处理大小
|
||||
|
||||
Returns:
|
||||
分析结果列表
|
||||
"""
|
||||
logger.info(f"开始分析 {len(images)} 张图片,使用OpenAI兼容Gemini代理")
|
||||
|
||||
# 加载图片
|
||||
loaded_images = []
|
||||
for img in images:
|
||||
if isinstance(img, (str, Path)):
|
||||
try:
|
||||
pil_img = PIL.Image.open(img)
|
||||
# 调整图片大小以优化性能
|
||||
if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
|
||||
pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
|
||||
loaded_images.append(pil_img)
|
||||
except Exception as e:
|
||||
logger.error(f"加载图片失败 {img}: {str(e)}")
|
||||
continue
|
||||
elif isinstance(img, PIL.Image.Image):
|
||||
loaded_images.append(img)
|
||||
else:
|
||||
logger.warning(f"不支持的图片类型: {type(img)}")
|
||||
continue
|
||||
|
||||
if not loaded_images:
|
||||
raise ValueError("没有有效的图片可以分析")
|
||||
|
||||
# 分批处理
|
||||
results = []
|
||||
total_batches = (len(loaded_images) + batch_size - 1) // batch_size
|
||||
|
||||
for i in tqdm(range(0, len(loaded_images), batch_size),
|
||||
desc="分析图片批次", total=total_batches):
|
||||
batch = loaded_images[i:i + batch_size]
|
||||
|
||||
try:
|
||||
response = await self._generate_content_with_retry(prompt, batch)
|
||||
results.append(response.text)
|
||||
|
||||
# 添加延迟以避免API限流
|
||||
if i + batch_size < len(loaded_images):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析批次 {i//batch_size + 1} 失败: {str(e)}")
|
||||
results.append(f"分析失败: {str(e)}")
|
||||
|
||||
logger.info(f"完成图片分析,共处理 {len(results)} 个批次")
|
||||
return results
|
||||
|
||||
def analyze_images_sync(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
prompt: str,
|
||||
batch_size: int = 10) -> List[str]:
|
||||
"""
|
||||
同步版本的图片分析方法
|
||||
"""
|
||||
return asyncio.run(self.analyze_images(images, prompt, batch_size))
|
||||
Loading…
x
Reference in New Issue
Block a user