mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
feat: add Mem0 memory integration with config, implementation, docs, tests, and dependency
This commit is contained in:
parent
cb75e0692c
commit
adc00f4faf
@ -32,12 +32,23 @@ memory:
|
|||||||
model: text-embedding-3-small
|
model: text-embedding-3-small
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Mem0 Memory Config
|
||||||
|
```yaml
|
||||||
|
memory:
|
||||||
|
- name: agent_memory
|
||||||
|
type: mem0
|
||||||
|
config:
|
||||||
|
api_key: ${MEM0_API_KEY}
|
||||||
|
agent_id: my-agent
|
||||||
|
```
|
||||||
|
|
||||||
## 3. Built-in Store Comparison
|
## 3. Built-in Store Comparison
|
||||||
| Type | Path | Highlights | Best for |
|
| Type | Path | Highlights | Best for |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| `simple` | `node/agent/memory/simple_memory.py` | Optional disk persistence (JSON) after runs; FAISS + semantic rerank; read/write capable. | Small conversation history, prototypes. |
|
| `simple` | `node/agent/memory/simple_memory.py` | Optional disk persistence (JSON) after runs; FAISS + semantic rerank; read/write capable. | Small conversation history, prototypes. |
|
||||||
| `file` | `node/agent/memory/file_memory.py` | Chunks files/dirs into a vector index, read-only, auto rebuilds when files change. | Knowledge bases, doc QA. |
|
| `file` | `node/agent/memory/file_memory.py` | Chunks files/dirs into a vector index, read-only, auto rebuilds when files change. | Knowledge bases, doc QA. |
|
||||||
| `blackboard` | `node/agent/memory/blackboard_memory.py` | Lightweight append-only log trimmed by time/count; no vector search. | Broadcast boards, pipeline debugging. |
|
| `blackboard` | `node/agent/memory/blackboard_memory.py` | Lightweight append-only log trimmed by time/count; no vector search. | Broadcast boards, pipeline debugging. |
|
||||||
|
| `mem0` | `node/agent/memory/mem0_memory.py` | Cloud-managed by Mem0; semantic search + graph relationships; no local embeddings or persistence needed. Requires `mem0ai` package. | Production memory, cross-session persistence, multi-agent memory sharing. |
|
||||||
|
|
||||||
All stores register through `register_memory_store()` so summaries show up in UI via `MemoryStoreConfig.field_specs()`.
|
All stores register through `register_memory_store()` so summaries show up in UI via `MemoryStoreConfig.field_specs()`.
|
||||||
|
|
||||||
@ -98,6 +109,14 @@ This schema lets multimodal outputs flow into Memory/Thinking modules without ex
|
|||||||
- **Retrieval** – Returns the latest `top_k` entries ordered by time.
|
- **Retrieval** – Returns the latest `top_k` entries ordered by time.
|
||||||
- **Write** – `update()` appends the latest snapshot (input/output blocks, attachments, previews). No embeddings are generated, so retrieval is purely recency-based.
|
- **Write** – `update()` appends the latest snapshot (input/output blocks, attachments, previews). No embeddings are generated, so retrieval is purely recency-based.
|
||||||
|
|
||||||
|
### 5.4 Mem0Memory
|
||||||
|
- **Config** – Requires `api_key` (from [app.mem0.ai](https://app.mem0.ai)). Optional `user_id`, `agent_id`, `org_id`, `project_id` for scoping.
|
||||||
|
- **Important**: `user_id` and `agent_id` are mutually exclusive in Mem0 API calls. If both are configured, two separate searches are made and results merged. For writes, `agent_id` takes precedence. Agent-generated content is stored with `role: "assistant"`.
|
||||||
|
- **Retrieval** – Uses Mem0's server-side semantic search. Supports `top_k` and `similarity_threshold` via `MemoryAttachmentConfig`.
|
||||||
|
- **Write** – `update()` sends conversation messages to Mem0 via the SDK. Agent outputs use `role: "assistant"`, user inputs use `role: "user"`.
|
||||||
|
- **Persistence** – Fully cloud-managed. `load()` and `save()` are no-ops. Memories persist across runs and sessions automatically.
|
||||||
|
- **Dependencies** – Requires `mem0ai` package (`pip install mem0ai`).
|
||||||
|
|
||||||
## 6. EmbeddingConfig Notes
|
## 6. EmbeddingConfig Notes
|
||||||
- Fields: `provider`, `model`, `api_key`, `base_url`, `params`.
|
- Fields: `provider`, `model`, `api_key`, `base_url`, `params`.
|
||||||
- `provider=openai` uses the official client; override `base_url` for compatibility layers.
|
- `provider=openai` uses the official client; override `base_url` for compatibility layers.
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from .node.memory import (
|
|||||||
EmbeddingConfig,
|
EmbeddingConfig,
|
||||||
FileMemoryConfig,
|
FileMemoryConfig,
|
||||||
FileSourceConfig,
|
FileSourceConfig,
|
||||||
|
Mem0MemoryConfig,
|
||||||
MemoryAttachmentConfig,
|
MemoryAttachmentConfig,
|
||||||
MemoryStoreConfig,
|
MemoryStoreConfig,
|
||||||
SimpleMemoryConfig,
|
SimpleMemoryConfig,
|
||||||
@ -43,6 +44,7 @@ __all__ = [
|
|||||||
"FunctionToolConfig",
|
"FunctionToolConfig",
|
||||||
"GraphDefinition",
|
"GraphDefinition",
|
||||||
"HumanConfig",
|
"HumanConfig",
|
||||||
|
"Mem0MemoryConfig",
|
||||||
"MemoryAttachmentConfig",
|
"MemoryAttachmentConfig",
|
||||||
"MemoryStoreConfig",
|
"MemoryStoreConfig",
|
||||||
"McpLocalConfig",
|
"McpLocalConfig",
|
||||||
|
|||||||
@ -279,6 +279,75 @@ class BlackboardMemoryConfig(BaseConfig):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Mem0MemoryConfig(BaseConfig):
|
||||||
|
"""Configuration for Mem0 managed memory service."""
|
||||||
|
|
||||||
|
api_key: str = ""
|
||||||
|
org_id: str | None = None
|
||||||
|
project_id: str | None = None
|
||||||
|
user_id: str | None = None
|
||||||
|
agent_id: str | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "Mem0MemoryConfig":
|
||||||
|
mapping = require_mapping(data, path)
|
||||||
|
api_key = require_str(mapping, "api_key", path)
|
||||||
|
org_id = optional_str(mapping, "org_id", path)
|
||||||
|
project_id = optional_str(mapping, "project_id", path)
|
||||||
|
user_id = optional_str(mapping, "user_id", path)
|
||||||
|
agent_id = optional_str(mapping, "agent_id", path)
|
||||||
|
return cls(
|
||||||
|
api_key=api_key,
|
||||||
|
org_id=org_id,
|
||||||
|
project_id=project_id,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
path=path,
|
||||||
|
)
|
||||||
|
|
||||||
|
FIELD_SPECS = {
|
||||||
|
"api_key": ConfigFieldSpec(
|
||||||
|
name="api_key",
|
||||||
|
display_name="Mem0 API Key",
|
||||||
|
type_hint="str",
|
||||||
|
required=True,
|
||||||
|
description="Mem0 API key (get one from app.mem0.ai)",
|
||||||
|
default="${MEM0_API_KEY}",
|
||||||
|
),
|
||||||
|
"org_id": ConfigFieldSpec(
|
||||||
|
name="org_id",
|
||||||
|
display_name="Organization ID",
|
||||||
|
type_hint="str",
|
||||||
|
required=False,
|
||||||
|
description="Mem0 organization ID for scoping",
|
||||||
|
advance=True,
|
||||||
|
),
|
||||||
|
"project_id": ConfigFieldSpec(
|
||||||
|
name="project_id",
|
||||||
|
display_name="Project ID",
|
||||||
|
type_hint="str",
|
||||||
|
required=False,
|
||||||
|
description="Mem0 project ID for scoping",
|
||||||
|
advance=True,
|
||||||
|
),
|
||||||
|
"user_id": ConfigFieldSpec(
|
||||||
|
name="user_id",
|
||||||
|
display_name="User ID",
|
||||||
|
type_hint="str",
|
||||||
|
required=False,
|
||||||
|
description="User ID for user-scoped memories. Mutually exclusive with agent_id in API calls.",
|
||||||
|
),
|
||||||
|
"agent_id": ConfigFieldSpec(
|
||||||
|
name="agent_id",
|
||||||
|
display_name="Agent ID",
|
||||||
|
type_hint="str",
|
||||||
|
required=False,
|
||||||
|
description="Agent ID for agent-scoped memories. Mutually exclusive with user_id in API calls.",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryStoreConfig(BaseConfig):
|
class MemoryStoreConfig(BaseConfig):
|
||||||
name: str
|
name: str
|
||||||
|
|||||||
@ -39,6 +39,7 @@ dependencies = [
|
|||||||
"filelock>=3.20.1",
|
"filelock>=3.20.1",
|
||||||
"markdown>=3.10",
|
"markdown>=3.10",
|
||||||
"xhtml2pdf>=0.2.17",
|
"xhtml2pdf>=0.2.17",
|
||||||
|
"mem0ai>=1.0.9",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from entity.configs.node.memory import (
|
from entity.configs.node.memory import (
|
||||||
BlackboardMemoryConfig,
|
BlackboardMemoryConfig,
|
||||||
FileMemoryConfig,
|
FileMemoryConfig,
|
||||||
|
Mem0MemoryConfig,
|
||||||
SimpleMemoryConfig,
|
SimpleMemoryConfig,
|
||||||
MemoryStoreConfig,
|
MemoryStoreConfig,
|
||||||
)
|
)
|
||||||
@ -34,6 +35,19 @@ register_memory_store(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mem0_memory(store):
|
||||||
|
from runtime.node.agent.memory.mem0_memory import Mem0Memory
|
||||||
|
return Mem0Memory(store)
|
||||||
|
|
||||||
|
|
||||||
|
register_memory_store(
|
||||||
|
"mem0",
|
||||||
|
config_cls=Mem0MemoryConfig,
|
||||||
|
factory=_create_mem0_memory,
|
||||||
|
summary="Mem0 managed memory with semantic search and graph relationships",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryFactory:
|
class MemoryFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_memory(store: MemoryStoreConfig) -> MemoryBase:
|
def create_memory(store: MemoryStoreConfig) -> MemoryBase:
|
||||||
|
|||||||
203
runtime/node/agent/memory/mem0_memory.py
Normal file
203
runtime/node/agent/memory/mem0_memory.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
"""Mem0 managed memory store implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from entity.configs import MemoryStoreConfig
|
||||||
|
from entity.configs.node.memory import Mem0MemoryConfig
|
||||||
|
from runtime.node.agent.memory.memory_base import (
|
||||||
|
MemoryBase,
|
||||||
|
MemoryContentSnapshot,
|
||||||
|
MemoryItem,
|
||||||
|
MemoryWritePayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mem0_client(config: Mem0MemoryConfig):
|
||||||
|
"""Lazy-import mem0ai and create a MemoryClient."""
|
||||||
|
try:
|
||||||
|
from mem0 import MemoryClient
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"mem0ai is required for Mem0Memory. Install it with: pip install mem0ai"
|
||||||
|
)
|
||||||
|
|
||||||
|
client_kwargs: Dict[str, Any] = {}
|
||||||
|
if config.api_key:
|
||||||
|
client_kwargs["api_key"] = config.api_key
|
||||||
|
if config.org_id:
|
||||||
|
client_kwargs["org_id"] = config.org_id
|
||||||
|
if config.project_id:
|
||||||
|
client_kwargs["project_id"] = config.project_id
|
||||||
|
|
||||||
|
return MemoryClient(**client_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Mem0Memory(MemoryBase):
|
||||||
|
"""Memory store backed by Mem0's managed cloud service.
|
||||||
|
|
||||||
|
Mem0 handles embeddings, storage, and semantic search server-side.
|
||||||
|
No local persistence or embedding computation is needed.
|
||||||
|
|
||||||
|
Important API constraints:
|
||||||
|
- Agent memories use role="assistant" + agent_id
|
||||||
|
- user_id and agent_id are stored as separate records in Mem0;
|
||||||
|
if both are configured, an OR filter is used to search across both scopes.
|
||||||
|
- search() uses filters dict; add() uses top-level kwargs.
|
||||||
|
- SDK returns {"memories": [...]} from search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, store: MemoryStoreConfig):
|
||||||
|
config = store.as_config(Mem0MemoryConfig)
|
||||||
|
if not config:
|
||||||
|
raise ValueError("Mem0Memory requires a Mem0 memory store configuration")
|
||||||
|
super().__init__(store)
|
||||||
|
self.config = config
|
||||||
|
self.client = _get_mem0_client(config)
|
||||||
|
self.user_id = config.user_id
|
||||||
|
self.agent_id = config.agent_id
|
||||||
|
|
||||||
|
# -------- Persistence (no-ops for cloud-managed store) --------
|
||||||
|
|
||||||
|
def load(self) -> None:
|
||||||
|
"""No-op: Mem0 manages persistence server-side."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
|
"""No-op: Mem0 manages persistence server-side."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# -------- Retrieval --------
|
||||||
|
|
||||||
|
def _build_search_filters(self, agent_role: str) -> Dict[str, Any]:
|
||||||
|
"""Build the filters dict for Mem0 search.
|
||||||
|
|
||||||
|
Mem0 search requires a filters dict for entity scoping.
|
||||||
|
user_id and agent_id are stored as separate records, so
|
||||||
|
when both are configured we use an OR filter to match either.
|
||||||
|
"""
|
||||||
|
if self.user_id and self.agent_id:
|
||||||
|
return {
|
||||||
|
"OR": [
|
||||||
|
{"user_id": self.user_id},
|
||||||
|
{"agent_id": self.agent_id},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
elif self.user_id:
|
||||||
|
return {"user_id": self.user_id}
|
||||||
|
elif self.agent_id:
|
||||||
|
return {"agent_id": self.agent_id}
|
||||||
|
else:
|
||||||
|
# Fallback: use agent_role as agent_id
|
||||||
|
return {"agent_id": agent_role}
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
self,
|
||||||
|
agent_role: str,
|
||||||
|
query: MemoryContentSnapshot,
|
||||||
|
top_k: int,
|
||||||
|
similarity_threshold: float,
|
||||||
|
) -> List[MemoryItem]:
|
||||||
|
"""Search Mem0 for relevant memories.
|
||||||
|
|
||||||
|
Uses the filters dict to scope by user_id, agent_id, or both
|
||||||
|
(via OR filter). The SDK returns {"memories": [...]}.
|
||||||
|
"""
|
||||||
|
if not query.text.strip():
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
filters = self._build_search_filters(agent_role)
|
||||||
|
search_kwargs: Dict[str, Any] = {
|
||||||
|
"query": query.text,
|
||||||
|
"top_k": top_k,
|
||||||
|
"filters": filters,
|
||||||
|
}
|
||||||
|
if similarity_threshold >= 0:
|
||||||
|
search_kwargs["threshold"] = similarity_threshold
|
||||||
|
|
||||||
|
response = self.client.search(**search_kwargs)
|
||||||
|
|
||||||
|
# SDK returns {"memories": [...]} — extract the list
|
||||||
|
if isinstance(response, dict):
|
||||||
|
raw_results = response.get("memories", response.get("results", []))
|
||||||
|
else:
|
||||||
|
raw_results = response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Mem0 search failed: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
items: List[MemoryItem] = []
|
||||||
|
for entry in raw_results:
|
||||||
|
item = MemoryItem(
|
||||||
|
id=entry.get("id", f"mem0_{uuid.uuid4().hex}"),
|
||||||
|
content_summary=entry.get("memory", ""),
|
||||||
|
metadata={
|
||||||
|
"agent_role": agent_role,
|
||||||
|
"score": entry.get("score"),
|
||||||
|
"categories": entry.get("categories", []),
|
||||||
|
"source": "mem0",
|
||||||
|
},
|
||||||
|
timestamp=time.time(),
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
# -------- Update --------
|
||||||
|
|
||||||
|
def update(self, payload: MemoryWritePayload) -> None:
|
||||||
|
"""Store a memory in Mem0.
|
||||||
|
|
||||||
|
Uses role="assistant" + agent_id for agent-generated memories,
|
||||||
|
and role="user" + user_id for user-scoped memories.
|
||||||
|
"""
|
||||||
|
snapshot = payload.output_snapshot or payload.input_snapshot
|
||||||
|
if not snapshot or not snapshot.text.strip():
|
||||||
|
return
|
||||||
|
|
||||||
|
messages = self._build_messages(payload)
|
||||||
|
if not messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
add_kwargs: Dict[str, Any] = {"messages": messages}
|
||||||
|
|
||||||
|
# Determine scoping: agent_id takes precedence for agent-generated content
|
||||||
|
if self.agent_id:
|
||||||
|
add_kwargs["agent_id"] = self.agent_id
|
||||||
|
elif self.user_id:
|
||||||
|
add_kwargs["user_id"] = self.user_id
|
||||||
|
else:
|
||||||
|
# Default: use agent_role as agent_id
|
||||||
|
add_kwargs["agent_id"] = payload.agent_role
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.add(**add_kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Mem0 add failed: %s", e)
|
||||||
|
|
||||||
|
def _build_messages(self, payload: MemoryWritePayload) -> List[Dict[str, str]]:
|
||||||
|
"""Build Mem0-compatible message list from write payload.
|
||||||
|
|
||||||
|
Agent-generated content uses role="assistant".
|
||||||
|
User input uses role="user".
|
||||||
|
"""
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
|
||||||
|
if payload.inputs_text and payload.inputs_text.strip():
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": payload.inputs_text.strip(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if payload.output_snapshot and payload.output_snapshot.text.strip():
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": payload.output_snapshot.text.strip(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages
|
||||||
384
tests/test_mem0_memory.py
Normal file
384
tests/test_mem0_memory.py
Normal file
@ -0,0 +1,384 @@
|
|||||||
|
"""Tests for Mem0 memory store implementation."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from entity.configs.node.memory import Mem0MemoryConfig
|
||||||
|
from runtime.node.agent.memory.memory_base import (
|
||||||
|
MemoryContentSnapshot,
|
||||||
|
MemoryItem,
|
||||||
|
MemoryWritePayload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_store(user_id=None, agent_id=None, api_key="test-key"):
|
||||||
|
"""Build a minimal MemoryStoreConfig mock for Mem0Memory."""
|
||||||
|
mem0_cfg = MagicMock(spec=Mem0MemoryConfig)
|
||||||
|
mem0_cfg.api_key = api_key
|
||||||
|
mem0_cfg.org_id = None
|
||||||
|
mem0_cfg.project_id = None
|
||||||
|
mem0_cfg.user_id = user_id
|
||||||
|
mem0_cfg.agent_id = agent_id
|
||||||
|
|
||||||
|
store = MagicMock()
|
||||||
|
store.name = "test_mem0"
|
||||||
|
|
||||||
|
# Return correct config type based on the requested class
|
||||||
|
def _as_config_side_effect(expected_type, **kwargs):
|
||||||
|
if expected_type is Mem0MemoryConfig:
|
||||||
|
return mem0_cfg
|
||||||
|
return None
|
||||||
|
|
||||||
|
store.as_config.side_effect = _as_config_side_effect
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mem0_memory(user_id=None, agent_id=None):
|
||||||
|
"""Create a Mem0Memory with a mocked client."""
|
||||||
|
with patch("runtime.node.agent.memory.mem0_memory._get_mem0_client") as mock_get:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_get.return_value = mock_client
|
||||||
|
from runtime.node.agent.memory.mem0_memory import Mem0Memory
|
||||||
|
store = _make_store(user_id=user_id, agent_id=agent_id)
|
||||||
|
memory = Mem0Memory(store)
|
||||||
|
return memory, mock_client
|
||||||
|
|
||||||
|
|
||||||
|
class TestMem0MemoryRetrieve:
|
||||||
|
|
||||||
|
def test_retrieve_with_agent_id(self):
|
||||||
|
"""Retrieve passes agent_id in filters dict to SDK search."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="agent-1")
|
||||||
|
client.search.return_value = {
|
||||||
|
"memories": [
|
||||||
|
{"id": "m1", "memory": "test fact", "score": 0.95},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="what do you know?")
|
||||||
|
results = memory.retrieve("writer", query, top_k=5, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
client.search.assert_called_once()
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["filters"] == {"agent_id": "agent-1"}
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].content_summary == "test fact"
|
||||||
|
assert results[0].metadata["source"] == "mem0"
|
||||||
|
|
||||||
|
def test_retrieve_with_user_id(self):
|
||||||
|
"""Retrieve passes user_id in filters dict to SDK search."""
|
||||||
|
memory, client = _make_mem0_memory(user_id="user-1")
|
||||||
|
client.search.return_value = {
|
||||||
|
"memories": [
|
||||||
|
{"id": "m1", "memory": "user pref", "score": 0.9},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="preferences")
|
||||||
|
results = memory.retrieve("assistant", query, top_k=3, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["filters"] == {"user_id": "user-1"}
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
def test_retrieve_with_both_ids_uses_or_filter(self):
|
||||||
|
"""When both user_id and agent_id are set, an OR filter is used."""
|
||||||
|
memory, client = _make_mem0_memory(user_id="user-1", agent_id="agent-1")
|
||||||
|
client.search.return_value = {
|
||||||
|
"memories": [
|
||||||
|
{"id": "u1", "memory": "user fact", "score": 0.8},
|
||||||
|
{"id": "a1", "memory": "agent fact", "score": 0.9},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
results = memory.retrieve("writer", query, top_k=5, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
client.search.assert_called_once()
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["filters"] == {
|
||||||
|
"OR": [
|
||||||
|
{"user_id": "user-1"},
|
||||||
|
{"agent_id": "agent-1"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
def test_retrieve_fallback_uses_agent_role(self):
|
||||||
|
"""When no IDs configured, fall back to agent_role as agent_id in filters."""
|
||||||
|
memory, client = _make_mem0_memory()
|
||||||
|
client.search.return_value = {"memories": []}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
memory.retrieve("coder", query, top_k=3, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["filters"] == {"agent_id": "coder"}
|
||||||
|
|
||||||
|
def test_retrieve_empty_query_returns_empty(self):
|
||||||
|
"""Empty query text returns empty without calling API."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text=" ")
|
||||||
|
results = memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
client.search.assert_not_called()
|
||||||
|
|
||||||
|
def test_retrieve_api_error_returns_empty(self):
|
||||||
|
"""API errors are caught and return empty list."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.search.side_effect = Exception("API down")
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
results = memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
def test_retrieve_respects_top_k(self):
|
||||||
|
"""top_k is passed to Mem0 search."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.search.return_value = {"memories": []}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
memory.retrieve("writer", query, top_k=7, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["top_k"] == 7
|
||||||
|
|
||||||
|
def test_retrieve_passes_threshold_when_non_negative(self):
|
||||||
|
"""Non-negative similarity_threshold is forwarded to Mem0."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.search.return_value = {"memories": []}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
memory.retrieve("writer", query, top_k=3, similarity_threshold=0.5)
|
||||||
|
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["threshold"] == 0.5
|
||||||
|
|
||||||
|
def test_retrieve_passes_zero_threshold(self):
|
||||||
|
"""A threshold of 0.0 is a valid value and should be sent."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.search.return_value = {"memories": []}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
memory.retrieve("writer", query, top_k=3, similarity_threshold=0.0)
|
||||||
|
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert call_kwargs["threshold"] == 0.0
|
||||||
|
|
||||||
|
def test_retrieve_skips_threshold_when_negative(self):
|
||||||
|
"""Negative similarity_threshold is not sent to Mem0."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.search.return_value = {"memories": []}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
call_kwargs = client.search.call_args[1]
|
||||||
|
assert "threshold" not in call_kwargs
|
||||||
|
|
||||||
|
def test_retrieve_handles_legacy_results_key(self):
|
||||||
|
"""Handles SDK response with 'results' key (older SDK versions)."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.search.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"id": "m1", "memory": "legacy format", "score": 0.8},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
query = MemoryContentSnapshot(text="test")
|
||||||
|
results = memory.retrieve("writer", query, top_k=3, similarity_threshold=-1.0)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].content_summary == "legacy format"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMem0MemoryUpdate:
|
||||||
|
|
||||||
|
def test_update_with_agent_id_uses_assistant_role(self):
|
||||||
|
"""Agent-scoped update sends role=assistant messages with agent_id."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="agent-1")
|
||||||
|
client.add.return_value = [{"id": "new", "event": "ADD"}]
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="writer",
|
||||||
|
inputs_text="Write about AI",
|
||||||
|
input_snapshot=MemoryContentSnapshot(text="Write about AI"),
|
||||||
|
output_snapshot=MemoryContentSnapshot(text="AI is transformative..."),
|
||||||
|
)
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
client.add.assert_called_once()
|
||||||
|
call_kwargs = client.add.call_args[1]
|
||||||
|
assert call_kwargs["agent_id"] == "agent-1"
|
||||||
|
assert "user_id" not in call_kwargs
|
||||||
|
messages = call_kwargs["messages"]
|
||||||
|
assert messages[0]["role"] == "user"
|
||||||
|
assert messages[1]["role"] == "assistant"
|
||||||
|
|
||||||
|
def test_update_with_user_id(self):
|
||||||
|
"""User-scoped update uses user_id, not agent_id."""
|
||||||
|
memory, client = _make_mem0_memory(user_id="user-1")
|
||||||
|
client.add.return_value = []
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="writer",
|
||||||
|
inputs_text="I prefer Python",
|
||||||
|
input_snapshot=None,
|
||||||
|
output_snapshot=MemoryContentSnapshot(text="Noted your preference"),
|
||||||
|
)
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
call_kwargs = client.add.call_args[1]
|
||||||
|
assert call_kwargs["user_id"] == "user-1"
|
||||||
|
assert "agent_id" not in call_kwargs
|
||||||
|
|
||||||
|
def test_update_fallback_uses_agent_role(self):
|
||||||
|
"""When no IDs configured, uses agent_role as agent_id."""
|
||||||
|
memory, client = _make_mem0_memory()
|
||||||
|
client.add.return_value = []
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="coder",
|
||||||
|
inputs_text="test input",
|
||||||
|
input_snapshot=None,
|
||||||
|
output_snapshot=MemoryContentSnapshot(text="test output"),
|
||||||
|
)
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
call_kwargs = client.add.call_args[1]
|
||||||
|
assert call_kwargs["agent_id"] == "coder"
|
||||||
|
|
||||||
|
def test_update_with_both_ids_prefers_agent_id(self):
|
||||||
|
"""When both user_id and agent_id configured, agent_id takes precedence for writes."""
|
||||||
|
memory, client = _make_mem0_memory(user_id="user-1", agent_id="agent-1")
|
||||||
|
client.add.return_value = []
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="writer",
|
||||||
|
inputs_text="input",
|
||||||
|
input_snapshot=None,
|
||||||
|
output_snapshot=MemoryContentSnapshot(text="output"),
|
||||||
|
)
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
call_kwargs = client.add.call_args[1]
|
||||||
|
assert call_kwargs["agent_id"] == "agent-1"
|
||||||
|
assert "user_id" not in call_kwargs
|
||||||
|
|
||||||
|
def test_update_empty_output_is_noop(self):
|
||||||
|
"""Empty output snapshot skips API call."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="writer",
|
||||||
|
inputs_text="",
|
||||||
|
input_snapshot=None,
|
||||||
|
output_snapshot=MemoryContentSnapshot(text=" "),
|
||||||
|
)
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
client.add.assert_not_called()
|
||||||
|
|
||||||
|
def test_update_no_snapshot_is_noop(self):
|
||||||
|
"""No snapshot at all skips API call."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="writer",
|
||||||
|
inputs_text="test",
|
||||||
|
input_snapshot=None,
|
||||||
|
output_snapshot=None,
|
||||||
|
)
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
client.add.assert_not_called()
|
||||||
|
|
||||||
|
def test_update_api_error_does_not_raise(self):
|
||||||
|
"""API errors are logged but do not propagate."""
|
||||||
|
memory, client = _make_mem0_memory(agent_id="a1")
|
||||||
|
client.add.side_effect = Exception("API error")
|
||||||
|
|
||||||
|
payload = MemoryWritePayload(
|
||||||
|
agent_role="writer",
|
||||||
|
inputs_text="test",
|
||||||
|
input_snapshot=None,
|
||||||
|
output_snapshot=MemoryContentSnapshot(text="output"),
|
||||||
|
)
|
||||||
|
# Should not raise
|
||||||
|
memory.update(payload)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMem0MemoryLoadSave:
|
||||||
|
|
||||||
|
def test_load_is_noop(self):
|
||||||
|
"""load() does nothing for cloud-managed store."""
|
||||||
|
memory, _ = _make_mem0_memory(agent_id="a1")
|
||||||
|
memory.load() # Should not raise
|
||||||
|
|
||||||
|
def test_save_is_noop(self):
|
||||||
|
"""save() does nothing for cloud-managed store."""
|
||||||
|
memory, _ = _make_mem0_memory(agent_id="a1")
|
||||||
|
memory.save() # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestMem0MemoryConfig:
|
||||||
|
|
||||||
|
def test_config_from_dict(self):
|
||||||
|
"""Config parses from dict correctly."""
|
||||||
|
data = {
|
||||||
|
"api_key": "test-key",
|
||||||
|
"user_id": "u1",
|
||||||
|
"org_id": "org-1",
|
||||||
|
}
|
||||||
|
config = Mem0MemoryConfig.from_dict(data, path="test")
|
||||||
|
assert config.api_key == "test-key"
|
||||||
|
assert config.user_id == "u1"
|
||||||
|
assert config.org_id == "org-1"
|
||||||
|
assert config.agent_id is None
|
||||||
|
assert config.project_id is None
|
||||||
|
|
||||||
|
def test_config_field_specs_exist(self):
|
||||||
|
"""FIELD_SPECS are defined for UI generation."""
|
||||||
|
specs = Mem0MemoryConfig.field_specs()
|
||||||
|
assert "api_key" in specs
|
||||||
|
assert "user_id" in specs
|
||||||
|
assert "agent_id" in specs
|
||||||
|
assert specs["api_key"].required is True
|
||||||
|
|
||||||
|
def test_config_requires_api_key(self):
|
||||||
|
"""Config raises ConfigError when api_key is missing."""
|
||||||
|
from entity.configs.base import ConfigError
|
||||||
|
|
||||||
|
data = {"agent_id": "a1"}
|
||||||
|
with pytest.raises(ConfigError):
|
||||||
|
Mem0MemoryConfig.from_dict(data, path="test")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMem0MemoryConstructor:
|
||||||
|
|
||||||
|
def test_raises_on_wrong_config_type(self):
|
||||||
|
"""Mem0Memory raises ValueError when store has wrong config type."""
|
||||||
|
from runtime.node.agent.memory.mem0_memory import Mem0Memory
|
||||||
|
|
||||||
|
store = MagicMock()
|
||||||
|
store.name = "bad_store"
|
||||||
|
store.as_config.return_value = None # Wrong config type
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Mem0 memory store configuration"):
|
||||||
|
Mem0Memory(store)
|
||||||
|
|
||||||
|
def test_import_error_when_mem0ai_missing(self):
|
||||||
|
"""Helpful ImportError when mem0ai is not installed."""
|
||||||
|
from runtime.node.agent.memory.mem0_memory import _get_mem0_client
|
||||||
|
|
||||||
|
mem0_cfg = MagicMock(spec=Mem0MemoryConfig)
|
||||||
|
mem0_cfg.api_key = "test"
|
||||||
|
mem0_cfg.org_id = None
|
||||||
|
mem0_cfg.project_id = None
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"mem0": None}):
|
||||||
|
with pytest.raises(ImportError, match="pip install mem0ai"):
|
||||||
|
_get_mem0_client(mem0_cfg)
|
||||||
47
yaml_instance/demo_mem0_memory.yaml
Normal file
47
yaml_instance/demo_mem0_memory.yaml
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
version: 0.4.0
|
||||||
|
vars: {}
|
||||||
|
graph:
|
||||||
|
id: ''
|
||||||
|
description: Memory-backed conversation using Mem0 managed memory service.
|
||||||
|
is_majority_voting: false
|
||||||
|
nodes:
|
||||||
|
- id: writer
|
||||||
|
type: agent
|
||||||
|
config:
|
||||||
|
base_url: ${BASE_URL}
|
||||||
|
api_key: ${API_KEY}
|
||||||
|
provider: openai
|
||||||
|
name: gpt-4o
|
||||||
|
role: |
|
||||||
|
You are a knowledgeable writer. Use your memories to build on past interactions.
|
||||||
|
If memory sections are provided (wrapped by ===== Related Memories =====),
|
||||||
|
incorporate relevant context from those memories into your response.
|
||||||
|
params:
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 2000
|
||||||
|
memories:
|
||||||
|
- name: mem0_store
|
||||||
|
top_k: 5
|
||||||
|
retrieve_stage:
|
||||||
|
- gen
|
||||||
|
read: true
|
||||||
|
write: true
|
||||||
|
edges: []
|
||||||
|
memory:
|
||||||
|
# Agent-scoped memory: uses agent_id for storing and retrieving
|
||||||
|
- name: mem0_store
|
||||||
|
type: mem0
|
||||||
|
config:
|
||||||
|
api_key: ${MEM0_API_KEY}
|
||||||
|
agent_id: writer-agent
|
||||||
|
|
||||||
|
# Alternative: User-scoped memory (uncomment to use instead)
|
||||||
|
# - name: mem0_store
|
||||||
|
# type: mem0
|
||||||
|
# config:
|
||||||
|
# api_key: ${MEM0_API_KEY}
|
||||||
|
# user_id: project-user-123
|
||||||
|
start:
|
||||||
|
- writer
|
||||||
|
end: []
|
||||||
|
initial_instruction: ''
|
||||||
Loading…
x
Reference in New Issue
Block a user