Skip to content

Commit e957cfe

Browse files
committed
Fix: R2D2 Prior Error
1 parent: 7ce69b1 · commit: e957cfe

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

POMDP/4-R2D2-Single/memory.py

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -80,27 +80,22 @@ def __init__(self, capacity):
8080
self.memory_probability = deque(maxlen=capacity)
8181

8282
def td_error_to_prior(self, td_error, lengths):
83-
abs_td_error_sum = td_error.sum(dim=1, keepdim=True).view(-1).abs().detach().numpy()
83+
abs_td_error_sum = td_error.abs().sum(dim=1, keepdim=True).view(-1).detach().numpy()
8484
lengths_burn = [length - burn_in_length for length in lengths]
85-
86-
prior = abs_td_error_sum / lengths_burn
87-
return prior
85+
86+
prior_max = td_error.abs().max(dim=1, keepdim=True)[0].view(-1).detach().numpy()
87+
88+
prior_mean = abs_td_error_sum / lengths_burn
89+
prior = eta * prior_max + (1 - eta) * prior_mean
90+
return prior
8891

8992
def push(self, td_error, batch, lengths):
9093
# batch.state[local_mini_batch, sequence_length, item]
9194
prior = self.td_error_to_prior(td_error, lengths)
92-
95+
9396
for i in range(len(batch)):
94-
if len(self.memory_probability) > 0:
95-
memory_probability = np.array(self.memory_probability)
96-
probability_max = max(memory_probability.max(), prior[i])
97-
probability_mean = (memory_probability.sum() + prior[i]) / (len(self.memory_probability) + 1)
98-
else:
99-
probability_max = prior[i]
100-
probability_mean = prior[i]
10197
self.memory.append([Transition(batch.state[i], batch.next_state[i], batch.action[i], batch.reward[i], batch.mask[i], batch.step[i], batch.rnn_state[i]), lengths[i]])
102-
p = eta * probability_max + (1 - eta) * probability_mean
103-
self.memory_probability.append(p)
98+
self.memory_probability.append(prior[i])
10499

105100
def sample(self, batch_size):
106101
probability = np.array(self.memory_probability)

POMDP/4-R2D2-Single/model.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,9 +80,8 @@ def slice_burn_in(item):
8080

8181
td_error = pred - target.detach()
8282

83-
td_error_slice = []
8483
for idx, length in enumerate(lengths):
85-
td_error_slice.append(td_error[idx][:length-burn_in_length][:])
84+
td_error[idx][length-burn_in_length:][:] = 0
8685

8786
return td_error
8887

0 commit comments

Comments (0)