Add the Adam optimizer from [Kingma et al., 2014](http://arxiv.org/abs/1412.6980). #264
Add the Adam optimizer from Kingma et al., 2014.
Some specific design decisions were made that differ from Keras/Optax.
We follow `optax` and `tensorflow`'s Adam optimizer's setting (see google-deepmind/optax#571), which differs from the original paper. We do correct for the bias, consistent with optax/pytorch.
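For reference, the bias-corrected update has the textbook form sketched below. This is a minimal, illustrative sketch in plain JAX; the function name and signature are invented for this example, and the exact epsilon placement in the PR follows optax/tensorflow rather than necessarily matching this sketch.

```python
import jax.numpy as jnp

def adam_update(grad, m, v, count, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Hypothetical reference step, not the kernel added in this PR.
    count = count + 1
    m = b1 * m + (1.0 - b1) * grad                 # first-moment EMA
    v = b2 * v + (1.0 - b2) * jnp.square(grad)     # second-moment EMA
    m_hat = m / (1.0 - b1 ** count)                # bias correction
    v_hat = v / (1.0 - b2 ** count)
    update = lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return update, m, v, count
```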
Keras's Adam offers `amsgrad: bool` as an option, which changes how the variable is updated, keeping track of the maximum velocity encountered. However, this would lead to an additional state parameter (`v_max`) and conditionally changes the number of slot variables. Slot variables are particularly expensive in large embedding lookups (each is the size of the entire sharded table), and supporting this would require a different underlying primitive anyways. If we need the option, we can create a new optimizer. This is consistent with optax, which has a separate `optax.amsgrad` optimizer.
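To illustrate why `amsgrad` would add a slot variable, a rough AMSGrad-style step is sketched below. It is illustrative only (names are invented for this example, and real implementations differ in how they combine the running max with bias correction); the point is the extra `v_max` buffer, which is the same shape as the table itself.

```python
import jax.numpy as jnp

def amsgrad_update(grad, m, v, v_max, count,
                   lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Note the extra `v_max` state, sized like the (sharded) table itself.
    count = count + 1
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * jnp.square(grad)
    v_max = jnp.maximum(v_max, v)                   # running max of the velocity
    m_hat = m / (1.0 - b1 ** count)
    update = lr * m_hat / (jnp.sqrt(v_max) + eps)   # denominator uses the max
    return update, m, v, v_max, count
```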
We also do not add a `nesterov: bool` option. Similar to `amsgrad`, this modifies the update rule. Technically the Nesterov modification also adds a step-dependent `beta_1` parameter and requires an additional state variable to keep track of the accumulated product, something Optax currently ignores. Keras handles this with a different optimizer, `keras.optimizers.Nadam`, which does add the additional state variable. PyTorch also has a separate `torch.optim.NAdam` specifically for this.
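For comparison, a simplified Nesterov-style Adam step is sketched below. It is illustrative only: the name is invented for this example, it uses the usual bias-corrected moments plus a one-step look-ahead on the momentum, and it deliberately omits the step-dependent `beta_1` schedule (and its accumulated product) that a full Nadam keeps as extra state.

```python
import jax.numpy as jnp

def nesterov_adam_update(grad, m, v, count,
                         lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Hypothetical sketch; omits the step-dependent beta_1 schedule of Nadam.
    count = count + 1
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * jnp.square(grad)
    m_hat = m / (1.0 - b1 ** count)
    v_hat = v / (1.0 - b2 ** count)
    # Look ahead by blending the corrected momentum with the current gradient.
    m_nesterov = b1 * m_hat + (1.0 - b1) * grad / (1.0 - b1 ** count)
    update = lr * m_nesterov / (jnp.sqrt(v_hat) + eps)
    return update, m, v, count
```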