# Charm_10 / reinforcement_learning.py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
# Neural Network for Deep Q-Learning
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_sizes=(256, 128, 64), dropout_rate=0.1):
super(QNetwork, self).__init__()
self.state_size = state_size
self.action_size = action_size
# Build a deeper network with configurable hidden layers
layers = []
prev_size = state_size
for hidden_size in hidden_sizes:
layers.append(nn.Linear(prev_size, hidden_size))
layers.append(nn.BatchNorm1d(hidden_size)) # Add batch normalization
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout_rate)) # Add dropout for regularization
prev_size = hidden_size
layers.append(nn.Linear(prev_size, action_size))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
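# Hedged usage sketch (added for illustration, not part of the original module):
# it shows the shapes QNetwork expects; the state/action sizes below are arbitrary
# assumptions rather than values taken from the rest of this file.
def _example_qnetwork_forward():
    net = QNetwork(state_size=4, action_size=2)
    net.eval()  # BatchNorm1d cannot run in training mode on a batch of one
    state = torch.zeros(1, 4)  # a single 4-dimensional state
    q_values = net(state)      # shape (1, 2): one Q-value per action
    return q_values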
# Prioritized Experience Replay Memory (simplified)
class PriorityReplayMemory:
def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001):
self.capacity = capacity
self.alpha = alpha # Priority exponent
self.beta = beta # Importance sampling weight
self.beta_increment = beta_increment
        self.priorities = deque(maxlen=capacity)  # Priorities kept index-aligned with experiences
self.experiences = deque(maxlen=capacity) # Store experiences
self.max_priority = 1.0
    def add(self, experience, error=None):
        priority = error if error is not None else self.max_priority
        priority = (abs(priority) + 1e-5) ** self.alpha  # Small constant avoids zero priority
        self.experiences.append(experience)
        self.priorities.append(priority)  # Appending to both deques keeps indices aligned
    def sample(self, batch_size):
        if len(self.experiences) < batch_size:
            return None, None, None
        # Sampling probabilities proportional to stored priorities
        priorities = np.array(self.priorities)
        probs = priorities / priorities.sum()
        # Sample experience indices without replacement
        indices = np.random.choice(len(self.experiences), batch_size, p=probs, replace=False)
        samples = [self.experiences[idx] for idx in indices]
        # Importance-sampling weights correct the bias from non-uniform sampling
        weights = (len(self.experiences) * probs[indices]) ** (-self.beta)
        weights /= weights.max()  # Normalize so the largest weight is 1
        self.beta = min(1.0, self.beta + self.beta_increment)  # Anneal beta toward 1
        return samples, indices, torch.FloatTensor(weights)
    def update_priorities(self, indices, errors):
        for idx, error in zip(indices, errors):
            priority = (abs(error) + 1e-5) ** self.alpha
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)
def __len__(self):
return len(self.experiences)
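# Hedged usage sketch of the replay buffer (illustration only; the transition
# contents below are made-up placeholders):
def _example_replay_usage():
    buffer = PriorityReplayMemory(capacity=100)
    for _ in range(10):
        transition = (np.zeros(4), 0, 0.0, np.zeros(4), False)  # (state, action, reward, next_state, done)
        buffer.add(transition)  # falls back to max_priority when no TD error is supplied
    samples, indices, weights = buffer.sample(batch_size=4)
    buffer.update_priorities(indices, errors=np.ones(len(indices)))  # refresh priorities with new TD errors
    return samples, weights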
# Enhanced Reinforcement Learning Agent
class RLAgent:
def __init__(self, state_size, action_size,
lr=0.0005, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01,
memory_capacity=10000, batch_size=64, target_update_freq=1000,
use_double_dqn=True, clip_grad_norm=1.0):
self.state_size = state_size
self.action_size = action_size
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_decay = epsilon_decay
self.min_epsilon = min_epsilon
self.batch_size = batch_size
self.use_double_dqn = use_double_dqn
self.clip_grad_norm = clip_grad_norm
# Networks
self.policy_net = QNetwork(state_size, action_size)
self.target_net = QNetwork(state_size, action_size)
self.target_net.load_state_dict(self.policy_net.state_dict()) # Copy weights
self.target_net.eval() # Target network doesn't train
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.criterion = nn.SmoothL1Loss(reduction='none')  # Per-sample Huber loss so importance weights apply element-wise
# Memory
self.memory = PriorityReplayMemory(memory_capacity)
self.steps = 0
self.target_update_freq = target_update_freq
    def remember(self, state, action, reward, next_state, done):
        # Estimate the TD error up front so the transition enters the buffer with a priority
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
        self.policy_net.eval()  # eval mode: BatchNorm1d cannot run in training mode on a batch of one
        with torch.no_grad():
            q_value = self.policy_net(state_tensor)[0, action]
            next_q = self.target_net(next_state_tensor).max().item()
        self.policy_net.train()
        target = reward + (1 - float(done)) * self.gamma * next_q
        error = abs(q_value.item() - target)
        self.memory.add((state, action, reward, next_state, done), error)
    def act(self, state):
        self.steps += 1
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)  # Explore
        state = torch.FloatTensor(state).unsqueeze(0)
        self.policy_net.eval()  # eval mode for the single-state greedy forward pass
        with torch.no_grad():
            action = torch.argmax(self.policy_net(state)).item()
        self.policy_net.train()
        return action
def train(self):
if len(self.memory) < self.batch_size:
return
# Sample from memory
batch, indices, weights = self.memory.sample(self.batch_size)
if batch is None:
return
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones)
        # weights stays 1-D (batch_size,) so it broadcasts against the element-wise loss
        # Q-values for the actions actually taken
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Double DQN or standard DQN targets; no gradients flow into the target network
        with torch.no_grad():
            if self.use_double_dqn:
                # Double DQN: the policy net selects the action, the target net evaluates it
                next_actions = self.policy_net(next_states).argmax(1)
                next_q_values = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            else:
                next_q_values = self.target_net(next_states).max(1)[0]
            # Bellman target: r + gamma * Q(s', a') for non-terminal transitions
            targets = rewards + self.gamma * next_q_values * (1 - dones)
# Compute TD errors for priority update
td_errors = (q_values - targets).detach().cpu().numpy()
# Loss with importance sampling weights
loss = (self.criterion(q_values, targets) * weights).mean()
# Optimize
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad_norm)
self.optimizer.step()
# Update priorities
self.memory.update_priorities(indices, td_errors)
# Update target network
if self.steps % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# Decay epsilon
self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
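# Hedged end-to-end sketch (added for illustration): a minimal loop wiring RLAgent to
# an environment. It assumes the optional gymnasium package and the CartPole-v1 task,
# neither of which is referenced elsewhere in this file; episode counts are arbitrary.
if __name__ == "__main__":
    import gymnasium as gym
    env = gym.make("CartPole-v1")
    agent = RLAgent(state_size=env.observation_space.shape[0],
                    action_size=env.action_space.n)
    for episode in range(5):
        state, _ = env.reset()
        episode_return, done = 0.0, False
        while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(state, action, reward, next_state, done)
            agent.train()  # no-op until the buffer holds at least one batch
            state = next_state
            episode_return += reward
        print(f"episode {episode}: return {episode_return:.1f}")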