"""
|
||
FileMemory: Memory system for vectorizing and retrieving file contents
|
||
"""
|
||
import json
|
||
import os
|
||
import hashlib
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any
|
||
import time
|
||
|
||
import faiss
|
||
import numpy as np
|
||
|
||
from runtime.node.agent.memory.memory_base import (
|
||
MemoryBase,
|
||
MemoryContentSnapshot,
|
||
MemoryItem,
|
||
MemoryWritePayload,
|
||
)
|
||
from entity.configs import MemoryStoreConfig, FileSourceConfig
|
||
from entity.configs.node.memory import FileMemoryConfig
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class FileMemory(MemoryBase):
|
||
"""
|
||
File-based memory system that indexes and retrieves content from files/directories.
    Supports multiple file types, character-based chunking with overlap, and incremental updates.
    """

    def __init__(self, store: MemoryStoreConfig):
        config = store.as_config(FileMemoryConfig)
        if not config:
            raise ValueError("FileMemory requires a file memory store configuration")
        super().__init__(store)

        if not config.file_sources:
            raise ValueError("FileMemory requires at least one file_source in configuration")

        self.file_config = config
        self.file_sources: List[FileSourceConfig] = config.file_sources
        self.index_path = self.file_config.index_path  # Path to store the index

        # Chunking configuration
        self.chunk_size = 500  # Characters per chunk
        self.chunk_overlap = 50  # Overlapping characters between chunks
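        # Example with the defaults above: a 1,200-character document with no
        # sentence terminators is split into chunks covering roughly
        # [0, 500), [450, 950) and [900, 1200); each chunk re-reads the last
        # 50 characters of the previous one (see _chunk_text()).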

        # File metadata cache {file_path: {hash, chunks_count, ...}}
        self.file_metadata: Dict[str, Dict[str, Any]] = {}

    def load(self) -> None:
        """
        Load an existing index or build a new one from the file sources.
        Validates index integrity and performs incremental updates if needed.
        """
        if self.index_path and os.path.exists(self.index_path):
            logger.info(f"Loading existing index from {self.index_path}")
            self._load_from_file()

            # Validate and update if files changed
            if self._validate_and_update_index():
                logger.info("Index updated due to file changes")
                self.save()
        else:
            logger.info("Building new index from file sources")
            self._build_index_from_sources()
            if self.index_path:
                self.save()

    def save(self) -> None:
        """Persist the memory index to disk"""
        if not self.index_path:
            logger.warning("No index_path specified, skipping save")
            return

        # Ensure the parent directory exists (index_path may be a bare
        # filename with no directory component)
        index_dir = os.path.dirname(self.index_path)
        if index_dir:
            os.makedirs(index_dir, exist_ok=True)

        # Prepare data for serialization
        data = {
            "file_metadata": self.file_metadata,
            "contents": [item.to_dict() for item in self.contents],
            "config": {
                "chunk_size": self.chunk_size,
                "chunk_overlap": self.chunk_overlap,
            }
        }
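        # Illustrative on-disk layout (item fields depend on MemoryItem.to_dict()):
        # {
        #   "file_metadata": {"/abs/docs/a.md": {"hash": "...", "size": 1234,
        #                                        "chunks_count": 3, "indexed_at": 1700000000.0}},
        #   "contents": [ ...serialized MemoryItem dicts... ],
        #   "config": {"chunk_size": 500, "chunk_overlap": 50}
        # }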

        # Save to JSON
        with open(self.index_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        logger.info(f"Index saved to {self.index_path} ({len(self.contents)} chunks)")

    def retrieve(
        self,
        agent_role: str,
        query: MemoryContentSnapshot,
        top_k: int,
        similarity_threshold: float,
    ) -> List[MemoryItem]:
        """
        Retrieve relevant file chunks based on query.

        Args:
            agent_role: Agent role (not used in file memory)
            query: Memory snapshot whose text is embedded and matched
            top_k: Number of results to return
            similarity_threshold: Minimum similarity score

        Returns:
            List of MemoryItem with file chunks
        """
        if self.count_memories() == 0:
            return []

        # Generate query embedding
        query_embedding = self.embedding.get_embedding(query.text)
        if isinstance(query_embedding, list):
            query_embedding = np.array(query_embedding, dtype=np.float32)
        query_embedding = query_embedding.reshape(1, -1)
        faiss.normalize_L2(query_embedding)

        expected_dim = query_embedding.shape[1]

        # Collect embeddings from memory items
        memory_embeddings = []
        valid_items = []
        for item in self.contents:
            if item.embedding is not None:
                if len(item.embedding) != expected_dim:
                    logger.warning(
                        "Skipping memory item %s: embedding dim %d != expected %d",
                        item.id, len(item.embedding), expected_dim,
                    )
                    continue
                memory_embeddings.append(item.embedding)
                valid_items.append(item)

        if not memory_embeddings:
            return []

        memory_embeddings = np.array(memory_embeddings, dtype=np.float32)

        # Build FAISS index and search
        index = faiss.IndexFlatIP(memory_embeddings.shape[1])
        index.add(memory_embeddings)
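        # Both the query vector and the stored embeddings are L2-normalized
        # (the latter at indexing time in _build_embeddings), so the inner
        # products returned by IndexFlatIP are cosine similarities in [-1, 1];
        # similarity_threshold is interpreted on that scale.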

        similarities, indices = index.search(query_embedding, min(top_k, len(valid_items)))

        # Filter by threshold and return results
        results = []
        for i in range(len(indices[0])):
            idx = indices[0][i]
            similarity = similarities[0][i]

            if idx != -1 and similarity >= similarity_threshold:
                results.append(valid_items[idx])

        return results

    def update(self, payload: MemoryWritePayload) -> None:
        """
        FileMemory is read-only; updates are not supported.
        This method is a no-op kept for interface compatibility.
        """
        logger.debug("FileMemory.update() called but FileMemory is read-only")

    # ========== Private Helper Methods ==========

    def _load_from_file(self) -> None:
        """Load index from JSON file"""
        try:
            with open(self.index_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            self.file_metadata = data.get("file_metadata", {})
            raw_contents = data.get("contents", [])
            contents: List[MemoryItem] = []
            for raw in raw_contents:
                try:
                    contents.append(MemoryItem.from_dict(raw))
                except Exception:
                    continue
            self.contents = contents

            # Load config if present
            config = data.get("config", {})
            self.chunk_size = config.get("chunk_size", self.chunk_size)
            self.chunk_overlap = config.get("chunk_overlap", self.chunk_overlap)

            logger.info(f"Loaded {len(self.contents)} chunks from index")
        except Exception as e:
            logger.error(f"Error loading index: {e}")
            self.file_metadata = {}
            self.contents = []

    def _build_index_from_sources(self) -> None:
        """Build index by scanning all file sources"""
        all_chunks = []

        for source in self.file_sources:
            logger.info(f"Scanning source: {source.source_path}")
            files = self._scan_files(source)
            logger.info(f"Found {len(files)} files in {source.source_path}")

            for file_path in files:
                chunks = self._read_and_chunk_file(file_path, source.encoding)
                all_chunks.extend(chunks)

        logger.info(f"Total chunks to index: {len(all_chunks)}")

        # Generate embeddings for all chunks
        self.contents = self._build_embeddings(all_chunks)

        logger.info(f"Index built with {len(self.contents)} chunks")

    def _validate_and_update_index(self) -> bool:
        """
        Validate index integrity and update if files changed.

        Returns:
            True if index was updated, False otherwise
        """
        updated = False
        current_files = set()

        # Scan current files
        for source in self.file_sources:
            files = self._scan_files(source)
            current_files.update(files)

        # Check for deleted files
        indexed_files = set(self.file_metadata.keys())
        deleted_files = indexed_files - current_files

        if deleted_files:
            logger.info(f"Removing {len(deleted_files)} deleted files from index")
            self._remove_files_from_index(deleted_files)
            updated = True

        # Check for new or modified files
        for source in self.file_sources:
            files = self._scan_files(source)

            for file_path in files:
                file_hash = self._compute_file_hash(file_path)

                # New file
                if file_path not in self.file_metadata:
                    logger.info(f"Indexing new file: {file_path}")
                    self._index_file(file_path, source.encoding)
                    updated = True

                # Modified file
                elif self.file_metadata[file_path].get("hash") != file_hash:
                    logger.info(f"Re-indexing modified file: {file_path}")
                    self._remove_files_from_index([file_path])
                    self._index_file(file_path, source.encoding)
                    updated = True

        return updated

    def _scan_files(self, source: FileSourceConfig) -> List[str]:
        """
        Scan file path and return list of matching files.

        Args:
            source: FileSourceConfig with path and filters

        Returns:
            List of absolute file paths
        """
        path = Path(source.source_path).expanduser().resolve()

        # Single file
        if path.is_file():
            if self._matches_file_types(path, source.file_types):
                return [str(path)]
            return []

        # Directory
        if not path.is_dir():
            logger.warning(f"Path does not exist or is not a directory: {source.source_path}")
            return []

        files = []

        if source.recursive:
            # Recursive scan
            for file_path in path.rglob("*"):
                if file_path.is_file() and self._matches_file_types(file_path, source.file_types):
                    files.append(str(file_path))
        else:
            # Non-recursive scan
            for file_path in path.glob("*"):
                if file_path.is_file() and self._matches_file_types(file_path, source.file_types):
                    files.append(str(file_path))

        return files

    def _matches_file_types(self, file_path: Path, file_types: List[str]) -> bool:
        """Check if file matches the file type filter"""
        if file_types is None:
            return True
        return file_path.suffix in file_types

    def _read_and_chunk_file(self, file_path: str, encoding: str = "utf-8") -> List[Dict]:
        """
        Read file and split into chunks.

        Args:
            file_path: Path to file
            encoding: File encoding

        Returns:
            List of chunk dictionaries with content and metadata
        """
        try:
            with open(file_path, 'r', encoding=encoding, errors='ignore') as f:
                content = f.read()
        except Exception as e:
            logger.error(f"Error reading file {file_path}: {e}")
            return []

        if not content.strip():
            return []

        # Compute file hash
        file_hash = self._compute_file_hash(file_path)
        file_size = os.path.getsize(file_path)

        # Chunk the content
        chunks = self._chunk_text(content)

        # Build chunk metadata
        chunk_dicts = []
        for i, chunk_text in enumerate(chunks):
            chunk_dicts.append({
                "content": chunk_text,
                "metadata": {
                    "source_type": "file",
                    "file_path": file_path,
                    "file_name": os.path.basename(file_path),
                    "file_hash": file_hash,
                    "file_size": file_size,
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                    "encoding": encoding,
                }
            })

        # Update file metadata cache
        self.file_metadata[file_path] = {
            "hash": file_hash,
            "size": file_size,
            "chunks_count": len(chunks),
            "indexed_at": time.time(),
        }

        return chunk_dicts

    def _chunk_text(self, text: str) -> List[str]:
        """
        Split text into chunks with overlap.

        Args:
            text: Input text

        Returns:
            List of text chunks
        """
        if len(text) <= self.chunk_size:
            return [text]

        chunks = []
        start = 0

        while start < len(text):
            end = start + self.chunk_size
            chunk = text[start:end]

            # Try to break at sentence boundary
            if end < len(text):
                # Look for sentence endings (CJK and ASCII)
                last_sentence = max(
                    chunk.rfind('。'),
                    chunk.rfind('！'),
                    chunk.rfind('？'),
                    chunk.rfind('.'),
                    chunk.rfind('!'),
                    chunk.rfind('?'),
                    chunk.rfind('\n')
                )
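                # str.rfind() returns -1 when a terminator is absent, so
                # last_sentence may be -1; the guard below also rejects breaks
                # in the first half of the window, so no chunk shrinks below
                # chunk_size // 2 characters.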

                if last_sentence > self.chunk_size // 2:  # Don't break too early
                    chunk = chunk[:last_sentence + 1]
                    end = start + last_sentence + 1

            chunks.append(chunk.strip())

            # Move start with overlap
            start = end - self.chunk_overlap

            if start >= len(text):
                break

        return [c for c in chunks if c]  # Filter empty chunks

    def _build_embeddings(self, chunks: List[Dict]) -> List[MemoryItem]:
        """
        Generate embeddings for chunks and create MemoryItems.

        Args:
            chunks: List of chunk dictionaries

        Returns:
            List of MemoryItem objects
        """
        memory_items = []

        for chunk_dict in chunks:
            content = chunk_dict["content"]
            metadata = chunk_dict["metadata"]

            # Generate embedding
            try:
                embedding = self.embedding.get_embedding(content)
                if isinstance(embedding, list):
                    embedding = np.array(embedding, dtype=np.float32)
                # Reshape unconditionally (mirrors retrieve()) so normalize_L2
                # always receives a 2-D float32 array
                embedding = embedding.reshape(1, -1)
                faiss.normalize_L2(embedding)
                embedding_list = embedding.tolist()[0]
            except Exception as e:
                logger.error(f"Error generating embedding for chunk: {e}")
                continue

            # Create MemoryItem
            item_id = f"{metadata['file_hash']}_{metadata['chunk_index']}"
            memory_item = MemoryItem(
                id=item_id,
                content_summary=content,
                metadata=metadata,
                embedding=embedding_list,
                timestamp=time.time(),
            )

            memory_items.append(memory_item)

        return memory_items

    def _compute_file_hash(self, file_path: str) -> str:
        """Compute a truncated (16 hex chars) MD5 digest of the file, used for change detection"""
        hash_md5 = hashlib.md5()
        try:
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(4096), b""):
                    hash_md5.update(chunk)
            return hash_md5.hexdigest()[:16]
        except Exception as e:
            logger.error(f"Error computing hash for {file_path}: {e}")
            return "error"

    def _index_file(self, file_path: str, encoding: str = "utf-8") -> None:
        """Index a single file (helper for incremental updates)"""
        chunks = self._read_and_chunk_file(file_path, encoding)
        if chunks:
            new_items = self._build_embeddings(chunks)
            self.contents.extend(new_items)

    def _remove_files_from_index(self, file_paths: List[str]) -> None:
        """Remove chunks from deleted files"""
        file_paths_set = set(file_paths)

        # Filter out chunks from deleted files
        self.contents = [
            item for item in self.contents
            if item.metadata.get("file_path") not in file_paths_set
        ]

        # Remove from metadata
        for file_path in file_paths:
            self.file_metadata.pop(file_path, None)