move clipping transforms to optax.transforms. #926

Merged
merged 1 commit into from
May 29, 2024
move clipping transforms to optax.transforms.
PiperOrigin-RevId: 638186308
mtthss authored and OptaxDev committed May 29, 2024
commit 36ee9f46b681d2bd5006538bbeb9aea7d403d40e
27 changes: 10 additions & 17 deletions optax/_src/base.py
@@ -212,6 +212,12 @@ class EmptyState(NamedTuple):
"""An empty state for the simplest stateless transformations."""


def init_empty_state(params: Params) -> EmptyState:
"""Init function for a :class:`GradientTransformation` with empty state."""
del params
return EmptyState()


def identity() -> GradientTransformation:
"""Stateless identity transformation that leaves input gradients untouched.

@@ -225,14 +231,11 @@ def identity() -> GradientTransformation:
A `GradientTransformation` object.
"""

def init_fn(_):
return EmptyState()

def update_fn(updates, state, params=None):
del params
return updates, state

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def set_to_zero() -> GradientTransformation:
@@ -255,15 +258,11 @@ def set_to_zero() -> GradientTransformation:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return EmptyState()

def update_fn(updates, state, params=None):
del params # Unused by the zero transform.
return jax.tree_util.tree_map(jnp.zeros_like, updates), state

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def stateless(
@@ -282,14 +281,11 @@
An `optax.GradientTransformation`.
"""

def init_fn(_):
return EmptyState()

def update_fn(updates, state, params=None):
del state
return f(updates, params), EmptyState()

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def stateless_with_tree_map(
@@ -310,9 +306,6 @@
An `optax.GradientTransformation`.
"""

def init_fn(_):
return EmptyState()

def update_fn(updates, state, params=None):
del state
if params is not None:
@@ -321,7 +314,7 @@ def update_fn(updates, state, params=None):
f_ = lambda u: f(u, None)
return jax.tree_util.tree_map(f_, updates), EmptyState()

return GradientTransformation(init_fn, update_fn)
return GradientTransformation(init_empty_state, update_fn)


def with_extra_args_support(
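
For context on the base.py change above: every stateless transform previously defined its own trivial init_fn, and the diff replaces those with the shared init_empty_state helper. A minimal sketch of the resulting pattern (illustrative only, not part of this diff; the transform name example_scale is made up):

import jax
from optax._src import base


def example_scale(step_size: float) -> base.GradientTransformation:
  """Hypothetical stateless transform built on the shared empty-state init."""

  def update_fn(updates, state, params=None):
    del params  # Unused by this transform.
    return jax.tree_util.tree_map(lambda g: step_size * g, updates), state

  # base.init_empty_state replaces the per-transform init_fn boilerplate.
  return base.GradientTransformation(base.init_empty_state, update_fn)
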
302 changes: 12 additions & 290 deletions optax/_src/clipping.py
@@ -12,298 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient clipping transformations.

Note that complex numbers are also supported, see
https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
"""

import chex
import jax
import jax.numpy as jnp
"""Gradient clipping transformations."""

from optax._src import base
from optax._src import linear_algebra
from optax._src import numerics

ClipState = base.EmptyState


def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
"""Clips updates element-wise, to be in ``[-max_delta, +max_delta]``.

Args:
max_delta: The maximum absolute value for each element in the update.

Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return ClipState()

def update_fn(updates, state, params=None):
del params
updates = jax.tree_util.tree_map(
lambda g: jnp.clip(g, -max_delta, max_delta), updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
"""Clips updates to a max rms for the gradient of each param vector or matrix.

A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
(e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

Args:
threshold: The maximum rms for the gradient of each param vector or matrix.

Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return base.EmptyState()

def update_fn(updates, state, params=None):
del params

def _clip_fn(u):
clip_denom = jnp.maximum(
1.0,
jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold)
return u / clip_denom

updates = jax.tree_util.tree_map(_clip_fn, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


ClipByGlobalNormState = base.EmptyState


def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
"""Clips updates using their global norm.

References:
[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

Args:
max_norm: The maximum global norm for an update.

Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return ClipByGlobalNormState()

def update_fn(updates, state, params=None):
del params
g_norm = linear_algebra.global_norm(updates)
# TODO(b/163995078): revert back to the following (faster) implementation
# once analysed how it affects backprop through update (e.g. meta-gradients)
# g_norm = jnp.maximum(max_norm, g_norm)
# updates = jax.tree_util.tree_map(
# lambda t: (t / g_norm) * max_norm, updates)
trigger = jnp.squeeze(g_norm < max_norm)
chex.assert_shape(trigger, ()) # A scalar.

def clip_fn(t):
return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm)

updates = jax.tree_util.tree_map(clip_fn, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
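
For readers skimming the moved code: the update above rescales every leaf by max_norm / ||g|| only when the global norm exceeds max_norm. A small standalone sketch of the same arithmetic, using hypothetical values and plain jax.numpy rather than the transform API:

import jax
import jax.numpy as jnp

max_norm = 1.0
updates = {'w': jnp.array([3.0, 4.0])}  # global norm is 5.0
g_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(updates)))
scale = jnp.where(g_norm < max_norm, 1.0, max_norm / g_norm)
clipped = jax.tree_util.tree_map(lambda t: t * scale, updates)
# clipped['w'] is roughly [0.6, 0.8], whose norm equals max_norm.
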


def per_example_global_norm_clip(
grads: list[chex.Array], l2_norm_clip: float
) -> tuple[list[chex.Array], jax.Array]:
"""Applies gradient clipping per-example using their global norm.

References:
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

Args:
grads: flattened update; the function expects these to have a batch
dimension on the 0th axis.
l2_norm_clip: maximum L2 norm of the per-example gradients.

Returns:
A tuple containing sum of the clipped per-example grads, and the number of
per-example grads that were clipped.
"""
bsize = grads[0].shape[0]

if any(g.ndim == 0 or bsize != g.shape[0] for g in grads):
raise ValueError(
'Unlike other transforms, `per_example_global_norm_clip` expects'
' `grads` to have a batch dimension in the 0th axis.')

global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
num_clipped = jnp.greater(divisors, 1.0).sum()
clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads]
return clipped_sum, num_clipped
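
A tiny usage sketch of the function being moved above (behavior unchanged, hypothetical values; imported here from its pre-move location optax._src.clipping):

import jax.numpy as jnp
from optax._src import clipping

# One layer holding a batch of two per-example gradients of shape (2,).
grads = [jnp.array([[3.0, 4.0],    # per-example norm 5.0 -> gets clipped
                    [0.3, 0.4]])]  # per-example norm 0.5 -> left as is
clipped_sum, num_clipped = clipping.per_example_global_norm_clip(
    grads, l2_norm_clip=1.0)
# clipped_sum[0] is roughly [0.9, 1.2]: the first example is rescaled to
# [0.6, 0.8] before summing with the untouched [0.3, 0.4]; num_clipped == 1.
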


def per_example_layer_norm_clip(
grads: list[chex.Array],
global_l2_norm_clip: float,
uniform: bool = True,
eps: float = 1e-8,
) -> tuple[list[chex.Array], list[chex.Array]]:
"""Applies gradient clipping per-example using per-layer norms.

References:
[McMahan et al, 2012](https://arxiv.org/abs/1710.06963)]

Args:
grads: flattened update; i.e. a list of gradients in which each item is
the gradient for one layer; the function expects these to have a batch
dimension on the 0th axis.
global_l2_norm_clip: overall L2 clip norm to use.
uniform: If `True`, per-layer clip norm is global_l2_norm_clip/sqrt(L),
where L is the number of layers. Otherwise, per-layer clip norm is
global_l2_norm_clip * sqrt(f), where f is the fraction of total model
parameters that are in this layer.
eps: Small positive value to add to norms to avoid possible division by
zero.

Let C = `global_l2_norm_clip value`. Then per-layer clipping is done as
follows:

* If `uniform` is `True`, each of the K layers has an individual clip
norm of C / sqrt(K).

* If `uniform` is `False`, each of the K layers has an individual clip
norm of C * sqrt(D_i / D) where D_i is the number of parameters in
layer i, and D is the total number of parameters in the model.

Returns:
A tuple containing sum of the clipped per-example grads and the number of
per-example grads that were clipped for each layer.
"""
bsize = grads[0].shape[0]

if any(g.ndim == 0 or bsize != g.shape[0] for g in grads):
raise ValueError(
'Unlike other transforms, `per_example_layer_norm_clip` expects'
' `grads` to have a batch dimension in the 0th axis; got shapes:'
f' {(g.shape for g in grads)}.'
)

num_layers = len(grads)

# Compute per-layer clip norms, based on whether we are using uniform
# variant or not.
if uniform:
# Create list of length `num_layers` of per-layer clip norm.
layer_clip_norms = (
global_l2_norm_clip * (1.0 / num_layers) ** 0.5,
) * num_layers
else:
total_params = sum(g[0].size for g in grads)
layer_clip_norms = tuple(
global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5
for g in grads
)

# Compute per-layer grad norms.
def map_layer_norm(grads_list):
return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list]

layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads)

# Perform clipping.
divisors = (
tuple(
jnp.maximum(
layer_grad_norm / (layer_clip_norm + eps), 1.0
)
for layer_grad_norm, layer_clip_norm in zip(
layer_grad_norms_per_example, layer_clip_norms
)
)
)
num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors]
clipped_sum = [
(g / jnp.expand_dims(d, axis=[i for i in range(1, g.ndim)])).sum(0)
for g, d in zip(grads, divisors)
]
return clipped_sum, num_clipped
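
To make the uniform versus weighted per-layer clip norms described above concrete, a quick worked example with made-up layer sizes:

import math

C = 1.0                         # global_l2_norm_clip
layer_sizes = [10, 30, 40, 20]  # hypothetical per-layer parameter counts D_i
D = sum(layer_sizes)            # total parameters, here 100
K = len(layer_sizes)            # number of layers, here 4

uniform_norms = [C / math.sqrt(K) for _ in layer_sizes]       # all 0.5
weighted_norms = [C * math.sqrt(d / D) for d in layer_sizes]  # ~[0.32, 0.55, 0.63, 0.45]
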


def unitwise_norm(x: chex.Array) -> chex.Array:
"""Computes norms of each output unit separately."""
if jnp.squeeze(x).ndim <= 1: # Scalars and vectors
squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True)
# Note that this assumes parameters with a shape of length 3 are multihead
# linear parameters--if you wish to apply AGC to 1D convs, you may need
# to modify this line.
elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear
squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True)
elif x.ndim == 4: # Conv kernels of shape HWIO
squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True)
else:
raise ValueError(
f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.')
chex.assert_is_broadcastable(squared_norm.shape, x.shape)
return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)
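
A quick shape check for unitwise_norm (hypothetical input; imported from its pre-move location):

import jax.numpy as jnp
from optax._src import clipping

w = jnp.ones((3, 5))  # a Linear kernel with 3 inputs and 5 output units
norms = clipping.unitwise_norm(w)
# Squared entries are summed over axis 0, giving one norm per output unit,
# then broadcast back: norms.shape == (3, 5) and every entry is sqrt(3) ~ 1.73.
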


def unitwise_clip(g_norm: chex.Array,
max_norm: chex.Array,
grad: chex.Array,
div_eps: float = 1e-6) -> chex.Array:
"""Applies gradient clipping unit-wise."""
# This little max(., div_eps) is distinct from the normal eps and just
# prevents division by zero. It technically should be impossible to engage.
clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps))
chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad))
return jnp.where(g_norm < max_norm, grad, clipped_grad)
from optax.transforms import _clipping


adaptive_grad_clip = _clipping.adaptive_grad_clip
AdaptiveGradClipState = base.EmptyState


def adaptive_grad_clip(clipping: float,
eps: float = 1e-3) -> base.GradientTransformation:
"""Clips updates to be at most ``clipping * parameter_norm``, unit-wise.

References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)

Args:
clipping: The maximum allowed ratio of update norm to parameter norm.
eps: An epsilon term to prevent clipping of zero-initialized params.

Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return AdaptiveGradClipState()

def update_fn(updates, state, params):
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
g_norm, p_norm = jax.tree_util.tree_map(unitwise_norm, (updates, params))
# Maximum allowable norm.
max_norm = jax.tree_util.tree_map(
lambda x: clipping * jnp.maximum(x, eps), p_norm)
# If grad norm > clipping * param_norm, rescale.
updates = jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
ClipState = base.EmptyState
clip = _clipping.clip
clip_by_block_rms = _clipping.clip_by_block_rms
clip_by_global_norm = _clipping.clip_by_global_norm
ClipByGlobalNormState = base.EmptyState
per_example_global_norm_clip = _clipping.per_example_global_norm_clip
per_example_layer_norm_clip = _clipping.per_example_layer_norm_clip
unitwise_norm = _clipping.unitwise_norm
unitwise_clip = _clipping.unitwise_clip
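
The alias block above keeps the existing public API intact, so code that imports clipping transforms through the top-level optax namespace should be unaffected by the move. A short sanity check with hypothetical values:

import jax.numpy as jnp
import optax

clipper = optax.clip(max_delta=0.5)  # same name, now backed by optax.transforms
params = {'w': jnp.zeros(3)}
state = clipper.init(params)
grads = {'w': jnp.array([-2.0, 0.1, 2.0])}
clipped, state = clipper.update(grads, state)
# clipped['w'] == [-0.5, 0.1, 0.5]
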