Description
Currently, `optax.scale_by_adam` should be equivalent to `torch.optim.Adam`. However, TensorFlow has a different implementation.
In short, if we change https://github.com/deepmind/optax/blob/cebdeff4a1922113a96c520e7a81b5bf79825b77/optax/_src/transform.py#L345-L348 to the following, the Adam optimizer would match TensorFlow's implementation:
```python
updates = jax.tree_util.tree_map(
    lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc))
    * m / (jnp.sqrt(v + eps_root) + eps),
    mu, nu)
```
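To make the difference concrete, here is a minimal sketch of the two per-step update rules on a single array (plain `jnp` arrays instead of pytrees, no `eps_root`). The function names are hypothetical. For any step `t`, the TensorFlow-style rule with `eps_hat` matches the PyTorch/optax rule with `eps = eps_hat / sqrt(1 - b2**t)`, so no single fixed `eps` makes the two agree at every step.

```python
import jax.numpy as jnp


def adam_step_pytorch_style(mu, nu, count, b1=0.9, b2=0.999, eps=1e-8):
  """Algorithm 1 (PyTorch / optax): bias-correct both moments, eps outside the sqrt."""
  mu_hat = mu / (1 - b1**count)
  nu_hat = nu / (1 - b2**count)
  return mu_hat / (jnp.sqrt(nu_hat) + eps)


def adam_step_tf_style(mu, nu, count, b1=0.9, b2=0.999, eps_hat=1e-8):
  """TensorFlow-style: bias correction folded into the step size, eps_hat added to raw sqrt(nu)."""
  scale = jnp.sqrt(1 - b2**count) / (1 - b1**count)
  return scale * mu / (jnp.sqrt(nu) + eps_hat)


# Hypothetical (uncorrected) moment estimates, just for illustration.
mu = jnp.array([1e-6, -1e-3, 0.5])
nu = jnp.array([1e-12, 1e-6, 1.0])
t = 1

tf_update = adam_step_tf_style(mu, nu, t)
# Matches the TF-style step once eps is rescaled by 1 / sqrt(1 - b2**t).
pt_update_rescaled = adam_step_pytorch_style(
    mu, nu, t, eps=1e-8 / (1 - 0.999**t) ** 0.5)
# With the same eps, the PyTorch-style step differs (most visibly for tiny nu).
pt_update_same_eps = adam_step_pytorch_style(mu, nu, t)
```

At `t = 1` with the defaults, the equivalent PyTorch-style `eps` is `1e-8 / sqrt(0.001)`, roughly `3.2e-7`.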
More context
Basically, PyTorch's and optax's Adam follow Algorithm 1 of Kingma and Ba's Adam paper (arxiv/1412.6980), but TensorFlow uses the formulation just before Section 2.1 of the paper, and the epsilon it exposes corresponds to "epsilon hat" in the paper.
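A short derivation of how the two formulations relate (notation follows the paper; $\alpha$ is the learning rate and $m_t, v_t$ are the uncorrected moment estimates):

```math
\begin{aligned}
\text{Algorithm 1:}\quad & \theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon},
\qquad \hat m_t = \frac{m_t}{1-\beta_1^t},\quad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
\text{Before Sec. 2.1:}\quad & \alpha_t = \alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t},
\qquad \theta_t = \theta_{t-1} - \alpha_t\,\frac{m_t}{\sqrt{v_t} + \hat\epsilon} \\
\text{hence}\quad & \alpha_t\,\frac{m_t}{\sqrt{v_t} + \hat\epsilon}
= \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \hat\epsilon/\sqrt{1-\beta_2^t}}
\end{aligned}
```

So the two updates coincide exactly when $\epsilon = \hat\epsilon / \sqrt{1-\beta_2^t}$. With $\beta_2 = 0.999$ that factor is about $31.6$ at $t = 1$ and decays toward $1$ as $t$ grows, so a fixed $\hat\epsilon$ behaves like a larger Algorithm-1 $\epsilon$ early in training.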
This was relevant in my recent reproduction of OpenAI's work in https://github.com/openai/lm-human-preferences. Long story short, below is an end-to-end experiment with torch's Adam (`adam_pt`) and TensorFlow-style Adam (`adam_tf`). While the final performance (`objective/scores`) looks the same, the learning curves differ in a non-trivial way. E.g., the torch Adam version had a much higher `clipfrac` initially, causing a more significant initial update.

The "initial aggressive update" issue gets aggravated in larger models (e.g., gpt2-large). You can see that `objective/kl` had a spike with `adam_tf`, so this could be a reproducibility issue.


Desired solution
Include a TensorFlow-style Adam variant, e.g.:
```python
import jax
import jax.numpy as jnp
from optax import ScaleByAdamState, update_moment, update_moment_per_elem_norm
# NOTE: the imports below are private optax internals and may move between versions.
from optax._src.alias import _scale_by_learning_rate
from optax._src import base, utils, combine, numerics


def scale_by_adam_tf_style(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype=None,
) -> base.GradientTransformation:
  """Rescale updates according to the Adam algorithm.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  WARNING: This is a TensorFlow-style Adam optimizer that uses the
  formulation just before Section 2.1 of the Kingma and Ba paper
  rather than the formulation in Algorithm 1; the "epsilon" referred
  to here is "epsilon hat" in the paper.

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability (epsilon hat).
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_util.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)
    ### `optax` default adam implementation
    # mu_hat = bias_correction(mu, b1, count_inc)
    # nu_hat = bias_correction(nu, b2, count_inc)
    # updates = jax.tree_util.tree_map(
    #     lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    ### TensorFlow adam implementation: bias correction folded into the step size
    updates = jax.tree_util.tree_map(
        lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc))
        * m / (jnp.sqrt(v + eps_root) + eps),
        mu, nu)
    mu = utils.cast_tree(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)


def adam_tf_style(
    learning_rate,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype=None,
) -> base.GradientTransformation:
  """TensorFlow-style Adam: `scale_by_adam_tf_style` followed by the learning rate."""
  return combine.chain(
      scale_by_adam_tf_style(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
      _scale_by_learning_rate(learning_rate),
  )
```
Obviously this is bad naming, but I figure you'd have much better ideas :)