mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-28 04:38:25 +00:00
* fix: support local models by making thought field optional in Plan model - Make thought field optional in Plan model to fix Pydantic validation errors with local models - Add Ollama configuration example to conf.yaml.example - Update documentation to include local model support - Improve planner prompt with better JSON format requirements Fixes local model integration issues where models like qwen3:14b would fail due to missing thought field in JSON output. * feat: Add intelligent clarification feature for research queries - Add multi-turn clarification process to refine vague research questions - Implement three-dimension clarification standard (Tech/App, Focus, Scope) - Add clarification state management in coordinator node - Update coordinator prompt with detailed clarification guidelines - Add UI settings to enable/disable clarification feature (disabled by default) - Update workflow to handle clarification rounds recursively - Add comprehensive test coverage for clarification functionality - Update documentation with clarification feature usage guide Key components: - src/graph/nodes.py: Core clarification logic and state management - src/prompts/coordinator.md: Detailed clarification guidelines - src/workflow.py: Recursive clarification handling - web/: UI settings integration - tests/: Comprehensive test coverage - docs/: Updated configuration guide * fix: Improve clarification conversation continuity - Add comprehensive conversation history to clarification context - Include previous exchanges summary in system messages - Add explicit guidelines for continuing rounds in coordinator prompt - Prevent LLM from starting new topics during clarification - Ensure topic continuity across clarification rounds Fixes issue where LLM would restart clarification instead of building upon previous exchanges. 
* fix: Add conversation history to clarification context * fix: resolve clarification feature message to planner, prompt, test issues - Optimize coordinator.md prompt template for better clarification flow - Simplify final message sent to planner after clarification - Fix API key assertion issues in test_search.py * fix: Add configurable max_clarification_rounds and comprehensive tests - Add max_clarification_rounds parameter for external configuration - Add comprehensive test cases for clarification feature in test_app.py - Fixes issues found during interactive mode testing where: - Recursive call failed due to missing initial_state parameter - Clarification exited prematurely at max rounds - Incorrect logging of max rounds reached * Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json
267 lines
8.6 KiB
Python
267 lines
8.6 KiB
Python
# src/utils/token_manager.py

import copy
import logging
from typing import List

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

from src.config import load_yaml_config

# Module-level logger, following the standard per-module logging convention.
logger = logging.getLogger(__name__)
def get_search_config():
    """Load the ``MODEL_TOKEN_LIMITS`` mapping from ``conf.yaml``.

    NOTE(review): despite the name, this returns the per-model token limit
    configuration, not search settings — consider renaming at the call
    sites. Name kept here for backward compatibility.

    Returns:
        The ``MODEL_TOKEN_LIMITS`` section of the config, or an empty dict
        when the key is absent.
    """
    config = load_yaml_config("conf.yaml")
    return config.get("MODEL_TOKEN_LIMITS", {})
