from abc import ABC, abstractmethod
import re
import logging
from typing import List, Optional

import openai
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from entity.configs import EmbeddingConfig

logger = logging.getLogger(__name__)


class EmbeddingBase(ABC):
    """Base class for embedding backends.

    Provides shared text preprocessing and chunking helpers; subclasses
    implement :meth:`get_embedding`.
    """

    def __init__(self, embedding_config: EmbeddingConfig):
        self.config = embedding_config

    @abstractmethod
    def get_embedding(self, text):
        """Return an embedding vector (list of floats) for ``text``."""
        ...

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text to improve embedding quality.

        Collapses runs of whitespace and strips everything except word
        characters, whitespace, and CJK ideographs (U+4E00-U+9FFF).
        Returns "" for falsy input.
        """
        if not text:
            return ""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text.strip())
        # Remove special characters and emoji (keep word chars + CJK)
        text = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', text)
        # Clean up whitespace introduced by the substitution above
        return re.sub(r'\s+', ' ', text.strip())

    def _chunk_text(self, text: str, max_length: int = 500) -> List[str]:
        """Split long text into chunks of at most ``max_length`` characters.

        Splits on CJK sentence terminators and newlines, re-appending a
        "\u3002" after each kept sentence.

        Fixes vs. the original:
        - the length check now accounts for the appended terminator, so a
          chunk can no longer end up one character over the limit;
        - a single sentence longer than ``max_length`` is hard-split
          instead of being emitted as an oversized chunk.
        """
        if len(text) <= max_length:
            return [text]

        # Split by CJK sentence boundaries: 。 ！ ？ ； and newline
        sentences = re.split(r'[\u3002\uff01\uff1f\uff1b\n]', text)
        chunks: List[str] = []
        current_chunk = ""
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            # Hard-split a sentence that alone reaches the limit.
            while len(sentence) >= max_length:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                    current_chunk = ""
                chunks.append(sentence[:max_length])
                sentence = sentence[max_length:]
            if not sentence:
                continue
            # +1 accounts for the "\u3002" terminator appended below.
            if len(current_chunk) + len(sentence) + 1 <= max_length:
                current_chunk += sentence + "\u3002"
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence + "\u3002"
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks


class EmbeddingFactory:
    """Creates the concrete embedding backend for a configuration."""

    @staticmethod
    def create_embedding(embedding_config: EmbeddingConfig) -> EmbeddingBase:
        """Instantiate the backend named by ``embedding_config.provider``.

        Raises:
            ValueError: if the provider is neither 'openai' nor 'local'.
        """
        provider = embedding_config.provider
        if provider == 'openai':
            return OpenAIEmbedding(embedding_config)
        elif provider == 'local':
            return LocalEmbedding(embedding_config)
        raise ValueError(f"Unsupported embedding model: {provider}")


class OpenAIEmbedding(EmbeddingBase):
    """Embeddings via the OpenAI API, with optional chunked aggregation."""

    def __init__(self, embedding_config: EmbeddingConfig):
        super().__init__(embedding_config)
        self.base_url = embedding_config.base_url
        self.api_key = embedding_config.api_key
        self.model_name = embedding_config.model or "text-embedding-3-small"  # Default model
        self.max_length = embedding_config.params.get('max_length', 8191)
        self.use_chunking = embedding_config.params.get('use_chunking', False)
        self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
        self._fallback_dim = 1536  # Default; updated after first successful call
        if self.base_url:
            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        else:
            self.client = openai.OpenAI(api_key=self.api_key)

    @retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(10))
    def _embed(self, text: str) -> List[float]:
        """Single embeddings API call, retried with exponential backoff.

        BUG FIX: the original put @retry on get_embedding but caught every
        exception *inside* it, so tenacity never saw a failure and the retry
        policy was dead code. Retries now genuinely fire here; callers catch
        the exception raised after the attempts are exhausted.
        """
        response = self.client.embeddings.create(
            input=text,
            model=self.model_name,
            encoding_format="float"
        )
        return response.data[0].embedding

    def get_embedding(self, text):
        """Return an embedding for ``text``.

        Falls back to a zero vector (dimension of the last successful call,
        1536 before any success) when preprocessing yields empty text or the
        API keeps failing after retries.
        """
        processed_text = self._preprocess_text(text)
        if not processed_text:
            logger.warning("Empty text after preprocessing")
            return [0.0] * self._fallback_dim

        # Handle long text via chunking
        if self.use_chunking and len(processed_text) > self.max_length:
            return self._get_chunked_embedding(processed_text)

        # Truncate text to the model's accepted length
        truncated_text = processed_text[:self.max_length]
        try:
            embedding = self._embed(truncated_text)
            self._fallback_dim = len(embedding)
            return embedding
        except Exception as e:
            logger.error(f"Error getting embedding: {e}")
            return [0.0] * self._fallback_dim

    def _get_chunked_embedding(self, text: str) -> List[float]:
        """Chunk long text, embed each chunk, then aggregate.

        Aggregation follows ``self.chunk_strategy``: 'average' (element-wise
        mean), 'weighted' (earlier chunks weigh more: 1/1, 1/2, 1/3, ...),
        anything else falls back to the first chunk's embedding.
        """
        chunks = self._chunk_text(text, self.max_length // 2)  # Halve the chunk length
        if not chunks:
            return [0.0] * self._fallback_dim

        chunk_embeddings: List[List[float]] = []
        for chunk in chunks:
            try:
                # Reuse the retried single-call path; a chunk that still
                # fails after retries is skipped (best-effort aggregation).
                chunk_embeddings.append(self._embed(chunk))
            except Exception as e:
                logger.warning(f"Error getting chunk embedding: {e}")
                continue

        if not chunk_embeddings:
            return [0.0] * self._fallback_dim

        dim = len(chunk_embeddings[0])
        if self.chunk_strategy == 'average':
            # Element-wise mean across chunks
            count = len(chunk_embeddings)
            return [sum(vec[i] for vec in chunk_embeddings) / count for i in range(dim)]
        elif self.chunk_strategy == 'weighted':
            # Weighted aggregation (earlier chunks weigh more)
            weights = [1.0 / (i + 1) for i in range(len(chunk_embeddings))]
            total_weight = sum(weights)
            return [
                sum(vec[i] * w for vec, w in zip(chunk_embeddings, weights)) / total_weight
                for i in range(dim)
            ]
        # Default to the first chunk
        return chunk_embeddings[0]


class LocalEmbedding(EmbeddingBase):
    """Embeddings computed locally via sentence-transformers."""

    def __init__(self, embedding_config: EmbeddingConfig):
        super().__init__(embedding_config)
        self.model_path = embedding_config.params.get('model_path')
        self.device = embedding_config.params.get('device', 'cpu')
        self._fallback_dim = 768  # Default; updated after first successful call
        if not self.model_path:
            raise ValueError("LocalEmbedding requires model_path parameter")
        # Load the local embedding model lazily-imported here so the module
        # works without sentence-transformers when only OpenAI is used.
        try:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(self.model_path, device=self.device)
        except ImportError:
            raise ImportError("sentence-transformers is required for LocalEmbedding")

    def get_embedding(self, text):
        """Return an embedding for ``text``; zero vector on empty input or failure."""
        processed_text = self._preprocess_text(text)
        if not processed_text:
            return [0.0] * self._fallback_dim
        try:
            embedding = self.model.encode(processed_text, convert_to_tensor=False)
            result = embedding.tolist()
            self._fallback_dim = len(result)
            return result
        except Exception as e:
            logger.error(f"Error getting local embedding: {e}")
            return [0.0] * self._fallback_dim