Merge pull request #582 from LaansDole/main

Fix: hardcoded embedding dimension
This commit is contained in:
Shu Yao 2026-03-17 10:43:17 +08:00 committed by GitHub
commit fea709142a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 220 additions and 8 deletions

View File

@ -41,6 +41,23 @@ validate-yamls: ## Validate all YAML configuration files
.PHONY: help .PHONY: help
help: ## Display this help message help: ## Display this help message
@python -c "import re; \ @uv run python -c "import re; \
p=r'$(firstword $(MAKEFILE_LIST))'.strip(); \ p=r'$(firstword $(MAKEFILE_LIST))'.strip(); \
[print(f'{m[0]:<20} {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(p, encoding='utf-8').read(), re.M)]" | sort [print(f'{m[0]:<20} {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(p, encoding='utf-8').read(), re.M)]" | sort
# ==============================================================================
# Quality Checks
# ==============================================================================
.PHONY: check-backend
check-backend: ## Run backend quality checks (tests + linting)
@$(MAKE) backend-tests
@$(MAKE) backend-lint
.PHONY: backend-tests
backend-tests: ## Run backend tests
@uv run pytest -v
.PHONY: backend-lint
backend-lint: ## Run backend linting
@uvx ruff check .

View File

@ -45,5 +45,19 @@ dependencies = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.pytest.ini_options]
pythonpath = ["."]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = [
# Upstream SWIG issue in faiss-cpu on Python 3.12; awaiting SWIG 4.4 fix.
"ignore:builtin type Swig.*:DeprecationWarning",
"ignore:builtin type swigvarlink.*:DeprecationWarning",
]
[tool.uv] [tool.uv]
package = false package = false

View File

@ -86,6 +86,7 @@ class OpenAIEmbedding(EmbeddingBase):
self.max_length = embedding_config.params.get('max_length', 8191) self.max_length = embedding_config.params.get('max_length', 8191)
self.use_chunking = embedding_config.params.get('use_chunking', False) self.use_chunking = embedding_config.params.get('use_chunking', False)
self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average') self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
self._fallback_dim = 1536 # Default; updated after first successful call
if self.base_url: if self.base_url:
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
@ -99,7 +100,7 @@ class OpenAIEmbedding(EmbeddingBase):
if not processed_text: if not processed_text:
logger.warning("Empty text after preprocessing") logger.warning("Empty text after preprocessing")
return [0.0] * 1536 # Return a zero vector return [0.0] * self._fallback_dim
# Handle long text via chunking # Handle long text via chunking
if self.use_chunking and len(processed_text) > self.max_length: if self.use_chunking and len(processed_text) > self.max_length:
@ -115,17 +116,18 @@ class OpenAIEmbedding(EmbeddingBase):
encoding_format="float" encoding_format="float"
) )
embedding = response.data[0].embedding embedding = response.data[0].embedding
self._fallback_dim = len(embedding)
return embedding return embedding
except Exception as e: except Exception as e:
logger.error(f"Error getting embedding: {e}") logger.error(f"Error getting embedding: {e}")
return [0.0] * 1536 # Return zero vector as fallback return [0.0] * self._fallback_dim
def _get_chunked_embedding(self, text: str) -> List[float]: def _get_chunked_embedding(self, text: str) -> List[float]:
"""Chunk long text, embed each chunk, then aggregate.""" """Chunk long text, embed each chunk, then aggregate."""
chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length
if not chunks: if not chunks:
return [0.0] * 1536 return [0.0] * self._fallback_dim
chunk_embeddings = [] chunk_embeddings = []
for chunk in chunks: for chunk in chunks:
@ -141,7 +143,7 @@ class OpenAIEmbedding(EmbeddingBase):
continue continue
if not chunk_embeddings: if not chunk_embeddings:
return [0.0] * 1536 return [0.0] * self._fallback_dim
# Aggregation strategy # Aggregation strategy
if self.chunk_strategy == 'average': if self.chunk_strategy == 'average':
@ -163,6 +165,7 @@ class LocalEmbedding(EmbeddingBase):
super().__init__(embedding_config) super().__init__(embedding_config)
self.model_path = embedding_config.params.get('model_path') self.model_path = embedding_config.params.get('model_path')
self.device = embedding_config.params.get('device', 'cpu') self.device = embedding_config.params.get('device', 'cpu')
self._fallback_dim = 768 # Default; updated after first successful call
if not self.model_path: if not self.model_path:
raise ValueError("LocalEmbedding requires model_path parameter") raise ValueError("LocalEmbedding requires model_path parameter")
@ -179,11 +182,13 @@ class LocalEmbedding(EmbeddingBase):
processed_text = self._preprocess_text(text) processed_text = self._preprocess_text(text)
if not processed_text: if not processed_text:
return [0.0] * 768 # Return zero vector return [0.0] * self._fallback_dim
try: try:
embedding = self.model.encode(processed_text, convert_to_tensor=False) embedding = self.model.encode(processed_text, convert_to_tensor=False)
return embedding.tolist() result = embedding.tolist()
self._fallback_dim = len(result)
return result
except Exception as e: except Exception as e:
logger.error(f"Error getting local embedding: {e}") logger.error(f"Error getting local embedding: {e}")
return [0.0] * 768 # Return zero vector as fallback return [0.0] * self._fallback_dim

