from abc import ABC, abstractmethod
import re
import logging
from typing import List

import openai
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from entity.configs import EmbeddingConfig

logger = logging.getLogger(__name__)


class EmbeddingBase(ABC):
    def __init__(self, embedding_config: EmbeddingConfig):
        self.config = embedding_config

    @abstractmethod
    def get_embedding(self, text: str) -> List[float]:
        ...

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text to improve embedding quality."""
        if not text:
            return ""

        # Collapse runs of whitespace into single spaces
        text = re.sub(r'\s+', ' ', text.strip())

        # Drop special characters and emoji, keeping word characters,
        # whitespace, and CJK ideographs (U+4E00-U+9FFF)
        text = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', text)

        # Clean up whitespace again
        text = re.sub(r'\s+', ' ', text.strip())

        return text

    def _chunk_text(self, text: str, max_length: int = 500) -> List[str]:
        """Split long text into chunks to improve embedding quality."""
        if len(text) <= max_length:
            return [text]

        # Split on CJK sentence boundaries (。！？；) and newlines
        sentences = re.split(r'[。！？；\n]', text)
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            if len(current_chunk + sentence) <= max_length:
                current_chunk += sentence + "。"
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence + "。"

        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks
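

# Illustration of _chunk_text on CJK input (values are illustrative; `base` is
# any concrete EmbeddingBase subclass instance):
#
#     base._chunk_text("今天天气很好。我们去公园散步。", max_length=10)
#     # -> ["今天天气很好。", "我们去公园散步。"]
#
# Each sentence fits within max_length on its own, so each becomes a chunk,
# with the sentence-final "。" restored after splitting.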


class EmbeddingFactory:
    @staticmethod
    def create_embedding(embedding_config: EmbeddingConfig) -> EmbeddingBase:
        provider = embedding_config.provider
        if provider == 'openai':
            return OpenAIEmbedding(embedding_config)
        elif provider == 'local':
            return LocalEmbedding(embedding_config)
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")
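

# Minimal usage sketch for the 'local' provider path. The EmbeddingConfig
# keyword arguments and the model name below are illustrative assumptions
# (see entity/configs.py for the real constructor signature):
#
#     local_cfg = EmbeddingConfig(
#         provider="local",
#         params={"model_path": "BAAI/bge-small-zh-v1.5", "device": "cpu"},
#     )
#     embedder = EmbeddingFactory.create_embedding(local_cfg)  # -> LocalEmbedding
#     vector = embedder.get_embedding("你好，世界")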


class OpenAIEmbedding(EmbeddingBase):
    def __init__(self, embedding_config: EmbeddingConfig):
        super().__init__(embedding_config)
        self.base_url = embedding_config.base_url
        self.api_key = embedding_config.api_key
        self.model_name = embedding_config.model or "text-embedding-3-small"  # Default model
        self.max_length = embedding_config.params.get('max_length', 8191)
        self.use_chunking = embedding_config.params.get('use_chunking', False)
        self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
        self._fallback_dim = 1536  # Default; updated after first successful call

        if self.base_url:
            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        else:
            self.client = openai.OpenAI(api_key=self.api_key)

    @retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(10))
    def _embed_request(self, text: str) -> List[float]:
        """Single embeddings API call, retried with exponential backoff on failure."""
        response = self.client.embeddings.create(
            input=text,
            model=self.model_name,
            encoding_format="float"
        )
        return response.data[0].embedding

    def get_embedding(self, text: str) -> List[float]:
        # Preprocess the text
        processed_text = self._preprocess_text(text)

        if not processed_text:
            logger.warning("Empty text after preprocessing")
            return [0.0] * self._fallback_dim

        # Handle long text via chunking
        if self.use_chunking and len(processed_text) > self.max_length:
            return self._get_chunked_embedding(processed_text)

        # Truncate text that exceeds the model's input limit
        truncated_text = processed_text[:self.max_length]

        try:
            embedding = self._embed_request(truncated_text)
            self._fallback_dim = len(embedding)
            return embedding
        except Exception as e:
            # Retries exhausted: fall back to a zero vector
            logger.error(f"Error getting embedding: {e}")
            return [0.0] * self._fallback_dim

    def _get_chunked_embedding(self, text: str) -> List[float]:
        """Chunk long text, embed each chunk, then aggregate."""
        chunks = self._chunk_text(text, self.max_length // 2)  # Halve the chunk length

        if not chunks:
            return [0.0] * self._fallback_dim

        chunk_embeddings = []
        for chunk in chunks:
            try:
                chunk_embeddings.append(self._embed_request(chunk))
            except Exception as e:
                logger.warning(f"Error getting chunk embedding: {e}")
                continue

        if not chunk_embeddings:
            return [0.0] * self._fallback_dim

        dim = len(chunk_embeddings[0])

        # Aggregation strategy
        if self.chunk_strategy == 'average':
            # Mean aggregation: element-wise average across chunks
            return [sum(chunk[i] for chunk in chunk_embeddings) / len(chunk_embeddings)
                    for i in range(dim)]
        elif self.chunk_strategy == 'weighted':
            # Weighted aggregation: chunk j gets weight 1/(j+1), so earlier chunks weigh more
            weights = [1.0 / (i + 1) for i in range(len(chunk_embeddings))]
            total_weight = sum(weights)
            return [sum(chunk[i] * weights[j] for j, chunk in enumerate(chunk_embeddings)) / total_weight
                    for i in range(dim)]
        else:
            # Default: fall back to the first chunk's embedding
            return chunk_embeddings[0]
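

# Worked example of the two aggregation strategies above, with toy numbers.
# Given two chunk embeddings [1.0, 3.0] and [3.0, 5.0]:
#   'average':  [(1 + 3) / 2, (3 + 5) / 2] = [2.0, 4.0]
#   'weighted': weights = [1/1, 1/2] = [1.0, 0.5], total_weight = 1.5, so
#               [(1*1.0 + 3*0.5) / 1.5, (3*1.0 + 5*0.5) / 1.5] ≈ [1.667, 3.667]
# The 1/(j+1) weighting biases the result toward earlier chunks.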


class LocalEmbedding(EmbeddingBase):
    def __init__(self, embedding_config: EmbeddingConfig):
        super().__init__(embedding_config)
        self.model_path = embedding_config.params.get('model_path')
        self.device = embedding_config.params.get('device', 'cpu')
        self._fallback_dim = 768  # Default; updated after first successful call

        if not self.model_path:
            raise ValueError("LocalEmbedding requires a model_path parameter")

        # Load the local embedding model (e.g., sentence-transformers)
        try:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(self.model_path, device=self.device)
        except ImportError:
            raise ImportError("sentence-transformers is required for LocalEmbedding")

    def get_embedding(self, text: str) -> List[float]:
        # Preprocess text before encoding
        processed_text = self._preprocess_text(text)

        if not processed_text:
            return [0.0] * self._fallback_dim

        try:
            embedding = self.model.encode(processed_text, convert_to_tensor=False)
            result = embedding.tolist()
            self._fallback_dim = len(result)
            return result
        except Exception as e:
            logger.error(f"Error getting local embedding: {e}")
            return [0.0] * self._fallback_dim
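

if __name__ == "__main__":
    # Smoke-test sketch. Assumes EmbeddingConfig accepts these keyword
    # arguments (inferred from the attribute reads above, not verified
    # against entity/configs.py) and that a valid OpenAI key is exported.
    import os

    config = EmbeddingConfig(
        provider="openai",
        api_key=os.environ.get("OPENAI_API_KEY", ""),
        model="text-embedding-3-small",
        params={"use_chunking": False},
    )
    embedder = EmbeddingFactory.create_embedding(config)
    vec = embedder.get_embedding("ChatDev builds software with LLM agents.")
    print(f"dim={len(vec)}, head={vec[:5]}")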