mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-10 18:02:51 +00:00
更新generate_narration_script.py、base.py和generate_short_summary.py文件,重构解说文案生成和视觉分析器的实现,优先使用新的LLM服务架构。添加回退机制以确保兼容性,增强系统的稳定性和用户体验。
169 lines
6.1 KiB
Python
169 lines
6.1 KiB
Python
"""
|
||
OpenAI API提供商实现
|
||
|
||
使用OpenAI API进行文本生成,也支持OpenAI兼容的其他服务
|
||
"""
|
||
|
||
import asyncio
|
||
from typing import List, Dict, Any, Optional
|
||
from openai import OpenAI, BadRequestError
|
||
from loguru import logger
|
||
|
||
from ..base import TextModelProvider
|
||
from ..exceptions import APICallError, RateLimitError, AuthenticationError
|
||
|
||
|
||
class OpenAITextProvider(TextModelProvider):
|
||
"""OpenAI文本生成提供商"""
|
||
|
||
@property
|
||
def provider_name(self) -> str:
|
||
return "openai"
|
||
|
||
@property
|
||
def supported_models(self) -> List[str]:
|
||
return [
|
||
"gpt-4o",
|
||
"gpt-4o-mini",
|
||
"gpt-4-turbo",
|
||
"gpt-4",
|
||
"gpt-3.5-turbo",
|
||
"gpt-3.5-turbo-16k",
|
||
# 支持其他OpenAI兼容模型
|
||
"deepseek-chat",
|
||
"deepseek-reasoner",
|
||
"qwen-plus",
|
||
"qwen-turbo",
|
||
"moonshot-v1-8k",
|
||
"moonshot-v1-32k",
|
||
"moonshot-v1-128k"
|
||
]
|
||
|
||
def _initialize(self):
|
||
"""初始化OpenAI客户端"""
|
||
if not self.base_url:
|
||
self.base_url = "https://api.openai.com/v1"
|
||
|
||
self.client = OpenAI(
|
||
api_key=self.api_key,
|
||
base_url=self.base_url
|
||
)
|
||
|
||
async def generate_text(self,
|
||
prompt: str,
|
||
system_prompt: Optional[str] = None,
|
||
temperature: float = 1.0,
|
||
max_tokens: Optional[int] = None,
|
||
response_format: Optional[str] = None,
|
||
**kwargs) -> str:
|
||
"""
|
||
使用OpenAI API生成文本
|
||
|
||
Args:
|
||
prompt: 用户提示词
|
||
system_prompt: 系统提示词
|
||
temperature: 生成温度
|
||
max_tokens: 最大token数
|
||
response_format: 响应格式 ('json' 或 None)
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
生成的文本内容
|
||
"""
|
||
# 构建消息列表
|
||
messages = self._build_messages(prompt, system_prompt)
|
||
|
||
# 构建请求参数
|
||
request_params = {
|
||
"model": self.model_name,
|
||
"messages": messages,
|
||
"temperature": temperature
|
||
}
|
||
|
||
if max_tokens:
|
||
request_params["max_tokens"] = max_tokens
|
||
|
||
# 处理JSON格式输出
|
||
if response_format == "json":
|
||
# 检查模型是否支持response_format
|
||
if self._supports_response_format():
|
||
request_params["response_format"] = {"type": "json_object"}
|
||
else:
|
||
# 对于不支持response_format的模型,在提示词中添加约束
|
||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||
|
||
try:
|
||
# 发送API请求
|
||
response = await asyncio.to_thread(
|
||
self.client.chat.completions.create,
|
||
**request_params
|
||
)
|
||
|
||
# 提取生成的内容
|
||
if response.choices and len(response.choices) > 0:
|
||
content = response.choices[0].message.content
|
||
|
||
# 对于不支持response_format的模型,清理输出
|
||
if response_format == "json" and not self._supports_response_format():
|
||
content = self._clean_json_output(content)
|
||
|
||
logger.debug(f"OpenAI API调用成功,消耗tokens: {response.usage.total_tokens if response.usage else 'N/A'}")
|
||
return content
|
||
else:
|
||
raise APICallError("OpenAI API返回空响应")
|
||
|
||
except BadRequestError as e:
|
||
# 处理不支持response_format的情况
|
||
if "response_format" in str(e) and response_format == "json":
|
||
logger.warning(f"模型 {self.model_name} 不支持response_format,重试不带格式约束的请求")
|
||
request_params.pop("response_format", None)
|
||
messages[-1]["content"] += "\n\n请确保输出严格的JSON格式,不要包含任何其他文字或标记。"
|
||
|
||
response = await asyncio.to_thread(
|
||
self.client.chat.completions.create,
|
||
**request_params
|
||
)
|
||
|
||
if response.choices and len(response.choices) > 0:
|
||
content = response.choices[0].message.content
|
||
content = self._clean_json_output(content)
|
||
return content
|
||
else:
|
||
raise APICallError("OpenAI API返回空响应")
|
||
else:
|
||
raise APICallError(f"OpenAI API请求失败: {str(e)}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"OpenAI API调用失败: {str(e)}")
|
||
raise APICallError(f"OpenAI API调用失败: {str(e)}")
|
||
|
||
def _supports_response_format(self) -> bool:
|
||
"""检查模型是否支持response_format参数"""
|
||
# 已知不支持response_format的模型
|
||
unsupported_models = [
|
||
"deepseek-reasoner",
|
||
"deepseek-r1"
|
||
]
|
||
|
||
return not any(unsupported in self.model_name.lower() for unsupported in unsupported_models)
|
||
|
||
def _clean_json_output(self, output: str) -> str:
|
||
"""清理JSON输出,移除markdown标记等"""
|
||
import re
|
||
|
||
# 移除可能的markdown代码块标记
|
||
output = re.sub(r'^```json\s*', '', output, flags=re.MULTILINE)
|
||
output = re.sub(r'^```\s*$', '', output, flags=re.MULTILINE)
|
||
output = re.sub(r'^```.*$', '', output, flags=re.MULTILINE)
|
||
|
||
# 移除前后空白字符
|
||
output = output.strip()
|
||
|
||
return output
|
||
|
||
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""执行API调用 - 由于使用OpenAI SDK,这个方法主要用于兼容基类"""
|
||
# 这个方法在OpenAI提供商中不直接使用,因为我们使用OpenAI SDK
|
||
# 但为了兼容基类接口,保留此方法
|
||
pass
|