viccy e6e39d2dcd feat(short-drama): 完整实现短剧解说剪辑全流程并新增LLM流式生成支持
- 新增短剧解说全流程四类提示词模板:解说文案生成、片段规划、文案画面匹配、脚本修复
- 重构原有脚本生成提示词至v2.1,改为基于上游规划片段生成合规解说脚本
- 为LLM基础服务层新增流式文本生成接口,完善OpenAI兼容提供商的流式实现,支持流式回调与推理内容提取
- 重构OpenAI兼容文本提供商的生成逻辑,提取公共参数构建方法
- 新增多语言国际化文案,覆盖解说语言、短剧类型、原片占比等配置项与交互提示
- 新增多套单元测试,覆盖脚本校验、适配器流程、工具函数等模块
- 封装SubtitleAnalyzerAdapter,统一短剧解说脚本生成的整套业务接口
- 新增前端交互所需的解说文案审核相关提示文案
2026-06-07 17:10:48 +08:00

213 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
大模型服务提供商基类定义
定义了统一的大模型服务接口,包括视觉模型和文本生成模型的抽象基类
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import PIL.Image
from loguru import logger
from .exceptions import LLMServiceError, ConfigurationError
class BaseLLMProvider(ABC):
"""大模型服务提供商基类"""
def __init__(self,
api_key: str,
model_name: str,
base_url: Optional[str] = None,
**kwargs):
"""
初始化大模型服务提供商
Args:
api_key: API密钥
model_name: 模型名称
base_url: API基础URL
**kwargs: 其他配置参数
"""
self.api_key = api_key
self.model_name = model_name
self.base_url = base_url
self.config = kwargs
# 验证必要配置
self._validate_config()
# 初始化提供商特定设置
self._initialize()
@property
@abstractmethod
def provider_name(self) -> str:
"""供应商名称"""
pass
@property
@abstractmethod
def supported_models(self) -> List[str]:
"""支持的模型列表"""
pass
def _validate_config(self):
"""验证配置参数"""
if not self.api_key:
raise ConfigurationError("API密钥不能为空", "api_key")
if not self.model_name:
raise ConfigurationError("模型名称不能为空", "model_name")
# 检查模型支持情况
self._validate_model_support()
def _validate_model_support(self):
"""验证模型支持情况(宽松模式,仅记录警告)"""
from loguru import logger
# OpenAI 兼容网关的模型数量较多,运行时由远端完成最终校验
if self.model_name not in self.supported_models:
logger.warning(
f"模型 {self.model_name} 未在供应商 {self.provider_name} 的预定义支持列表中。"
f"支持的模型列表: {self.supported_models}"
)
def _initialize(self):
"""初始化提供商特定设置,子类可重写"""
pass
@abstractmethod
async def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""执行API调用子类必须实现"""
pass
def _handle_api_error(self, status_code: int, response_text: str) -> LLMServiceError:
"""处理API错误返回适当的异常"""
from .exceptions import APICallError, RateLimitError, AuthenticationError
if status_code == 401:
return AuthenticationError()
elif status_code == 429:
return RateLimitError()
elif status_code in [502, 503, 504]:
return APICallError(f"服务器错误 HTTP {status_code}", status_code, response_text)
elif status_code == 524:
return APICallError(f"服务器处理超时 HTTP {status_code}", status_code, response_text)
else:
return APICallError(f"HTTP {status_code}", status_code, response_text)
class VisionModelProvider(BaseLLMProvider):
"""视觉模型提供商基类"""
@abstractmethod
async def analyze_images(self,
images: List[Union[str, Path, PIL.Image.Image]],
prompt: str,
batch_size: int = 10,
max_concurrency: int = 1,
**kwargs) -> List[str]:
"""
分析图片并返回结果
Args:
images: 图片路径列表或PIL图片对象列表
prompt: 分析提示词
batch_size: 批处理大小
max_concurrency: 最大并发批次数(实现支持时生效)
**kwargs: 其他参数
Returns:
分析结果列表
"""
pass
def _prepare_images(self, images: List[Union[str, Path, PIL.Image.Image]]) -> List[PIL.Image.Image]:
"""预处理图片统一转换为PIL.Image对象"""
processed_images = []
for img in images:
try:
if isinstance(img, (str, Path)):
pil_img = PIL.Image.open(img)
elif isinstance(img, PIL.Image.Image):
pil_img = img
else:
logger.warning(f"不支持的图片类型: {type(img)}")
continue
# 调整图片大小以优化性能
if pil_img.size[0] > 1024 or pil_img.size[1] > 1024:
pil_img.thumbnail((1024, 1024), PIL.Image.Resampling.LANCZOS)
processed_images.append(pil_img)
except Exception as e:
logger.error(f"加载图片失败 {img}: {str(e)}")
continue
return processed_images
class TextModelProvider(BaseLLMProvider):
"""文本生成模型提供商基类"""
@abstractmethod
async def generate_text(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
**kwargs) -> str:
"""
生成文本内容
Args:
prompt: 用户提示词
system_prompt: 系统提示词
temperature: 生成温度
max_tokens: 最大token数
response_format: 响应格式 ('json' 或 None)
**kwargs: 其他参数
Returns:
生成的文本内容
"""
pass
async def generate_text_stream(self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
response_format: Optional[str] = None,
on_chunk=None,
**kwargs) -> str:
"""生成文本内容并尽可能回调流式片段;默认退化为一次性输出。"""
result = await self.generate_text(
prompt=prompt,
system_prompt=system_prompt,
temperature=temperature,
max_tokens=max_tokens,
response_format=response_format,
**kwargs,
)
if on_chunk:
on_chunk({"type": "content", "text": result})
return result
def _build_messages(self, prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
"""构建消息列表"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
return messages