ChatDev/puppeteer/inference/policy/REINFORCE_continuous.py
2025-09-23 20:31:19 +08:00

593 lines
28 KiB
Python

import torch
import atexit
import os
import datetime
import json
import numpy as np
import torch.nn as nn
import yaml
import logging
from utils.other_utils import Singleton
from inference.policy.base_policy import LLMPolicy, LearningPolicy
from model.embedding import RewardModelTokenRepresentation
# Parse the project-wide YAML settings once at import time. The context
# manager closes the file handle as soon as parsing is done instead of leaving
# it to the garbage collector.
with open("./config/global.yaml", "r") as _global_config_file:
    global_config = yaml.safe_load(_global_config_file)
# Logger registered under the name "train".
logger = logging.getLogger("train")
@Singleton
class MLP_PolicyNetwork(nn.Module):
    """Four-layer perceptron that turns a state vector into action probabilities.

    Layer widths are ``input_dim -> 512 -> 128 -> 32 -> output_dim`` with a
    ReLU after each of the first three linear layers and a softmax over
    dimension 1 on the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute names determine the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, output_dim)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        # Remembered so checkpoints can record and validate the architecture.
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, x):
        """Cast ``x`` to float32 and return the softmax over the final layer."""
        hidden = x.to(torch.float32)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return self.softmax(self.fc4(hidden))
@Singleton
class ContinuousREINFORCE(LearningPolicy):
def __init__(self, agent_graph, action_graph, config_path="config/policy.json"):
    """Build the policy from the JSON configuration at ``config_path``.

    Reads hyper-parameters from the config, creates the state encoder, the
    policy network, its Adam optimizer and the LLM prior policy, and
    registers ``save_model`` to run at interpreter exit.
    """
    super().__init__(agent_graph, action_graph)
    with open(config_path, 'r') as config_file:
        self.config = json.load(config_file)
    cfg = self.config

    # Device and checkpoint location
    self.device = cfg["device"]["type"]
    self.model_path = cfg["paths"]["model_path"]

    # Optimisation hyper-parameters
    training_cfg = cfg["training"]
    self.training = training_cfg["training"]
    self.loading = training_cfg["loading"]
    self.learning_rate = training_cfg["learning_rate"]
    self.gamma = training_cfg["gamma"]
    self.sample_size = training_cfg["sample_size"]
    self.lambda_kl_loss = training_cfg["lambda_kl_loss"]

    # Agent-selection limits
    agent_cfg = cfg["agent"]
    self.max_num_agents = agent_cfg["max_num_agents"]
    self.next_num_agents = agent_cfg["next_num_agents"]
    self.max_path = agent_cfg["max_path"]
    self.threshold = agent_cfg["threshold"]

    # LLM prior settings
    llm_cfg = cfg["llm"]
    self.llm_prior = llm_cfg["prior"]
    self.llm_prior_redistribution = llm_cfg["prior_redistribution"]
    self.redistribution_weight = llm_cfg["redistribution_weight"]

    # State encoder and policy network
    self.state_representation = RewardModelTokenRepresentation()
    network = MLP_PolicyNetwork(self.state_representation.dim, self.actions_dim)
    self.policy_network = network.to(self.device)

    # NOTE(review): the optimizer is only created further down, and
    # load_model restores optimizer state only when the attribute already
    # exists, so these loads restore weights but not optimizer state --
    # confirm that this is intended.
    if not self.training:
        self.load_model(self.get_latest_model_path())
    if self.loading:
        self.load_model(self.model_path)

    # Agent lookup tables taken from the agent graph
    self.agent_hash_list = agent_graph.hash_nodes
    self.agent_role_list = agent_graph.role_nodes

    # Per-run bookkeeping
    self.executed_trajectories = []
    self.execution_count = 0
    self.current_trajectories = []
    self.current_trajectory_idx = 0
    self.policy_losses = []
    self.rewards_history = []
    self.action_probs_history = []
    self.llm_action_probs_history = []
    self.reward_from_rm = []
    self.accumulated_acc = []
    self.entropy_history = []

    # Special action indices kept as tensors on the configured device
    self.end_action = torch.tensor(self.agent_graph.terminator_agent_index, device=self.device)
    self.web_actions = torch.tensor(self.agent_graph.search_agent_indices, device=self.device)

    # Per-action reward multipliers: a default everywhere, then overrides for
    # the terminator and for every web-search agent.
    reward_factors = agent_cfg["reward_factors"]
    factors = [reward_factors["default"]] * self.actions_dim
    factors[self.end_action.item()] = reward_factors["terminator"]
    for web_idx in self.web_actions:
        factors[web_idx.item()] = reward_factors["web_search"]
    self.agent_reward_factor = factors

    self.current_task = None
    self.previous_task = None
    self.global_step = 0
    self.prob_step = 0

    self.optimizer = torch.optim.Adam(self.policy_network.parameters(), lr=self.learning_rate)
    self.max_step_num = global_config.get("graph").get("max_step_num")
    self.llm_policy = LLMPolicy(self.agent_graph, self.action_graph)
    # Persist a checkpoint when the interpreter shuts down.
    atexit.register(self.save_model)
def logarithmic_cost(self, step):
    """Return the cost of ``step`` on a normalised logarithmic curve.

    ``scale``, ``growth_rate`` and ``inverse`` are read from
    ``self.config["cost"]``. The step is mapped to
    ``s = (step + 1) / (max_step_num + 1)`` and the cost is
    ``scale * log(1 + growth_rate * s) / log(1 + growth_rate)``. With
    ``inverse`` set, the log ratio is subtracted from one before scaling.
    The result is a tensor created on ``self.device``.
    """
    cost_cfg = self.config["cost"]
    scale = cost_cfg["scale"]
    growth_rate = cost_cfg["growth_rate"]
    progress = (step + 1) / (self.max_step_num + 1)
    numerator = torch.log(torch.tensor(1 + growth_rate * progress, device=self.device))
    denominator = torch.log(torch.tensor(1 + growth_rate, device=self.device))
    log_ratio = numerator / denominator
    if cost_cfg["inverse"]:
        log_ratio = 1 - log_ratio
    step_cost = scale * log_ratio
    print("\033[1;33mstep cost: {}\033[0m".format(step_cost))
    return step_cost
def save_model(self, path=None, tag=None):
    """Write a checkpoint of the policy network and return its file path.

    Args:
        path: Directory to write into. When ``None`` the directory from
            ``self.config["paths"]["checkpoint_path"]`` is used.
        tag: Optional suffix appended to the generated file name.

    Returns:
        The path of the written ``.pt`` file, or ``None`` when
        ``torch.save`` failed (the error is printed, not raised).
    """
    # Honour an explicitly requested directory; the argument used to be
    # overwritten unconditionally by the configured checkpoint path.
    if path is None:
        path = self.config["paths"]["checkpoint_path"]
    os.makedirs(path, exist_ok=True)

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f'policy_net_{timestamp}' + (f'_{tag}' if tag else '') + '.pt'
    save_path = os.path.join(path, filename)

    checkpoint = {
        'model_state_dict': self.policy_network.state_dict(),
        # The optimizer may not exist yet if saving happens early.
        'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self, 'optimizer') else None,
        'input_dim': self.policy_network.input_dim,
        'output_dim': self.policy_network.output_dim,
        'timestamp': timestamp,
        'config': self.config,  # Save the config with the model
        'metadata': {
            'tag': tag,
            'version': '1.0',
        },
    }

    # Best effort: this also runs from an atexit hook, so report and carry on.
    try:
        torch.save(checkpoint, save_path)
        print(f"Model saved successfully to {save_path}")
        return save_path
    except Exception as e:
        print(f"Error saving model: {str(e)}")
        return None
def update_executed_trajectories(self):
    """Point ``current_trajectories`` at the slot of the task being run.

    When ``current_task`` differs from ``previous_task`` the execution
    counter is advanced. ``executed_trajectories`` is then padded with
    empty lists up to that count and the slot at index
    ``execution_count - 1`` becomes ``current_trajectories``.
    """
    task_changed = self.current_task != self.previous_task
    if task_changed:
        self.previous_task = self.current_task
        self.execution_count += 1

    missing_slots = self.execution_count - len(self.executed_trajectories)
    for _ in range(max(missing_slots, 0)):
        self.executed_trajectories.append([])

    self.current_trajectories = self.executed_trajectories[self.execution_count - 1]
def init_forward(self, global_info):
    """Pick the first agent(s) for a task and open one trajectory per pick.

    Encodes the current state, runs the policy network, records the action
    probabilities, reward-model value and entropy, then selects agents via
    ``select_agents_by_threshold``. The first selected agent is written to
    trajectory 0; every further agent gets one of the freshly appended
    trajectory slots. Returns the selected agent indices.
    """
    print("\033[1;33mInit Policy Forward\033[0m")
    logger.info("[Init Policy Forward]")
    self.current_task = global_info.task
    self.update_executed_trajectories()

    state, rew = self.get_state_representation(global_info)
    action_probs = self.policy_network(state)  # shape (1, action_dim)
    self.action_probs_history.append(action_probs.T.squeeze(1))
    self.reward_from_rm.append(rew)
    logger.info("Action probs: {}".format(action_probs))

    # Entropy of the action distribution; the epsilon guards log(0).
    entropy = -torch.sum(action_probs * torch.log(action_probs + 1e-10))
    self.entropy_history.append(entropy)

    m = torch.distributions.Categorical(action_probs)
    selected = self.select_agents_by_threshold(action_probs).T.squeeze(1)
    self.current_trajectory_idx = 0

    # Reserve one empty trajectory slot per selected agent (in place, since
    # the list is shared with executed_trajectories).
    branch_count = selected.shape[0]
    self.current_trajectories.extend([] for _ in range(branch_count))
    first_new_slot = len(self.current_trajectories) - branch_count

    for position, agent_idx in enumerate(selected):
        prob_value = action_probs[0, agent_idx.item()]
        if position == 0:
            target_idx = self.current_trajectory_idx
        else:
            target_idx = first_new_slot + position
        self.append_to_trajectory(target_idx, agent_idx, prob_value, global_info, None, m, rew)
    return selected
def iter_forward(self, global_info):
    """Pick the next agent(s) for the reasoning path ``global_info.path_id``.

    Optionally queries the LLM prior, runs the policy network on the
    encoded state and, when prior redistribution is enabled, blends the
    prior into the policy probabilities and renormalises. The first
    selected agent extends the current path; each additional agent forks a
    new trajectory that starts as a copy of the current path (without the
    step just appended) and is then extended with that agent.

    Returns:
        Tensor with the selected agent indices.
    """
    print("\033[1;33mFollowing Policy Forward\033[0m")
    logger.info("Following Policy Forward")
    self.current_task = global_info.task

    prior_action_probs = self.llm_policy.forward_prior(global_info) if self.llm_prior else None
    print("Prior action probs: {}".format(prior_action_probs))
    logger.info("Prior action probs: {}".format(prior_action_probs))

    state, rew = self.get_state_representation(global_info)
    action_probs = self.policy_network(state)  # shape (1, action_dim)

    if self.llm_prior_redistribution:
        if prior_action_probs is None:
            # Redistribution was requested but the prior is switched off;
            # mixing with None would raise a TypeError, so fall back to the
            # raw policy distribution and make the misconfiguration visible.
            logger.warning("prior_redistribution is enabled but no LLM prior is available; "
                           "using policy probabilities unchanged")
        else:
            weight = self.redistribution_weight
            action_probs = (1 - weight) * action_probs + weight * prior_action_probs
            # Out-of-place renormalisation keeps the autograd graph intact.
            action_probs = action_probs / action_probs.sum()

    self.action_probs_history.append(action_probs.T.squeeze(1))
    self.reward_from_rm.append(rew)
    logger.info("Action probs: {}".format(action_probs))

    # Entropy of the action distribution; the epsilon guards log(0).
    entropy = -(action_probs * torch.log(action_probs + 1e-10)).sum()
    self.entropy_history.append(entropy)

    m = torch.distributions.Categorical(action_probs)
    agent_indices = self.select_agents_by_threshold(action_probs).T.squeeze(1)
    self.current_trajectory_idx = global_info.path_id

    # One new slot per *additional* agent; the first agent reuses the
    # current path.
    original_length = len(self.current_trajectories)
    target_length = original_length + len(agent_indices) - 1
    while len(self.current_trajectories) < target_length:
        self.current_trajectories.append([])

    for i, agent_idx in enumerate(agent_indices):
        prob_value = action_probs[0, agent_idx.item()]
        if i == 0:
            trajectory_idx = self.current_trajectory_idx
        else:
            trajectory_idx = original_length + i - 1
            # Fork: start from the current path minus its newest step.
            self.current_trajectories[trajectory_idx] = self.clone_trajectory(self.current_trajectory_idx)
        self.append_to_trajectory(trajectory_idx, agent_idx, prob_value, global_info,
                                  prior_action_probs, m, rew)
    return agent_indices
def append_to_trajectory(self, trajectory_idx, agent_idx, prob_value, global_info, prior_action_probs, m, rew=0):
    """Record one decision step on trajectory ``trajectory_idx``.

    The step reward is the logarithmic cost for the trajectory's current
    length multiplied by the reward factor of the chosen agent. The entry
    also stores the action probability, its log-probability under ``m``,
    the workflow state, the agent role, the reward-model value and, when a
    prior is supplied, the prior probability of the chosen agent.
    """
    trajectory = self.current_trajectories[trajectory_idx]
    action_id = agent_idx.item()
    cost = self.logarithmic_cost(len(trajectory)) * self.agent_reward_factor[action_id]
    if prior_action_probs is None:
        prior_prob = None
    else:
        prior_prob = prior_action_probs[action_id]
    step_record = {
        'prob': prob_value,
        'log_prob': m.log_prob(agent_idx),
        'state_identifier': global_info.workflow.state,
        'action': self.agent_role_list[action_id],
        'reward': cost,
        'reward_model': rew,
        'prior_prob': prior_prob,
    }
    trajectory.append(step_record)
    print(trajectory_idx, trajectory)
def clone_trajectory(self, source_idx):
    """Return a copy of trajectory ``source_idx`` without its last step.

    ``prob``, ``log_prob`` and a non-``None`` ``prior_prob`` are cloned; the
    remaining listed fields are carried over by reference. Keys other than
    the seven listed here are not copied.
    """
    copied_steps = []
    for step in self.current_trajectories[source_idx][:-1]:
        prior = step['prior_prob']
        copied_steps.append({
            'prob': step['prob'].clone(),
            'log_prob': step['log_prob'].clone(),
            'state_identifier': step['state_identifier'],
            'action': step['action'],
            'reward': step['reward'],
            'reward_model': step['reward_model'],
            'prior_prob': prior.clone() if prior is not None else None,
        })
    return copied_steps
def forward(self, global_info):
    """Select the next agents and return their hashes.

    A ``path_id`` of -1 marks the first step of a task and is handled by
    ``init_forward``; any other value goes through ``iter_forward``.
    """
    is_first_step = global_info.path_id == -1
    step_fn = self.init_forward if is_first_step else self.iter_forward
    agent_indices = step_fn(global_info)
    print("Agent Indices: {}".format(agent_indices))
    return [self.agent_hash_list[i] for i in agent_indices]
def calculate_returns(self, trajectory):
    """Return the discounted return of every step as a tensor on ``self.device``.

    Walks the trajectory backwards with ``R = reward + gamma * R``; a step
    without a ``'reward'`` key counts as reward 0.
    """
    discounted = []
    running = 0
    for step in reversed(trajectory):
        running = step.get('reward', 0) + self.gamma * running
        discounted.append(running)
    # Values were collected last-to-first; restore chronological order.
    discounted.reverse()
    return torch.tensor(discounted, device=self.device)
def get_state_representation(self, global_info):
    """Encode the current dialog history and return ``(state, reward)``.

    The dialog history for the roles reported by ``global_info`` and the
    task's ``"Question"`` entry is fetched from the agent graph and passed
    to ``self.state_representation``, whose two results are returned as-is.
    Intermediate values are printed for debugging.
    """
    roles = global_info.agent_role_list()
    print(roles)
    question = global_info.task.get("Question")
    dialog = self.agent_graph.get_agent_dialog_history(roles, question=question)
    print(dialog)
    print(type(dialog))
    encoded_state, rm_reward = self.state_representation(dialog)
    print(encoded_state, rm_reward)
    return encoded_state, rm_reward
# Run one policy-gradient update over the recorded trajectories.
# Returns a dict with 'policy_loss' and 'mean_reward' after an optimizer step,
# and {} in evaluation mode or when no update was performed.
def update(self):
logger.info("Update")
logger.info("Executed trajectories: {}".format(self.executed_trajectories))
# Evaluation mode: log summed action probabilities and mean entropy, reset the
# trajectory/entropy buffers and skip optimisation entirely.
if not self.training:
metrics = {
'reasoning/action_probs': torch.sum(torch.stack(self.action_probs_history), dim=0),
"training/entropy": np.mean([e.detach().cpu().item() for e in self.entropy_history])
}
logger.info("metrics: {}".format(metrics))
self.current_trajectories = []
self.executed_trajectories = []
self.entropy_history = []
self.execution_count = 0
return {}
# Only update once at least sample_size task slots have been recorded.
if len(self.executed_trajectories) >= self.sample_size:
# Episode-level accumulators: one entry per task that produced values.
episode_returns = []
episode_lengths = []
episode_last_rewards = []
episode_acc = []
episode_tokens = []
episode_cost = []
episode_metrics = {}
kl_losses = []
logger.info("Update with sample size {}".format(self.sample_size))
policy_loss = []
episode_loss = []
# Each element of executed_trajectories holds the trajectories of one task.
for trajectories in self.executed_trajectories[:self.sample_size]:
task_avg_length = []
task_avg_reward = []
task_last_reward = []
task_acc = []
task_avg_tokens = []
task_avg_cost = []
task_avg_metrics = []
for trajectory in trajectories:
# Only trajectories whose last step is flagged 'finalized' contribute.
# NOTE(review): trajectory[-1] raises IndexError on an empty trajectory --
# confirm that empty trajectory slots cannot reach this loop.
if trajectory[-1].get('finalized', False):
logger.info("Trajectory: {}".format(trajectory))
returns = self.calculate_returns(trajectory)
# episode_returns.append(sum(returns))
task_avg_reward.append(sum(returns))
task_avg_length.append(len(trajectory))
task_last_reward.append(trajectory[-1].get('reward', 0))
task_avg_tokens.append(trajectory[-1].get('total_tokens', 0))
task_avg_cost.append(trajectory[-1].get('total_cost', 0))
task_avg_metrics.append(trajectory[-1].get('metrics', {}))
# A trajectory counts as correct when its final reward is positive.
if task_last_reward[-1] > 0:
task_acc.append(1)
else:
task_acc.append(0)
# task_acc.append(task_last_reward[-1].cpu().item())
# episode_lengths.append(len(trajectory))
print("returns: {}".format(returns))
logger.info("Trajectory returns: {}".format(returns))
for t, R in zip(trajectory, returns):
# Penalty term prior * log(prior / prob) for the chosen action, used only
# when both the policy probability and a prior probability were stored.
if t.get('prob', None) is not None and t.get('prior_prob', None) is not None:
kl_loss = t.get('prior_prob', 0) * torch.log(t['prior_prob'] / (t['prob']+1e-10))
logger.info("Add KL loss: {}".format(kl_loss))
else:
kl_loss = 0
logger.info("No KL loss: {}".format(kl_loss))
# NOTE(review): when kl_loss is already a tensor, torch.tensor(...) is expected
# to produce a detached copy, which would stop gradients through the KL term --
# confirm against the torch.tensor documentation.
kl_loss = torch.tensor(kl_loss).to(self.device)
kl_losses.append(kl_loss)
# REINFORCE term (-log_prob * return) plus the weighted KL penalty.
loss = (-t['log_prob'] * R + self.lambda_kl_loss * kl_loss).to(self.device)
if loss.dim() == 0: # scalar loss, convert to shape [1]
loss = loss.view(1)
elif loss.dim() == 1: # already [1], keep it
pass
policy_loss.append(loss)
logger.info("loss for one sample: {}".format(policy_loss))
# Fold this task's averages into the episode-level lists; each `continue`
# skips the remaining aggregation for a task that produced no values.
if len(task_avg_length) == 0:
continue
else:
episode_lengths.append(sum(task_avg_length)/len(task_avg_length))
if len(task_avg_reward) == 0:
continue
else:
episode_returns.append(sum(task_avg_reward)/len(task_avg_reward))
if len(task_last_reward) == 0:
continue
else:
episode_last_rewards.append(sum(task_last_reward)/len(task_last_reward))
if len(task_avg_tokens) == 0:
continue
else:
episode_tokens.append(sum(task_avg_tokens)/len(task_avg_tokens))
if len(task_avg_cost) == 0:
continue
else:
episode_cost.append(sum(task_avg_cost)/len(task_avg_cost))
if len(task_acc) == 0:
continue
else:
episode_acc.append(sum(task_acc)/len(task_acc))
# Extra metrics are averaged per key, using the keys of the first
# trajectory's metrics dict.
if len(task_avg_metrics) == 0:
continue
elif task_avg_metrics[0] == {}:
continue
else:
for key in task_avg_metrics[0].keys():
if key not in episode_metrics:
episode_metrics[key] = []
episode_metrics[key].append(sum([m[key] for m in task_avg_metrics])/len(task_avg_metrics))
if policy_loss:
logger.info("Policy loss: {}".format(policy_loss))
# Sum of all step losses divided by sample_size (not by the number of steps).
policy_loss = torch.stack(policy_loss).sum()/(self.sample_size)
logger.info("Policy loss stack: {}".format(policy_loss))
# Subtract every entropy value accumulated in entropy_history, so a
# higher-entropy action distribution lowers the loss.
policy_loss -= sum(self.entropy_history)
logger.info("Policy loss with entropy: {}".format(policy_loss))
self.optimizer.zero_grad()
policy_loss.backward()
self.optimizer.step()
metrics = {
'reasoning/action_probs': torch.sum(torch.stack(self.action_probs_history), dim=0),
'reasoning/reward_from_rm': sum(self.reward_from_rm),
'reasoning/acc': np.mean([a for a in episode_acc]),
'reasoning/tokens': np.mean([t for t in episode_tokens]),
'reasoning/cost': np.mean([c for c in episode_cost]),
'training/policy_loss': policy_loss.item(),
'reasoning/mean_return': np.mean([r.detach().cpu().item() for r in episode_returns]),
'reasoning/mean_episode_length': np.mean(episode_lengths),
'reasoning/mean_last_reward': np.mean([r.detach().cpu().item() for r in episode_last_rewards]),
'training/mean_kl_loss': np.mean([kl.detach().cpu().item() for kl in kl_losses]),
"training/entropy": np.mean([e.detach().cpu().item() for e in self.entropy_history]),
}
# NOTE(review): assumes the per-key metric averages are tensors, since
# .cpu().item() is called on them -- TODO confirm against the caller.
metrics.update({f'reasoning/{key}': np.mean([r.cpu().item() for r in episode_metrics[key]]) for key in episode_metrics})
logger.info("metrics: {}".format(metrics))
self.global_step += 1
self.policy_losses.append(policy_loss.item())
# Reset all per-batch buffers for the next round of tasks.
self.current_trajectories = []
self.executed_trajectories = []
self.entropy_history = []
self.execution_count = 0
self.reward_from_rm = []
self.action_probs_history = []
self.llm_action_probs_history = []
# NOTE(review): `returns` here is the value left over from the last finalized
# trajectory processed above, so 'mean_reward' covers only that trajectory --
# confirm this is intended.
return {
'policy_loss': policy_loss.item(),
'mean_reward': torch.tensor(returns, device=self.device).mean().item()
}
return {}
# Attach the final task reward from `transition` to the trajectory selected by
# transition['path_id'] (default 0) and mark its last step as finalized.
def finalize_task(self, transition, global_info):
print("\033[1;33mtransition reward: {}\033[0m".format(transition.get('reward', 0)))
# Re-point at the trajectory list of the most recently started task.
self.current_trajectories = self.executed_trajectories[self.execution_count-1]
idx = transition.get('path_id', 0)
if self.current_trajectories and idx < len(self.current_trajectories):
# Evaluate the policy once more on the final state to obtain the probability
# and log-probability of the terminator action.
state, rew = self.get_state_representation(global_info)
action_probs = self.policy_network(state)
prob_value = action_probs[0, self.end_action.item()]
m = torch.distributions.Categorical(action_probs)
current_trajectory = self.current_trajectories[idx]
# Scale each recorded step reward by that workflow action's cost / 100000.
# NOTE(review): indexing current_trajectory[index] assumes the trajectory has
# at least as many steps as the workflow has actions -- confirm.
for index, action in enumerate(global_info.workflow.workflow):
cost = action.cost
print("\033[1;33mtoken cost: {}\033[0m".format(cost))
print("\033[1;33mcost factor: {}\033[0m".format(cost/100000))
current_trajectory[index]["reward"] *= cost/100000
print("\033[1;33mReward: {}\033[0m".format(current_trajectory[index]['reward']))
if current_trajectory:
step_reward = self.logarithmic_cost(len(current_trajectory))
total_tokens = global_info.total_tokens
total_cost = global_info.total_cost
# A positive task reward has the terminator-weighted step cost added;
# otherwise that weighted step cost is subtracted.
if transition.get('reward', 0) > 0:
reward = transition.get('reward', 0) + self.agent_reward_factor[self.end_action.item()] * step_reward
else:
reward = transition.get('reward', 0) - self.agent_reward_factor[self.end_action.item()] * step_reward
# If the trajectory already ends with the terminator action, overwrite that
# step's reward and bookkeeping; otherwise append a new terminator step.
if current_trajectory[-1].get("action") == self.agent_role_list[self.end_action.item()]:
current_trajectory[-1]["reward"] = reward
current_trajectory[-1]['total_tokens'] = total_tokens
current_trajectory[-1]['total_cost'] = total_cost
current_trajectory[-1]['finalized'] = True
current_trajectory[-1]['reward_model'] = rew
current_trajectory[-1]['metrics'] = transition.get('metrics', {})
print("\033[1;33mLast Reward: {}\033[0m".format(current_trajectory[-1]['reward']))
else:
current_trajectory.append({
'prob': prob_value,
'log_prob': m.log_prob(self.end_action),
'state_identifier': transition.get('state', global_info.workflow.state),
'action': self.agent_role_list[self.end_action.item()],
'reward': reward,
'reward_model': rew,
'finalized': True,
'total_tokens': total_tokens,
'total_cost': total_cost,
'metrics': transition.get('metrics', {})
})
print("\033[1;33mLast Reward: {}\033[0m".format(current_trajectory[-1]['reward']))
# Keep the raw (unadjusted) task reward for later inspection.
self.rewards_history.append(transition.get('reward', 0))
def select_agents_by_probability(self, action_probs):
    """Sample a random number of distinct agents from ``action_probs``.

    The number of agents is drawn uniformly from 1 to ``max_num_agents``
    inclusive; the indices are then sampled without replacement.
    """
    upper_bound = self.max_num_agents + 1  # randint's upper bound is exclusive
    sample_count = torch.randint(1, upper_bound, (1,)).item()
    return torch.multinomial(action_probs, sample_count, replacement=False)
def select_agents_by_threshold(self, action_probs, threshold=0.1):
    """Select the agents whose probability exceeds the cut-off.

    The ``threshold`` argument is overridden: the cut-off actually used is
    ``2 / self.agent_graph.num``. Qualifying agents are ordered by
    descending probability and truncated to
    ``min(max_path, max_num_agents)``. When no agent qualifies, that many
    agents are sampled from ``action_probs`` without replacement instead.

    Returns:
        Tensor of shape ``(1, k)`` with the selected agent indices.
    """
    threshold = 2 / self.agent_graph.num
    row = action_probs[0]
    above = (row > threshold).nonzero().squeeze(1)

    if len(above) == 0:
        # Nothing clears the cut-off: fall back to sampling.
        fallback_count = min(self.max_path, self.max_num_agents)
        return torch.multinomial(action_probs, fallback_count, replacement=False)

    order = torch.argsort(row[above], descending=True)
    ranked = above[order]
    keep = min(len(ranked), self.max_path, self.max_num_agents)
    return ranked[:keep].unsqueeze(0)
def load_model(self, path, strict=True):
    """Load policy weights, and optimizer state when possible, from ``path``.

    Args:
        path: Checkpoint file as written by ``save_model``. ``None``, an
            empty string or a missing file counts as a failed load.
        strict: Passed to ``load_state_dict``. When true, a mismatch
            between the stored and current ``input_dim``/``output_dim``
            also fails the load; when false it is only logged.

    Returns:
        ``True`` if the weights were loaded, ``False`` otherwise. Errors are
        logged rather than raised.
    """
    try:
        if not path or not os.path.exists(path):
            logger.error(f"Model file not found: {path}")
            return False

        # NOTE(security): torch.load may unpickle arbitrary objects; only
        # load checkpoints that come from a trusted source.
        checkpoint = torch.load(path, map_location=self.device)

        # Validate model architecture
        expected_dims = (self.policy_network.input_dim, self.policy_network.output_dim)
        stored_dims = (checkpoint['input_dim'], checkpoint['output_dim'])
        if stored_dims != expected_dims:
            if strict:
                raise ValueError(f"Model architecture mismatch. Expected input_dim={self.policy_network.input_dim}, "
                                 f"output_dim={self.policy_network.output_dim} but got input_dim={checkpoint['input_dim']}, "
                                 f"output_dim={checkpoint['output_dim']}")
            logger.warning("Model architecture mismatch, but continuing due to non-strict mode")

        # Load model state
        self.policy_network.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        self.policy_network = self.policy_network.to(self.device)

        # Everything below is optional: a checkpoint that lacks these keys
        # must not turn an otherwise successful weight load into a failure.
        optimizer_state = checkpoint.get('optimizer_state_dict')
        if optimizer_state and hasattr(self, 'optimizer'):
            self.optimizer.load_state_dict(optimizer_state)
            # Move optimizer state to correct device
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)

        # Merge the stored config into the current one; current values win.
        stored_config = checkpoint.get('config')
        if stored_config:
            for key, value in stored_config.items():
                self.config.setdefault(key, value)

        logger.info(f"Model loaded successfully from {path}")
        timestamp = checkpoint.get('timestamp')
        if timestamp:
            logger.info(f"Model timestamp: {timestamp}")
        tag = (checkpoint.get('metadata') or {}).get('tag')
        if tag:
            logger.info(f"Model tag: {tag}")
        return True
    except Exception as e:
        # Deliberately best effort: callers rely on the boolean result.
        logger.error(f"Error loading model: {str(e)}")
        return False
def get_latest_model_path(self):
    """Return the checkpoint file to load, or ``None`` if there is none.

    ``self.model_path`` wins when it names an existing file. Otherwise the
    ``.pt`` file with the greatest ``os.path.getctime`` inside the
    configured checkpoint directory is returned. Any error is printed and
    results in ``None``.
    """
    try:
        explicit_path = self.model_path
        if os.path.exists(explicit_path) and os.path.isfile(explicit_path):
            return explicit_path

        checkpoint_dir = self.config["paths"]["checkpoint_path"]
        if not os.path.exists(checkpoint_dir):
            return None

        candidates = [
            os.path.join(checkpoint_dir, name)
            for name in os.listdir(checkpoint_dir)
            if name.endswith('.pt')
        ]
        if not candidates:
            return None
        return max(candidates, key=os.path.getctime)
    except Exception as e:
        print(f"Error finding latest model: {str(e)}")
        return None