deer-flow/src/utils/context_manager.py
jimmyuconn1982 2510cc61de
feat: Add intelligent clarification feature in coordinate step for research queries (#613)
* fix: support local models by making thought field optional in Plan model

- Make thought field optional in Plan model to fix Pydantic validation errors with local models
- Add Ollama configuration example to conf.yaml.example
- Update documentation to include local model support
- Improve planner prompt with better JSON format requirements

Fixes local model integration issues where models like qwen3:14b would fail
due to missing thought field in JSON output.

* feat: Add intelligent clarification feature for research queries

- Add multi-turn clarification process to refine vague research questions
- Implement three-dimension clarification standard (Tech/App, Focus, Scope)
- Add clarification state management in coordinator node
- Update coordinator prompt with detailed clarification guidelines
- Add UI settings to enable/disable clarification feature (disabled by default)
- Update workflow to handle clarification rounds recursively
- Add comprehensive test coverage for clarification functionality
- Update documentation with clarification feature usage guide

Key components:
- src/graph/nodes.py: Core clarification logic and state management
- src/prompts/coordinator.md: Detailed clarification guidelines
- src/workflow.py: Recursive clarification handling
- web/: UI settings integration
- tests/: Comprehensive test coverage
- docs/: Updated configuration guide

* fix: Improve clarification conversation continuity

- Add comprehensive conversation history to clarification context
- Include previous exchanges summary in system messages
- Add explicit guidelines for continuing rounds in coordinator prompt
- Prevent LLM from starting new topics during clarification
- Ensure topic continuity across clarification rounds

Fixes issue where LLM would restart clarification instead of building upon previous exchanges.

* fix: Add conversation history to clarification context

* fix: resolve clarification feature message to planner, prompt, test issues

- Optimize coordinator.md prompt template for better clarification flow
- Simplify final message sent to planner after clarification
- Fix API key assertion issues in test_search.py

* fix: Add configurable max_clarification_rounds and comprehensive tests

- Add max_clarification_rounds parameter for external configuration
- Add comprehensive test cases for clarification feature in test_app.py
- Fixes issues found during interactive mode testing where:
  - Recursive call failed due to missing initial_state parameter
  - Clarification exited prematurely at max rounds
  - Incorrect logging of max rounds reached

* Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json
2025-10-14 13:35:57 +08:00

267 lines
8.6 KiB
Python

# src/utils/context_manager.py
import copy
import logging
from typing import List
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from src.config import load_yaml_config
logger = logging.getLogger(__name__)
def get_search_config():
    """
    Load conf.yaml and return its MODEL_TOKEN_LIMITS section

    Returns:
        The value stored under the "MODEL_TOKEN_LIMITS" key of the loaded
        configuration, or an empty dict when that key is absent
    """
    return load_yaml_config("conf.yaml").get("MODEL_TOKEN_LIMITS", {})
class ContextManager:
    """
    Context manager and compression class

    Estimates the token usage of a message list with a character-based
    heuristic and trims the list so the estimate fits a token budget.
    """

    def __init__(self, token_limit: int, preserve_prefix_message_count: int = 0):
        """
        Initialize ContextManager

        Args:
            token_limit: Maximum token limit. None disables compression
                (compress_messages returns the state untouched).
            preserve_prefix_message_count: Number of messages to preserve at the beginning of the context
        """
        self.token_limit = token_limit
        self.preserve_prefix_message_count = preserve_prefix_message_count

    def count_tokens(self, messages: List[BaseMessage]) -> int:
        """
        Count tokens in message list

        Args:
            messages: List of messages

        Returns:
            Sum of the per-message estimates (0 for an empty list)
        """
        return sum(self._count_message_tokens(message) for message in messages)

    def _count_message_tokens(self, message: BaseMessage) -> int:
        """
        Count tokens in a single message

        The estimate is: content tokens + role tokens, scaled by a
        per-message-type multiplier, plus the tokens of additional_kwargs.

        Args:
            message: Message object

        Returns:
            Estimated number of tokens, never less than 1
        """
        token_count = 0

        # Content may be a plain string or a list of content blocks.
        content = getattr(message, "content", None)
        if content:
            token_count += self._count_content_tokens(content)

        # Count role-related tokens (_count_text_tokens returns 0 for empty values)
        token_count += self._count_text_tokens(getattr(message, "type", None))

        # Per-type scaling; the first matching type wins. Human messages use
        # the unscaled estimate, as does any type not listed here.
        type_multipliers = (
            (SystemMessage, 1.1),
            (HumanMessage, 1.0),
            (AIMessage, 1.2),
            (ToolMessage, 1.3),
        )
        for message_cls, multiplier in type_multipliers:
            if isinstance(message, message_cls):
                token_count = int(token_count * multiplier)
                break

        # Process additional information in additional_kwargs
        extra_kwargs = getattr(message, "additional_kwargs", None)
        if extra_kwargs:
            # Simple estimation of extra field tokens
            token_count += self._count_text_tokens(str(extra_kwargs))
            # If there are tool_calls, add a flat allowance for function call information
            if "tool_calls" in extra_kwargs:
                token_count += 50

        # Ensure at least 1 token
        return max(1, token_count)

    def _count_content_tokens(self, content) -> int:
        """
        Count tokens of a message's content field

        Args:
            content: A string, or a list whose items are strings or dict blocks

        Returns:
            Estimated tokens. Dict blocks with a string "text" entry are
            counted by that text; any other block is counted by its str()
            form (a rough estimate). Content of any other type counts as 0.
        """
        if isinstance(content, str):
            return self._count_text_tokens(content)
        if not isinstance(content, list):
            return 0

        total = 0
        for block in content:
            if isinstance(block, str):
                total += self._count_text_tokens(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                total += self._count_text_tokens(block["text"])
            else:
                total += self._count_text_tokens(str(block))
        return total

    def _count_text_tokens(self, text: str) -> int:
        """
        Count tokens in text with different calculations for English and non-English characters.
        English characters: 4 characters ≈ 1 token
        Non-English characters (e.g., Chinese): 1 character ≈ 1 token

        Args:
            text: Text to count tokens for

        Returns:
            Number of tokens (0 for empty or None text)
        """
        if not text:
            return 0

        # ASCII code points are treated as "English"; everything else costs
        # one token per character.
        ascii_chars = sum(1 for char in text if ord(char) < 128)
        non_ascii_chars = len(text) - ascii_chars
        return ascii_chars // 4 + non_ascii_chars

    def is_over_limit(self, messages: List[BaseMessage]) -> bool:
        """
        Check if messages exceed token limit

        Args:
            messages: List of messages

        Returns:
            Whether limit is exceeded; always False when no token_limit is set
        """
        # Without a limit there is nothing to exceed (and int > None would raise).
        if self.token_limit is None:
            return False
        return self.count_tokens(messages) > self.token_limit

    def compress_messages(self, state: dict) -> dict:
        """
        Compress messages to fit within token limit

        Args:
            state: state with original messages under the "messages" key

        Returns:
            The same state object. When compression was needed its
            "messages" entry has been replaced in place with the compressed
            list; otherwise the state is returned unchanged.
        """
        # If not set token_limit, return original state
        if self.token_limit is None:
            logger.info("No token_limit set, the context management doesn't work.")
            return state

        if not isinstance(state, dict) or "messages" not in state:
            logger.warning("No messages found in state")
            return state

        messages = state["messages"]
        if not self.is_over_limit(messages):
            return state

        compressed_messages = self._compress_messages(messages)

        # Counting walks every message, so only do it when the line is emitted.
        if logger.isEnabledFor(logging.INFO):
            logger.info(
                "Message compression completed: %s -> %s tokens",
                self.count_tokens(messages),
                self.count_tokens(compressed_messages),
            )

        state["messages"] = compressed_messages
        return state

    def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        """
        Compress compressible messages

        Keeps up to preserve_prefix_message_count messages from the head,
        then fills the remaining budget with messages taken from the tail.
        The first message that does not fit is truncated and ends the
        selection; messages between head and tail are discarded.

        Args:
            messages: List of messages to compress

        Returns:
            Compressed message list
        """
        # NOTE(review): tail selection may keep a ToolMessage whose originating
        # AIMessage was dropped - confirm downstream models tolerate that.
        available_token = self.token_limit
        prefix_messages = []

        # 1. Preserve head messages of specified length to retain system prompts and user input
        prefix_count = max(0, min(self.preserve_prefix_message_count, len(messages)))
        for message in messages[:prefix_count]:
            if available_token <= 0:
                break
            cur_token_cnt = self._count_message_tokens(message)
            if cur_token_cnt <= available_token:
                prefix_messages.append(message)
                available_token -= cur_token_cnt
            else:
                # Truncate content to fit available tokens; budget is then spent
                prefix_messages.append(
                    self._truncate_message_content(message, available_token)
                )
                return prefix_messages

        # 2. Compress subsequent messages from the tail, some messages may be discarded
        remaining_messages = messages[len(prefix_messages) :]
        # Collected newest-first and reversed once, instead of prepending per message.
        suffix_reversed = []
        for message in reversed(remaining_messages):
            if available_token <= 0:
                break
            cur_token_cnt = self._count_message_tokens(message)
            if cur_token_cnt <= available_token:
                suffix_reversed.append(message)
                available_token -= cur_token_cnt
            else:
                # Truncate content to fit available tokens; nothing older is kept
                suffix_reversed.append(
                    self._truncate_message_content(message, available_token)
                )
                break

        suffix_reversed.reverse()
        return prefix_messages + suffix_reversed

    def _truncate_message_content(
        self, message: BaseMessage, max_tokens: int
    ) -> BaseMessage:
        """
        Truncate message content while preserving all other attributes by copying the original message
        and only modifying its content attribute.

        max_tokens is applied as a character budget. Under _count_text_tokens
        a character never costs more than one token, so the kept text is
        estimated at no more than max_tokens; role tokens, type multipliers
        and additional_kwargs are not subtracted from the budget.

        Args:
            message: The message to truncate
            max_tokens: Maximum number of tokens to keep (negative is treated as 0)

        Returns:
            New message instance with truncated content; the original is not modified
        """
        # Create a deep copy of the original message to preserve all attributes
        truncated_message = copy.deepcopy(message)
        char_budget = max(0, max_tokens)

        content = truncated_message.content
        if isinstance(content, str):
            truncated_message.content = content[:char_budget]
        elif isinstance(content, list):
            truncated_message.content = self._truncate_content_blocks(
                content, char_budget
            )
        return truncated_message

    def _truncate_content_blocks(self, blocks: list, char_budget: int) -> list:
        """
        Keep leading content blocks until the character budget is used up

        Args:
            blocks: Content blocks (strings or dicts)
            char_budget: Maximum number of characters to keep in total

        Returns:
            New list of blocks. Text (a string block, or the "text" entry of
            a dict block) is cut at the budget; any other block is charged
            by the length of its str() form and, if it does not fit, it and
            every later block are dropped.
        """
        kept_blocks = []
        remaining = char_budget
        for block in blocks:
            if remaining <= 0:
                break
            if isinstance(block, str):
                piece = block[:remaining]
                kept_blocks.append(piece)
                remaining -= len(piece)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                piece = block["text"][:remaining]
                kept_blocks.append({**block, "text": piece})
                remaining -= len(piece)
            else:
                block_size = len(str(block))
                if block_size > remaining:
                    break
                kept_blocks.append(block)
                remaining -= block_size
        return kept_blocks

    def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Create summary for messages

        Not implemented yet: currently returns None for any input.

        Args:
            messages: Messages to summarize

        Returns:
            Summary message
        """
        # TODO: summary implementation
        return None