saurav/scale by grad norm #1000 (Saurav Maheshkar)

Merged
feat: add normalize_by_update_norm
SauravMaheshkar committed Jun 25, 2024
commit 1528c68e1e60fe137ea8dc07e42ab970aae9ea32
2 changes: 2 additions & 0 deletions optax/__init__.py
@@ -105,6 +105,8 @@
from optax._src.transform import centralize
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import NormalizeByUpdateNormState
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adam
37 changes: 37 additions & 0 deletions optax/_src/transform.py
@@ -1419,6 +1419,43 @@ def update_fn(
return base.GradientTransformationExtraArgs(_init_empty_state, update_fn)


class NormalizeByUpdateNormState(NamedTuple):
  """State for normalize_by_update_norm."""
  scale_factor: float
  eps: float


def normalize_by_update_norm(
    scale_factor: float = 1.0, eps: float = 1e-6
) -> base.GradientTransformation:
  """Scale updates by the inverse of their global (gradient) norm.

  Args:
    scale_factor: (float) scaling factor applied after normalization.
    eps: (float) jitter term to avoid dividing by 0.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    del params
    return NormalizeByUpdateNormState(scale_factor, eps)

  def update_fn(
      updates: base.Updates,
      state: NormalizeByUpdateNormState,
      params: Optional[base.Params] = None,
  ) -> tuple[base.Updates, NormalizeByUpdateNormState]:
    del params
    g_norm = (utils.global_norm(updates) + eps) / scale_factor
    updates = jtu.tree_map(lambda g: g / g_norm, updates)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)


### Legacy symbols to be removed. ###


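For context, a minimal usage sketch of the transform introduced in this commit, assuming the normalize_by_update_norm export lands as shown in the optax/__init__.py change above; the parameter tree, learning rate, and scale_factor below are purely illustrative:

import jax.numpy as jnp
import optax

params = {'w': jnp.array([3.0, 4.0])}
grads = {'w': jnp.array([3.0, 4.0])}  # global norm = 5.0

# Chain the new transform with SGD: after normalization the update tree has
# (roughly) unit global norm, which SGD then scales by the learning rate.
tx = optax.chain(
    optax.normalize_by_update_norm(scale_factor=1.0),
    optax.sgd(learning_rate=0.1),
)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)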
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
@@ -53,6 +53,7 @@ def setUp(self):
('param_block_norm', transform.scale_by_param_block_norm),
('param_block_rms', transform.scale_by_param_block_rms),
('distance_over_gradients', transform.scale_by_distance_over_gradients),
('normalize_by_update_norm', transform.normalize_by_update_norm),
])
def test_scalers(self, scaler_constr):
params = self.init_params
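As a rough numerical check of what the parameterized scaler test exercises (illustrative only, not the PR's test code; it assumes the eps jitter is negligible for non-zero gradients):

import jax.numpy as jnp
from optax._src import transform, utils

grads = {'a': jnp.array([3.0, 0.0]), 'b': jnp.array([0.0, 4.0])}  # global norm = 5.0
tx = transform.normalize_by_update_norm(scale_factor=2.0)
state = tx.init(grads)
updates, _ = tx.update(grads, state)
print(utils.global_norm(updates))  # ~= 2.0, i.e. the requested scale_factor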