mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(memory): inject stored facts into system prompt memory context (#1083)
* fix(memory): inject stored facts into system prompt memory context

  - add Facts section rendering in format_memory_for_injection
  - rank facts by confidence and coerce confidence values safely
  - enforce max token budget while appending fact lines
  - add regression tests for fact inclusion, ordering, and budget behavior

  Fixes #1059

* Update the document with the latest status

* fix(memory): harden fact injection: NaN/inf confidence, None content, incremental token budget (#1090)

  * Initial plan
  * fix(memory): address review feedback on confidence coercion, None content, and token budget

  Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

  ---------

  Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
  Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
parent
3521cc2668
commit
b5fcb1334a
@@ -1,281 +1,65 @@
|
|||||||
# Memory System Improvements
|
# Memory System Improvements
|
||||||
|
|
||||||
This document describes recent improvements to the memory system's fact injection mechanism.
|
This document tracks memory injection behavior and roadmap status.
|
||||||
|
|
||||||
## Overview
|
## Status (As Of 2026-03-10)
|
||||||
|
|
||||||
Two major improvements have been made to the `format_memory_for_injection` function:
|
Implemented in `main`:
|
||||||
|
- Accurate token counting via `tiktoken` in `format_memory_for_injection`.
|
||||||
|
- Facts are injected into prompt memory context.
|
||||||
|
- Facts are ranked by confidence (descending).
|
||||||
|
- Injection respects `max_injection_tokens` budget.
|
||||||
|
|
||||||
1. **Similarity-Based Fact Retrieval**: Uses TF-IDF to select facts most relevant to current conversation context
|
Planned / not yet merged:
|
||||||
2. **Accurate Token Counting**: Uses tiktoken for precise token estimation instead of rough character-based approximation
|
- TF-IDF similarity-based fact retrieval.
|
||||||
|
- `current_context` input for context-aware scoring.
|
||||||
|
- Configurable similarity/confidence weights (`similarity_weight`, `confidence_weight`).
|
||||||
|
- Middleware/runtime wiring for context-aware retrieval before each model call.
|
||||||
|
|
||||||
## 1. Similarity-Based Fact Retrieval
|
## Current Behavior
|
||||||
|
|
||||||
### Problem
|
Function today:
|
||||||
The original implementation selected facts based solely on confidence scores, taking the top 15 highest-confidence facts regardless of their relevance to the current conversation. This could result in injecting irrelevant facts while omitting contextually important ones.
|
|
||||||
|
|
||||||
### Solution
|
|
||||||
The new implementation uses **TF-IDF (Term Frequency-Inverse Document Frequency)** vectorization with cosine similarity to measure how relevant each fact is to the current conversation context.
|
|
||||||
|
|
||||||
**Scoring Formula**:
|
|
||||||
```
|
|
||||||
final_score = (similarity × 0.6) + (confidence × 0.4)
|
|
||||||
```
|
|
||||||
|
|
||||||
- **Similarity (60% weight)**: Cosine similarity between fact content and current context
|
|
||||||
- **Confidence (40% weight)**: LLM-assigned confidence score (0-1)
|
|
||||||
|
|
||||||
### Benefits
|
|
||||||
- **Context-Aware**: Prioritizes facts relevant to what the user is currently discussing
|
|
||||||
- **Dynamic**: Different facts surface based on conversation topic
|
|
||||||
- **Balanced**: Considers both relevance and reliability
|
|
||||||
- **Fallback**: Gracefully degrades to confidence-only ranking if context is unavailable
|
|
||||||
|
|
||||||
### Example
|
|
||||||
Given facts about Python, React, and Docker:
|
|
||||||
- User asks: *"How should I write Python tests?"*
|
|
||||||
- Prioritizes: Python testing, type hints, pytest
|
|
||||||
- User asks: *"How to optimize my Next.js app?"*
|
|
||||||
- Prioritizes: React/Next.js experience, performance optimization
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
Customize weights in `config.yaml` (optional):
|
|
||||||
```yaml
|
|
||||||
memory:
|
|
||||||
similarity_weight: 0.6 # Weight for TF-IDF similarity (0-1)
|
|
||||||
confidence_weight: 0.4 # Weight for confidence score (0-1)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note**: Weights should sum to 1.0 for best results.
|
|
||||||
|
|
||||||
## 2. Accurate Token Counting
|
|
||||||
|
|
||||||
### Problem
|
|
||||||
The original implementation estimated tokens using a simple formula:
|
|
||||||
```python
|
|
||||||
max_chars = max_tokens * 4
|
|
||||||
```
|
|
||||||
|
|
||||||
This assumes ~4 characters per token, which is:
|
|
||||||
- Inaccurate for many languages and content types
|
|
||||||
- Can lead to over-injection (exceeding token limits)
|
|
||||||
- Can lead to under-injection (wasting available budget)
|
|
||||||
|
|
||||||
### Solution
|
|
||||||
The new implementation uses **tiktoken**, OpenAI's official tokenizer library, to count tokens accurately:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import tiktoken
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||||
|
|
||||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
|
||||||
return len(encoding.encode(text))
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- Uses `cl100k_base` encoding (GPT-4, GPT-3.5, text-embedding-ada-002)
|
Current injection format:
|
||||||
- Provides exact token counts for budget management
|
- `User Context` section from `user.*.summary`
|
||||||
- Falls back to character-based estimation if tiktoken fails
|
- `History` section from `history.*.summary`
|
||||||
|
- `Facts` section from `facts[]`, sorted by confidence, appended until token budget is reached
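
The "appended until token budget is reached" behavior can be illustrated with a minimal sketch. The function name `select_fact_lines` and its parameters are hypothetical, and character counting stands in for the real tokenizer (the regression test in this commit uses the same stand-in), so this is not the code in `main`:

```python
from typing import Callable


def select_fact_lines(
    lines: list[str],
    max_tokens: int,
    used_tokens: int = 0,
    count: Callable[[str], int] = len,  # assumption: characters stand in for tokens
) -> list[str]:
    """Keep pre-ranked lines until the next one would exceed the budget."""
    selected: list[str] = []
    running = used_tokens
    for line in lines:
        # Every line after the first is joined with a newline, so count it too.
        piece = ("\n" + line) if selected else line
        cost = count(piece)
        if running + cost > max_tokens:
            break  # stop at the first line that does not fit
        selected.append(line)
        running += cost
    return selected


ranked = [
    "- [knowledge | 0.95] uses PostgreSQL",
    "- [preference | 0.80] prefers SQLAlchemy",
]
print(select_fact_lines(ranked, max_tokens=40))
```

Stopping at the first line that does not fit keeps the output a strict prefix of the confidence ranking, at the cost of possibly leaving a few tokens unused.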
|
||||||
|
|
||||||
### Benefits
|
Token counting:
|
||||||
- **Precision**: Exact token counts match what the model sees
|
- Uses `tiktoken` (`cl100k_base`) when available
|
||||||
- **Budget Optimization**: Maximizes use of available token budget
|
- Falls back to `len(text) // 4` if tokenizer import fails
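
A minimal sketch of this fallback, assuming `tiktoken` is an optional dependency. The name `count_tokens_with_fallback` is hypothetical, and the broad `except` is an assumption made so the sketch also degrades when the encoding file cannot be loaded:

```python
def count_tokens_with_fallback(text: str, encoding_name: str = "cl100k_base") -> int:
    """Count tokens with tiktoken when usable, else estimate from length."""
    try:
        import tiktoken  # optional dependency

        encoding = tiktoken.get_encoding(encoding_name)
        return len(encoding.encode(text))
    except Exception:
        # Fallback from the list above: roughly four characters per token.
        return len(text) // 4


print(count_tokens_with_fallback("Facts are ranked by confidence."))
```

The estimate is only a rough guide: the ratio of characters to tokens varies with language and content type.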
|
||||||
- **No Overflows**: Prevents exceeding `max_injection_tokens` limit
|
|
||||||
- **Better Planning**: Each section's token cost is known precisely
|
|
||||||
|
|
||||||
### Example
|
## Known Gap
|
||||||
```python
|
|
||||||
text = "This is a test string to count tokens accurately using tiktoken."
|
|
||||||
|
|
||||||
# Old method
|
Previous versions of this document described TF-IDF/context-aware retrieval as if it were already shipped.
|
||||||
char_count = len(text) # 64 characters
|
That was not accurate for `main` and caused confusion.
|
||||||
old_estimate = char_count // 4 # 16 tokens (overestimate)
|
|
||||||
|
|
||||||
# New method
|
Issue reference: `#1059`
|
||||||
accurate_count = _count_tokens(text) # 13 tokens (exact)
|
|
||||||
|
## Roadmap (Planned)
|
||||||
|
|
||||||
|
Planned scoring strategy:
|
||||||
|
|
||||||
|
```text
|
||||||
|
final_score = (similarity * 0.6) + (confidence * 0.4)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Result**: 3-token difference (18.75% error rate)
|
Planned integration shape:
|
||||||
|
1. Extract recent conversational context from filtered user/final-assistant turns.
|
||||||
|
2. Compute TF-IDF cosine similarity between each fact and current context.
|
||||||
|
3. Rank by weighted score and inject under token budget.
|
||||||
|
4. Fall back to confidence-only ranking if context is unavailable.
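
Because this ranking is planned and not merged, the following is only a sketch of how the four steps might fit together. It uses a small pure-Python TF-IDF instead of scikit-learn so that it runs without extra dependencies; `rank_facts`, its helpers, and the smoothed IDF formula are assumptions, and fact confidences are assumed to be clean floats:

```python
import math
import re
from collections import Counter


def _tokenize(text: str) -> list[str]:
    return re.findall(r"[a-z0-9]+", text.lower())


def _tfidf_vectors(docs: list[str]) -> list[dict[str, float]]:
    tokenized = [_tokenize(doc) for doc in docs]
    n_docs = len(docs)
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        term_freq = Counter(tokens)
        # Smoothed IDF keeps a positive weight for terms found in every document.
        vectors.append(
            {t: c * (math.log((1 + n_docs) / (1 + doc_freq[t])) + 1.0) for t, c in term_freq.items()}
        )
    return vectors


def _cosine(a: dict[str, float], b: dict[str, float]) -> float:
    dot = sum(weight * b.get(term, 0.0) for term, weight in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def rank_facts(facts, current_context=None, similarity_weight=0.6, confidence_weight=0.4):
    """Rank by weighted score; fall back to confidence-only without context."""
    if not current_context or not current_context.strip():
        return sorted(facts, key=lambda f: f["confidence"], reverse=True)
    vectors = _tfidf_vectors([f["content"] for f in facts] + [current_context])
    context_vector = vectors[-1]
    scored = [
        (similarity_weight * _cosine(vec, context_vector) + confidence_weight * fact["confidence"], fact)
        for fact, vec in zip(facts, vectors)
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [fact for _, fact in scored]


facts = [
    {"content": "Uses Docker for containerization", "confidence": 0.9},
    {"content": "Prefers pytest for Python testing", "confidence": 0.7},
]
print(rank_facts(facts, "How do I write Python tests with pytest?")[0]["content"])
```

With a context about Python testing, the lower-confidence pytest fact can outrank the Docker fact; without context, the ordering reduces to confidence alone.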
|
||||||
|
|
||||||
In production, errors can be much larger for:
|
## Validation
|
||||||
- Code snippets (more tokens per character)
|
|
||||||
- Non-English text (variable token ratios)
|
|
||||||
- Technical jargon (often multi-token words)
|
|
||||||
|
|
||||||
## Implementation Details
|
Current regression coverage includes:
|
||||||
|
- facts inclusion in memory injection output
|
||||||
|
- confidence ordering
|
||||||
|
- token-budget-limited fact inclusion
|
||||||
|
|
||||||
### Function Signature
|
Tests:
|
||||||
```python
|
- `backend/tests/test_memory_prompt_injection.py`
|
||||||
def format_memory_for_injection(
|
|
||||||
memory_data: dict[str, Any],
|
|
||||||
max_tokens: int = 2000,
|
|
||||||
current_context: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
```
|
|
||||||
|
|
||||||
**New Parameter**:
|
|
||||||
- `current_context`: Optional string containing recent conversation messages for similarity calculation
|
|
||||||
|
|
||||||
### Backward Compatibility
|
|
||||||
The function remains **100% backward compatible**:
|
|
||||||
- If `current_context` is `None` or empty, falls back to confidence-only ranking
|
|
||||||
- Existing callers without the parameter work exactly as before
|
|
||||||
- Token counting is always accurate (transparent improvement)
|
|
||||||
|
|
||||||
### Integration Point
|
|
||||||
Memory is **dynamically injected** via `MemoryMiddleware.before_model()`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# src/agents/middlewares/memory_middleware.py
|
|
||||||
|
|
||||||
def _extract_conversation_context(messages: list, max_turns: int = 3) -> str:
|
|
||||||
"""Extract recent conversation (user input + final responses only)."""
|
|
||||||
context_parts = []
|
|
||||||
turn_count = 0
|
|
||||||
|
|
||||||
for msg in reversed(messages):
|
|
||||||
if msg.type == "human":
|
|
||||||
# Always include user messages
|
|
||||||
context_parts.append(extract_text(msg))
|
|
||||||
turn_count += 1
|
|
||||||
if turn_count >= max_turns:
|
|
||||||
break
|
|
||||||
|
|
||||||
elif msg.type == "ai" and not msg.tool_calls:
|
|
||||||
# Only include final AI responses (no tool_calls)
|
|
||||||
context_parts.append(extract_text(msg))
|
|
||||||
|
|
||||||
# Skip tool messages and AI messages with tool_calls
|
|
||||||
|
|
||||||
return " ".join(reversed(context_parts))
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
|
||||||
def before_model(self, state, runtime):
|
|
||||||
"""Inject memory before EACH LLM call (not just before_agent)."""
|
|
||||||
|
|
||||||
# Get recent conversation context (filtered)
|
|
||||||
conversation_context = _extract_conversation_context(
|
|
||||||
state["messages"],
|
|
||||||
max_turns=3
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load memory with context-aware fact selection
|
|
||||||
memory_data = get_memory_data()
|
|
||||||
memory_content = format_memory_for_injection(
|
|
||||||
memory_data,
|
|
||||||
max_tokens=config.max_injection_tokens,
|
|
||||||
current_context=conversation_context, # ✅ Clean conversation only
|
|
||||||
)
|
|
||||||
|
|
||||||
# Inject as system message
|
|
||||||
memory_message = SystemMessage(
|
|
||||||
content=f"<memory>\n{memory_content}\n</memory>",
|
|
||||||
name="memory_context",
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"messages": [memory_message] + state["messages"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
### How It Works
|
|
||||||
|
|
||||||
1. **User continues conversation**:
|
|
||||||
```
|
|
||||||
Turn 1: "I'm working on a Python project"
|
|
||||||
Turn 2: "It uses FastAPI and SQLAlchemy"
|
|
||||||
Turn 3: "How do I write tests?" ← Current query
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Extract recent context**: Last 3 turns combined:
|
|
||||||
```
|
|
||||||
"I'm working on a Python project. It uses FastAPI and SQLAlchemy. How do I write tests?"
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **TF-IDF scoring**: Ranks facts by relevance to this context
|
|
||||||
- High score: "Prefers pytest for testing" (testing + Python)
|
|
||||||
- High score: "Likes type hints in Python" (Python related)
|
|
||||||
- High score: "Expert in Python and FastAPI" (Python + FastAPI)
|
|
||||||
- Low score: "Uses Docker for containerization" (less relevant)
|
|
||||||
|
|
||||||
4. **Injection**: Top-ranked facts injected into system prompt's `<memory>` section
|
|
||||||
|
|
||||||
5. **Agent sees**: Full system prompt with relevant memory context
|
|
||||||
|
|
||||||
### Benefits of Dynamic System Prompt
|
|
||||||
|
|
||||||
- **Multi-Turn Context**: Uses last 3 turns, not just current question
|
|
||||||
- Captures ongoing conversation flow
|
|
||||||
- Better understanding of user's current focus
|
|
||||||
- **Query-Specific Facts**: Different facts surface based on conversation topic
|
|
||||||
- **Clean Architecture**: No middleware message manipulation
|
|
||||||
- **LangChain Native**: Uses built-in dynamic system prompt support
|
|
||||||
- **Runtime Flexibility**: Memory regenerated for each agent invocation
|
|
||||||
|
|
||||||
## Dependencies
|
|
||||||
|
|
||||||
New dependencies added to `pyproject.toml`:
|
|
||||||
```toml
|
|
||||||
dependencies = [
|
|
||||||
# ... existing dependencies ...
|
|
||||||
"tiktoken>=0.8.0", # Accurate token counting
|
|
||||||
"scikit-learn>=1.6.1", # TF-IDF vectorization
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
Install with:
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
uv sync
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
Run the test script to verify improvements:
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
python test_memory_improvement.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected output shows:
|
|
||||||
- Different fact ordering based on context
|
|
||||||
- Accurate token counts vs old estimates
|
|
||||||
- Budget-respecting fact selection
|
|
||||||
|
|
||||||
## Performance Impact
|
|
||||||
|
|
||||||
### Computational Cost
|
|
||||||
- **TF-IDF Calculation**: O(n × m) where n=facts, m=vocabulary
|
|
||||||
- Negligible for typical fact counts (10-100 facts)
|
|
||||||
- Caching opportunities if context doesn't change
|
|
||||||
- **Token Counting**: ~10-100µs per call
|
|
||||||
- Faster than the old character-counting approach
|
|
||||||
- Minimal overhead compared to LLM inference
|
|
||||||
|
|
||||||
### Memory Usage
|
|
||||||
- **TF-IDF Vectorizer**: ~1-5MB for typical vocabulary
|
|
||||||
- Instantiated once per injection call
|
|
||||||
- Garbage collected after use
|
|
||||||
- **Tiktoken Encoding**: ~1MB (cached singleton)
|
|
||||||
- Loaded once per process lifetime
|
|
||||||
|
|
||||||
### Recommendations
|
|
||||||
- Current implementation is optimized for accuracy over caching
|
|
||||||
- For high-throughput scenarios, consider:
|
|
||||||
- Pre-computing fact embeddings (store in memory.json)
|
|
||||||
- Caching TF-IDF vectorizer between calls
|
|
||||||
- Using approximate nearest neighbor search for >1000 facts
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
| Aspect | Before | After |
|
|
||||||
|--------|--------|-------|
|
|
||||||
| Fact Selection | Top 15 by confidence only | Relevance-based (similarity + confidence) |
|
|
||||||
| Token Counting | `len(text) // 4` | `tiktoken.encode(text)` |
|
|
||||||
| Context Awareness | None | TF-IDF cosine similarity |
|
|
||||||
| Accuracy | ±25% token estimate | Exact token count |
|
|
||||||
| Configuration | Fixed weights | Customizable similarity/confidence weights |
|
|
||||||
|
|
||||||
These improvements result in:
|
|
||||||
- **More relevant** facts injected into context
|
|
||||||
- **Better utilization** of available token budget
|
|
||||||
- **Fewer hallucinations** due to focused context
|
|
||||||
- **Higher quality** agent responses
|
|
||||||
|
|||||||
@@ -1,260 +1,38 @@
|
|||||||
# Memory System Improvements - Summary
|
# Memory System Improvements - Summary
|
||||||
|
|
||||||
## Improvement Overview
|
## Sync Note (2026-03-10)
|
||||||
|
|
||||||
Two issues you raised have been addressed:
|
This summary is synchronized with the `main` branch implementation.
|
||||||
1. ✅ **Crude token estimation** (`character count * 4`) → exact counting with tiktoken
|
TF-IDF/context-aware retrieval is **planned**, not merged yet.
|
||||||
2. ✅ **No similarity-based recall** → TF-IDF over the recent conversation context
|
|
||||||
|
|
||||||
## Core Improvements
|
## Implemented
|
||||||
|
|
||||||
### 1. Context-Aware Fact Recall Based on the Conversation
|
- Accurate token counting with `tiktoken` in memory injection.
|
||||||
|
- Facts are injected into `<memory>` prompt content.
|
||||||
|
- Facts are ordered by confidence and bounded by `max_injection_tokens`.
|
||||||
|
|
||||||
**Before**:
|
## Planned (Not Yet Merged)
|
||||||
- Took only the top 15 facts sorted by confidence
|
|
||||||
- Injected the same facts no matter what the user was discussing
|
|
||||||
|
|
||||||
**Now**:
|
- TF-IDF cosine similarity recall based on recent conversation context.
|
||||||
- Extracts the last **3 conversation turns** (human + AI messages) as context
|
- `current_context` parameter for `format_memory_for_injection`.
|
||||||
- Uses **TF-IDF cosine similarity** to score each fact's relevance to the conversation
|
- Weighted ranking (`similarity` + `confidence`).
|
||||||
- Combined score: `similarity (60%) + confidence (40%)`
|
- Runtime extraction/injection flow for context-aware fact selection.
|
||||||
- Dynamically selects the most relevant facts
|
|
||||||
|
|
||||||
**Example**:
|
## Why This Sync Was Needed
|
||||||
```
|
|
||||||
Conversation history:
|
|
||||||
Turn 1: "I'm working on a Python project"
|
|
||||||
Turn 2: "It uses FastAPI and SQLAlchemy"
|
|
||||||
Turn 3: "How do I write tests?"
|
|
||||||
|
|
||||||
Context: "I'm working on a Python project It uses FastAPI and SQLAlchemy How do I write tests?"
|
Earlier docs described TF-IDF behavior as already implemented, which did not match code in `main`.
|
||||||
|
This mismatch is tracked in issue `#1059`.
|
||||||
|
|
||||||
Highly relevant facts:
|
## Current API Shape
|
||||||
✓ "Prefers pytest for testing" (Python + testing)
|
|
||||||
✓ "Expert in Python and FastAPI" (Python + FastAPI)
|
|
||||||
✓ "Likes type hints in Python" (Python)
|
|
||||||
|
|
||||||
Less relevant facts:
|
|
||||||
✗ "Uses Docker for containerization" (not relevant)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Accurate Token Counting
|
|
||||||
|
|
||||||
**Before**:
|
|
||||||
```python
|
|
||||||
max_chars = max_tokens * 4 # crude estimate
|
|
||||||
```
|
|
||||||
|
|
||||||
**Now**:
|
|
||||||
```python
|
|
||||||
import tiktoken
|
|
||||||
|
|
||||||
def _count_tokens(text: str) -> int:
|
|
||||||
encoding = tiktoken.get_encoding("cl100k_base") # GPT-4/3.5
|
|
||||||
return len(encoding.encode(text))
|
|
||||||
```
|
|
||||||
|
|
||||||
**Comparison**:
|
|
||||||
```python
|
|
||||||
text = "This is a test string to count tokens accurately."
|
|
||||||
Old method: len(text) // 4 = 12 tokens (estimate)
|
|
||||||
New method: tiktoken.encode = 10 tokens (exact)
|
|
||||||
Error: 20%
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Multi-Turn Conversation Context
|
|
||||||
|
|
||||||
**The earlier concern**:
|
|
||||||
> "Would passing only the latest human message provide too little context?"
|
|
||||||
|
|
||||||
**Current solution**:
|
|
||||||
- Extracts the last **3 conversation turns** (configurable)
|
|
||||||
- Includes both human and AI messages
|
|
||||||
- Gives a more complete conversation context
|
|
||||||
|
|
||||||
**Example**:
|
|
||||||
```
|
|
||||||
Single message: "How do I write tests?"
|
|
||||||
→ Lacks context; the project is unknown
|
|
||||||
|
|
||||||
3 turns: "Python project + FastAPI + How do I write tests?"
|
|
||||||
→ Full context; more relevant facts can be selected
|
|
||||||
```
|
|
||||||
|
|
||||||
## Implementation
|
|
||||||
|
|
||||||
### Dynamic Injection via Middleware
|
|
||||||
|
|
||||||
Memory is injected **before every LLM call** using the `before_model` hook:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# src/agents/middlewares/memory_middleware.py
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||||
|
|
||||||
def _extract_conversation_context(messages: list, max_turns: int = 3) -> str:
|
|
||||||
"""Extract the last 3 conversation turns (user input and final replies only)."""
|
|
||||||
context_parts = []
|
|
||||||
turn_count = 0
|
|
||||||
|
|
||||||
for msg in reversed(messages):
|
|
||||||
msg_type = getattr(msg, "type", None)
|
|
||||||
|
|
||||||
if msg_type == "human":
|
|
||||||
# ✅ Always include user messages
|
|
||||||
content = extract_text(msg)
|
|
||||||
if content:
|
|
||||||
context_parts.append(content)
|
|
||||||
turn_count += 1
|
|
||||||
if turn_count >= max_turns:
|
|
||||||
break
|
|
||||||
|
|
||||||
elif msg_type == "ai":
|
|
||||||
# ✅ Only include AI messages without tool_calls (final replies)
|
|
||||||
tool_calls = getattr(msg, "tool_calls", None)
|
|
||||||
if not tool_calls:
|
|
||||||
content = extract_text(msg)
|
|
||||||
if content:
|
|
||||||
context_parts.append(content)
|
|
||||||
|
|
||||||
# ✅ Skip tool messages and AI messages with tool_calls
|
|
||||||
|
|
||||||
return " ".join(reversed(context_parts))
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
|
||||||
def before_model(self, state, runtime):
|
|
||||||
"""Inject memory before every LLM call (not in before_agent)."""
|
|
||||||
|
|
||||||
# 1. Extract the last 3 conversation turns (tool calls filtered out)
|
|
||||||
messages = state["messages"]
|
|
||||||
conversation_context = _extract_conversation_context(messages, max_turns=3)
|
|
||||||
|
|
||||||
# 2. Select relevant facts using the clean conversation context
|
|
||||||
memory_data = get_memory_data()
|
|
||||||
memory_content = format_memory_for_injection(
|
|
||||||
memory_data,
|
|
||||||
max_tokens=config.max_injection_tokens,
|
|
||||||
current_context=conversation_context, # ✅ Real conversation content only
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Inject as a system message at the start of the message list
|
|
||||||
memory_message = SystemMessage(
|
|
||||||
content=f"<memory>\n{memory_content}\n</memory>",
|
|
||||||
name="memory_context", # used for duplicate detection
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Prepend to the message list
|
|
||||||
updated_messages = [memory_message] + messages
|
|
||||||
return {"messages": updated_messages}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Why This Design?
|
No `current_context` argument is currently available in `main`.
|
||||||
|
|
||||||
Based on your three key observations:
|
## Verification Pointers
|
||||||
|
|
||||||
1. **Use `before_model` rather than `before_agent`**
|
- Implementation: `backend/src/agents/memory/prompt.py`
|
||||||
- ✅ `before_agent`: called only once, when the whole agent starts
|
- Prompt assembly: `backend/src/agents/lead_agent/prompt.py`
|
||||||
- ✅ `before_model`: called **before every LLM call**
|
- Regression tests: `backend/tests/test_memory_prompt_injection.py`
|
||||||
- ✅ So every LLM inference sees the latest relevant memory
|
|
||||||
|
|
||||||
2. **The messages array contains only human/ai/tool messages, no system message**
|
|
||||||
- ✅ Although uncommon, LangChain allows inserting a system message into a conversation
|
|
||||||
- ✅ Middleware can modify the messages array
|
|
||||||
- ✅ `name="memory_context"` prevents duplicate injection
|
|
||||||
|
|
||||||
3. **AI messages with tool calls should be dropped; pass only user input and final output**
|
|
||||||
- ✅ Filters out AI messages with `tool_calls` (intermediate steps)
|
|
||||||
- ✅ Keeps only: - Human messages (user input)
|
|
||||||
- AI messages without tool_calls (final replies)
|
|
||||||
- ✅ Cleaner context makes the TF-IDF similarity more accurate
|
|
||||||
|
|
||||||
## Configuration Options
|
|
||||||
|
|
||||||
The following can be adjusted in `config.yaml`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
memory:
|
|
||||||
enabled: true
|
|
||||||
max_injection_tokens: 2000 # ✅ uses exact token counting
|
|
||||||
|
|
||||||
# Advanced settings (optional)
|
|
||||||
# max_context_turns: 3 # number of conversation turns (default 3)
|
|
||||||
# similarity_weight: 0.6 # similarity weight
|
|
||||||
# confidence_weight: 0.4 # confidence weight
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dependency Changes
|
|
||||||
|
|
||||||
New dependencies:
|
|
||||||
```toml
|
|
||||||
dependencies = [
|
|
||||||
"tiktoken>=0.8.0", # exact token counting
|
|
||||||
"scikit-learn>=1.6.1", # TF-IDF vectorization
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
Install:
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
uv sync
|
|
||||||
```
|
|
||||||
|
|
||||||
## Performance Impact
|
|
||||||
|
|
||||||
- **TF-IDF computation**: O(n × m), where n = number of facts and m = vocabulary size
|
|
||||||
- Typical case (10-100 facts): < 10ms
|
|
||||||
- **Token counting**: ~100µs per call
|
|
||||||
- Even faster than character counting
|
|
||||||
- **Total overhead**: negligible (compared with LLM inference)
|
|
||||||
|
|
||||||
## Backward Compatibility
|
|
||||||
|
|
||||||
✅ Fully backward compatible:
|
|
||||||
- Without `current_context`, falls back to ranking by confidence
|
|
||||||
- All existing configuration keeps working
|
|
||||||
- No impact on other features
|
|
||||||
|
|
||||||
## Changed Files
|
|
||||||
|
|
||||||
1. **Core functionality**
|
|
||||||
- `src/agents/memory/prompt.py` - adds TF-IDF recall and exact token counting
|
|
||||||
- `src/agents/lead_agent/prompt.py` - dynamic system prompt
|
|
||||||
- `src/agents/lead_agent/agent.py` - passes a function instead of a string
|
|
||||||
|
|
||||||
2. **Dependencies**
|
|
||||||
- `pyproject.toml` - adds tiktoken and scikit-learn
|
|
||||||
|
|
||||||
3. **Documentation**
|
|
||||||
- `docs/MEMORY_IMPROVEMENTS.md` - detailed technical documentation
|
|
||||||
- `docs/MEMORY_IMPROVEMENTS_SUMMARY.md` - improvement summary (this file)
|
|
||||||
- `CLAUDE.md` - updated architecture notes
|
|
||||||
- `config.example.yaml` - added configuration notes
|
|
||||||
|
|
||||||
## Testing and Verification
|
|
||||||
|
|
||||||
Run the project to verify:
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
make dev
|
|
||||||
```
|
|
||||||
|
|
||||||
Test in a conversation:
|
|
||||||
1. Discuss different topics (Python, React, Docker, etc.)
|
|
||||||
2. Check whether different conversations inject different facts
|
|
||||||
3. Check whether the token budget is enforced accurately
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
| Issue | Before | Now |
|
|
||||||
|------|------|------|
|
|
||||||
| Token counting | `len(text) // 4` (±25% error) | `tiktoken.encode()` (exact) |
|
|
||||||
| Fact selection | Fixed ordering by confidence | TF-IDF similarity + confidence |
|
|
||||||
| Context | None | Last 3 conversation turns |
|
|
||||||
| Implementation | Static system prompt | Dynamic system prompt function |
|
|
||||||
| Configurability | Limited | Adjustable turn count and weights |
|
|
||||||
|
|
||||||
All improvements are implemented, and they:
|
|
||||||
- ✅ Do not modify the messages array
|
|
||||||
- ✅ Use multi-turn conversation context
|
|
||||||
- ✅ Count tokens exactly
|
|
||||||
- ✅ Recall facts by similarity
|
|
||||||
- ✅ Are fully backward compatible
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Prompt templates for memory update and injection."""
|
"""Prompt templates for memory update and injection."""
|
||||||
|
|
||||||
|
import math
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -166,6 +167,22 @@ def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
|||||||
return len(text) // 4
|
return len(text) // 4
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
    """Coerce a confidence-like value to a bounded float in [0, 1].

    Non-finite values (NaN, inf, -inf) are treated as invalid and fall back
    to the default before clamping, preventing them from dominating ranking.
    The ``default`` parameter is assumed to be a finite value.
    """
    try:
        confidence = float(value)
    except (TypeError, ValueError):
        return max(0.0, min(1.0, default))
    if not math.isfinite(confidence):
        return max(0.0, min(1.0, default))
    return max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
|
||||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||||
"""Format memory data for injection into system prompt.
|
"""Format memory data for injection into system prompt.
|
||||||
|
|
||||||
@@ -217,6 +234,55 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
if history_sections:
|
if history_sections:
|
||||||
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
||||||
|
|
||||||
|
    # Format facts (sorted by confidence; include as many as token budget allows)
    facts_data = memory_data.get("facts", [])
    if isinstance(facts_data, list) and facts_data:
        ranked_facts = sorted(
            (
                f
                for f in facts_data
                if isinstance(f, dict)
                and isinstance(f.get("content"), str)
                and f.get("content").strip()
            ),
            key=lambda fact: _coerce_confidence(fact.get("confidence"), default=0.0),
            reverse=True,
        )

        # Compute token count for existing sections once, then account
        # incrementally for each fact line to avoid full-string re-tokenization.
        base_text = "\n\n".join(sections)
        base_tokens = _count_tokens(base_text) if base_text else 0
        # Account for the separator between existing sections and the facts section.
        facts_header = "Facts:\n"
        separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
        running_tokens = base_tokens + separator_tokens

        fact_lines: list[str] = []
        for fact in ranked_facts:
            content_value = fact.get("content")
            if not isinstance(content_value, str):
                continue
            content = content_value.strip()
            if not content:
                continue
            category = str(fact.get("category", "context")).strip() or "context"
            confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
            line = f"- [{category} | {confidence:.2f}] {content}"

            # Each additional line is preceded by a newline (except the first).
            line_text = ("\n" + line) if fact_lines else line
            line_tokens = _count_tokens(line_text)

            if running_tokens + line_tokens <= max_tokens:
                fact_lines.append(line)
                running_tokens += line_tokens
            else:
                break

        if fact_lines:
            sections.append("Facts:\n" + "\n".join(fact_lines))
|
||||||
|
|
||||||
if not sections:
|
if not sections:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
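
The docstring of `_coerce_confidence` says non-finite values must not dominate ranking. A small sketch shows why the finiteness check matters: a plain clamp turns both NaN and inf into 1.0, the highest possible confidence. `naive_clamp` is a hypothetical comparison helper that is not in the codebase, and `coerce_confidence` is a local copy of the helper from this diff so the example runs on its own:

```python
import math
from typing import Any


def naive_clamp(value: Any) -> float:
    """Clamp without a finiteness check (hypothetical, for comparison only)."""
    return max(0.0, min(1.0, float(value)))


def coerce_confidence(value: Any, default: float = 0.0) -> float:
    """Local copy of the helper added in this commit."""
    try:
        confidence = float(value)
    except (TypeError, ValueError):
        return max(0.0, min(1.0, default))
    if not math.isfinite(confidence):
        return max(0.0, min(1.0, default))
    return max(0.0, min(1.0, confidence))


# Comparisons with NaN are always False, so min(1.0, nan) keeps 1.0.
for raw in (math.nan, math.inf):
    print(repr(raw), naive_clamp(raw), coerce_confidence(raw))

# Strings and None are handled by the try/except branch.
for raw in ("0.7", None):
    print(repr(raw), coerce_confidence(raw))
```

Falling back to the default rather than clamping means a corrupted confidence value sorts last under the default of 0.0 instead of first.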
122
backend/tests/test_memory_prompt_injection.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
"""Tests for memory prompt injection formatting."""

import math

from src.agents.memory.prompt import _coerce_confidence, format_memory_for_injection


def test_format_memory_includes_facts_section() -> None:
    memory_data = {
        "user": {},
        "history": {},
        "facts": [
            {"content": "User uses PostgreSQL", "category": "knowledge", "confidence": 0.9},
            {"content": "User prefers SQLAlchemy", "category": "preference", "confidence": 0.8},
        ],
    }

    result = format_memory_for_injection(memory_data, max_tokens=2000)

    assert "Facts:" in result
    assert "User uses PostgreSQL" in result
    assert "User prefers SQLAlchemy" in result


def test_format_memory_sorts_facts_by_confidence_desc() -> None:
    memory_data = {
        "user": {},
        "history": {},
        "facts": [
            {"content": "Low confidence fact", "category": "context", "confidence": 0.4},
            {"content": "High confidence fact", "category": "knowledge", "confidence": 0.95},
        ],
    }

    result = format_memory_for_injection(memory_data, max_tokens=2000)

    assert result.index("High confidence fact") < result.index("Low confidence fact")


def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
    # Make token counting deterministic for this test by counting characters.
    monkeypatch.setattr("src.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))

    memory_data = {
        "user": {},
        "history": {},
        "facts": [
            {"content": "First fact should fit", "category": "knowledge", "confidence": 0.95},
            {"content": "Second fact should not fit in tiny budget", "category": "knowledge", "confidence": 0.90},
        ],
    }

    first_fact_only_memory_data = {
        "user": {},
        "history": {},
        "facts": [
            {"content": "First fact should fit", "category": "knowledge", "confidence": 0.95},
        ],
    }
    one_fact_result = format_memory_for_injection(first_fact_only_memory_data, max_tokens=2000)
    two_facts_result = format_memory_for_injection(memory_data, max_tokens=2000)
    # Choose a budget that can include exactly one fact section line.
    max_tokens = (len(one_fact_result) + len(two_facts_result)) // 2

    first_only_result = format_memory_for_injection(memory_data, max_tokens=max_tokens)

    assert "First fact should fit" in first_only_result
    assert "Second fact should not fit in tiny budget" not in first_only_result


def test_coerce_confidence_nan_falls_back_to_default() -> None:
    """NaN should not be treated as a valid confidence value."""
    result = _coerce_confidence(math.nan, default=0.5)
    assert result == 0.5


def test_coerce_confidence_inf_falls_back_to_default() -> None:
    """Infinite values should fall back to default rather than clamping to 1.0."""
    assert _coerce_confidence(math.inf, default=0.3) == 0.3
    assert _coerce_confidence(-math.inf, default=0.3) == 0.3


def test_coerce_confidence_valid_values_are_clamped() -> None:
    """Valid floats outside [0, 1] are clamped; values inside are preserved."""
    assert _coerce_confidence(1.5) == 1.0
    assert _coerce_confidence(-0.5) == 0.0
    assert abs(_coerce_confidence(0.75) - 0.75) < 1e-9


def test_format_memory_skips_none_content_facts() -> None:
    """Facts with content=None must not produce a 'None' line in the output."""
    memory_data = {
        "facts": [
            {"content": None, "category": "knowledge", "confidence": 0.9},
            {"content": "Real fact", "category": "knowledge", "confidence": 0.8},
        ],
    }

    result = format_memory_for_injection(memory_data, max_tokens=2000)

    assert "None" not in result
    assert "Real fact" in result


def test_format_memory_skips_non_string_content_facts() -> None:
    """Facts with non-string content (e.g. int/list) must be ignored."""
    memory_data = {
        "facts": [
            {"content": 42, "category": "knowledge", "confidence": 0.9},
            {"content": ["list"], "category": "knowledge", "confidence": 0.85},
            {"content": "Valid fact", "category": "knowledge", "confidence": 0.7},
        ],
    }

    result = format_memory_for_injection(memory_data, max_tokens=2000)

    # The formatted line for an integer content would be "- [knowledge | 0.90] 42".
    assert "| 0.90] 42" not in result
    # The formatted line for a list content would be "- [knowledge | 0.85] ['list']".
    assert "| 0.85]" not in result
    assert "Valid fact" in result