# ChatDev/runtime/node/agent/memory/simple_memory.py
import hashlib
import json
import logging
import os
import re
import time
from typing import List

import faiss
import numpy as np

from entity.configs import MemoryStoreConfig
from entity.configs.node.memory import SimpleMemoryConfig
from runtime.node.agent.memory.memory_base import (
    MemoryBase,
    MemoryContentSnapshot,
    MemoryItem,
    MemoryWritePayload,
)

logger = logging.getLogger(__name__)
class SimpleMemory(MemoryBase):
    """FAISS-backed agent memory with key-content extraction, hash-based
    deduplication, and similarity-reranked retrieval."""

    def __init__(self, store: MemoryStoreConfig):
        config = store.as_config(SimpleMemoryConfig)
        if not config:
            raise ValueError("SimpleMemory requires a simple memory store configuration")
        super().__init__(store)
        self.config = config
        # Prompt templates used to build retrieval queries and stored entries.
        self.retrieve_prompt = "Query: {input}"
        self.update_prompt = "Input: {input}\nOutput: {output}"
        self.memory_path = self.config.memory_path
        # Bounds applied by _extract_key_content when summarizing content.
        self.max_content_length = 500  # Maximum content length
        self.min_content_length = 20   # Minimum content length
    def _extract_key_content(self, content: str) -> str:
        """Extract key content while stripping redundant text."""
        # Strip common templated instructions first, while the original line
        # breaks still exist; the "\n\n" boundary below can never match once
        # whitespace has been collapsed to single spaces.
        content = re.sub(r'(?:Agent|Model) Role:.*?\n\n', '', content)
        content = re.sub(r'(?:You are|你是一位).*?(?:,|，)', '', content)            # "你是一位" = "You are a"
        content = re.sub(r'(?:User will input|用户会输入).*?(?:,|，)', '', content)  # "用户会输入" = "User will input"
        content = re.sub(r'(?:You need to|你需要).*?(?:,|，)', '', content)          # "你需要" = "You need to"
        # Collapse redundant whitespace.
        content = re.sub(r'\s+', ' ', content.strip())
        # Skip heavy processing for short snippets.
        if len(content) <= 100:
            return content
        # Split on CJK sentence enders (。！？；) and newlines, keeping only
        # sentences long enough to carry meaning.
        sentences = re.split(r'[。！？；\n]', content)
        key_sentences = [s.strip() for s in sentences if len(s.strip()) >= self.min_content_length]
        # Fall back to truncated original content when no sentence survives.
        if not key_sentences:
            return content[:self.max_content_length]
        # Recombine at most the first three sentences.
        extracted_content = '。'.join(key_sentences[:3])
        if len(extracted_content) > self.max_content_length:
            extracted_content = extracted_content[:self.max_content_length] + "..."
        return extracted_content.strip()
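    # Illustrative behavior (using the defaults above): given
    #   "Agent Role: reviewer\n\nA。B。C。D。"
    # where A..D are each sentences of at least min_content_length
    # characters, the role preamble is stripped and only the first three
    # sentences ("A。B。C") are kept.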
def _generate_content_hash(self, content: str) -> str:
"""Generate a content hash used for deduplication."""
return hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
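    # Example: hashlib.md5(b"hello").hexdigest() is
    # "5d41402abc4b2a76b9719d911017c592", so _generate_content_hash("hello")
    # returns "5d41402a". Eight hex chars (32 bits) keep collisions unlikely
    # at the ~1000-item cap enforced in update().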
    def load(self) -> None:
        """Load persisted memory items from the JSON file, skipping bad entries."""
        if self.memory_path and os.path.exists(self.memory_path) and self.memory_path.endswith(".json"):
            try:
                with open(self.memory_path, encoding="utf-8") as file:
                    raw_data = json.load(file)
                contents = []
                for raw in raw_data:
                    try:
                        contents.append(MemoryItem.from_dict(raw))
                    except Exception:
                        logger.warning("Skipping malformed memory item in %s", self.memory_path)
                        continue
                self.contents = contents
            except Exception:
                logger.warning("Failed to load memory from %s; starting empty", self.memory_path)
                self.contents = []
    def save(self) -> None:
        """Persist all memory items to the JSON file."""
        if self.memory_path and self.memory_path.endswith(".json"):
            directory = os.path.dirname(self.memory_path)
            if directory:  # os.makedirs("") raises for bare filenames
                os.makedirs(directory, exist_ok=True)
            with open(self.memory_path, "w", encoding="utf-8") as file:
                json.dump([item.to_dict() for item in self.contents], file, indent=2, ensure_ascii=False)
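    # Persisted format: a JSON array of MemoryItem.to_dict() payloads.
    # (The exact item schema is defined in memory_base and not assumed here.)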
    def retrieve(
        self,
        agent_role: str,
        query: MemoryContentSnapshot,
        top_k: int,
        similarity_threshold: float,
    ) -> List[MemoryItem]:
        """Return up to top_k stored items ranked by a blend of embedding
        similarity and lexical similarity to the query snapshot."""
        if self.count_memories() == 0 or not self.embedding:
            return []
        # Build a compact query string for embedding.
        query_text = self.retrieve_prompt.format(input=query.text)
        query_text = self._extract_key_content(query_text)
inputs_embedding = self.embedding.get_embedding(query_text)
if isinstance(inputs_embedding, list):
inputs_embedding = np.array(inputs_embedding, dtype=np.float32)
inputs_embedding = inputs_embedding.reshape(1, -1)
faiss.normalize_L2(inputs_embedding)
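        # With L2-normalized vectors, inner product equals cosine similarity,
        # so the IndexFlatIP search below effectively ranks by cosine score.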
expected_dim = inputs_embedding.shape[1]
memory_embeddings = []
valid_items = []
for item in self.contents:
if item.embedding is not None:
if len(item.embedding) != expected_dim:
logger.warning(
"Skipping memory item %s: embedding dim %d != expected %d",
item.id, len(item.embedding), expected_dim,
)
continue
memory_embeddings.append(item.embedding)
valid_items.append(item)
if not memory_embeddings:
return []
memory_embeddings = np.array(memory_embeddings, dtype=np.float32)
# Use an efficient inner-product index
index = faiss.IndexFlatIP(memory_embeddings.shape[1])
index.add(memory_embeddings)
# Retrieve extra candidates for reranking
retrieval_k = min(top_k * 3, len(valid_items))
similarities, indices = index.search(inputs_embedding, retrieval_k)
# Filter and rerank the candidates
candidates = []
for i in range(len(indices[0])):
idx = indices[0][i]
similarity = similarities[0][i]
if idx != -1 and similarity >= similarity_threshold:
item = valid_items[idx]
# Calculate an auxiliary semantic similarity score
semantic_score = self._calculate_semantic_similarity(query_text, item.content_summary)
# Combine similarity metrics
combined_score = 0.7 * similarity + 0.3 * semantic_score
candidates.append((item, combined_score))
        # Sort by combined score and keep the top_k items.
        candidates.sort(key=lambda x: x[1], reverse=True)
        return [item for item, _ in candidates[:top_k]]
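    # Illustrative call (a sketch: it assumes a snapshot object exposing
    # `.text`, the only attribute this method reads from the query):
    #   hits = memory.retrieve("engineer", snapshot, top_k=5,
    #                          similarity_threshold=0.35)
    #   for item in hits:
    #       print(item.id, item.content_summary[:80])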
    def _calculate_semantic_similarity(self, query: str, content: str) -> float:
        """Blend four lexical signals into a single similarity score."""
        query_lower = query.lower()
        content_lower = content.lower()
# 1. Token overlap (Jaccard similarity)
query_words = set(query_lower.split())
content_words = set(content_lower.split())
        if not query_words or not content_words:
            jaccard_sim = 0.0
        else:
            intersection = query_words & content_words
            union = query_words | content_words
            jaccard_sim = len(intersection) / len(union)
# 2. Longest common subsequence similarity
lcs_sim = self._calculate_lcs_similarity(query_lower, content_lower)
# 3. Keyword match score
keyword_sim = self._calculate_keyword_similarity(query_lower, content_lower)
# 4. Length penalty factor (avoid overly short/long matches)
length_factor = self._calculate_length_factor(query_lower, content_lower)
# Weighted final score
final_score = (0.4 * jaccard_sim +
0.3 * lcs_sim +
0.2 * keyword_sim +
0.1 * length_factor)
return min(final_score, 1.0)
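    # Worked example of the blend: jaccard=0.5, lcs=0.4, keyword=0.6,
    # length=1.0 gives 0.4*0.5 + 0.3*0.4 + 0.2*0.6 + 0.1*1.0 = 0.54.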
    def _calculate_lcs_similarity(self, s1: str, s2: str) -> float:
        """Longest-common-subsequence length, normalized by the longer string.

        O(len(s1) * len(s2)) dynamic programming; both inputs are capped at
        max_content_length upstream, which keeps this affordable.
        """
m, n = len(s1), len(s2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if s1[i-1] == s2[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
lcs_length = dp[m][n]
return lcs_length / max(len(s1), len(s2)) if max(len(s1), len(s2)) > 0 else 0.0
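    # Example: the LCS of "abcde" and "ace" is "ace" (length 3), so the
    # similarity is 3 / max(5, 3) = 0.6.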
def _calculate_keyword_similarity(self, query: str, content: str) -> float:
"""Compute keyword match similarity."""
# Extract potential keywords (length >= 2)
query_keywords = set(word for word in query.split() if len(word) >= 2)
content_keywords = set(word for word in content.split() if len(word) >= 2)
if not query_keywords:
return 0.0
matches = query_keywords & content_keywords
return len(matches) / len(query_keywords)
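    # Example: query "build snake game" vs content "snake game in python"
    # matches {"snake", "game"} out of 3 query keywords -> 2/3 ≈ 0.67.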
    def _calculate_length_factor(self, query: str, content: str) -> float:
        """Penalize matches that deviate too much in length."""
        query_len = len(query)
        content_len = len(content)
        if query_len == 0 or content_len == 0:
            return 0.0
        # Content between 0.5x and 2x the query length scores full marks.
        ideal_ratio_min = 0.5
        ideal_ratio_max = 2.0
        ratio = content_len / query_len
if ideal_ratio_min <= ratio <= ideal_ratio_max:
return 1.0
elif ratio < ideal_ratio_min:
return ratio / ideal_ratio_min
else:
return max(0.1, ideal_ratio_max / ratio)
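    # Examples: ratio 3.0 -> max(0.1, 2.0 / 3.0) ≈ 0.67; ratio 0.25 ->
    # 0.25 / 0.5 = 0.5; anything between 0.5x and 2x scores 1.0.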
    def update(self, payload: MemoryWritePayload) -> None:
        """Summarize an interaction, deduplicate it, and store it with its embedding."""
        if not self.embedding:
            return
        snapshot = payload.output_snapshot
        if not snapshot or not snapshot.text.strip():
            return
raw_content = self.update_prompt.format(
input=payload.inputs_text,
output=snapshot.text,
)
extracted_content = self._extract_key_content(raw_content)
if len(extracted_content) < self.min_content_length:
return
        content_hash = self._generate_content_hash(extracted_content)
        # Hash-based deduplication: skip content that is already stored.
        for existing_item in self.contents:
            existing_hash = self._generate_content_hash(existing_item.content_summary)
            if existing_hash == content_hash:
                return
        embedding_vector = self.embedding.get_embedding(extracted_content)
        if embedding_vector is None:
            return
        # Normalize so retrieval's inner-product search behaves as cosine.
        embedding_array = np.array(embedding_vector, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(embedding_array)
metadata = {
"agent_role": payload.agent_role,
"input_preview": (payload.inputs_text or "")[:200],
"content_length": len(extracted_content),
"attachments": snapshot.attachment_overview(),
}
memory_item = MemoryItem(
id=f"{content_hash}_{int(time.time())}",
content_summary=extracted_content,
metadata=metadata,
embedding=embedding_array.tolist()[0],
input_snapshot=payload.input_snapshot,
output_snapshot=snapshot,
)
self.contents.append(memory_item)
        # Simple FIFO eviction: keep only the most recent entries.
        max_memories = 1000
        if len(self.contents) > max_memories:
            self.contents = self.contents[-max_memories:]
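

# ---------------------------------------------------------------------------
# Minimal smoke test of the pure-text helpers. This is a sketch only: it
# bypasses __init__ because the MemoryStoreConfig wiring and the embedding
# backend are project specific, and it sets just the two attributes those
# helpers read.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    memory = object.__new__(SimpleMemory)  # skip __init__ (assumes no __slots__)
    memory.max_content_length = 500
    memory.min_content_length = 20

    query = "implement the snake game main loop with pygame"
    content = "a snake game main loop implemented with pygame and a score board"
    print("extracted:", memory._extract_key_content(query))
    print("semantic :", memory._calculate_semantic_similarity(query, content))
    print("lcs      :", memory._calculate_lcs_similarity(query, content))
    print("keywords :", memory._calculate_keyword_similarity(query, content))
    print("hash     :", memory._generate_content_hash(content))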