Change sampling method from randint to choice in Replay and robustify policy networks in SAC #111
This pull request proposes two changes:
Replacement of np.random.randint with np.random.choice in SimpleReplayBuffer
np.random.randint samples with replacement, so the same transition can appear more than once in a batch. In effect, the gradients/errors of those duplicated transitions have a higher influence on the update. np.random.choice with replace=False prevents this. To stay compatible with the current implementation, replace is set to False only when size > batch_size. Otherwise duplicates are still allowed (I assume the user should have checked for size > batch_size).
The output of a simplified demo highlights the issue: indices such as 86 and 54 appear several times in one batch.
array([82, 49, 37, 39, 19, 86, 86, 12, 44, 68, 86, 30, 59, 82, 20, 66, 12,
53, 99, 95, 56, 69, 96, 89, 2, 7, 93, 38, 54, 48, 16, 71, 58, 7,
29, 34, 18, 54, 4, 62, 14, 95, 75, 59, 69, 98, 54, 57, 8, 8, 54,
14, 76, 66, 77, 37, 78, 30, 71, 43, 99, 70, 51, 20])
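The demo code itself is not shown above, so the following is only a minimal sketch of how such an output could be produced and how the proposed sampling differs. The values `size = 100` and `batch_size = 64` are assumptions chosen to match the shape of the printed array, not values taken from SimpleReplayBuffer.

```python
import numpy as np

size, batch_size = 100, 64  # hypothetical buffer fill level and batch size

# randint draws with replacement, so indices can repeat within one batch.
with_dupes = np.random.randint(0, size, batch_size)

# choice with replace=False draws each index at most once. Replacement is
# kept as a fallback when the buffer is not larger than the batch.
unique = np.random.choice(size, batch_size, replace=not (size > batch_size))

# Number of distinct indices in each batch.
print(len(np.unique(with_dupes)), len(np.unique(unique)))
```

With these assumed sizes, the second count always equals `batch_size`, while the first is usually smaller.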
This change might affect other algorithms that use the buffer, so another round of tests might be a good idea here ;)
Small fix in the policy network of SAC
Just a small fix that should not affect any other method:
Currently, the logprob is summed with .sum(dim=1), which assumes that there is only one batch dimension. With two or more batch dimensions, dim=1 is itself a batch dimension, so the sum would run over the wrong axis. Changing this to .sum(dim=-1) states explicitly that we sum over the last (data) dimension.
This makes the class more flexible and should not change behavior for current users, since dim=1 and dim=-1 are the same axis for 2-D inputs.
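A small sketch of the difference, assuming PyTorch tensors. The shapes below (5 time steps, 3 environments, 2 action dimensions) are hypothetical and only serve to illustrate the point; they are not taken from the SAC policy code.

```python
import torch

# One batch dimension: (batch, action_dim). Both variants agree.
flat = torch.ones(3, 2)
same = torch.equal(flat.sum(dim=1), flat.sum(dim=-1))

# Two batch dimensions: (time, batch, action_dim).
logprob = torch.ones(5, 3, 2)
wrong = logprob.sum(dim=1)   # sums over the second batch dimension
right = logprob.sum(dim=-1)  # sums over the action (data) dimension

print(same, tuple(wrong.shape), tuple(right.shape))
```

Only the dim=-1 variant keeps one log-probability per element of every batch dimension.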