class ContextManager:
    """Context manager and compression class.

    Estimates the token usage of a LangChain message list and, when a
    budget is configured, compresses the list to fit. Token counts are
    heuristic (character-based), not tokenizer-accurate: ASCII characters
    are counted at ~4 chars/token and all other characters at 1 char/token,
    with per-message-type multipliers applied on top.
    """

    def __init__(self, token_limit: int, preserve_prefix_message_count: int = 0):
        """
        Initialize ContextManager.

        Args:
            token_limit: Maximum token limit. A value of ``None`` disables
                compression (``compress_messages`` becomes a no-op).
            preserve_prefix_message_count: Number of messages to preserve at
                the beginning of the context (e.g. system prompt and the
                initial user input).
        """
        self.token_limit = token_limit
        self.preserve_prefix_message_count = preserve_prefix_message_count

    def count_tokens(self, messages: List[BaseMessage]) -> int:
        """
        Count estimated tokens in a message list.

        Args:
            messages: List of messages

        Returns:
            Estimated total number of tokens
        """
        return sum(self._count_message_tokens(message) for message in messages)

    def _count_message_tokens(self, message: BaseMessage) -> int:
        """
        Estimate tokens in a single message.

        Only string ``content`` is counted; structured (list) content is
        currently ignored. Message-type multipliers roughly account for role
        overhead and typical payload density.

        Args:
            message: Message object

        Returns:
            Estimated number of tokens (always at least 1)
        """
        token_count = 0

        # Count tokens in the content field (string content only).
        if hasattr(message, "content") and message.content:
            if isinstance(message.content, str):
                token_count += self._count_text_tokens(message.content)

        # Count role-related tokens (message.type, e.g. "human", "ai").
        if hasattr(message, "type"):
            token_count += self._count_text_tokens(message.type)

        # Per-type adjustment factors.
        if isinstance(message, SystemMessage):
            # System messages are usually short but important, slightly increase estimate
            token_count = int(token_count * 1.1)
        elif isinstance(message, HumanMessage):
            # Human messages use normal estimation
            pass
        elif isinstance(message, AIMessage):
            # AI messages may contain reasoning content, slightly increase estimate
            token_count = int(token_count * 1.2)
        elif isinstance(message, ToolMessage):
            # Tool messages may contain large amounts of structured data, increase estimate
            token_count = int(token_count * 1.3)

        # Account for extra metadata carried in additional_kwargs.
        if hasattr(message, "additional_kwargs") and message.additional_kwargs:
            # Simple estimation of extra field tokens via their string form.
            token_count += self._count_text_tokens(str(message.additional_kwargs))

            # If there are tool_calls, add a flat estimate for call framing.
            if "tool_calls" in message.additional_kwargs:
                token_count += 50  # Add estimation for function call information

        # Ensure at least 1 token per message.
        return max(1, token_count)

    def _count_text_tokens(self, text: str) -> int:
        """
        Count tokens in text with different calculations for English and
        non-English characters.

        English (ASCII) characters: 4 characters ≈ 1 token.
        Non-English characters (e.g., Chinese): 1 character ≈ 1 token.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated number of tokens
        """
        if not text:
            return 0

        # ASCII covers English letters, digits, and punctuation.
        english_chars = sum(1 for char in text if ord(char) < 128)
        non_english_chars = len(text) - english_chars

        # English at 4 chars/token (integer division), others at 1 char/token.
        return english_chars // 4 + non_english_chars

    def is_over_limit(self, messages: List[BaseMessage]) -> bool:
        """
        Check if messages exceed the token limit.

        Assumes ``token_limit`` is set; callers with a possibly-``None``
        limit should guard first (as ``compress_messages`` does).

        Args:
            messages: List of messages

        Returns:
            Whether the limit is exceeded
        """
        return self.count_tokens(messages) > self.token_limit

    def compress_messages(self, state: dict) -> dict:
        """
        Compress the messages in ``state`` to fit within the token limit.

        The state dict is modified in place and also returned.

        Args:
            state: State dict containing a ``"messages"`` list.

        Returns:
            The same state, with ``"messages"`` replaced by the compressed
            list when compression was required. (Fixed: return annotation
            previously claimed ``List[BaseMessage]`` but a dict is returned.)
        """
        # Without a configured limit, context management is disabled.
        if self.token_limit is None:
            logger.info("No token_limit set, the context management doesn't work.")
            return state

        if not isinstance(state, dict) or "messages" not in state:
            logger.warning("No messages found in state")
            return state

        messages = state["messages"]

        # Nothing to do if we are already within budget.
        if not self.is_over_limit(messages):
            return state

        compressed_messages = self._compress_messages(messages)

        logger.info(
            f"Message compression completed: {self.count_tokens(messages)} -> {self.count_tokens(compressed_messages)} tokens"
        )

        state["messages"] = compressed_messages
        return state

    def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        """
        Compress compressible messages into the token budget.

        Strategy:
        1. Keep up to ``preserve_prefix_message_count`` head messages
           (system prompts / initial user input), truncating the last kept
           one if it does not fully fit.
        2. Fill the remaining budget with messages taken from the tail
           (most recent first); messages in the middle may be discarded.

        Args:
            messages: List of messages to compress

        Returns:
            Compressed message list
        """
        available_token = self.token_limit
        prefix_messages = []

        # 1. Preserve head messages of specified length to retain system
        #    prompts and user input. (_count_message_tokens returns >= 1,
        #    so `available_token >= cur_token_cnt` also implies a positive
        #    budget — the loops now share one consistent guard.)
        for i in range(min(self.preserve_prefix_message_count, len(messages))):
            cur_token_cnt = self._count_message_tokens(messages[i])
            if available_token >= cur_token_cnt:
                prefix_messages.append(messages[i])
                available_token -= cur_token_cnt
            elif available_token > 0:
                # Truncate content to fit the remaining budget; the budget
                # is then exhausted, so everything after is dropped.
                truncated_message = self._truncate_message_content(
                    messages[i], available_token
                )
                prefix_messages.append(truncated_message)
                return prefix_messages
            else:
                break

        # 2. Compress subsequent messages from the tail; middle messages
        #    may be discarded entirely.
        remaining = messages[len(prefix_messages):]
        suffix_messages = []
        for i in range(len(remaining) - 1, -1, -1):
            cur_token_cnt = self._count_message_tokens(remaining[i])

            if available_token >= cur_token_cnt:
                suffix_messages.insert(0, remaining[i])
                available_token -= cur_token_cnt
            elif available_token > 0:
                # Truncate content to fit the remaining budget, then stop.
                truncated_message = self._truncate_message_content(
                    remaining[i], available_token
                )
                suffix_messages.insert(0, truncated_message)
                return prefix_messages + suffix_messages
            else:
                break

        return prefix_messages + suffix_messages

    def _truncate_message_content(
        self, message: BaseMessage, max_tokens: int
    ) -> BaseMessage:
        """
        Truncate message content while preserving all other attributes by
        copying the original message and only modifying its content.

        NOTE(review): content is sliced by *character* count, not estimated
        tokens, so ASCII-heavy text is truncated more aggressively than the
        budget strictly requires (conservative, never over budget).

        Args:
            message: The message to truncate
            max_tokens: Maximum number of tokens to keep

        Returns:
            New message instance with truncated content
        """
        # Deep copy preserves all attributes (ids, kwargs, metadata).
        truncated_message = copy.deepcopy(message)

        # Truncate only the content attribute.
        truncated_message.content = message.content[:max_tokens]

        return truncated_message

    def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Create a summary message for ``messages``.

        Args:
            messages: Messages to summarize

        Returns:
            Summary message

        Raises:
            NotImplementedError: always — summarization is not implemented
                yet. (Previously the stub silently returned ``None``,
                violating the declared ``BaseMessage`` return contract.)
        """
        # TODO: summary implementation
        raise NotImplementedError("Message summarization is not implemented yet")