Skip to content

Commit

Permalink
Fix a huge bug with recurrence: thank you Dima Bahdanau
Browse files Browse the repository at this point in the history
  • Loading branch information
lcswillems committed Aug 10, 2018
1 parent 76b2a6d commit 588f9a9
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions torch_rl/torch_rl/algos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,30 @@ def collect_experiences(self):
 delta = self.rewards[i] + self.discount * next_value * next_mask - self.values[i]
 self.advantages[i] = delta + self.discount * self.gae_lambda * next_advantage * next_mask

-# Defines experiences
+# Define experiences:
+# the whole experience is the concatenation of the experience
+# of each process.
+# In comments below:
+# - T is self.num_frames_per_proc,
+# - P is self.num_procs,
+# - D is the dimensionality.

 exps = DictList()
-exps.obs = [obs for obss in self.obss for obs in obss]
+exps.obs = [self.obss[i][j]
+for j in range(self.num_procs)
+for i in range(self.num_frames_per_proc)]
 if self.acmodel.recurrent:
-exps.memory = self.memories.view(-1, *self.memories.shape[2:])
-exps.mask = self.masks.view(-1, *self.masks.shape[2:]).unsqueeze(1)
-exps.action = self.actions.view(-1, *self.actions.shape[2:])
-exps.value = self.values.view(-1, *self.values.shape[2:])
-exps.reward = self.rewards.view(-1, *self.rewards.shape[2:])
-exps.advantage = self.advantages.view(-1, *self.advantages.shape[2:])
+# T x P x D -> P x T x D -> (P * T) x D
+exps.memory = self.memories.transpose(0, 1).reshape(-1, *self.memories.shape[2:])
+# T x P -> P x T -> (P * T) x 1
+exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
+# for all tensors below, T x P -> P x T -> P * T
+exps.action = self.actions.transpose(0, 1).reshape(-1)
+exps.value = self.values.transpose(0, 1).reshape(-1)
+exps.reward = self.rewards.transpose(0, 1).reshape(-1)
+exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
 exps.returnn = exps.value + exps.advantage
-exps.log_prob = self.log_probs.view(-1, *self.log_probs.shape[2:])
+exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

 # Preprocess experiences

Expand Down

0 comments on commit 588f9a9

Please sign in to comment.