mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
Merge branch 'main' into main
This commit is contained in:
commit
036fb6efe1
19
Makefile
19
Makefile
@ -41,6 +41,23 @@ validate-yamls: ## Validate all YAML configuration files
|
|||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help: ## Display this help message
|
help: ## Display this help message
|
||||||
@python -c "import re; \
|
@uv run python -c "import re; \
|
||||||
p=r'$(firstword $(MAKEFILE_LIST))'.strip(); \
|
p=r'$(firstword $(MAKEFILE_LIST))'.strip(); \
|
||||||
[print(f'{m[0]:<20} {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(p, encoding='utf-8').read(), re.M)]" | sort
|
[print(f'{m[0]:<20} {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(p, encoding='utf-8').read(), re.M)]" | sort
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Quality Checks
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
.PHONY: check-backend
|
||||||
|
check-backend: ## Run backend quality checks (tests + linting)
|
||||||
|
@$(MAKE) backend-tests
|
||||||
|
@$(MAKE) backend-lint
|
||||||
|
|
||||||
|
.PHONY: backend-tests
|
||||||
|
backend-tests: ## Run backend tests
|
||||||
|
@uv run pytest -v
|
||||||
|
|
||||||
|
.PHONY: backend-lint
|
||||||
|
backend-lint: ## Run backend linting
|
||||||
|
@uvx ruff check .
|
||||||
|
|||||||
@ -2390,6 +2390,7 @@ watch(
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
pointer-events: auto;
|
pointer-events: auto;
|
||||||
z-index: auto;
|
z-index: auto;
|
||||||
|
min-height: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-panel-fullscreen .chat-panel-content {
|
.chat-panel-fullscreen .chat-panel-content {
|
||||||
@ -2408,6 +2409,7 @@ watch(
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
gap: 12px;
|
gap: 12px;
|
||||||
min-width: 0;
|
min-width: 0;
|
||||||
|
min-height: 0;
|
||||||
pointer-events: auto;
|
pointer-events: auto;
|
||||||
background: rgba(26, 26, 26, 0.92);
|
background: rgba(26, 26, 26, 0.92);
|
||||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||||
@ -2463,6 +2465,7 @@ watch(
|
|||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
|
min-height: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-messages::-webkit-scrollbar {
|
.chat-messages::-webkit-scrollbar {
|
||||||
|
|||||||
@ -45,5 +45,19 @@ dependencies = [
|
|||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
pythonpath = ["."]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
|
filterwarnings = [
|
||||||
|
# Upstream SWIG issue in faiss-cpu on Python 3.12; awaiting SWIG 4.4 fix.
|
||||||
|
"ignore:builtin type Swig.*:DeprecationWarning",
|
||||||
|
"ignore:builtin type swigvarlink.*:DeprecationWarning",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
package = false
|
package = false
|
||||||
|
|||||||
@ -86,6 +86,7 @@ class OpenAIEmbedding(EmbeddingBase):
|
|||||||
self.max_length = embedding_config.params.get('max_length', 8191)
|
self.max_length = embedding_config.params.get('max_length', 8191)
|
||||||
self.use_chunking = embedding_config.params.get('use_chunking', False)
|
self.use_chunking = embedding_config.params.get('use_chunking', False)
|
||||||
self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
|
self.chunk_strategy = embedding_config.params.get('chunk_strategy', 'average')
|
||||||
|
self._fallback_dim = 1536 # Default; updated after first successful call
|
||||||
|
|
||||||
if self.base_url:
|
if self.base_url:
|
||||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||||
@ -99,7 +100,7 @@ class OpenAIEmbedding(EmbeddingBase):
|
|||||||
|
|
||||||
if not processed_text:
|
if not processed_text:
|
||||||
logger.warning("Empty text after preprocessing")
|
logger.warning("Empty text after preprocessing")
|
||||||
return [0.0] * 1536 # Return a zero vector
|
return [0.0] * self._fallback_dim
|
||||||
|
|
||||||
# Handle long text via chunking
|
# Handle long text via chunking
|
||||||
if self.use_chunking and len(processed_text) > self.max_length:
|
if self.use_chunking and len(processed_text) > self.max_length:
|
||||||
@ -115,17 +116,18 @@ class OpenAIEmbedding(EmbeddingBase):
|
|||||||
encoding_format="float"
|
encoding_format="float"
|
||||||
)
|
)
|
||||||
embedding = response.data[0].embedding
|
embedding = response.data[0].embedding
|
||||||
|
self._fallback_dim = len(embedding)
|
||||||
return embedding
|
return embedding
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting embedding: {e}")
|
logger.error(f"Error getting embedding: {e}")
|
||||||
return [0.0] * 1536 # Return zero vector as fallback
|
return [0.0] * self._fallback_dim
|
||||||
|
|
||||||
def _get_chunked_embedding(self, text: str) -> List[float]:
|
def _get_chunked_embedding(self, text: str) -> List[float]:
|
||||||
"""Chunk long text, embed each chunk, then aggregate."""
|
"""Chunk long text, embed each chunk, then aggregate."""
|
||||||
chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length
|
chunks = self._chunk_text(text, self.max_length // 2) # Halve the chunk length
|
||||||
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return [0.0] * 1536
|
return [0.0] * self._fallback_dim
|
||||||
|
|
||||||
chunk_embeddings = []
|
chunk_embeddings = []
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
@ -141,7 +143,7 @@ class OpenAIEmbedding(EmbeddingBase):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if not chunk_embeddings:
|
if not chunk_embeddings:
|
||||||
return [0.0] * 1536
|
return [0.0] * self._fallback_dim
|
||||||
|
|
||||||
# Aggregation strategy
|
# Aggregation strategy
|
||||||
if self.chunk_strategy == 'average':
|
if self.chunk_strategy == 'average':
|
||||||
@ -163,6 +165,7 @@ class LocalEmbedding(EmbeddingBase):
|
|||||||
super().__init__(embedding_config)
|
super().__init__(embedding_config)
|
||||||
self.model_path = embedding_config.params.get('model_path')
|
self.model_path = embedding_config.params.get('model_path')
|
||||||
self.device = embedding_config.params.get('device', 'cpu')
|
self.device = embedding_config.params.get('device', 'cpu')
|
||||||
|
self._fallback_dim = 768 # Default; updated after first successful call
|
||||||
|
|
||||||
if not self.model_path:
|
if not self.model_path:
|
||||||
raise ValueError("LocalEmbedding requires model_path parameter")
|
raise ValueError("LocalEmbedding requires model_path parameter")
|
||||||
@ -179,11 +182,13 @@ class LocalEmbedding(EmbeddingBase):
|
|||||||
processed_text = self._preprocess_text(text)
|
processed_text = self._preprocess_text(text)
|
||||||
|
|
||||||
if not processed_text:
|
if not processed_text:
|
||||||
return [0.0] * 768 # Return zero vector
|
return [0.0] * self._fallback_dim
|
||||||
|
|
||||||
try:
|
try:
|
||||||
embedding = self.model.encode(processed_text, convert_to_tensor=False)
|
embedding = self.model.encode(processed_text, convert_to_tensor=False)
|
||||||
return embedding.tolist()
|
result = embedding.tolist()
|
||||||
|
self._fallback_dim = len(result)
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting local embedding: {e}")
|
logger.error(f"Error getting local embedding: {e}")
|
||||||
return [0.0] * 768 # Return zero vector as fallback
|
return [0.0] * self._fallback_dim
|
||||||
|
|||||||
@ -123,11 +123,19 @@ class FileMemory(MemoryBase):
|
|||||||
query_embedding = query_embedding.reshape(1, -1)
|
query_embedding = query_embedding.reshape(1, -1)
|
||||||
faiss.normalize_L2(query_embedding)
|
faiss.normalize_L2(query_embedding)
|
||||||
|
|
||||||
|
expected_dim = query_embedding.shape[1]
|
||||||
|
|
||||||
# Collect embeddings from memory items
|
# Collect embeddings from memory items
|
||||||
memory_embeddings = []
|
memory_embeddings = []
|
||||||
valid_items = []
|
valid_items = []
|
||||||
for item in self.contents:
|
for item in self.contents:
|
||||||
if item.embedding is not None:
|
if item.embedding is not None:
|
||||||
|
if len(item.embedding) != expected_dim:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping memory item %s: embedding dim %d != expected %d",
|
||||||
|
item.id, len(item.embedding), expected_dim,
|
||||||
|
)
|
||||||
|
continue
|
||||||
memory_embeddings.append(item.embedding)
|
memory_embeddings.append(item.embedding)
|
||||||
valid_items.append(item)
|
valid_items.append(item)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -16,6 +17,8 @@ from runtime.node.agent.memory.memory_base import (
|
|||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class SimpleMemory(MemoryBase):
|
class SimpleMemory(MemoryBase):
|
||||||
def __init__(self, store: MemoryStoreConfig):
|
def __init__(self, store: MemoryStoreConfig):
|
||||||
config = store.as_config(SimpleMemoryConfig)
|
config = store.as_config(SimpleMemoryConfig)
|
||||||
@ -107,10 +110,18 @@ class SimpleMemory(MemoryBase):
|
|||||||
inputs_embedding = inputs_embedding.reshape(1, -1)
|
inputs_embedding = inputs_embedding.reshape(1, -1)
|
||||||
faiss.normalize_L2(inputs_embedding)
|
faiss.normalize_L2(inputs_embedding)
|
||||||
|
|
||||||
|
expected_dim = inputs_embedding.shape[1]
|
||||||
|
|
||||||
memory_embeddings = []
|
memory_embeddings = []
|
||||||
valid_items = []
|
valid_items = []
|
||||||
for item in self.contents:
|
for item in self.contents:
|
||||||
if item.embedding is not None:
|
if item.embedding is not None:
|
||||||
|
if len(item.embedding) != expected_dim:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping memory item %s: embedding dim %d != expected %d",
|
||||||
|
item.id, len(item.embedding), expected_dim,
|
||||||
|
)
|
||||||
|
continue
|
||||||
memory_embeddings.append(item.embedding)
|
memory_embeddings.append(item.embedding)
|
||||||
valid_items.append(item)
|
valid_items.append(item)
|
||||||
|
|
||||||
|
|||||||
157
tests/test_memory_embedding_consistency.py
Normal file
157
tests/test_memory_embedding_consistency.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
"""Tests for memory embedding dimension consistency."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from runtime.node.agent.memory.memory_base import MemoryContentSnapshot, MemoryItem
|
||||||
|
from runtime.node.agent.memory.simple_memory import SimpleMemory
|
||||||
|
|
||||||
|
def _make_store(memory_path=None):
|
||||||
|
"""Build a minimal MemoryStoreConfig mock for SimpleMemory."""
|
||||||
|
simple_cfg = MagicMock()
|
||||||
|
simple_cfg.memory_path = memory_path
|
||||||
|
simple_cfg.embedding = None # We'll set embedding manually
|
||||||
|
|
||||||
|
store = MagicMock()
|
||||||
|
store.name = "test_store"
|
||||||
|
store.as_config.return_value = simple_cfg
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def _make_embedding(dim: int):
|
||||||
|
"""Create a mock EmbeddingBase that produces vectors of the given dimension."""
|
||||||
|
emb = MagicMock()
|
||||||
|
emb.get_embedding.return_value = [0.1] * dim
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def _make_memory_item(item_id: str, dim: int):
|
||||||
|
"""Create a MemoryItem with an embedding of the specified dimension."""
|
||||||
|
return MemoryItem(
|
||||||
|
id=item_id,
|
||||||
|
content_summary=f"content for {item_id}",
|
||||||
|
metadata={},
|
||||||
|
embedding=[float(i) for i in range(dim)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimpleMemoryRetrieveMixedDimensions:
|
||||||
|
|
||||||
|
def test_mixed_dimensions_does_not_crash(self):
|
||||||
|
"""Retrieve with mixed-dimensional embeddings MUST not raise."""
|
||||||
|
store = _make_store()
|
||||||
|
memory = SimpleMemory(store)
|
||||||
|
memory.embedding = _make_embedding(dim=768)
|
||||||
|
|
||||||
|
# 3 items with correct dim, 2 with wrong dim
|
||||||
|
memory.contents = [
|
||||||
|
_make_memory_item("ok_1", 768),
|
||||||
|
_make_memory_item("bad_1", 1536),
|
||||||
|
_make_memory_item("ok_2", 768),
|
||||||
|
_make_memory_item("bad_2", 256),
|
||||||
|
_make_memory_item("ok_3", 768),
|
||||||
|
]
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test query")
|
||||||
|
# Should NOT raise ValueError / numpy error
|
||||||
|
results = memory.retrieve(
|
||||||
|
agent_role="tester",
|
||||||
|
query=query,
|
||||||
|
top_k=5,
|
||||||
|
similarity_threshold=-1.0,
|
||||||
|
)
|
||||||
|
# Only the 3 correct-dimension items should be candidates
|
||||||
|
assert len(results) <= 3
|
||||||
|
|
||||||
|
def test_all_same_dimension_returns_results(self):
|
||||||
|
"""When all embeddings share the correct dimension, all are candidates."""
|
||||||
|
store = _make_store()
|
||||||
|
memory = SimpleMemory(store)
|
||||||
|
memory.embedding = _make_embedding(dim=768)
|
||||||
|
|
||||||
|
memory.contents = [
|
||||||
|
_make_memory_item("a", 768),
|
||||||
|
_make_memory_item("b", 768),
|
||||||
|
]
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test query")
|
||||||
|
results = memory.retrieve(
|
||||||
|
agent_role="tester",
|
||||||
|
query=query,
|
||||||
|
top_k=5,
|
||||||
|
similarity_threshold=-1.0,
|
||||||
|
)
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
def test_all_wrong_dimension_returns_empty(self):
|
||||||
|
"""When every stored embedding has a wrong dimension, return empty."""
|
||||||
|
store = _make_store()
|
||||||
|
memory = SimpleMemory(store)
|
||||||
|
memory.embedding = _make_embedding(dim=768)
|
||||||
|
|
||||||
|
memory.contents = [
|
||||||
|
_make_memory_item("x", 1536),
|
||||||
|
_make_memory_item("y", 1536),
|
||||||
|
]
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test query")
|
||||||
|
results = memory.retrieve(
|
||||||
|
agent_role="tester",
|
||||||
|
query=query,
|
||||||
|
top_k=5,
|
||||||
|
similarity_threshold=-1.0,
|
||||||
|
)
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIEmbeddingDynamicFallback:
|
||||||
|
|
||||||
|
def test_fallback_uses_model_dimension_after_success(self):
|
||||||
|
"""After a successful call the fallback dimension MUST match the model."""
|
||||||
|
from runtime.node.agent.memory.embedding import OpenAIEmbedding
|
||||||
|
|
||||||
|
cfg = MagicMock()
|
||||||
|
cfg.base_url = "http://localhost:11434/v1"
|
||||||
|
cfg.api_key = "test"
|
||||||
|
cfg.model = "test-model"
|
||||||
|
cfg.params = {}
|
||||||
|
|
||||||
|
emb = OpenAIEmbedding(cfg)
|
||||||
|
assert emb._fallback_dim == 1536 # default before any call
|
||||||
|
|
||||||
|
# Simulate a successful 768-dim response
|
||||||
|
mock_data = MagicMock()
|
||||||
|
mock_data.embedding = [0.1] * 768
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [mock_data]
|
||||||
|
|
||||||
|
with patch.object(emb.client.embeddings, "create", return_value=mock_response):
|
||||||
|
result = emb.get_embedding("hello world")
|
||||||
|
|
||||||
|
assert len(result) == 768
|
||||||
|
assert emb._fallback_dim == 768 # updated after success
|
||||||
|
|
||||||
|
def test_fallback_zero_vector_matches_cached_dim(self):
|
||||||
|
"""After caching dim, fallback zero-vectors MUST use that dim."""
|
||||||
|
from runtime.node.agent.memory.embedding import OpenAIEmbedding
|
||||||
|
|
||||||
|
cfg = MagicMock()
|
||||||
|
cfg.base_url = "http://localhost:11434/v1"
|
||||||
|
cfg.api_key = "test"
|
||||||
|
cfg.model = "test-model"
|
||||||
|
cfg.params = {}
|
||||||
|
|
||||||
|
emb = OpenAIEmbedding(cfg)
|
||||||
|
|
||||||
|
# Simulate successful 512-dim call
|
||||||
|
mock_data = MagicMock()
|
||||||
|
mock_data.embedding = [0.1] * 512
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [mock_data]
|
||||||
|
|
||||||
|
with patch.object(emb.client.embeddings, "create", return_value=mock_response):
|
||||||
|
emb.get_embedding("first call")
|
||||||
|
|
||||||
|
# Now simulate a failure — fallback should be 512-dim
|
||||||
|
with patch.object(emb.client.embeddings, "create", side_effect=Exception("API down")):
|
||||||
|
fallback = emb.get_embedding("failing call")
|
||||||
|
|
||||||
|
assert len(fallback) == 512
|
||||||
|
assert all(v == 0.0 for v in fallback)
|
||||||
Loading…
x
Reference in New Issue
Block a user