ChatDev/tests/test_mem0_memory.py

385 lines
14 KiB
Python

"""Tests for Mem0 memory store implementation."""
from unittest.mock import MagicMock, patch
import pytest
from entity.configs.node.memory import Mem0MemoryConfig
from runtime.node.agent.memory.memory_base import (
MemoryContentSnapshot,
MemoryItem,
MemoryWritePayload,
)
def _make_store(user_id=None, agent_id=None, api_key="test-key"):
    """Return a fake memory store whose ``as_config`` yields a Mem0 config mock.

    The returned ``MagicMock`` exposes the two things ``Mem0Memory`` reads from
    a store: a ``name`` attribute and an ``as_config`` method. Asking
    ``as_config`` for ``Mem0MemoryConfig`` returns a spec'd config mock carrying
    the supplied identifiers; asking for any other type returns ``None``.
    """
    config_mock = MagicMock(spec=Mem0MemoryConfig)
    for attr_name, attr_value in (
        ("api_key", api_key),
        ("org_id", None),
        ("project_id", None),
        ("user_id", user_id),
        ("agent_id", agent_id),
    ):
        setattr(config_mock, attr_name, attr_value)

    fake_store = MagicMock()
    fake_store.name = "test_mem0"
    # Only the Mem0 config type resolves; any other request looks "absent".
    fake_store.as_config.side_effect = (
        lambda expected_type, **kwargs: config_mock
        if expected_type is Mem0MemoryConfig
        else None
    )
    return fake_store
def _make_mem0_memory(user_id=None, agent_id=None):
    """Build a ``Mem0Memory`` wired to a mocked SDK client.

    ``_get_mem0_client`` is patched while the memory is constructed, so no real
    Mem0 client is created. Returns ``(memory, client)``, where ``client`` is
    the ``MagicMock`` that the patched factory handed to the memory.
    """
    factory_path = "runtime.node.agent.memory.mem0_memory._get_mem0_client"
    fake_client = MagicMock()
    with patch(factory_path, return_value=fake_client):
        from runtime.node.agent.memory.mem0_memory import Mem0Memory

        memory = Mem0Memory(_make_store(user_id=user_id, agent_id=agent_id))
    return memory, fake_client