View File

@ -123,11 +123,19 @@ class FileMemory(MemoryBase):
query_embedding = query_embedding.reshape(1, -1) query_embedding = query_embedding.reshape(1, -1)
faiss.normalize_L2(query_embedding) faiss.normalize_L2(query_embedding)
expected_dim = query_embedding.shape[1]
# Collect embeddings from memory items # Collect embeddings from memory items
memory_embeddings = [] memory_embeddings = []
valid_items = [] valid_items = []
for item in self.contents: for item in self.contents:
if item.embedding is not None: if item.embedding is not None:
if len(item.embedding) != expected_dim:
logger.warning(
"Skipping memory item %s: embedding dim %d != expected %d",
item.id, len(item.embedding), expected_dim,
)
continue
memory_embeddings.append(item.embedding) memory_embeddings.append(item.embedding)
valid_items.append(item) valid_items.append(item)

View File

@ -1,5 +1,6 @@
import hashlib import hashlib
import json import json
import logging
import os import os
import re import re
import time import time
@ -16,6 +17,8 @@ from runtime.node.agent.memory.memory_base import (
import faiss import faiss
import numpy as np import numpy as np
logger = logging.getLogger(__name__)
class SimpleMemory(MemoryBase): class SimpleMemory(MemoryBase):
def __init__(self, store: MemoryStoreConfig): def __init__(self, store: MemoryStoreConfig):
config = store.as_config(SimpleMemoryConfig) config = store.as_config(SimpleMemoryConfig)
@ -107,10 +110,18 @@ class SimpleMemory(MemoryBase):
inputs_embedding = inputs_embedding.reshape(1, -1) inputs_embedding = inputs_embedding.reshape(1, -1)
faiss.normalize_L2(inputs_embedding) faiss.normalize_L2(inputs_embedding)
expected_dim = inputs_embedding.shape[1]
memory_embeddings = [] memory_embeddings = []
valid_items = [] valid_items = []
for item in self.contents: for item in self.contents:
if item.embedding is not None: if item.embedding is not None:
if len(item.embedding) != expected_dim:
logger.warning(
"Skipping memory item %s: embedding dim %d != expected %d",
item.id, len(item.embedding), expected_dim,
)
continue
memory_embeddings.append(item.embedding) memory_embeddings.append(item.embedding)
valid_items.append(item) valid_items.append(item)

View File

@ -0,0 +1,157 @@
"""Tests for memory embedding dimension consistency."""
from unittest.mock import MagicMock, patch
from runtime.node.agent.memory.memory_base import MemoryContentSnapshot, MemoryItem
from runtime.node.agent.memory.simple_memory import SimpleMemory
def _make_store(memory_path=None):
"""Build a minimal MemoryStoreConfig mock for SimpleMemory."""
simple_cfg = MagicMock()
simple_cfg.memory_path = memory_path
simple_cfg.embedding = None # We'll set embedding manually
store = MagicMock()
store.name = "test_store"
store.as_config.return_value = simple_cfg
return store
def _make_embedding(dim: int):
"""Create a mock EmbeddingBase that produces vectors of the given dimension."""
emb = MagicMock()
emb.get_embedding.return_value = [0.1] * dim
return emb
def _make_memory_item(item_id: str, dim: int):
    """Build a MemoryItem whose embedding has exactly dim components (0.0..dim-1)."""
    vector = list(map(float, range(dim)))
    return MemoryItem(
        id=item_id,
        content_summary=f"content for {item_id}",
        metadata={},
        embedding=vector,
    )
class TestSimpleMemoryRetrieveMixedDimensions:
    """Retrieval must tolerate stored embeddings of mismatched dimensions."""

    @staticmethod
    def _build_memory(items):
        """Construct a SimpleMemory with a 768-dim mock embedder and the given items."""
        memory = SimpleMemory(_make_store())
        memory.embedding = _make_embedding(dim=768)
        memory.contents = items
        return memory

    @staticmethod
    def _retrieve(memory):
        """Run retrieve with a permissive threshold so similarity never filters."""
        snapshot = MemoryContentSnapshot(text="test query")
        return memory.retrieve(
            agent_role="tester",
            query=snapshot,
            top_k=5,
            similarity_threshold=-1.0,
        )

    def test_mixed_dimensions_does_not_crash(self):
        """Retrieve with mixed-dimensional embeddings MUST not raise."""
        # 3 items with correct dim, 2 with wrong dim
        memory = self._build_memory([
            _make_memory_item("ok_1", 768),
            _make_memory_item("bad_1", 1536),
            _make_memory_item("ok_2", 768),
            _make_memory_item("bad_2", 256),
            _make_memory_item("ok_3", 768),
        ])
        # Should NOT raise ValueError / numpy error
        results = self._retrieve(memory)
        # Only the 3 correct-dimension items should be candidates
        assert len(results) <= 3

    def test_all_same_dimension_returns_results(self):
        """When all embeddings share the correct dimension, all are candidates."""
        memory = self._build_memory([
            _make_memory_item("a", 768),
            _make_memory_item("b", 768),
        ])
        assert len(self._retrieve(memory)) == 2

    def test_all_wrong_dimension_returns_empty(self):
        """When every stored embedding has a wrong dimension, return empty."""
        memory = self._build_memory([
            _make_memory_item("x", 1536),
            _make_memory_item("y", 1536),
        ])
        assert self._retrieve(memory) == []
class TestOpenAIEmbeddingDynamicFallback:
    """The OpenAI embedder must learn its fallback dimension from real responses."""

    @staticmethod
    def _build_embedder():
        """Create an OpenAIEmbedding wired to a local-style base_url with empty params."""
        from runtime.node.agent.memory.embedding import OpenAIEmbedding
        cfg = MagicMock()
        cfg.base_url = "http://localhost:11434/v1"
        cfg.api_key = "test"
        cfg.model = "test-model"
        cfg.params = {}
        return OpenAIEmbedding(cfg)

    @staticmethod
    def _response_of_dim(dim):
        """Build a mock API response carrying one embedding of the given dimension."""
        item = MagicMock()
        item.embedding = [0.1] * dim
        response = MagicMock()
        response.data = [item]
        return response

    def test_fallback_uses_model_dimension_after_success(self):
        """After a successful call the fallback dimension MUST match the model."""
        emb = self._build_embedder()
        assert emb._fallback_dim == 1536  # default before any call
        # Simulate a successful 768-dim response
        with patch.object(
            emb.client.embeddings, "create",
            return_value=self._response_of_dim(768),
        ):
            result = emb.get_embedding("hello world")
        assert len(result) == 768
        assert emb._fallback_dim == 768  # updated after success

    def test_fallback_zero_vector_matches_cached_dim(self):
        """After caching dim, fallback zero-vectors MUST use that dim."""
        emb = self._build_embedder()
        # Simulate successful 512-dim call
        with patch.object(
            emb.client.embeddings, "create",
            return_value=self._response_of_dim(512),
        ):
            emb.get_embedding("first call")
        # Now simulate a failure — fallback should be 512-dim
        with patch.object(
            emb.client.embeddings, "create",
            side_effect=Exception("API down"),
        ):
            fallback = emb.get_embedding("failing call")
        assert len(fallback) == 512
        assert all(v == 0.0 for v in fallback)