Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saurav maheshkar saurav/scale by grad norm #1000

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/api/transformations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ Transformations
identity
keep_params_nonnegative
NonNegativeParamsState
normalize_by_update_norm
NormalizeByUpdateNormState
OptState
Params
per_example_global_norm_clip
Expand Down Expand Up @@ -172,6 +174,10 @@ Transformations and states
.. autoclass:: NonNegativeParamsState
:members:

.. autofunction:: normalize_by_update_norm
.. autoclass:: NormalizeByUpdateNormState
:members:

.. autofunction:: per_example_global_norm_clip
.. autofunction:: per_example_layer_norm_clip

Expand Down
2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
from optax._src.transform import centralize
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import NormalizeByUpdateNormState
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adam
Expand Down
58 changes: 58 additions & 0 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,64 @@ def update_fn(
return base.GradientTransformationExtraArgs(_init_empty_state, update_fn)


class NormalizeByUpdateNormState(NamedTuple):
"""State for normalize_by_update_norm."""
scale_factor: float
eps: float


def normalize_by_update_norm(
scale_factor: float = 1.0, eps: float = 1e-6
) -> base.GradientTransformation:
"""
Scale by the inverse of the update norm.

Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.normalize_by_update_norm(scale_factor=1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 2.25E+01
Objective function: 3.30E+01
Objective function: 4.54E+01
Objective function: 5.99E+01
Objective function: 7.64E+01

Args:
scale_factor: factor by which the update will be multiplied (defaults to 1).
eps: (float) jitter term to avoid dividing by 0

Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return NormalizeByUpdateNormState(scale_factor, eps)

def update_fn(
updates: base.Updates,
state: base.EmptyState,
params: Optional[base.Params] = None,
) -> tuple[base.Updates, base.EmptyState]:
del params
g_norm = (otu.tree_l2_norm(updates) + eps) / scale_factor
updates = jtu.tree_map(lambda g: g / g_norm, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


### Legacy symbols to be removed. ###


Expand Down
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def setUp(self):
('param_block_norm', transform.scale_by_param_block_norm),
('param_block_rms', transform.scale_by_param_block_rms),
('distance_over_gradients', transform.scale_by_distance_over_gradients),
('normalize_by_update_norm', transform.normalize_by_update_norm),
])
def test_scalers(self, scaler_constr):
params = self.init_params
Expand Down
Loading