From 6a898ce2da1d91bf5eac64c5d72b0dc1f6a795fc Mon Sep 17 00:00:00 2001 From: laansdole Date: Sun, 15 Mar 2026 22:50:40 +0700 Subject: [PATCH] fix: stored memory embeddings had mixed dim --- runtime/node/agent/memory/embedding.py | 19 ++- runtime/node/agent/memory/file_memory.py | 8 + runtime/node/agent/memory/simple_memory.py | 11 ++ tests/test_memory_embedding_consistency.py | 168 +++++++++++++++++++++ 4 files changed, 199 insertions(+), 7 deletions(-) create mode 100644 tests/test_memory_embedding_consistency.py diff --git a/runtime/node/agent/memory/embedding.py b/runtime/node/agent/memory/embedding.py index 02bcf0bc..7f27f1a1 100755 --- a/runtime/node/agent/memory/embedding.py +++ b/runtime/node/agent/memory/embedding.py @@ -86,6 +86,7 @@ class OpenAIEmbedding(EmbeddingBase): self.max_length = embedding_config.params.get('max_length', 8191) self.use_chunking = embedding_config.params.get('use_chunking', False) self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average') + self._fallback_dim = 1536 # Default; updated after first successful call if self.base_url: self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) @@ -99,7 +100,7 @@ class OpenAIEmbedding(EmbeddingBase): if not processed_text: logger.warning("Empty text after preprocessing") - return [0.0] * 1536 # Return a zero vector + return [0.0] * self._fallback_dim # Handle long text via chunking if self.use_chunking and len(processed_text) > self.max_length: @@ -115,17 +116,18 @@ class OpenAIEmbedding(EmbeddingBase): encoding_format="float" ) embedding = response.data[0].embedding + self._fallback_dim = len(embedding) return embedding except Exception as e: logger.error(f"Error getting embedding: {e}") - return [0.0] * 1536 # Return zero vector as fallback + return [0.0] * self._fallback_dim def _get_chunked_embedding(self, text: str) -> List[float]: """Chunk long text, embed each chunk, then aggregate.""" chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length if not chunks: - return [0.0] * 1536 + return [0.0] * self._fallback_dim chunk_embeddings = [] for chunk in chunks: @@ -141,7 +143,7 @@ class OpenAIEmbedding(EmbeddingBase): continue if not chunk_embeddings: - return [0.0] * 1536 + return [0.0] * self._fallback_dim # Aggregation strategy if self.chunk_strategy == 'average': @@ -163,6 +165,7 @@ class LocalEmbedding(EmbeddingBase): super().__init__(embedding_config) self.model_path = embedding_config.params.get('model_path') self.device = embedding_config.params.get('device', 'cpu') + self._fallback_dim = 768 # Default; updated after first successful call if not self.model_path: raise ValueError("LocalEmbedding requires model_path parameter") @@ -179,11 +182,13 @@ class LocalEmbedding(EmbeddingBase): processed_text = self._preprocess_text(text) if not processed_text: - return [0.0] * 768 # Return zero vector + return [0.0] * self._fallback_dim try: embedding = self.model.encode(processed_text, convert_to_tensor=False) - return embedding.tolist() + result = embedding.tolist() + self._fallback_dim = len(result) + return result except Exception as e: logger.error(f"Error getting local embedding: {e}") - return [0.0] * 768 # Return zero vector as fallback + return [0.0] * self._fallback_dim diff --git a/runtime/node/agent/memory/file_memory.py b/runtime/node/agent/memory/file_memory.py index 1a020f30..45ba3138 100755 --- a/runtime/node/agent/memory/file_memory.py +++ b/runtime/node/agent/memory/file_memory.py @@ -123,11 +123,19 @@ class FileMemory(MemoryBase): query_embedding = query_embedding.reshape(1, -1) faiss.normalize_L2(query_embedding) + expected_dim = query_embedding.shape[1] + # Collect embeddings from memory items memory_embeddings = [] valid_items = [] for item in self.contents: if item.embedding is not None: + if len(item.embedding) != expected_dim: + logger.warning( + "Skipping memory item %s: embedding dim %d != expected %d", + item.id, len(item.embedding), expected_dim, + ) + continue memory_embeddings.append(item.embedding) valid_items.append(item) diff --git a/runtime/node/agent/memory/simple_memory.py b/runtime/node/agent/memory/simple_memory.py index abf02c53..363e4359 100755 --- a/runtime/node/agent/memory/simple_memory.py +++ b/runtime/node/agent/memory/simple_memory.py @@ -1,5 +1,6 @@ import hashlib import json +import logging import os import re import time @@ -16,6 +17,8 @@ from runtime.node.agent.memory.memory_base import ( import faiss import numpy as np +logger = logging.getLogger(__name__) + class SimpleMemory(MemoryBase): def __init__(self, store: MemoryStoreConfig): config = store.as_config(SimpleMemoryConfig) @@ -107,10 +110,18 @@ class SimpleMemory(MemoryBase): inputs_embedding = inputs_embedding.reshape(1, -1) faiss.normalize_L2(inputs_embedding) + expected_dim = inputs_embedding.shape[1] + memory_embeddings = [] valid_items = [] for item in self.contents: if item.embedding is not None: + if len(item.embedding) != expected_dim: + logger.warning( + "Skipping memory item %s: embedding dim %d != expected %d", + item.id, len(item.embedding), expected_dim, + ) + continue memory_embeddings.append(item.embedding) valid_items.append(item) diff --git a/tests/test_memory_embedding_consistency.py b/tests/test_memory_embedding_consistency.py new file mode 100644 index 00000000..2b7c7a91 --- /dev/null +++ b/tests/test_memory_embedding_consistency.py @@ -0,0 +1,168 @@ +"""Tests for memory embedding dimension consistency.""" +from unittest.mock import MagicMock, patch +from runtime.node.agent.memory.memory_base import MemoryContentSnapshot, MemoryItem +from runtime.node.agent.memory.simple_memory import SimpleMemory + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_store(memory_path=None): + """Build a minimal MemoryStoreConfig mock for SimpleMemory.""" + simple_cfg = MagicMock() + simple_cfg.memory_path = memory_path + simple_cfg.embedding = None # We'll set embedding manually + + store = MagicMock() + store.name = "test_store" + store.as_config.return_value = simple_cfg + return store + + +def _make_embedding(dim: int): + """Create a mock EmbeddingBase that produces vectors of the given dimension.""" + emb = MagicMock() + emb.get_embedding.return_value = [0.1] * dim + return emb + + +def _make_memory_item(item_id: str, dim: int): + """Create a MemoryItem with an embedding of the specified dimension.""" + return MemoryItem( + id=item_id, + content_summary=f"content for {item_id}", + metadata={}, + embedding=[float(i) for i in range(dim)], + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestSimpleMemoryRetrieveMixedDimensions: + """Task 2.1: verify retrieve() handles mixed-dimension embeddings.""" + + def test_mixed_dimensions_does_not_crash(self): + """Retrieve with mixed-dimensional embeddings MUST not raise.""" + store = _make_store() + memory = SimpleMemory(store) + memory.embedding = _make_embedding(dim=768) + + # 3 items with correct dim, 2 with wrong dim + memory.contents = [ + _make_memory_item("ok_1", 768), + _make_memory_item("bad_1", 1536), + _make_memory_item("ok_2", 768), + _make_memory_item("bad_2", 256), + _make_memory_item("ok_3", 768), + ] + + query = MemoryContentSnapshot(text="test query") + # Should NOT raise ValueError / numpy error + results = memory.retrieve( + agent_role="tester", + query=query, + top_k=5, + similarity_threshold=-1.0, + ) + # Only the 3 correct-dimension items should be candidates + assert len(results) <= 3 + + def test_all_same_dimension_returns_results(self): + """When all embeddings share the correct dimension, all are candidates.""" + store = _make_store() + memory = SimpleMemory(store) + memory.embedding = _make_embedding(dim=768) + + memory.contents = [ + _make_memory_item("a", 768), + _make_memory_item("b", 768), + ] + + query = MemoryContentSnapshot(text="test query") + results = memory.retrieve( + agent_role="tester", + query=query, + top_k=5, + similarity_threshold=-1.0, + ) + assert len(results) == 2 + + def test_all_wrong_dimension_returns_empty(self): + """When every stored embedding has a wrong dimension, return empty.""" + store = _make_store() + memory = SimpleMemory(store) + memory.embedding = _make_embedding(dim=768) + + memory.contents = [ + _make_memory_item("x", 1536), + _make_memory_item("y", 1536), + ] + + query = MemoryContentSnapshot(text="test query") + results = memory.retrieve( + agent_role="tester", + query=query, + top_k=5, + similarity_threshold=-1.0, + ) + assert results == [] + + +class TestOpenAIEmbeddingDynamicFallback: + """Task 2.2: verify dynamic fallback dimension caching.""" + + def test_fallback_uses_model_dimension_after_success(self): + """After a successful call the fallback dimension MUST match the model.""" + from runtime.node.agent.memory.embedding import OpenAIEmbedding + + cfg = MagicMock() + cfg.base_url = "http://localhost:11434/v1" + cfg.api_key = "test" + cfg.model = "test-model" + cfg.params = {} + + emb = OpenAIEmbedding(cfg) + assert emb._fallback_dim == 1536 # default before any call + + # Simulate a successful 768-dim response + mock_data = MagicMock() + mock_data.embedding = [0.1] * 768 + mock_response = MagicMock() + mock_response.data = [mock_data] + + with patch.object(emb.client.embeddings, "create", return_value=mock_response): + result = emb.get_embedding("hello world") + + assert len(result) == 768 + assert emb._fallback_dim == 768 # updated after success + + def test_fallback_zero_vector_matches_cached_dim(self): + """After caching dim, fallback zero-vectors MUST use that dim.""" + from runtime.node.agent.memory.embedding import OpenAIEmbedding + + cfg = MagicMock() + cfg.base_url = "http://localhost:11434/v1" + cfg.api_key = "test" + cfg.model = "test-model" + cfg.params = {} + + emb = OpenAIEmbedding(cfg) + + # Simulate successful 512-dim call + mock_data = MagicMock() + mock_data.embedding = [0.1] * 512 + mock_response = MagicMock() + mock_response.data = [mock_data] + + with patch.object(emb.client.embeddings, "create", return_value=mock_response): + emb.get_embedding("first call") + + # Now simulate a failure — fallback should be 512-dim + with patch.object(emb.client.embeddings, "create", side_effect=Exception("API down")): + fallback = emb.get_embedding("failing call") + + assert len(fallback) == 512 + assert all(v == 0.0 for v in fallback)