import random
from collections import namedtuple, deque
from config import sequence_length, burn_in_length, eta
import torch
import numpy as np

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask', 'rnn_state'))


class LocalBuffer(object):
    """Actor-side buffer that slices the episode stream into fixed-length sequences."""

    def __init__(self):
        self.local_memory = []   # transitions of the sequence currently being built
        self.memory = []         # finished, zero-padded sequences waiting to be handed over

    def push(self, state, next_state, action, reward, mask, rnn_state):
        # rnn_state is the recurrent network's (hidden, cell) pair, flattened to [2, hidden_size]
        # so it can be stored with the transition and replayed during burn-in.
        self.local_memory.append(Transition(state, next_state, action, reward, mask,
                                            torch.stack(rnn_state).view(2, -1)))
        # Flush when the sequence is full or the episode terminates (mask == 0).
        if len(self.local_memory) == sequence_length or mask == 0:
            # Sequences no longer than the burn-in period are dropped entirely.
            if len(self.local_memory) > burn_in_length:
                length = len(self.local_memory)
                # Zero-pad the tail so every stored sequence has exactly sequence_length steps.
                while len(self.local_memory) < sequence_length:
                    self.local_memory.append(Transition(
                        torch.Tensor([0, 0]),
                        torch.Tensor([0, 0]),
                        0,
                        0,
                        0,
                        torch.zeros([2, 1, 16]).view(2, -1)
                    ))
                self.memory.append([self.local_memory, length])
            self.local_memory = []

    def sample(self):
        # Hand over every completed sequence, stacked along the time dimension,
        # then clear the buffer. `lengths` keeps the unpadded length of each sequence.
        episodes = self.memory
        batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state = [], [], [], [], [], []
        lengths = []
        for episode, length in episodes:
            batch = Transition(*zip(*episode))

            batch_state.append(torch.stack(list(batch.state)))
            batch_next_state.append(torch.stack(list(batch.next_state)))
            batch_action.append(torch.Tensor(list(batch.action)))
            batch_reward.append(torch.Tensor(list(batch.reward)))
            batch_mask.append(torch.Tensor(list(batch.mask)))
            batch_rnn_state.append(torch.stack(list(batch.rnn_state)))

            lengths.append(length)
        self.memory = []
        return Transition(batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state), lengths


class Memory(object):
    """Central prioritized replay that stores fixed-length sequences from the actors."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)   # one priority per stored sequence

    def td_error_to_prior(self, td_error, lengths):
        # Priority of a sequence: absolute value of its summed TD error,
        # normalised by the number of trained (non-burn-in) steps.
        abs_td_error_sum = td_error.sum(dim=1, keepdim=True).view(-1).abs().detach().numpy()
        lengths_burn = [length - burn_in_length for length in lengths]

        prior = abs_td_error_sum / lengths_burn
        return prior
| 64 | + |
    def push(self, td_error, batch, lengths):
        # batch fields are indexed as [local_mini_batch, sequence_length, item]
        prior = self.td_error_to_prior(td_error, lengths)

        # Store each sequence with a priority p = eta * max + (1 - eta) * mean,
        # taken over the priorities already in memory plus this sequence's own prior.
        for i in range(len(lengths)):
            if len(self.memory_probability) > 0:
                memory_probability = np.array(self.memory_probability)
                probability_max = max(memory_probability.max(), prior[i])
                probability_mean = (memory_probability.sum() + prior[i]) / (len(self.memory_probability) + 1)
            else:
                probability_max = prior[i]
                probability_mean = prior[i]
            self.memory.append([Transition(batch.state[i], batch.next_state[i], batch.action[i], batch.reward[i], batch.mask[i], batch.rnn_state[i]), lengths[i]])
            p = eta * probability_max + (1 - eta) * probability_mean
            self.memory_probability.append(p)
| 80 | + |
    def sample(self, batch_size):
        # Sample sequences with probability proportional to their stored priority.
        probability = np.array(self.memory_probability)
        probability = probability / probability.sum()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=probability)
        # indexes = np.random.choice(range(len(self.memory_probability)), batch_size)  # uniform-sampling alternative
        episodes = [self.memory[idx][0] for idx in indexes]
        lengths = [self.memory[idx][1] for idx in indexes]

        batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state = [], [], [], [], [], []
        for episode in episodes:
            batch_state.append(episode.state)
            batch_next_state.append(episode.next_state)
            batch_action.append(episode.action)
            batch_reward.append(episode.reward)
            batch_mask.append(episode.mask)
            batch_rnn_state.append(episode.rnn_state)

        return Transition(batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state), indexes, lengths
| 100 | + |
    def update_prior(self, indexes, td_error, lengths):
        # Refresh the priorities of the sampled sequences with the learner's new TD errors.
        prior = self.td_error_to_prior(td_error, lengths)
        for priors_idx, idx in enumerate(indexes):
            self.memory_probability[idx] = prior[priors_idx]

    def __len__(self):
        return len(self.memory)
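

if __name__ == '__main__':
    # Minimal usage sketch, illustrative only: it assumes config defines
    # sequence_length > burn_in_length, a 2-dimensional state (as in the zero
    # padding above) and an LSTM hidden size of 16. The random tensors stand in
    # for a real actor/learner and are not part of the training code.
    local_buffer = LocalBuffer()
    replay_memory = Memory(capacity=1000)

    # Actor side: push one full sequence of dummy transitions.
    rnn_state = (torch.zeros(1, 1, 16), torch.zeros(1, 1, 16))
    for _ in range(sequence_length):
        local_buffer.push(torch.rand(2), torch.rand(2), action=0, reward=0.0,
                          mask=1, rnn_state=rnn_state)

    batch, lengths = local_buffer.sample()

    # Learner side: a placeholder TD error of shape [num_sequences, sequence_length].
    td_error = torch.rand(len(lengths), sequence_length)
    replay_memory.push(td_error, batch, lengths)

    sampled, indexes, sampled_lengths = replay_memory.sample(batch_size=1)
    replay_memory.update_prior(indexes, torch.rand(len(indexes), sequence_length), sampled_lengths)
    print(len(replay_memory), 'sequences stored')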