# ChatDev/runtime/node/agent/memory/file_memory.py
"""
FileMemory: Memory system for vectorizing and retrieving file contents
"""
import json
import os
import hashlib
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import time
import faiss
import numpy as np
from runtime.node.agent.memory.memory_base import (
MemoryBase,
MemoryContentSnapshot,
MemoryItem,
MemoryWritePayload,
)
from entity.configs import MemoryStoreConfig, FileSourceConfig
from entity.configs.node.memory import FileMemoryConfig
logger = logging.getLogger(__name__)
class FileMemory(MemoryBase):
"""
File-based memory system that indexes and retrieves content from files/directories.
Supports multiple file types, chunking strategies, and incremental updates.
"""
def __init__(self, store: MemoryStoreConfig):
config = store.as_config(FileMemoryConfig)
if not config:
raise ValueError("FileMemory requires a file memory store configuration")
super().__init__(store)
if not config.file_sources:
raise ValueError("FileMemory requires at least one file_source in configuration")
self.file_config = config
self.file_sources: List[FileSourceConfig] = config.file_sources
self.index_path = self.file_config.index_path # Path to store the index
# Chunking configuration
self.chunk_size = 500 # Characters per chunk
self.chunk_overlap = 50 # Overlapping characters between chunks
# File metadata cache {file_path: {hash, chunks_count, ...}}
self.file_metadata: Dict[str, Dict[str, Any]] = {}
def load(self) -> None:
"""
Load existing index or build new one from file sources.
Validates index integrity and performs incremental updates if needed.
"""
if self.index_path and os.path.exists(self.index_path):
logger.info(f"Loading existing index from {self.index_path}")
self._load_from_file()
# Validate and update if files changed
if self._validate_and_update_index():
logger.info("Index updated due to file changes")
self.save()
else:
logger.info("Building new index from file sources")
self._build_index_from_sources()
if self.index_path:
self.save()
def save(self) -> None:
"""Persist the memory index to disk"""
if not self.index_path:
logger.warning("No index_path specified, skipping save")
return
# Ensure directory exists (index_path may be a bare filename with no directory part)
index_dir = os.path.dirname(self.index_path)
if index_dir:
    os.makedirs(index_dir, exist_ok=True)
# Prepare data for serialization
data = {
"file_metadata": self.file_metadata,
"contents": [item.to_dict() for item in self.contents],
"config": {
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap,
}
}
# Save to JSON
with open(self.index_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
logger.info(f"Index saved to {self.index_path} ({len(self.contents)} chunks)")
def retrieve(
self,
agent_role: str,
query: MemoryContentSnapshot,
top_k: int,
similarity_threshold: float,
) -> List[MemoryItem]:
"""
Retrieve relevant file chunks based on query.
Args:
agent_role: Agent role (not used in file memory)
query: Snapshot whose text is embedded and searched against the index
top_k: Number of results to return
similarity_threshold: Minimum similarity score
Returns:
List of MemoryItem with file chunks
"""
if self.count_memories() == 0:
return []
# Generate query embedding
query_embedding = self.embedding.get_embedding(query.text)
if isinstance(query_embedding, list):
query_embedding = np.array(query_embedding, dtype=np.float32)
query_embedding = query_embedding.reshape(1, -1)
faiss.normalize_L2(query_embedding)
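# The query vector and all stored chunk embeddings are L2-normalized, so the
# IndexFlatIP inner-product search below returns cosine similarity scores.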
expected_dim = query_embedding.shape[1]
# Collect embeddings from memory items
memory_embeddings = []
valid_items = []
for item in self.contents:
if item.embedding is not None:
if len(item.embedding) != expected_dim:
logger.warning(
"Skipping memory item %s: embedding dim %d != expected %d",
item.id, len(item.embedding), expected_dim,
)
continue
memory_embeddings.append(item.embedding)
valid_items.append(item)
if not memory_embeddings:
return []
memory_embeddings = np.array(memory_embeddings, dtype=np.float32)
# Build FAISS index and search
index = faiss.IndexFlatIP(memory_embeddings.shape[1])
index.add(memory_embeddings)
similarities, indices = index.search(query_embedding, min(top_k, len(valid_items)))
# Filter by threshold and return results
results = []
for i in range(len(indices[0])):
idx = indices[0][i]
similarity = similarities[0][i]
if idx != -1 and similarity >= similarity_threshold:
results.append(valid_items[idx])
return results
def update(self, payload: MemoryWritePayload) -> None:
"""
FileMemory is read-only, updates are not supported.
This method is a no-op to maintain interface compatibility.
"""
logger.debug("FileMemory.update() called but FileMemory is read-only")
# ========== Private Helper Methods ==========
def _load_from_file(self) -> None:
"""Load index from JSON file"""
try:
with open(self.index_path, 'r', encoding='utf-8') as f:
data = json.load(f)
self.file_metadata = data.get("file_metadata", {})
raw_contents = data.get("contents", [])
contents: List[MemoryItem] = []
for raw in raw_contents:
try:
contents.append(MemoryItem.from_dict(raw))
except Exception:
continue
self.contents = contents
# Load config if present
config = data.get("config", {})
self.chunk_size = config.get("chunk_size", self.chunk_size)
self.chunk_overlap = config.get("chunk_overlap", self.chunk_overlap)
logger.info(f"Loaded {len(self.contents)} chunks from index")
except Exception as e:
logger.error(f"Error loading index: {e}")
self.file_metadata = {}
self.contents = []
def _build_index_from_sources(self) -> None:
"""Build index by scanning all file sources"""
all_chunks = []
for source in self.file_sources:
logger.info(f"Scanning source: {source.source_path}")
files = self._scan_files(source)
logger.info(f"Found {len(files)} files in {source.source_path}")
for file_path in files:
chunks = self._read_and_chunk_file(file_path, source.encoding)
all_chunks.extend(chunks)
logger.info(f"Total chunks to index: {len(all_chunks)}")
# Generate embeddings for all chunks
self.contents = self._build_embeddings(all_chunks)
logger.info(f"Index built with {len(self.contents)} chunks")
def _validate_and_update_index(self) -> bool:
"""
Validate index integrity and update if files changed.
Returns:
True if index was updated, False otherwise
"""
updated = False
current_files = set()
# Scan current files
for source in self.file_sources:
files = self._scan_files(source)
current_files.update(files)
# Check for deleted files
indexed_files = set(self.file_metadata.keys())
deleted_files = indexed_files - current_files
if deleted_files:
logger.info(f"Removing {len(deleted_files)} deleted files from index")
self._remove_files_from_index(deleted_files)
updated = True
# Check for new or modified files
for source in self.file_sources:
files = self._scan_files(source)
for file_path in files:
file_hash = self._compute_file_hash(file_path)
# New file
if file_path not in self.file_metadata:
logger.info(f"Indexing new file: {file_path}")
self._index_file(file_path, source.encoding)
updated = True
# Modified file
elif self.file_metadata[file_path].get("hash") != file_hash:
logger.info(f"Re-indexing modified file: {file_path}")
self._remove_files_from_index([file_path])
self._index_file(file_path, source.encoding)
updated = True
return updated
def _scan_files(self, source: FileSourceConfig) -> List[str]:
"""
Scan file path and return list of matching files.
Args:
source: FileSourceConfig with path and filters
Returns:
List of absolute file paths
"""
path = Path(source.source_path).expanduser().resolve()
# Single file
if path.is_file():
if self._matches_file_types(path, source.file_types):
return [str(path)]
return []
# Directory
if not path.is_dir():
logger.warning(f"Path does not exist: {source.source_path}")
return []
files = []
if source.recursive:
# Recursive scan
for file_path in path.rglob("*"):
if file_path.is_file() and self._matches_file_types(file_path, source.file_types):
files.append(str(file_path))
else:
# Non-recursive scan
for file_path in path.glob("*"):
if file_path.is_file() and self._matches_file_types(file_path, source.file_types):
files.append(str(file_path))
return files
def _matches_file_types(self, file_path: Path, file_types: Optional[List[str]]) -> bool:
"""Check if file matches the file type filter"""
if file_types is None:
return True
return file_path.suffix in file_types
def _read_and_chunk_file(self, file_path: str, encoding: str = "utf-8") -> List[Dict]:
"""
Read file and split into chunks.
Args:
file_path: Path to file
encoding: File encoding
Returns:
List of chunk dictionaries with content and metadata
"""
try:
with open(file_path, 'r', encoding=encoding, errors='ignore') as f:
content = f.read()
except Exception as e:
logger.error(f"Error reading file {file_path}: {e}")
return []
if not content.strip():
return []
# Compute file hash
file_hash = self._compute_file_hash(file_path)
file_size = os.path.getsize(file_path)
# Chunk the content
chunks = self._chunk_text(content)
# Build chunk metadata
chunk_dicts = []
for i, chunk_text in enumerate(chunks):
chunk_dicts.append({
"content": chunk_text,
"metadata": {
"source_type": "file",
"file_path": file_path,
"file_name": os.path.basename(file_path),
"file_hash": file_hash,
"file_size": file_size,
"chunk_index": i,
"total_chunks": len(chunks),
"encoding": encoding,
}
})
# Update file metadata cache
self.file_metadata[file_path] = {
"hash": file_hash,
"size": file_size,
"chunks_count": len(chunks),
"indexed_at": time.time(),
}
return chunk_dicts
def _chunk_text(self, text: str) -> List[str]:
"""
Split text into chunks with overlap.
Args:
text: Input text
Returns:
List of text chunks
"""
if len(text) <= self.chunk_size:
return [text]
chunks = []
start = 0
while start < len(text):
end = start + self.chunk_size
chunk = text[start:end]
# Try to break at sentence boundary
if end < len(text):
# Look for sentence endings (full-width CJK and ASCII punctuation)
last_sentence = max(
chunk.rfind('。'),
chunk.rfind('！'),
chunk.rfind('？'),
chunk.rfind('.'),
chunk.rfind('!'),
chunk.rfind('?'),
chunk.rfind('\n')
)
if last_sentence > self.chunk_size // 2: # Don't break too early
chunk = chunk[:last_sentence + 1]
end = start + last_sentence + 1
chunks.append(chunk.strip())
# Move start with overlap
start = end - self.chunk_overlap
if start >= len(text):
break
return [c for c in chunks if c] # Filter empty chunks
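# Worked illustration (hypothetical values, not the 500/50 defaults above):
# with chunk_size=20 and chunk_overlap=5, a 45-character text containing no
# sentence-ending characters is split into text[0:20], text[15:35] and
# text[30:45]; each chunk re-reads the final `chunk_overlap` characters of
# the previous one so content cut at a boundary remains retrievable.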
def _build_embeddings(self, chunks: List[Dict]) -> List[MemoryItem]:
"""
Generate embeddings for chunks and create MemoryItems.
Args:
chunks: List of chunk dictionaries
Returns:
List of MemoryItem objects
"""
memory_items = []
for chunk_dict in chunks:
content = chunk_dict["content"]
metadata = chunk_dict["metadata"]
# Generate embedding
try:
embedding = self.embedding.get_embedding(content)
# Ensure a contiguous float32 (1, dim) array regardless of whether the
# provider returned a list or an ndarray, then L2-normalize so the
# inner-product search in retrieve() behaves as cosine similarity
embedding = np.asarray(embedding, dtype=np.float32).reshape(1, -1)
faiss.normalize_L2(embedding)
embedding_list = embedding.tolist()[0]
except Exception as e:
logger.error(f"Error generating embedding for chunk: {e}")
continue
# Create MemoryItem
item_id = f"{metadata['file_hash']}_{metadata['chunk_index']}"
memory_item = MemoryItem(
id=item_id,
content_summary=content,
metadata=metadata,
embedding=embedding_list,
timestamp=time.time(),
)
memory_items.append(memory_item)
return memory_items
def _compute_file_hash(self, file_path: str) -> str:
"""Compute MD5 hash of file"""
hash_md5 = hashlib.md5()
try:
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()[:16]
except Exception as e:
logger.error(f"Error computing hash for {file_path}: {e}")
return "error"
def _index_file(self, file_path: str, encoding: str = "utf-8") -> None:
"""Index a single file (helper for incremental updates)"""
chunks = self._read_and_chunk_file(file_path, encoding)
if chunks:
new_items = self._build_embeddings(chunks)
self.contents.extend(new_items)
def _remove_files_from_index(self, file_paths: List[str]) -> None:
"""Remove chunks from deleted files"""
file_paths_set = set(file_paths)
# Filter out chunks from deleted files
self.contents = [
item for item in self.contents
if item.metadata.get("file_path") not in file_paths_set
]
# Remove from metadata
for file_path in file_paths:
self.file_metadata.pop(file_path, None)
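# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). The
# MemoryStoreConfig(...) and MemoryContentSnapshot(text=...) constructor calls
# are assumptions about those config/dataclass types; only the FileMemory
# methods shown are taken from this file.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     store = MemoryStoreConfig(...)  # must resolve to a FileMemoryConfig with file_sources
#     memory = FileMemory(store)
#     memory.load()   # builds the index on first run, incrementally updates afterwards
#     hits = memory.retrieve(
#         agent_role="",  # ignored by FileMemory
#         query=MemoryContentSnapshot(text="how are file chunks overlapped?"),
#         top_k=5,
#         similarity_threshold=0.3,
#     )
#     for item in hits:
#         print(item.metadata["file_path"], item.content_summary[:80])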