Skip to content

Commit

Permalink
move clipping transforms to optax.transforms.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623473687
  • Loading branch information
mtthss authored and OptaxDev committed Apr 30, 2024
1 parent 8a3ee74 commit 820e372
Show file tree
Hide file tree
Showing 13 changed files with 1,129 additions and 858 deletions.
27 changes: 10 additions & 17 deletions optax/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ class EmptyState(NamedTuple):
"""An empty state for the simplest stateless transformations."""


def init_empty_state(params: Params) -> EmptyState:
"""Init function for a :class:`GradientTransformation` with empty state."""
del params
return EmptyState()


def identity() -> GradientTransformation:
"""Stateless identity transformation that leaves input gradients untouched.
Expand All @@ -225,14 +231,11 @@ def identity() -> GradientTransformation:
A `GradientTransformation` object.
"""

def init_fn(_):
return EmptyState()

def update_fn(updates, state, params=None):
del params
return updates, state

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def set_to_zero() -> GradientTransformation:
Expand All @@ -255,15 +258,11 @@ def set_to_zero() -> GradientTransformation:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return EmptyState()

def update_fn(updates, state, params=None):
del params # Unused by the zero transform.
return jax.tree_util.tree_map(jnp.zeros_like, updates), state

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def stateless(
Expand All @@ -282,14 +281,11 @@ def stateless(
An `optax.GradientTransformation`.
"""

def init_fn(_):
return EmptyState()

def update_fn(updates, state, params=None):
del state
return f(updates, params), EmptyState()

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def stateless_with_tree_map(
Expand All @@ -310,9 +306,6 @@ def stateless_with_tree_map(
An `optax.GradientTransformation`.
"""

def init_fn(_):
return EmptyState()

def update_fn(updates, state, params=None):
del state
if params is not None:
Expand All @@ -321,7 +314,7 @@ def update_fn(updates, state, params=None):
f_ = lambda u: f(u, None)
return jax.tree_util.tree_map(f_, updates), EmptyState()

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def with_extra_args_support(
Expand Down
300 changes: 12 additions & 288 deletions optax/_src/clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,296 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient clipping transformations.
Note that complex numbers are also supported, see
https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
"""

import chex
import jax
import jax.numpy as jnp
"""Gradient clipping transformations."""

from optax._src import base
from optax._src import linear_algebra
from optax._src import numerics

ClipState = base.EmptyState


def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
"""Clips updates element-wise, to be in ``[-max_delta, +max_delta]``.
Args:
max_delta: The maximum absolute value for each element in the update.
Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return ClipState()

def update_fn(updates, state, params=None):
del params
updates = jax.tree_util.tree_map(
lambda g: jnp.clip(g, -max_delta, max_delta), updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
"""Clips updates to a max rms for the gradient of each param vector or matrix.
A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
(e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
Args:
threshold: The maximum rms for the gradient of each param vector or matrix.
Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return base.EmptyState()

def update_fn(updates, state, params=None):
del params

def _clip_fn(u):
clip_denom = jnp.maximum(
1.0,
jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold)
return u / clip_denom

updates = jax.tree_util.tree_map(_clip_fn, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


ClipByGlobalNormState = base.EmptyState


def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
"""Clips updates using their global norm.
References:
[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)
Args:
max_norm: The maximum global norm for an update.
Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return ClipByGlobalNormState()

def update_fn(updates, state, params=None):
del params
g_norm = linear_algebra.global_norm(updates)
# TODO(b/163995078): revert back to the following (faster) implementation
# once analysed how it affects backprop through update (e.g. meta-gradients)
# g_norm = jnp.maximum(max_norm, g_norm)
# updates = jax.tree_util.tree_map(
# lambda t: (t / g_norm) * max_norm, updates)
trigger = jnp.squeeze(g_norm < max_norm)
chex.assert_shape(trigger, ()) # A scalar.

def clip_fn(t):
return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm)

updates = jax.tree_util.tree_map(clip_fn, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


def per_example_global_norm_clip(
grads: list[chex.Array], l2_norm_clip: float
) -> tuple[list[chex.Array], jax.Array]:
"""Applies gradient clipping per-example using their global norm.
References:
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)
Args:
grads: flattened update; the function expects these to have a batch
dimension on the 0th axis.
l2_norm_clip: maximum L2 norm of the per-example gradients.
Returns:
A tuple containing sum of the clipped per-example grads, and the number of
per-example grads that were clipped.
"""
bsize = grads[0].shape[0]

if any(g.ndim == 0 or bsize != g.shape[0] for g in grads):
raise ValueError(
'Unlike other transforms, `per_example_global_norm_clip` expects'
' `grads` to have a batch dimension in the 0th axis.')

global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
num_clipped = jnp.greater(divisors, 1.0).sum()
clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads]
return clipped_sum, num_clipped


def per_example_layer_norm_clip(
grads: list[chex.Array],
global_l2_norm_clip: float,
uniform: bool = True,
eps: float = 1e-8,
) -> tuple[list[chex.Array], list[chex.Array]]:
"""Applies gradient clipping per-example using per-layer norms.
References:
[McMahan et al, 2012](https://arxiv.org/abs/1710.06963)]
Args:
grads: flattened update; i.e. a list of gradients in which each item is
the gradient for one layer; the function expects these to have a batch
dimension on the 0th axis.
global_l2_norm_clip: overall L2 clip norm to use.
uniform: If `True`, per-layer clip norm is global_l2_norm_clip/sqrt(L),
where L is the number of layers. Otherwise, per-layer clip norm is
global_l2_norm_clip * sqrt(f), where f is the fraction of total model
parameters that are in this layer.
eps: Small positive value to add to norms to avoid possible division by
zero.
Let C = `global_l2_norm_clip value`. Then per-layer clipping is done as
follows:
(1) If `uniform` is `True`, each of the K layers has an individual clip
norm of C / sqrt(K).
(2) If `uniform` is `False`, each of the K layers has an individual clip
norm of C * sqrt(D_i / D) where D_i is the number of parameters in
layer i, and D is the total number of parameters in the model.
Returns:
A tuple containing sum of the clipped per-example grads and the number of
per-example grads that were clipped for each layer.
"""
bsize = grads[0].shape[0]

if any(g.ndim == 0 or bsize != g.shape[0] for g in grads):
raise ValueError(
'Unlike other transforms, `per_example_layer_norm_clip` expects'
' `grads` to have a batch dimension in the 0th axis; got shapes:'
f' {(g.shape for g in grads)}.'
)

num_layers = len(grads)

# Compute per-layer clip norms, based on whether we are using uniform
# variant or not.
if uniform:
# Create list of length `num_layers` of per-layer clip norm.
layer_clip_norms = (
global_l2_norm_clip * (1.0 / num_layers) ** 0.5,
) * num_layers
else:
total_params = sum(g[0].size for g in grads)
layer_clip_norms = tuple(
global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5
for g in grads
)

# Compute per-layer grad norms.
def map_layer_norm(grads_list):
return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list]

layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads)

# Perform clipping.
divisors = (
tuple(
jnp.maximum(
layer_grad_norm / (layer_clip_norm + eps), 1.0
)
for layer_grad_norm, layer_clip_norm in zip(
layer_grad_norms_per_example, layer_clip_norms
)
)
)
num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors]
clipped_sum = [
(g / jnp.expand_dims(d, axis=[i for i in range(1, g.ndim)])).sum(0)
for g, d in zip(grads, divisors)
]
return clipped_sum, num_clipped