class TestMem0MemoryRetrieve:
    """Tests for ``Mem0Memory.retrieve``.

    Covers how configured identifiers become the search ``filters``, how
    ``top_k`` / ``similarity_threshold`` are forwarded (a negative threshold
    means "do not send one"), and how empty queries, SDK errors and the
    legacy ``results`` response key are handled.
    """

    def test_retrieve_with_agent_id(self):
        """Retrieve passes agent_id in filters dict to SDK search."""
        memory, client = _make_mem0_memory(agent_id="agent-1")
        client.search.return_value = {
            "memories": [
                {"id": "m1", "memory": "test fact", "score": 0.95},
            ]
        }
        query = MemoryContentSnapshot(text="what do you know?")
        results = memory.retrieve("writer", query, top_k=5, similarity_threshold=-1.0)
        client.search.assert_called_once()
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["filters"] == {"agent_id": "agent-1"}
        assert len(results) == 1
        assert results[0].content_summary == "test fact"
        assert results[0].metadata["source"] == "mem0"

    def test_retrieve_with_user_id(self):
        """Retrieve passes user_id in filters dict to SDK search."""
        memory, client = _make_mem0_memory(user_id="user-1")
        client.search.return_value = {
            "memories": [
                {"id": "m1", "memory": "user pref", "score": 0.9},
            ]
        }
        query = MemoryContentSnapshot(text="preferences")
        results = memory.retrieve("assistant", query, top_k=3, similarity_threshold=-1.0)
        # Mirror the agent-id test: exactly one search, with a clear failure
        # message instead of a TypeError on ``call_args[1]`` if it never ran.
        client.search.assert_called_once()
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["filters"] == {"user_id": "user-1"}
        assert len(results) == 1

    def test_retrieve_with_both_ids_uses_or_filter(self):
        """When both user_id and agent_id are set, an OR filter is used."""
        memory, client = _make_mem0_memory(user_id="user-1", agent_id="agent-1")
        client.search.return_value = {
            "memories": [
                {"id": "u1", "memory": "user fact", "score": 0.8},
                {"id": "a1", "memory": "agent fact", "score": 0.9},
            ]
        }
        query = MemoryContentSnapshot(text="test")
        results = memory.retrieve("writer", query, top_k=5, similarity_threshold=-1.0)
        client.search.assert_called_once()
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["filters"] == {
            "OR": [
                {"user_id": "user-1"},
                {"agent_id": "agent-1"},
            ]
        }
        assert len(results) == 2

    def test_retrieve_fallback_uses_agent_role(self):
        """When no IDs configured, fall back to agent_role as agent_id in filters."""
        memory, client = _make_mem0_memory()
        client.search.return_value = {"memories": []}
        query = MemoryContentSnapshot(text="test")
        memory.retrieve("coder", query, top_k=3, similarity_threshold=-1.0)
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["filters"] == {"agent_id": "coder"}

    def test_retrieve_empty_query_returns_empty(self):
        """Empty query text returns empty without calling API."""
        memory, client = _make_mem0_memory(agent_id="a1")
        query = MemoryContentSnapshot(text=" ")
        results = memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
        assert results == []
        client.search.assert_not_called()

    def test_retrieve_api_error_returns_empty(self):
        """API errors are caught and return empty list."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.search.side_effect = Exception("API down")
        query = MemoryContentSnapshot(text="test")
        results = memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
        # Without this, an early ``return []`` that never reached the SDK would
        # also pass; make sure the error path was actually exercised.
        client.search.assert_called()
        assert results == []

    def test_retrieve_respects_top_k(self):
        """top_k is passed to Mem0 search."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.search.return_value = {"memories": []}
        query = MemoryContentSnapshot(text="test")
        memory.retrieve("writer", query, top_k=7, similarity_threshold=-1.0)
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["top_k"] == 7

    def test_retrieve_passes_threshold_when_non_negative(self):
        """Non-negative similarity_threshold is forwarded to Mem0."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.search.return_value = {"memories": []}
        query = MemoryContentSnapshot(text="test")
        memory.retrieve("writer", query, top_k=3, similarity_threshold=0.5)
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["threshold"] == 0.5

    def test_retrieve_passes_zero_threshold(self):
        """A threshold of 0.0 is a valid value and should be sent."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.search.return_value = {"memories": []}
        query = MemoryContentSnapshot(text="test")
        memory.retrieve("writer", query, top_k=3, similarity_threshold=0.0)
        call_kwargs = client.search.call_args[1]
        assert call_kwargs["threshold"] == 0.0

    def test_retrieve_skips_threshold_when_negative(self):
        """Negative similarity_threshold is not sent to Mem0."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.search.return_value = {"memories": []}
        query = MemoryContentSnapshot(text="test")
        memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
        call_kwargs = client.search.call_args[1]
        assert "threshold" not in call_kwargs

    def test_retrieve_handles_legacy_results_key(self):
        """Handles SDK response with 'results' key (older SDK versions)."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.search.return_value = {
            "results": [
                {"id": "m1", "memory": "legacy format", "score": 0.8},
            ]
        }
        query = MemoryContentSnapshot(text="test")
        results = memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
        assert len(results) == 1
        assert results[0].content_summary == "legacy format"
