# src/utils/token_manager.py
import copy
import logging
from typing import List, Optional

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

from src.config import load_yaml_config

logger = logging.getLogger(__name__)


def get_search_config():
    """
    Load ``conf.yaml`` and return its ``MODEL_TOKEN_LIMITS`` section

    Returns:
        The value stored under ``MODEL_TOKEN_LIMITS``, or an empty dict when
        that key is absent
    """
    config = load_yaml_config("conf.yaml")
    return config.get("MODEL_TOKEN_LIMITS", {})


class ContextManager:
    """
    Context manager and compression class

    Estimates the token usage of chat messages with a character-based
    heuristic and trims a message list so that the estimate fits within
    ``token_limit``.
    """

    def __init__(
        self, token_limit: Optional[int], preserve_prefix_message_count: int = 0
    ):
        """
        Initialize ContextManager

        Args:
            token_limit: Maximum token limit; ``None`` disables compression
            preserve_prefix_message_count: Number of messages to preserve at
                the beginning of the context
        """
        self.token_limit = token_limit
        self.preserve_prefix_message_count = preserve_prefix_message_count

    def count_tokens(self, messages: List[BaseMessage]) -> int:
        """
        Count tokens in message list

        Args:
            messages: List of messages

        Returns:
            Sum of the per-message estimates (0 for an empty list)
        """
        return sum(self._count_message_tokens(message) for message in messages)

    def _count_message_tokens(self, message: BaseMessage) -> int:
        """
        Estimate the number of tokens in a single message

        The estimate is built from the textual content, the message type
        name, a per-type multiplier and the stringified ``additional_kwargs``.
        Only ``str`` content is counted; content of any other type (for
        example a list of content blocks) contributes nothing.

        Args:
            message: Message object

        Returns:
            Estimated number of tokens, never less than 1
        """
        token_count = 0

        # Content field: only plain strings are measured
        content = getattr(message, "content", None)
        if content and isinstance(content, str):
            token_count += self._count_text_tokens(content)

        # Role-related tokens (the message type name)
        message_type = getattr(message, "type", None)
        if message_type:
            token_count += self._count_text_tokens(message_type)

        # Per-type scaling; HumanMessage and unknown types use the plain
        # estimate
        if isinstance(message, SystemMessage):
            # System messages are usually short but important, slightly
            # increase estimate
            token_count = int(token_count * 1.1)
        elif isinstance(message, AIMessage):
            # AI messages may contain reasoning content, slightly increase
            # estimate
            token_count = int(token_count * 1.2)
        elif isinstance(message, ToolMessage):
            # Tool messages may contain large amounts of structured data,
            # increase estimate
            token_count = int(token_count * 1.3)

        # Extra fields are estimated from their string representation; they
        # are added after scaling, so the multiplier does not apply to them
        extra_fields = getattr(message, "additional_kwargs", None)
        if extra_fields:
            token_count += self._count_text_tokens(str(extra_fields))
            if "tool_calls" in extra_fields:
                # Flat allowance for function call information
                token_count += 50

        # Ensure at least 1 token
        return max(1, token_count)

    def _count_text_tokens(self, text: str) -> int:
        """
        Count tokens in text with different calculations for English and
        non-English characters.

        English (ASCII) characters: 4 characters ≈ 1 token
        Non-English characters (e.g., Chinese): 1 character ≈ 1 token

        Args:
            text: Text to count tokens for

        Returns:
            Estimated number of tokens (0 for empty or ``None`` text)
        """
        if not text:
            return 0

        # ASCII covers English letters, digits and punctuation
        ascii_chars = sum(1 for char in text if ord(char) < 128)
        non_ascii_chars = len(text) - ascii_chars

        # Integer division: up to 3 trailing ASCII characters count as 0
        return ascii_chars // 4 + non_ascii_chars

    def is_over_limit(self, messages: List[BaseMessage]) -> bool:
        """
        Check if messages exceed token limit

        Args:
            messages: List of messages

        Returns:
            Whether limit is exceeded; always ``False`` when no limit is set
        """
        if self.token_limit is None:
            # No limit configured, so nothing can exceed it
            return False
        return self.count_tokens(messages) > self.token_limit

    def compress_messages(self, state: dict) -> dict:
        """
        Compress messages to fit within token limit

        ``state`` is updated in place: when compression happens its
        ``"messages"`` entry is replaced with the compressed list.

        Args:
            state: state with original messages

        Returns:
            The same ``state`` object, with compressed messages when the
            limit was exceeded and unchanged otherwise
        """
        # If not set token_limit, return original state
        if self.token_limit is None:
            logger.info("No token_limit set, the context management doesn't work.")
            return state

        if not isinstance(state, dict) or "messages" not in state:
            logger.warning("No messages found in state")
            return state

        messages = state["messages"]

        # Count once and reuse the value for both the check and the log line
        original_tokens = self.count_tokens(messages)
        if original_tokens <= self.token_limit:
            return state

        compressed_messages = self._compress_messages(messages)
        compressed_tokens = self.count_tokens(compressed_messages)

        logger.info(
            f"Message compression completed: {original_tokens} -> {compressed_tokens} tokens"
        )

        state["messages"] = compressed_messages
        return state

    def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        """
        Compress compressible messages

        Keeps up to ``preserve_prefix_message_count`` messages from the head,
        then fills the remaining budget with messages taken from the tail.
        The first message that does not fit is truncated and ends the
        selection; everything between the kept head and the kept tail is
        discarded.

        Args:
            messages: List of messages to compress

        Returns:
            Compressed message list
        """
        if self.token_limit is None:
            # Without a budget there is nothing to compress against
            return messages

        available_token = self.token_limit
        prefix_messages: List[BaseMessage] = []

        # 1. Preserve head messages of specified length to retain system
        #    prompts and user input. Clamp at 0 so a negative count keeps
        #    nothing instead of slicing from the end.
        prefix_count = max(0, min(self.preserve_prefix_message_count, len(messages)))
        for message in messages[:prefix_count]:
            if available_token <= 0:
                break

            cur_token_cnt = self._count_message_tokens(message)
            if available_token >= cur_token_cnt:
                prefix_messages.append(message)
                available_token -= cur_token_cnt
                continue

            # Truncate content to fit available tokens; the budget is spent,
            # so no tail messages can be kept
            prefix_messages.append(
                self._truncate_message_content(message, available_token)
            )
            return prefix_messages

        # 2. Compress subsequent messages from the tail, some messages may be
        #    discarded. Collected newest-first and reversed once at the end to
        #    avoid repeated list prepends.
        remaining_messages = messages[len(prefix_messages):]
        kept_from_tail: List[BaseMessage] = []
        for message in reversed(remaining_messages):
            cur_token_cnt = self._count_message_tokens(message)

            if cur_token_cnt > 0 and available_token >= cur_token_cnt:
                kept_from_tail.append(message)
                available_token -= cur_token_cnt
            elif available_token > 0:
                # Truncate content to fit available tokens, then stop
                kept_from_tail.append(
                    self._truncate_message_content(message, available_token)
                )
                break
            else:
                break

        kept_from_tail.reverse()
        return prefix_messages + kept_from_tail

    def _truncate_message_content(
        self, message: BaseMessage, max_tokens: int
    ) -> BaseMessage:
        """
        Truncate message content while preserving all other attributes by
        copying the original message and only modifying its content attribute.

        The cut is a plain slice: the first ``max_tokens`` elements of
        ``content`` are kept (characters when the content is a string). It is
        not an exact token-based cut.

        Args:
            message: The message to truncate
            max_tokens: Maximum number of content elements to keep; negative
                values are treated as 0

        Returns:
            New message instance with truncated content
        """
        # Create a deep copy of the original message to preserve all attributes
        truncated_message = copy.deepcopy(message)

        # Truncate only the content attribute; clamp so a negative budget
        # cannot turn into a slice counted from the end
        truncated_message.content = message.content[: max(max_tokens, 0)]

        return truncated_message

    def _create_summary_message(
        self, messages: List[BaseMessage]
    ) -> Optional[BaseMessage]:
        """
        Create summary for messages

        Not implemented yet: currently always returns ``None``.

        Args:
            messages: Messages to summarize

        Returns:
            Summary message (``None`` until implemented)
        """
        # TODO: summary implementation
        pass