ChatDev/utils/token_tracker.py
2026-01-07 16:24:01 +08:00

153 lines
5.8 KiB
Python
Executable File

"""Token usage tracking module for DevAll project."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Any
from collections import defaultdict
@dataclass
class TokenUsage:
"""Stores token usage metrics for individual API calls."""
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
timestamp: datetime = field(default_factory=datetime.now)
node_id: Optional[str] = None
model_name: Optional[str] = None
workflow_id: Optional[str] = None
provider: Optional[str] = None # Add provider field
def to_dict(self):
"""Convert to dictionary format."""
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"total_tokens": self.total_tokens,
"metadata": dict(self.metadata),
"timestamp": self.timestamp.isoformat(),
"node_id": self.node_id,
"model_name": self.model_name,
"workflow_id": self.workflow_id,
"provider": self.provider # Include provider in output
}
class TokenTracker:
"""Singleton class to track token usage across a workflow."""
def __init__(self, workflow_id: str):
self.workflow_id = workflow_id
self.total_usage = TokenUsage()
self.node_usages = defaultdict(TokenUsage)
self.model_usages = defaultdict(TokenUsage)
self.call_history = []
self.node_call_counts = defaultdict(int) # Track how many times each node is called
def record_usage(self, node_id: str, model_name: str, usage: TokenUsage, provider: str = None):
"""Records token usage for a specific call, handling multiple node executions."""
# Update the usage with provider if it wasn't set already
if provider and not usage.provider:
usage.provider = provider
# Add to total usage
self.total_usage.input_tokens += usage.input_tokens
self.total_usage.output_tokens += usage.output_tokens
self.total_usage.total_tokens += usage.total_tokens
# Add to node-specific usage
node_usage = self.node_usages[node_id]
node_usage.input_tokens += usage.input_tokens
node_usage.output_tokens += usage.output_tokens
node_usage.total_tokens += usage.total_tokens
if provider:
node_usage.provider = provider # Store provider info
# Add to model-specific usage
model_usage = self.model_usages[model_name]
model_usage.input_tokens += usage.input_tokens
model_usage.output_tokens += usage.output_tokens
model_usage.total_tokens += usage.total_tokens
if provider:
model_usage.provider = provider # Store provider info
# Increment call count for this node
self.node_call_counts[node_id] += 1
# Add to call history
history_entry = {
"node_id": node_id,
"model_name": model_name,
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"total_tokens": usage.total_tokens,
"metadata": dict(usage.metadata),
"timestamp": usage.timestamp.isoformat(),
"execution_number": self.node_call_counts[node_id] # Track which execution this is
}
# Add provider to history entry if available
if provider:
history_entry["provider"] = provider
self.call_history.append(history_entry)
def get_total_usage(self) -> TokenUsage:
"""Get total token usage for the workflow."""
return self.total_usage
def get_node_usage(self, node_id: str) -> TokenUsage:
"""Get token usage for a specific node (across all its executions)."""
return self.node_usages[node_id]
def get_model_usage(self, model_name: str) -> TokenUsage:
"""Get token usage for a specific model."""
return self.model_usages[model_name]
def get_node_execution_count(self, node_id: str) -> int:
"""Get how many times a node was executed."""
return self.node_call_counts[node_id]
def get_token_usage(self) -> Dict[str, Any]:
data = {
"workflow_id": self.workflow_id,
"total_usage": {
"input_tokens": self.total_usage.input_tokens,
"output_tokens": self.total_usage.output_tokens,
"total_tokens": self.total_usage.total_tokens,
},
"node_usages": {
node_id: {
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"total_tokens": usage.total_tokens,
}
for node_id, usage in self.node_usages.items()
},
"model_usages": {
model_name: {
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"total_tokens": usage.total_tokens,
}
for model_name, usage in self.model_usages.items()
},
"node_execution_counts": dict(self.node_call_counts),
"call_history": self.call_history,
}
return data
def export_to_file(self, filepath: str):
"""Export token usage data to a JSON file."""
import json
from pathlib import Path
# Create directory if it doesn't exist
path = Path(filepath)
path.parent.mkdir(parents=True, exist_ok=True)
data = self.get_token_usage()
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)