class TestMem0MemoryUpdate:
    """Tests for ``Mem0Memory.update``.

    Covers which identifier scopes the ``client.add`` call (agent_id wins over
    user_id, agent_role is the fallback), the message roles sent, the no-op
    cases for empty/missing output, and that SDK errors are not propagated.
    """

    def test_update_with_agent_id_uses_assistant_role(self):
        """Agent-scoped update sends role=assistant messages with agent_id."""
        memory, client = _make_mem0_memory(agent_id="agent-1")
        client.add.return_value = [{"id": "new", "event": "ADD"}]
        payload = MemoryWritePayload(
            agent_role="writer",
            inputs_text="Write about AI",
            input_snapshot=MemoryContentSnapshot(text="Write about AI"),
            output_snapshot=MemoryContentSnapshot(text="AI is transformative..."),
        )
        memory.update(payload)
        client.add.assert_called_once()
        call_kwargs = client.add.call_args[1]
        assert call_kwargs["agent_id"] == "agent-1"
        assert "user_id" not in call_kwargs
        messages = call_kwargs["messages"]
        assert messages[0]["role"] == "user"
        assert messages[1]["role"] == "assistant"

    def test_update_with_user_id(self):
        """User-scoped update uses user_id, not agent_id."""
        memory, client = _make_mem0_memory(user_id="user-1")
        client.add.return_value = []
        payload = MemoryWritePayload(
            agent_role="writer",
            inputs_text="I prefer Python",
            input_snapshot=None,
            output_snapshot=MemoryContentSnapshot(text="Noted your preference"),
        )
        memory.update(payload)
        # Mirror the agent-id test so a missing call fails with a clear message
        # instead of a TypeError on ``call_args[1]``.
        client.add.assert_called_once()
        call_kwargs = client.add.call_args[1]
        assert call_kwargs["user_id"] == "user-1"
        assert "agent_id" not in call_kwargs

    def test_update_fallback_uses_agent_role(self):
        """When no IDs configured, uses agent_role as agent_id."""
        memory, client = _make_mem0_memory()
        client.add.return_value = []
        payload = MemoryWritePayload(
            agent_role="coder",
            inputs_text="test input",
            input_snapshot=None,
            output_snapshot=MemoryContentSnapshot(text="test output"),
        )
        memory.update(payload)
        call_kwargs = client.add.call_args[1]
        assert call_kwargs["agent_id"] == "coder"

    def test_update_with_both_ids_prefers_agent_id(self):
        """When both user_id and agent_id configured, agent_id takes precedence for writes."""
        memory, client = _make_mem0_memory(user_id="user-1", agent_id="agent-1")
        client.add.return_value = []
        payload = MemoryWritePayload(
            agent_role="writer",
            inputs_text="input",
            input_snapshot=None,
            output_snapshot=MemoryContentSnapshot(text="output"),
        )
        memory.update(payload)
        call_kwargs = client.add.call_args[1]
        assert call_kwargs["agent_id"] == "agent-1"
        assert "user_id" not in call_kwargs

    def test_update_empty_output_is_noop(self):
        """Empty output snapshot skips API call."""
        memory, client = _make_mem0_memory(agent_id="a1")
        payload = MemoryWritePayload(
            agent_role="writer",
            inputs_text="",
            input_snapshot=None,
            output_snapshot=MemoryContentSnapshot(text=" "),
        )
        memory.update(payload)
        client.add.assert_not_called()

    def test_update_no_snapshot_is_noop(self):
        """No snapshot at all skips API call."""
        memory, client = _make_mem0_memory(agent_id="a1")
        payload = MemoryWritePayload(
            agent_role="writer",
            inputs_text="test",
            input_snapshot=None,
            output_snapshot=None,
        )
        memory.update(payload)
        client.add.assert_not_called()

    def test_update_api_error_does_not_raise(self):
        """API errors are logged but do not propagate."""
        memory, client = _make_mem0_memory(agent_id="a1")
        client.add.side_effect = Exception("API error")
        payload = MemoryWritePayload(
            agent_role="writer",
            inputs_text="test",
            input_snapshot=None,
            output_snapshot=MemoryContentSnapshot(text="output"),
        )
        # Should not raise
        memory.update(payload)
        # Guard against a vacuous pass: the failing SDK call must actually
        # have been attempted for this test to prove errors are swallowed.
        client.add.assert_called()
