mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
153 lines
5.8 KiB
Python
Executable File
153 lines
5.8 KiB
Python
Executable File
"""Token usage tracking module for DevAll project."""
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from collections import defaultdict
|
|
|
|
|
|
@dataclass
|
|
class TokenUsage:
|
|
"""Stores token usage metrics for individual API calls."""
|
|
input_tokens: int = 0
|
|
output_tokens: int = 0
|
|
total_tokens: int = 0
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
timestamp: datetime = field(default_factory=datetime.now)
|
|
node_id: Optional[str] = None
|
|
model_name: Optional[str] = None
|
|
workflow_id: Optional[str] = None
|
|
provider: Optional[str] = None # Add provider field
|
|
|
|
def to_dict(self):
|
|
"""Convert to dictionary format."""
|
|
return {
|
|
"input_tokens": self.input_tokens,
|
|
"output_tokens": self.output_tokens,
|
|
"total_tokens": self.total_tokens,
|
|
"metadata": dict(self.metadata),
|
|
"timestamp": self.timestamp.isoformat(),
|
|
"node_id": self.node_id,
|
|
"model_name": self.model_name,
|
|
"workflow_id": self.workflow_id,
|
|
"provider": self.provider # Include provider in output
|
|
}
|
|
|
|
|
|
class TokenTracker:
|
|
"""Singleton class to track token usage across a workflow."""
|
|
|
|
def __init__(self, workflow_id: str):
|
|
self.workflow_id = workflow_id
|
|
self.total_usage = TokenUsage()
|
|
self.node_usages = defaultdict(TokenUsage)
|
|
self.model_usages = defaultdict(TokenUsage)
|
|
self.call_history = []
|
|
self.node_call_counts = defaultdict(int) # Track how many times each node is called
|
|
|
|
def record_usage(self, node_id: str, model_name: str, usage: TokenUsage, provider: str = None):
|
|
"""Records token usage for a specific call, handling multiple node executions."""
|
|
# Update the usage with provider if it wasn't set already
|
|
if provider and not usage.provider:
|
|
usage.provider = provider
|
|
|
|
# Add to total usage
|
|
self.total_usage.input_tokens += usage.input_tokens
|
|
self.total_usage.output_tokens += usage.output_tokens
|
|
self.total_usage.total_tokens += usage.total_tokens
|
|
|
|
# Add to node-specific usage
|
|
node_usage = self.node_usages[node_id]
|
|
node_usage.input_tokens += usage.input_tokens
|
|
node_usage.output_tokens += usage.output_tokens
|
|
node_usage.total_tokens += usage.total_tokens
|
|
if provider:
|
|
node_usage.provider = provider # Store provider info
|
|
|
|
# Add to model-specific usage
|
|
model_usage = self.model_usages[model_name]
|
|
model_usage.input_tokens += usage.input_tokens
|
|
model_usage.output_tokens += usage.output_tokens
|
|
model_usage.total_tokens += usage.total_tokens
|
|
if provider:
|
|
model_usage.provider = provider # Store provider info
|
|
|
|
# Increment call count for this node
|
|
self.node_call_counts[node_id] += 1
|
|
|
|
# Add to call history
|
|
history_entry = {
|
|
"node_id": node_id,
|
|
"model_name": model_name,
|
|
"input_tokens": usage.input_tokens,
|
|
"output_tokens": usage.output_tokens,
|
|
"total_tokens": usage.total_tokens,
|
|
"metadata": dict(usage.metadata),
|
|
"timestamp": usage.timestamp.isoformat(),
|
|
"execution_number": self.node_call_counts[node_id] # Track which execution this is
|
|
}
|
|
|
|
# Add provider to history entry if available
|
|
if provider:
|
|
history_entry["provider"] = provider
|
|
|
|
self.call_history.append(history_entry)
|
|
|
|
def get_total_usage(self) -> TokenUsage:
|
|
"""Get total token usage for the workflow."""
|
|
return self.total_usage
|
|
|
|
def get_node_usage(self, node_id: str) -> TokenUsage:
|
|
"""Get token usage for a specific node (across all its executions)."""
|
|
return self.node_usages[node_id]
|
|
|
|
def get_model_usage(self, model_name: str) -> TokenUsage:
|
|
"""Get token usage for a specific model."""
|
|
return self.model_usages[model_name]
|
|
|
|
def get_node_execution_count(self, node_id: str) -> int:
|
|
"""Get how many times a node was executed."""
|
|
return self.node_call_counts[node_id]
|
|
|
|
def get_token_usage(self) -> Dict[str, Any]:
|
|
data = {
|
|
"workflow_id": self.workflow_id,
|
|
"total_usage": {
|
|
"input_tokens": self.total_usage.input_tokens,
|
|
"output_tokens": self.total_usage.output_tokens,
|
|
"total_tokens": self.total_usage.total_tokens,
|
|
},
|
|
"node_usages": {
|
|
node_id: {
|
|
"input_tokens": usage.input_tokens,
|
|
"output_tokens": usage.output_tokens,
|
|
"total_tokens": usage.total_tokens,
|
|
}
|
|
for node_id, usage in self.node_usages.items()
|
|
},
|
|
"model_usages": {
|
|
model_name: {
|
|
"input_tokens": usage.input_tokens,
|
|
"output_tokens": usage.output_tokens,
|
|
"total_tokens": usage.total_tokens,
|
|
}
|
|
for model_name, usage in self.model_usages.items()
|
|
},
|
|
"node_execution_counts": dict(self.node_call_counts),
|
|
"call_history": self.call_history,
|
|
}
|
|
return data
|
|
|
|
def export_to_file(self, filepath: str):
|
|
"""Export token usage data to a JSON file."""
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Create directory if it doesn't exist
|
|
path = Path(filepath)
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
data = self.get_token_usage()
|
|
|
|
with open(filepath, 'w', encoding='utf-8') as f:
|
|
json.dump(data, f, ensure_ascii=False, indent=2)
|