
optax and TensorFlow's Adam optimizer setting #571

Closed
@vwxyzjn

Description

Currently, optax.scale_by_adam should be equivalent to torch.optim.Adam. However, TensorFlow has a different implementation.

In short, if we change https://github.com/deepmind/optax/blob/cebdeff4a1922113a96c520e7a81b5bf79825b77/optax/_src/transform.py#L345-L348 to the following, then the Adam optimizer would match TensorFlow's implementation.

updates = jax.tree_util.tree_map(
    lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc)) * m / (jnp.sqrt(v + eps_root) + eps), mu, nu)

More context

Basically, PyTorch's and optax's Adam follow Algorithm 1 of Kingma and Ba's Adam paper (arxiv/1412.6980), but TensorFlow uses the formulation given just before Section 2.1 of the paper, and the epsilon it exposes corresponds to epsilon hat in the paper.
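
For reference, here is my reading of the two update rules side by side, in the paper's notation (a sketch, not an authoritative restatement).

Algorithm 1 (PyTorch and current optax):

$$\hat m_t = \frac{m_t}{1-\beta_1^t},\qquad \hat v_t = \frac{v_t}{1-\beta_2^t},\qquad \theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}$$

Formulation just before Section 2.1 (TensorFlow):

$$\theta_t = \theta_{t-1} - \alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}\cdot\frac{m_t}{\sqrt{v_t}+\hat\epsilon}$$

Algebraically, Algorithm 1 with a fixed $\epsilon$ is equivalent to the second form with $\hat\epsilon = \epsilon\sqrt{1-\beta_2^t}$, so the two only coincide when $\epsilon = \hat\epsilon = 0$. For small $t$ the effective $\hat\epsilon$ of the PyTorch-style rule is much smaller than $\epsilon$, which is consistent with the more aggressive initial updates described below.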

This was a relevant issue in my recent reproduction of OpenAI's work in https://github.com/openai/lm-human-preferences. Long story short, below is an end-to-end experiment with PyTorch-style Adam (adam_pt) and TensorFlow-style Adam (adam_tf). While the final performance (objective/scores) looks the same, the learning curves differ in a non-trivial way. E.g., the adam_pt version had a much higher clipfrac initially, causing a more significant initial update.

[figure: training curves comparing adam_pt and adam_tf]

The "initial aggressive update" issue gets aggravated in larger models (e.g., gpt2-large). You can see that objective/kl had a spike with adam_tf, so this could be a reproducibility issue.

[figures: gpt2-large runs showing the objective/kl spike]

Desired solution

Include something like the following:

import jax
import jax.numpy as jnp
from optax import ScaleByAdamState, update_moment, update_moment_per_elem_norm
from optax._src.alias import _scale_by_learning_rate
from optax._src import base, utils, combine, numerics


def scale_by_adam_tf_style(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype = None,
) -> base.GradientTransformation:
  """Rescale updates according to the Adam algorithm.
  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
  WARNING: This is a TensorFlow-style Adam optimizer that uses the
    formulation given just before Section 2.1 of the Kingma and Ba paper
    rather than the formulation in Algorithm 1; the "epsilon" referred
    to here is "epsilon hat" in the paper.
  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability. (epsilon hat)
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
  Returns:
    A `GradientTransformation` object.
  """

  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_util.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)

    ### `optax` default adam implementation
    # mu_hat = bias_correction(mu, b1, count_inc)
    # nu_hat = bias_correction(nu, b2, count_inc)
    # updates = jax.tree_util.tree_map(
    #     lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    ### Tensorflow adam implementation
    updates = jax.tree_util.tree_map(
        lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc)) * m / (jnp.sqrt(v + eps_root) + eps), mu, nu)
    mu = utils.cast_tree(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)


def adam_tf_style(
    learning_rate,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype = None,
):
  return combine.chain(
      scale_by_adam_tf_style(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
      _scale_by_learning_rate(learning_rate),
  )

Obviously this is bad naming, but I figure you'd have much better ideas :)
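
For what it's worth, here is a minimal sketch of how the difference shows up on a single update step, assuming the adam_tf_style defined above (the toy values and the 1e-5 epsilon are arbitrary):

import jax
import jax.numpy as jnp
import optax

# Toy parameters and gradients; the values are arbitrary.
params = {"w": jnp.array([1.0, 2.0, 3.0])}
grads = {"w": jnp.array([0.1, -0.2, 0.3])}

# Default optax Adam (Algorithm 1, PyTorch-style).
opt_pt = optax.adam(learning_rate=1e-3, eps=1e-5)
updates_pt, _ = opt_pt.update(grads, opt_pt.init(params), params)

# Proposed TensorFlow-style Adam (adam_tf_style as defined above).
opt_tf = adam_tf_style(learning_rate=1e-3, eps=1e-5)
updates_tf, _ = opt_tf.update(grads, opt_tf.init(params), params)

# With a non-zero eps, even the first update differs slightly; the gap is
# largest early in training, when sqrt(1 - b2**t) is far from 1.
print(jax.tree_util.tree_map(
    lambda a, b: jnp.max(jnp.abs(a - b)), updates_pt, updates_tf))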