class TestMem0MemoryLoadSave:
    """``load``/``save`` are no-ops for the cloud-managed Mem0 store.

    "No-op" is checked as both "does not raise" and "makes no calls on the
    Mem0 client".
    """

    def test_load_is_noop(self):
        """load() does nothing for cloud-managed store."""
        memory, client = _make_mem0_memory(agent_id="a1")
        # Snapshot first so any client calls made during construction are
        # not attributed to load().
        calls_before = list(client.mock_calls)
        memory.load()  # Should not raise
        assert client.mock_calls == calls_before

    def test_save_is_noop(self):
        """save() does nothing for cloud-managed store."""
        memory, client = _make_mem0_memory(agent_id="a1")
        # Snapshot first so any client calls made during construction are
        # not attributed to save().
        calls_before = list(client.mock_calls)
        memory.save()  # Should not raise
        assert client.mock_calls == calls_before
class TestMem0MemoryConfig:
    """Tests for parsing and describing ``Mem0MemoryConfig``."""

    def test_config_from_dict(self):
        """Fields present in the dict are parsed; absent optional ones stay None."""
        config = Mem0MemoryConfig.from_dict(
            {"api_key": "test-key", "user_id": "u1", "org_id": "org-1"},
            path="test",
        )
        for field_name, expected in (
            ("api_key", "test-key"),
            ("user_id", "u1"),
            ("org_id", "org-1"),
        ):
            assert getattr(config, field_name) == expected
        for field_name in ("agent_id", "project_id"):
            assert getattr(config, field_name) is None

    def test_config_field_specs_exist(self):
        """field_specs() exposes the UI-relevant fields, with api_key required."""
        specs = Mem0MemoryConfig.field_specs()
        for key in ("api_key", "user_id", "agent_id"):
            assert key in specs
        assert specs["api_key"].required is True

    def test_config_requires_api_key(self):
        """Parsing without api_key raises ConfigError."""
        from entity.configs.base import ConfigError

        with pytest.raises(ConfigError):
            Mem0MemoryConfig.from_dict({"agent_id": "a1"}, path="test")
class TestMem0MemoryConstructor:
    """Tests for ``Mem0Memory`` construction and client-factory failure modes."""

    def test_raises_on_wrong_config_type(self):
        """Mem0Memory raises ValueError when store has wrong config type."""
        from runtime.node.agent.memory.mem0_memory import Mem0Memory

        store = MagicMock()
        store.name = "bad_store"
        store.as_config.return_value = None  # Wrong config type
        # Stub the client factory, as _make_mem0_memory does, so this test
        # never builds a real Mem0 SDK client whichever order the constructor
        # validates the config and creates the client in.
        with patch("runtime.node.agent.memory.mem0_memory._get_mem0_client"):
            with pytest.raises(ValueError, match="Mem0 memory store configuration"):
                Mem0Memory(store)

    def test_import_error_when_mem0ai_missing(self):
        """Helpful ImportError when mem0ai is not installed."""
        import sys

        from runtime.node.agent.memory.mem0_memory import _get_mem0_client

        mem0_cfg = MagicMock(spec=Mem0MemoryConfig)
        mem0_cfg.api_key = "test"
        mem0_cfg.org_id = None
        mem0_cfg.project_id = None
        # Block the top-level package AND any already-imported submodules:
        # a cached "mem0.<sub>" entry would otherwise be returned directly by
        # the import system if the helper imports from a submodule.
        blocked = {
            name: None
            for name in list(sys.modules)
            if name == "mem0" or name.startswith("mem0.")
        }
        blocked["mem0"] = None
        with patch.dict("sys.modules", blocked):
            with pytest.raises(ImportError, match="pip install mem0ai"):
                _get_mem0_client(mem0_cfg)