def unitwise_norm(x: chex.Array) -> chex.Array:
"""Computes norms of each output unit separately."""
if jnp.squeeze(x).ndim <= 1: # Scalars and vectors
squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True)
# Note that this assumes parameters with a shape of length 3 are multihead
# linear parameters--if you wish to apply AGC to 1D convs, you may need
# to modify this line.
elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear
squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True)
elif x.ndim == 4: # Conv kernels of shape HWIO
squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True)
else:
raise ValueError(
f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.')
chex.assert_is_broadcastable(squared_norm.shape, x.shape)
return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)


def unitwise_clip(g_norm: chex.Array,
max_norm: chex.Array,
grad: chex.Array,
div_eps: float = 1e-6) -> chex.Array:
"""Applies gradient clipping unit-wise."""
# This little max(., div_eps) is distinct from the normal eps and just
# prevents division by zero. It technically should be impossible to engage.
clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps))
chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad))
return jnp.where(g_norm < max_norm, grad, clipped_grad)
from optax.transforms import _clipping


adaptive_grad_clip = _clipping.adaptive_grad_clip
AdaptiveGradClipState = base.EmptyState


def adaptive_grad_clip(clipping: float,
eps: float = 1e-3) -> base.GradientTransformation:
"""Clips updates to be at most ``clipping * parameter_norm``, unit-wise.
References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)
Args:
clipping: The maximum allowed ratio of update norm to parameter norm.
eps: An epsilon term to prevent clipping of zero-initialized params.
Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return AdaptiveGradClipState()

def update_fn(updates, state, params):
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
g_norm, p_norm = jax.tree_util.tree_map(unitwise_norm, (updates, params))
# Maximum allowable norm.
max_norm = jax.tree_util.tree_map(
lambda x: clipping * jnp.maximum(x, eps), p_norm)
# If grad norm > clipping * param_norm, rescale.
updates = jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
ClipState = base.EmptyState
clip = _clipping.clip
clip_by_block_rms = _clipping.clip_by_block_rms
clip_by_global_norm = _clipping.clip_by_global_norm
ClipByGlobalNormState = base.EmptyState
per_example_global_norm_clip = _clipping.per_example_global_norm_clip
per_example_layer_norm_clip = _clipping.per_example_layer_norm_clip
unitwise_norm = _clipping.unitwise_norm
unitwise_clip = _clipping.unitwise_clip
Loading

0 comments on commit 820e372

Please sign in to comment.