mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
Merge branch 'main' into main
This commit is contained in:
commit
036fb6efe1
19
Makefile
19
Makefile
@ -41,6 +41,23 @@ validate-yamls: ## Validate all YAML configuration files
|
||||
|
||||
.PHONY: help
|
||||
help: ## Display this help message
|
||||
@python -c "import re; \
|
||||
@uv run python -c "import re; \
|
||||
p=r'$(firstword $(MAKEFILE_LIST))'.strip(); \
|
||||
[print(f'{m[0]:<20} {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(p, encoding='utf-8').read(), re.M)]" | sort
|
||||
|
||||
# ==============================================================================
|
||||
# Quality Checks
|
||||
# ==============================================================================
|
||||
|
||||
.PHONY: check-backend
|
||||
check-backend: ## Run backend quality checks (tests + linting)
|
||||
@$(MAKE) backend-tests
|
||||
@$(MAKE) backend-lint
|
||||
|
||||
.PHONY: backend-tests
|
||||
backend-tests: ## Run backend tests
|
||||
@uv run pytest -v
|
||||
|
||||
.PHONY: backend-lint
|
||||
backend-lint: ## Run backend linting
|
||||
@uvx ruff check .
|
||||
|
||||
@ -2390,6 +2390,7 @@ watch(
|
||||
flex-direction: column;
|
||||
pointer-events: auto;
|
||||
z-index: auto;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.chat-panel-fullscreen .chat-panel-content {
|
||||
@ -2408,6 +2409,7 @@ watch(
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
min-width: 0;
|
||||
min-height: 0;
|
||||
pointer-events: auto;
|
||||
background: rgba(26, 26, 26, 0.92);
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
@ -2463,6 +2465,7 @@ watch(
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.chat-messages::-webkit-scrollbar {
|
||||
|
||||
@ -45,5 +45,19 @@ dependencies = [
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["."]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
addopts = "-v --tb=short"
|
||||
filterwarnings = [
|
||||
# Upstream SWIG issue in faiss-cpu on Python 3.12; awaiting SWIG 4.4 fix.
|
||||
"ignore:builtin type Swig.*:DeprecationWarning",
|
||||
"ignore:builtin type swigvarlink.*:DeprecationWarning",
|
||||
]
|
||||
|
||||
|
||||
[tool.uv]
|
||||
package = false
|
||||
|
||||
@ -86,6 +86,7 @@ class OpenAIEmbedding(EmbeddingBase):
|
||||
self.max_length = embedding_config.params.get('max_length', 8191)
|
||||
self.use_chunking = embedding_config.params.get('use_chunking', False)
|
||||
self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
|
||||
self._fallback_dim = 1536 # Default; updated after first successful call
|
||||
|
||||
if self.base_url:
|
||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
@ -99,7 +100,7 @@ class OpenAIEmbedding(EmbeddingBase):
|
||||
|
||||
if not processed_text:
|
||||
logger.warning("Empty text after preprocessing")
|
||||
return [0.0] * 1536 # Return a zero vector
|
||||
return [0.0] * self._fallback_dim
|
||||
|
||||
# Handle long text via chunking
|
||||
if self.use_chunking and len(processed_text) > self.max_length:
|
||||
@ -115,17 +116,18 @@ class OpenAIEmbedding(EmbeddingBase):
|
||||
encoding_format="float"
|
||||
)
|
||||
embedding = response.data[0].embedding
|
||||
self._fallback_dim = len(embedding)
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting embedding: {e}")
|
||||
return [0.0] * 1536 # Return zero vector as fallback
|
||||
return [0.0] * self._fallback_dim
|
||||
|
||||
def _get_chunked_embedding(self, text: str) -> List[float]:
|
||||
"""Chunk long text, embed each chunk, then aggregate."""
|
||||
chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length
|
||||
|
||||
if not chunks:
|
||||
return [0.0] * 1536
|
||||
return [0.0] * self._fallback_dim
|
||||
|
||||
chunk_embeddings = []
|
||||
for chunk in chunks:
|
||||
@ -141,7 +143,7 @@ class OpenAIEmbedding(EmbeddingBase):
|
||||
continue
|
||||
|
||||
if not chunk_embeddings:
|
||||
return [0.0] * 1536
|
||||
return [0.0] * self._fallback_dim
|
||||
|
||||
# Aggregation strategy
|
||||
if self.chunk_strategy == 'average':
|
||||
@ -163,6 +165,7 @@ class LocalEmbedding(EmbeddingBase):
|
||||
super().__init__(embedding_config)
|
||||
self.model_path = embedding_config.params.get('model_path')
|
||||
self.device = embedding_config.params.get('device', 'cpu')
|
||||
self._fallback_dim = 768 # Default; updated after first successful call
|
||||
|
||||
if not self.model_path:
|
||||
raise ValueError("LocalEmbedding requires model_path parameter")
|
||||
@ -179,11 +182,13 @@ class LocalEmbedding(EmbeddingBase):
|
||||
processed_text = self._preprocess_text(text)
|
||||
|
||||
if not processed_text:
|
||||
return [0.0] * 768 # Return zero vector
|
||||
return [0.0] * self._fallback_dim
|
||||
|
||||
try:
|
||||
embedding = self.model.encode(processed_text, convert_to_tensor=False)
|
||||
return embedding.tolist()
|
||||
result = embedding.tolist()
|
||||
self._fallback_dim = len(result)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting local embedding: {e}")
|
||||
return [0.0] * 768 # Return zero vector as fallback
|
||||
return [0.0] * self._fallback_dim
|
||||
|
||||
@ -123,11 +123,19 @@ class FileMemory(MemoryBase):
|
||||
query_embedding = query_embedding.reshape(1, -1)
|
||||
faiss.normalize_L2(query_embedding)
|
||||
|
||||
expected_dim = query_embedding.shape[1]
|
||||
|
||||
# Collect embeddings from memory items
|
||||
memory_embeddings = []
|
||||
valid_items = []
|
||||
for item in self.contents:
|
||||
if item.embedding is not None:
|
||||
if len(item.embedding) != expected_dim:
|
||||
logger.warning(
|
||||
"Skipping memory item %s: embedding dim %d != expected %d",
|
||||
item.id, len(item.embedding), expected_dim,
|
||||
)
|
||||
continue
|
||||
memory_embeddings.append(item.embedding)
|
||||
valid_items.append(item)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@ -16,6 +17,8 @@ from runtime.node.agent.memory.memory_base import (
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SimpleMemory(MemoryBase):
|
||||
def __init__(self, store: MemoryStoreConfig):
|
||||
config = store.as_config(SimpleMemoryConfig)
|
||||
@ -107,10 +110,18 @@ class SimpleMemory(MemoryBase):
|
||||
inputs_embedding = inputs_embedding.reshape(1, -1)
|
||||
faiss.normalize_L2(inputs_embedding)
|
||||
|
||||
expected_dim = inputs_embedding.shape[1]
|
||||
|
||||
memory_embeddings = []
|
||||
valid_items = []
|
||||
for item in self.contents:
|
||||
if item.embedding is not None:
|
||||
if len(item.embedding) != expected_dim:
|
||||
logger.warning(
|
||||
"Skipping memory item %s: embedding dim %d != expected %d",
|
||||
item.id, len(item.embedding), expected_dim,
|
||||
)
|
||||
continue
|
||||
memory_embeddings.append(item.embedding)
|
||||
valid_items.append(item)
|
||||
|
||||
|
||||
157
tests/test_memory_embedding_consistency.py
Normal file
157
tests/test_memory_embedding_consistency.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""Tests for memory embedding dimension consistency."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from runtime.node.agent.memory.memory_base import MemoryContentSnapshot, MemoryItem
|
||||
from runtime.node.agent.memory.simple_memory import SimpleMemory
|
||||
|
||||
def _make_store(memory_path=None):
|
||||
"""Build a minimal MemoryStoreConfig mock for SimpleMemory."""
|
||||
simple_cfg = MagicMock()
|
||||
simple_cfg.memory_path = memory_path
|
||||
simple_cfg.embedding = None # We'll set embedding manually
|
||||
|
||||
store = MagicMock()
|
||||
store.name = "test_store"
|
||||
store.as_config.return_value = simple_cfg
|
||||
return store
|
||||
|
||||
|
||||
def _make_embedding(dim: int):
|
||||
"""Create a mock EmbeddingBase that produces vectors of the given dimension."""
|
||||
emb = MagicMock()
|
||||
emb.get_embedding.return_value = [0.1] * dim
|
||||
return emb
|
||||
|
||||
|
||||
def _make_memory_item(item_id: str, dim: int):
|
||||
"""Create a MemoryItem with an embedding of the specified dimension."""
|
||||
return MemoryItem(
|
||||
id=item_id,
|
||||
content_summary=f"content for {item_id}",
|
||||
metadata={},
|
||||
embedding=[float(i) for i in range(dim)],
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleMemoryRetrieveMixedDimensions:
|
||||
|
||||
def test_mixed_dimensions_does_not_crash(self):
|
||||
"""Retrieve with mixed-dimensional embeddings MUST not raise."""
|
||||
store = _make_store()
|
||||
memory = SimpleMemory(store)
|
||||
memory.embedding = _make_embedding(dim=768)
|
||||
|
||||
# 3 items with correct dim, 2 with wrong dim
|
||||
memory.contents = [
|
||||
_make_memory_item("ok_1", 768),
|
||||
_make_memory_item("bad_1", 1536),
|
||||
_make_memory_item("ok_2", 768),
|
||||
_make_memory_item("bad_2", 256),
|
||||
_make_memory_item("ok_3", 768),
|
||||
]
|
||||
|
||||
query = MemoryContentSnapshot(text="test query")
|
||||
# Should NOT raise ValueError / numpy error
|
||||
results = memory.retrieve(
|
||||
agent_role="tester",
|
||||
query=query,
|
||||
top_k=5,
|
||||
similarity_threshold=-1.0,
|
||||
)
|
||||
# Only the 3 correct-dimension items should be candidates
|
||||
assert len(results) <= 3
|
||||
|
||||
def test_all_same_dimension_returns_results(self):
|
||||
"""When all embeddings share the correct dimension, all are candidates."""
|
||||
store = _make_store()
|
||||
memory = SimpleMemory(store)
|
||||
memory.embedding = _make_embedding(dim=768)
|
||||
|
||||
memory.contents = [
|
||||
_make_memory_item("a", 768),
|
||||
_make_memory_item("b", 768),
|
||||
]
|
||||
|
||||
query = MemoryContentSnapshot(text="test query")
|
||||
results = memory.retrieve(
|
||||
agent_role="tester",
|
||||
query=query,
|
||||
top_k=5,
|
||||
similarity_threshold=-1.0,
|
||||
)
|
||||
assert len(results) == 2
|
||||
|
||||
def test_all_wrong_dimension_returns_empty(self):
|
||||
"""When every stored embedding has a wrong dimension, return empty."""
|
||||
store = _make_store()
|
||||
memory = SimpleMemory(store)
|
||||
memory.embedding = _make_embedding(dim=768)
|
||||
|
||||
memory.contents = [
|
||||
_make_memory_item("x", 1536),
|
||||
_make_memory_item("y", 1536),
|
||||
]
|
||||
|
||||
query = MemoryContentSnapshot(text="test query")
|
||||
results = memory.retrieve(
|
||||
agent_role="tester",
|
||||
query=query,
|
||||
top_k=5,
|
||||
similarity_threshold=-1.0,
|
||||
)
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestOpenAIEmbeddingDynamicFallback:
|
||||
|
||||
def test_fallback_uses_model_dimension_after_success(self):
|
||||
"""After a successful call the fallback dimension MUST match the model."""
|
||||
from runtime.node.agent.memory.embedding import OpenAIEmbedding
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.base_url = "http://localhost:11434/v1"
|
||||
cfg.api_key = "test"
|
||||
cfg.model = "test-model"
|
||||
cfg.params = {}
|
||||
|
||||
emb = OpenAIEmbedding(cfg)
|
||||
assert emb._fallback_dim == 1536 # default before any call
|
||||
|
||||
# Simulate a successful 768-dim response
|
||||
mock_data = MagicMock()
|
||||
mock_data.embedding = [0.1] * 768
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [mock_data]
|
||||
|
||||
with patch.object(emb.client.embeddings, "create", return_value=mock_response):
|
||||
result = emb.get_embedding("hello world")
|
||||
|
||||
assert len(result) == 768
|
||||
assert emb._fallback_dim == 768 # updated after success
|
||||
|
||||
def test_fallback_zero_vector_matches_cached_dim(self):
|
||||
"""After caching dim, fallback zero-vectors MUST use that dim."""
|
||||
from runtime.node.agent.memory.embedding import OpenAIEmbedding
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.base_url = "http://localhost:11434/v1"
|
||||
cfg.api_key = "test"
|
||||
cfg.model = "test-model"
|
||||
cfg.params = {}
|
||||
|
||||
emb = OpenAIEmbedding(cfg)
|
||||
|
||||
# Simulate successful 512-dim call
|
||||
mock_data = MagicMock()
|
||||
mock_data.embedding = [0.1] * 512
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [mock_data]
|
||||
|
||||
with patch.object(emb.client.embeddings, "create", return_value=mock_response):
|
||||
emb.get_embedding("first call")
|
||||
|
||||
# Now simulate a failure — fallback should be 512-dim
|
||||
with patch.object(emb.client.embeddings, "create", side_effect=Exception("API down")):
|
||||
fallback = emb.get_embedding("failing call")
|
||||
|
||||
assert len(fallback) == 512
|
||||
assert all(v == 0.0 for v in fallback)
|
||||
Loading…
x
Reference in New Issue
Block a user