2026-03-16 16:27:37 +07:00

195 lines
7.1 KiB
Python
Executable File

from abc import ABC, abstractmethod
import re
import logging
from typing import List, Optional
import openai
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from entity.configs import EmbeddingConfig
# Module-level logger used by the embedding backends below to report
# empty inputs and request/encoding failures.
logger = logging.getLogger(__name__)
class EmbeddingBase(ABC):
    """Abstract base class for embedding backends.

    Stores the embedding configuration and offers text preprocessing and
    chunking helpers that concrete backends can reuse.
    """

    def __init__(self, embedding_config: EmbeddingConfig):
        """Keep a reference to the embedding configuration."""
        self.config = embedding_config

    @abstractmethod
    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for *text*."""
        ...

    def _preprocess_text(self, text: str) -> str:
        """Normalize text before it is embedded.

        Collapses whitespace runs to a single space and replaces every
        character that is neither a word character nor whitespace with a
        space. Note that this removes punctuation, including the sentence
        delimiters ``_chunk_text`` splits on, so sentence-aware chunking has
        to happen before preprocessing.

        Args:
            text: Raw input text; a falsy value (``None``, ``""``) is allowed.

        Returns:
            The cleaned text, or ``""`` when the input is falsy.
        """
        if not text:
            return ""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text.strip())
        # Remove special characters and emoji
        text = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', text)
        # Clean up whitespace again (the substitution above may leave runs)
        text = re.sub(r'\s+', ' ', text.strip())
        return text

    def _chunk_text(self, text: str, max_length: int = 500) -> List[str]:
        """Split long text into chunks of at most ``max_length`` characters.

        Text is split on the CJK sentence delimiters U+3002, U+FF01, U+FF1F,
        U+FF1B and on newlines. Sentences are packed greedily into chunks,
        each sentence re-terminated with U+3002. A sentence that cannot fit
        into a chunk on its own is cut into pieces, preferring to break at a
        space and otherwise breaking mid-word; such pieces carry no
        terminator except for the final remainder when there is room.

        Args:
            text: Text to split.
            max_length: Maximum chunk length in characters. Values below 1
                are treated as 1 so that splitting always makes progress.

        Returns:
            ``[text]`` unchanged when it already fits, otherwise the list of
            chunks (possibly empty if the text contains only delimiters).
        """
        if len(text) <= max_length:
            return [text]

        terminator = "\u3002"
        limit = max(max_length, 1)
        chunks: List[str] = []
        current = ""

        # Split by sentence boundaries
        for sentence in re.split(r'[\u3002\uff01\uff1f\uff1b\n]', text):
            sentence = sentence.strip()
            if not sentence:
                continue

            if len(sentence) + len(terminator) > limit:
                # The sentence cannot share a chunk with anything else, so
                # flush what has been collected and cut it down to size.
                if current:
                    chunks.append(current)
                    current = ""
                while len(sentence) > limit:
                    # Prefer the last space inside the window; index 0 is
                    # excluded so every piece is non-empty.
                    cut = sentence.rfind(' ', 1, limit + 1)
                    if cut == -1:
                        cut = limit
                    chunks.append(sentence[:cut].rstrip())
                    sentence = sentence[cut:].lstrip()
                if not sentence:
                    continue
                if len(sentence) + len(terminator) > limit:
                    # Remainder fills a chunk exactly; no room for the
                    # terminator.
                    chunks.append(sentence)
                    continue

            piece = sentence + terminator
            # Account for the terminator when checking the fit so a chunk
            # never exceeds the limit by one character.
            if len(current) + len(piece) <= limit:
                current += piece
            else:
                if current:
                    chunks.append(current)
                current = piece

        if current:
            chunks.append(current)
        return chunks
class EmbeddingFactory:
    """Builds the embedding backend selected by a configuration object."""

    @staticmethod
    def create_embedding(embedding_config: EmbeddingConfig) -> EmbeddingBase:
        """Instantiate the backend named by ``embedding_config.provider``.

        Args:
            embedding_config: Configuration whose ``provider`` attribute
                selects the backend (``'openai'`` or ``'local'``); it is also
                passed to the backend's constructor.

        Returns:
            A new ``OpenAIEmbedding`` or ``LocalEmbedding`` instance.

        Raises:
            ValueError: If the provider is not one of the supported values.
        """
        provider = embedding_config.provider

        if provider == 'openai':
            return OpenAIEmbedding(embedding_config)

        if provider == 'local':
            return LocalEmbedding(embedding_config)

        raise ValueError(f"Unsupported embedding model: {provider}")
class OpenAIEmbedding(EmbeddingBase):
    """Embedding backend that calls an OpenAI-compatible embeddings endpoint.

    Request failures never propagate to the caller: each request is retried
    and, once the retries are exhausted, a zero vector of the last known
    embedding dimension is returned instead.
    """

    def __init__(self, embedding_config: EmbeddingConfig):
        """Read settings from the configuration and create the API client.

        Recognised ``params`` keys: ``max_length`` (default 8191),
        ``use_chunking`` (default ``False``) and ``chunk_strategy``
        (``'average'``, ``'weighted'``, anything else means "first chunk";
        default ``'average'``).
        """
        super().__init__(embedding_config)
        self.base_url = embedding_config.base_url
        self.api_key = embedding_config.api_key
        self.model_name = embedding_config.model or "text-embedding-3-small"  # Default model
        # Applied as a character count (string slicing / len()) in this class.
        self.max_length = embedding_config.params.get('max_length', 8191)
        self.use_chunking = embedding_config.params.get('use_chunking', False)
        self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
        self._fallback_dim = 1536  # Default; updated after first successful call

        # Only pass base_url when one is configured so the client keeps its
        # own default endpoint otherwise.
        client_kwargs = {"api_key": self.api_key}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = openai.OpenAI(**client_kwargs)

    def _zero_vector(self) -> List[float]:
        """Return the all-zero fallback vector of the last known dimension."""
        return [0.0] * self._fallback_dim

    # The retry wraps only the raw request. Previously it decorated
    # get_embedding, whose own try/except swallowed every API error, so no
    # retry could ever happen. reraise=True surfaces the last real exception
    # rather than tenacity's RetryError wrapper.
    @retry(
        wait=wait_random_exponential(min=2, max=5),
        stop=stop_after_attempt(10),
        reraise=True,
    )
    def _request_embedding(self, text: str) -> List[float]:
        """Request one embedding, retrying on any exception.

        Also records the returned vector length as the fallback dimension.

        Raises:
            Exception: Whatever the client raised on the final attempt.
        """
        response = self.client.embeddings.create(
            input=text,
            model=self.model_name,
            encoding_format="float"
        )
        embedding = response.data[0].embedding
        self._fallback_dim = len(embedding)
        return embedding

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding for *text*, or a zero vector on failure.

        The text is preprocessed first. If chunking is enabled and the
        cleaned text exceeds ``max_length`` characters, the text is embedded
        chunk by chunk and aggregated; otherwise it is truncated to
        ``max_length`` characters and embedded in a single request.

        Args:
            text: Raw input text.

        Returns:
            The embedding vector, or a zero vector when the text is empty
            after preprocessing or the request still fails after retries.
        """
        # Preprocess the text
        processed_text = self._preprocess_text(text)
        if not processed_text:
            logger.warning("Empty text after preprocessing")
            return self._zero_vector()

        # Handle long text via chunking. The raw text is handed over because
        # preprocessing strips the sentence delimiters that chunking splits
        # on; each chunk is cleaned individually afterwards.
        if self.use_chunking and len(processed_text) > self.max_length:
            return self._get_chunked_embedding(text)

        # Truncate text
        truncated_text = processed_text[:self.max_length]
        try:
            return self._request_embedding(truncated_text)
        except Exception as e:
            logger.error("Error getting embedding: %s", e)
            return self._zero_vector()

    def _get_chunked_embedding(self, text: str) -> List[float]:
        """Chunk long text, embed each chunk, then aggregate.

        Chunks are limited to half of ``max_length`` characters. Every chunk
        is preprocessed (a no-op for text that is already clean) and any
        piece still longer than the limit is cut into fixed-size windows, so
        no request exceeds the limit. Chunks whose request fails after the
        retries are skipped with a warning; note that each failing chunk
        goes through the full retry cycle.

        Args:
            text: Raw or already preprocessed text.

        Returns:
            The aggregated vector according to ``chunk_strategy``, or a zero
            vector when no chunk could be embedded.
        """
        limit = max(self.max_length // 2, 1)  # Halve the chunk length

        pieces: List[str] = []
        for chunk in self._chunk_text(text, limit):
            cleaned = self._preprocess_text(chunk)
            pieces.extend(
                cleaned[start:start + limit]
                for start in range(0, len(cleaned), limit)
            )
        if not pieces:
            return self._zero_vector()

        vectors: List[List[float]] = []
        for piece in pieces:
            try:
                vectors.append(self._request_embedding(piece))
            except Exception as e:
                logger.warning("Error getting chunk embedding: %s", e)
        if not vectors:
            return self._zero_vector()

        # Aggregation strategy
        if self.chunk_strategy == 'average':
            # Mean aggregation, one output value per embedding dimension
            count = len(vectors)
            return [sum(column) / count for column in zip(*vectors)]
        if self.chunk_strategy == 'weighted':
            # Weighted aggregation (earlier chunks weigh more)
            weights = [1.0 / (rank + 1) for rank in range(len(vectors))]
            total_weight = sum(weights)
            return [
                sum(value * weight for value, weight in zip(column, weights)) / total_weight
                for column in zip(*vectors)
            ]
        # Default to the first chunk
        return vectors[0]
class LocalEmbedding(EmbeddingBase):
    """Embedding backend backed by a local sentence-transformers model."""

    def __init__(self, embedding_config: EmbeddingConfig):
        """Load the local model described by the configuration.

        Recognised ``params`` keys: ``model_path`` (required) and ``device``
        (default ``'cpu'``).

        Raises:
            ValueError: If ``model_path`` is missing or empty.
            ImportError: If sentence-transformers is not installed.
        """
        super().__init__(embedding_config)
        self.model_path = embedding_config.params.get('model_path')
        self.device = embedding_config.params.get('device', 'cpu')
        self._fallback_dim = 768  # Default; updated after first successful call

        if not self.model_path:
            raise ValueError("LocalEmbedding requires model_path parameter")

        # Load the local embedding model (e.g., sentence-transformers).
        # Only the import sits inside the try so that errors raised while
        # constructing the model are not mislabelled as a missing package.
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as exc:
            raise ImportError("sentence-transformers is required for LocalEmbedding") from exc
        self.model = SentenceTransformer(self.model_path, device=self.device)

        # Seed the fallback dimension from the model when it reports one, so
        # zero vectors returned before the first successful encode already
        # have the right length. Guarded lookup: keep the default if the
        # installed version does not expose this method.
        dimension_getter = getattr(self.model, "get_sentence_embedding_dimension", None)
        dimension = dimension_getter() if callable(dimension_getter) else None
        if dimension:
            self._fallback_dim = dimension

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding for *text*, or a zero vector on failure.

        Args:
            text: Raw input text.

        Returns:
            The encoded vector as a list, or a zero vector of the last known
            dimension when the text is empty after preprocessing or encoding
            raises.
        """
        # Preprocess text before encoding
        processed_text = self._preprocess_text(text)
        if not processed_text:
            return [0.0] * self._fallback_dim

        try:
            embedding = self.model.encode(processed_text, convert_to_tensor=False)
            result = embedding.tolist()
            self._fallback_dim = len(result)
            return result
        except Exception as e:
            logger.error("Error getting local embedding: %s", e)
            return [0.0] * self._fallback_dim