Change sampling method from randint to choice in Replay and robustify policy networks in SAC #111

Merged
merged 5 commits into from
Aug 10, 2020

Conversation

ksluck
Contributor

@ksluck ksluck commented Jul 3, 2020

This pull request proposes two changes:

Replacement of np.random.randint with np.random.choice in SimpleReplayBuffer

Using randint can produce duplicated transitions within a batch, which in effect gives the gradients/errors of those transitions a higher influence on the updates. choice is the better option here because it prevents this when replace=False. To stay compatible with the current implementation, replace is only set to False when size > batch_size; otherwise we allow duplicates (I assume the user should have checked for size > batch_size).
This simplified demo code highlights the issue:

>>> import numpy as np
>>> size = 100
>>> batch_size = 64
>>> indices = np.random.randint(0, size, batch_size)
>>> indices
array([82, 49, 37, 39, 19, 86, 86, 12, 44, 68, 86, 30, 59, 82, 20, 66, 12,
       53, 99, 95, 56, 69, 96, 89,  2,  7, 93, 38, 54, 48, 16, 71, 58,  7,
       29, 34, 18, 54,  4, 62, 14, 95, 75, 59, 69, 98, 54, 57,  8,  8, 54,
       14, 76, 66, 77, 37, 78, 30, 71, 43, 99, 70, 51, 20])

>>> indices_2 = np.random.choice(size, size=batch_size, replace=size<batch_size)
>>> indices_2
array([36,  6, 87, 79, 93,  0, 62, 98, 95, 71, 18, 73, 92, 37, 55, 80, 19,
       43, 49, 74, 56, 39,  1, 45, 29,  5, 32, 78, 28,  9,  2, 41, 26, 64,
       44, 38,  3, 33, 85,  8, 60, 51, 22, 16, 89, 63, 52, 83, 75, 81, 17,
       82, 15, 88, 53,  7,  4, 77, 40, 25, 30, 84, 13, 50])

This change might have an impact on other algorithms as well, so another round of tests might be a good idea here ;)

Small fix in the policy network of SAC

Just a small fix which should not have an impact on any other method:
Currently, the logprob is summed with .sum(dim=1), which assumes that there is only one batch dimension. This might cause issues with two or more batch dimensions, so we should change it to .sum(dim=-1) to indicate that we want to sum over the data dimension (the last one).
This change makes the class more flexible and should not have any impact on current users.
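To illustrate the difference between summing over axis 1 and over the last axis, here is a minimal sketch. It uses NumPy arrays as a stand-in for the policy's tensors (NumPy's `axis` argument plays the role of `dim` here), and the array names and shapes are made up for the example rather than taken from the SAC code.

```python
import numpy as np

rng = np.random.default_rng(0)
action_dim = 3  # hypothetical size of the data (action) dimension

# One batch dimension: (batch, action_dim)
logprob_1 = rng.normal(size=(8, action_dim))
# Two batch dimensions: (outer_batch, batch, action_dim)
logprob_2 = rng.normal(size=(4, 8, action_dim))

# With a single batch dimension both variants agree.
same_a = logprob_1.sum(axis=1)    # shape (8,)
same_b = logprob_1.sum(axis=-1)   # shape (8,)

# With two batch dimensions, axis=1 collapses the inner batch dimension
# instead of the data dimension.
wrong = logprob_2.sum(axis=1)     # shape (4, 3)
right = logprob_2.sum(axis=-1)    # shape (4, 8)

print(same_a.shape, same_b.shape, wrong.shape, right.shape)
```

Summing over the last axis keeps one value per sample regardless of how many leading batch dimensions there are.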

@vitchyr
Collaborator

vitchyr commented Jul 6, 2020

Thanks for this! One reason I used random.randint is that it's quite a bit faster than choice, presumably because it doesn't try to prevent duplicates. Have you found that it actually makes a large difference? I imagine it wouldn't matter in many use cases, but it'd be nice to have this option. Instead of replacing the behavior, could you add a flag that allows people to choose between the two options (with the default being to use random.randint)? That way, it won't surprisingly change the training time for users, but gives people the option to sample without replacement.

Also, thanks for the policy network change!

@ksluck
Contributor Author

ksluck commented Jul 8, 2020

I would assume that in most cases where data generation is much faster than the training process (as in simulations or the Atari games) and there is a large number of steps per episode, this should indeed not be much of an issue. The probability of selecting the same transition more than once in a batch decreases quickly as the buffer grows.
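How quickly that probability drops can be estimated with the standard birthday-problem formula: when drawing k indices uniformly with replacement from n entries, the chance that all are distinct is the product of (1 - i/n) for i = 0..k-1. The sketch below is only an illustration; the function name and the buffer sizes are chosen for the example and do not come from the repository.

```python
import math


def prob_no_duplicates(size: int, batch_size: int) -> float:
    """Probability that batch_size uniform draws with replacement
    from range(size) contain no repeated index."""
    if batch_size > size:
        return 0.0  # a repeat is unavoidable (pigeonhole principle)
    # Sum logs instead of multiplying to avoid underflow for large batches.
    log_p = sum(math.log1p(-i / size) for i in range(batch_size))
    return math.exp(log_p)


if __name__ == "__main__":
    batch_size = 64
    for size in (100, 10_000, 1_000_000):  # hypothetical buffer sizes
        p = prob_no_duplicates(size, batch_size)
        print(f"size={size:>9}: P(no duplicates) = {p:.6f}")
```

For a buffer of 100 transitions and a batch of 64, a duplicate-free batch is practically impossible, whereas for a buffer of a million transitions duplicates become rare.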

The impact will be larger when episodes have few steps, or when training outpaces data collection, for example because the simulations are complex or we collect data in the real world and have to (or can) train longer on the first few hundred steps collected. Well, the joy of doing robotics 😄

Good point about the processing time. Actually, if I remember correctly, choice uses randint when the replace flag is set to True, so we could introduce a class parameter named replace with a default value of True.
That way we can use choice exclusively and still get the behavior you describe. I am wondering what the best handling of the case batch_size > size would be. Silently using replace=True when the class flag is set to False seems a bit unsatisfying; maybe throwing a warning would be a good idea?

import time
import numpy as np

size = 10000000
batch_size = 512

# randint: sampling with replacement
a = time.time(); indices = np.random.randint(0, size, batch_size); b = time.time()
print(b - a)
# 0.00011873245239257812

# choice without replacement
a = time.time(); indices_2 = np.random.choice(size, size=batch_size, replace=False); b = time.time()
print(b - a)
# 0.4596829414367676

# choice with replacement
a = time.time(); indices_2 = np.random.choice(size, size=batch_size, replace=True); b = time.time()
print(b - a)
# 0.00021719932556152344
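The replace flag with a warning, as discussed in this comment, could look roughly like the sketch below. It is not the code that was merged: the function name `sample_indices`, its signature, and the warning text are assumptions made for illustration, and only `np.random.choice` and `warnings.warn` are real library calls.

```python
import warnings

import numpy as np


def sample_indices(size, batch_size, replace=True):
    """Hypothetical helper: draw batch_size indices from range(size).

    replace=True keeps the fast behaviour (duplicates possible).
    replace=False requests unique indices, but falls back to sampling
    with replacement (and warns) when the buffer is too small.
    """
    use_replace = replace
    if not replace and size < batch_size:
        warnings.warn(
            "Replay buffer holds fewer transitions than batch_size; "
            "sampling with replacement instead."
        )
        use_replace = True
    return np.random.choice(size, size=batch_size, replace=use_replace)


if __name__ == "__main__":
    unique = sample_indices(100, 64, replace=False)
    print(len(set(unique.tolist())))  # number of distinct indices drawn
```

Keeping replace=True as the default preserves the current speed, while replace=False is an explicit opt-in to the slower duplicate-free sampling.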

@ksluck
Contributor Author

ksluck commented Aug 9, 2020

Added the proposed changes, plus a warning if the desired behaviour is not met.

Collaborator

@vitchyr vitchyr left a comment


I think your solution is a good one! Thanks for the PR.

@vitchyr vitchyr merged commit 55ace41 into rail-berkeley:master Aug 10, 2020