From cbebea23519d62feb61396f7e79269cc97b6734b Mon Sep 17 00:00:00 2001 From: Matteo Hessel Date: Wed, 10 Apr 2024 06:22:46 -0700 Subject: [PATCH] move clipping transforms to optax.transforms. PiperOrigin-RevId: 623473687 --- optax/_src/base.py | 27 +- optax/_src/clipping.py | 302 +----------------- optax/_src/combine.py | 236 +------------- optax/_src/transform.py | 227 ++----------- optax/_src/transform_test.py | 102 ------ optax/transforms/_accumulation_test.py | 42 +++ optax/transforms/_adding.py | 105 ++++++ optax/transforms/_adding_test.py | 96 ++++++ optax/transforms/_clipping.py | 282 ++++++++++++++++ .../_clipping_test.py} | 22 +- optax/transforms/_combining.py | 255 +++++++++++++++ optax/transforms/_combining_test.py | 284 ++++++++++++++++ optax/tree_utils/_random_test.py | 18 +- optax/tree_utils/_tree_math.py | 11 +- 14 files changed, 1147 insertions(+), 862 deletions(-) create mode 100644 optax/transforms/_adding.py create mode 100644 optax/transforms/_adding_test.py create mode 100644 optax/transforms/_clipping.py rename optax/{_src/clipping_test.py => transforms/_clipping_test.py} (91%) create mode 100644 optax/transforms/_combining.py create mode 100644 optax/transforms/_combining_test.py diff --git a/optax/_src/base.py b/optax/_src/base.py index 38fc609aa..dfda98350 100644 --- a/optax/_src/base.py +++ b/optax/_src/base.py @@ -212,6 +212,12 @@ class EmptyState(NamedTuple): """An empty state for the simplest stateless transformations.""" +def init_empty_state(params: Params) -> EmptyState: + """Init function for a :class:`GradientTransformation` with empty state.""" + del params + return EmptyState() + + def identity() -> GradientTransformation: """Stateless identity transformation that leaves input gradients untouched. @@ -225,14 +231,11 @@ def identity() -> GradientTransformation: A `GradientTransformation` object. """ - def init_fn(_): - return EmptyState() - def update_fn(updates, state, params=None): del params return updates, state - return GradientTransformation(init_fn, update_fn) + return GradientTransformation(init_empty_state, update_fn) def set_to_zero() -> GradientTransformation: @@ -255,15 +258,11 @@ def set_to_zero() -> GradientTransformation: A `GradientTransformation` object. """ - def init_fn(params): - del params - return EmptyState() - def update_fn(updates, state, params=None): del params # Unused by the zero transform. return jax.tree_util.tree_map(jnp.zeros_like, updates), state - return GradientTransformation(init_fn, update_fn) + return GradientTransformation(init_empty_state, update_fn) def stateless( @@ -282,14 +281,11 @@ def stateless( An `optax.GradientTransformation`. """ - def init_fn(_): - return EmptyState() - def update_fn(updates, state, params=None): del state return f(updates, params), EmptyState() - return GradientTransformation(init_fn, update_fn) + return GradientTransformation(init_empty_state, update_fn) def stateless_with_tree_map( @@ -310,9 +306,6 @@ def stateless_with_tree_map( An `optax.GradientTransformation`. 
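Illustrative sketch (not part of this patch): the new `base.init_empty_state` helper added above can be shared by any stateless transformation instead of defining a local `init_fn`; `scale_by_two` below is a hypothetical name used only for illustration.

import jax
from optax._src import base

def scale_by_two() -> base.GradientTransformation:
  # A stateless transform only needs an update_fn; the empty-state init is shared.
  def update_fn(updates, state, params=None):
    del params
    return jax.tree_util.tree_map(lambda g: 2.0 * g, updates), state
  return base.GradientTransformation(base.init_empty_state, update_fn)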
""" - def init_fn(_): - return EmptyState() - def update_fn(updates, state, params=None): del state if params is not None: @@ -321,7 +314,7 @@ def update_fn(updates, state, params=None): f_ = lambda u: f(u, None) return jax.tree_util.tree_map(f_, updates), EmptyState() - return GradientTransformation(init_fn, update_fn) + return GradientTransformation(init_empty_state, update_fn) def with_extra_args_support( diff --git a/optax/_src/clipping.py b/optax/_src/clipping.py index 14ee5e9ad..794c427c2 100644 --- a/optax/_src/clipping.py +++ b/optax/_src/clipping.py @@ -12,298 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Gradient clipping transformations. - -Note that complex numbers are also supported, see -https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 -""" - -import chex -import jax -import jax.numpy as jnp +"""Gradient clipping transformations.""" from optax._src import base -from optax._src import linear_algebra -from optax._src import numerics - -ClipState = base.EmptyState - - -def clip(max_delta: chex.Numeric) -> base.GradientTransformation: - """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``. - - Args: - max_delta: The maximum absolute value for each element in the update. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return ClipState() - - def update_fn(updates, state, params=None): - del params - updates = jax.tree_util.tree_map( - lambda g: jnp.clip(g, -max_delta, max_delta), updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def clip_by_block_rms(threshold: float) -> base.GradientTransformation: - """Clips updates to a max rms for the gradient of each param vector or matrix. - - A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix - (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. - - Args: - threshold: The maximum rms for the gradient of each param vector or matrix. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return base.EmptyState() - - def update_fn(updates, state, params=None): - del params - - def _clip_fn(u): - clip_denom = jnp.maximum( - 1.0, - jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold) - return u / clip_denom - - updates = jax.tree_util.tree_map(_clip_fn, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -ClipByGlobalNormState = base.EmptyState - - -def clip_by_global_norm(max_norm: float) -> base.GradientTransformation: - """Clips updates using their global norm. - - References: - [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) - - Args: - max_norm: The maximum global norm for an update. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return ClipByGlobalNormState() - - def update_fn(updates, state, params=None): - del params - g_norm = linear_algebra.global_norm(updates) - # TODO(b/163995078): revert back to the following (faster) implementation - # once analysed how it affects backprop through update (e.g. meta-gradients) - # g_norm = jnp.maximum(max_norm, g_norm) - # updates = jax.tree_util.tree_map( - # lambda t: (t / g_norm) * max_norm, updates) - trigger = jnp.squeeze(g_norm < max_norm) - chex.assert_shape(trigger, ()) # A scalar. 
- - def clip_fn(t): - return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm) - - updates = jax.tree_util.tree_map(clip_fn, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def per_example_global_norm_clip( - grads: list[chex.Array], l2_norm_clip: float -) -> tuple[list[chex.Array], jax.Array]: - """Applies gradient clipping per-example using their global norm. - - References: - [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) - - Args: - grads: flattened update; the function expects these to have a batch - dimension on the 0th axis. - l2_norm_clip: maximum L2 norm of the per-example gradients. - - Returns: - A tuple containing sum of the clipped per-example grads, and the number of - per-example grads that were clipped. - """ - bsize = grads[0].shape[0] - - if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): - raise ValueError( - 'Unlike other transforms, `per_example_global_norm_clip` expects' - ' `grads` to have a batch dimension in the 0th axis.') - - global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads) - divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0) - num_clipped = jnp.greater(divisors, 1.0).sum() - clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads] - return clipped_sum, num_clipped - - -def per_example_layer_norm_clip( - grads: list[chex.Array], - global_l2_norm_clip: float, - uniform: bool = True, - eps: float = 1e-8, -) -> tuple[list[chex.Array], list[chex.Array]]: - """Applies gradient clipping per-example using per-layer norms. - - References: - [McMahan et al, 2012](https://arxiv.org/abs/1710.06963)] - - Args: - grads: flattened update; i.e. a list of gradients in which each item is - the gradient for one layer; the function expects these to have a batch - dimension on the 0th axis. - global_l2_norm_clip: overall L2 clip norm to use. - uniform: If `True`, per-layer clip norm is global_l2_norm_clip/sqrt(L), - where L is the number of layers. Otherwise, per-layer clip norm is - global_l2_norm_clip * sqrt(f), where f is the fraction of total model - parameters that are in this layer. - eps: Small positive value to add to norms to avoid possible division by - zero. - - Let C = `global_l2_norm_clip value`. Then per-layer clipping is done as - follows: - - * If `uniform` is `True`, each of the K layers has an individual clip - norm of C / sqrt(K). - - * If `uniform` is `False`, each of the K layers has an individual clip - norm of C * sqrt(D_i / D) where D_i is the number of parameters in - layer i, and D is the total number of parameters in the model. - - Returns: - A tuple containing sum of the clipped per-example grads and the number of - per-example grads that were clipped for each layer. - """ - bsize = grads[0].shape[0] - - if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): - raise ValueError( - 'Unlike other transforms, `per_example_layer_norm_clip` expects' - ' `grads` to have a batch dimension in the 0th axis; got shapes:' - f' {(g.shape for g in grads)}.' - ) - - num_layers = len(grads) - - # Compute per-layer clip norms, based on whether we are using uniform - # variant or not. - if uniform: - # Create list of length `num_layers` of per-layer clip norm. 
- layer_clip_norms = ( - global_l2_norm_clip * (1.0 / num_layers) ** 0.5, - ) * num_layers - else: - total_params = sum(g[0].size for g in grads) - layer_clip_norms = tuple( - global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5 - for g in grads - ) - - # Compute per-layer grad norms. - def map_layer_norm(grads_list): - return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list] - - layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads) - - # Perform clipping. - divisors = ( - tuple( - jnp.maximum( - layer_grad_norm / (layer_clip_norm + eps), 1.0 - ) - for layer_grad_norm, layer_clip_norm in zip( - layer_grad_norms_per_example, layer_clip_norms - ) - ) - ) - num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors] - clipped_sum = [ - (g / jnp.expand_dims(d, axis=[i for i in range(1, g.ndim)])).sum(0) - for g, d in zip(grads, divisors) - ] - return clipped_sum, num_clipped - - -def unitwise_norm(x: chex.Array) -> chex.Array: - """Computes norms of each output unit separately.""" - if jnp.squeeze(x).ndim <= 1: # Scalars and vectors - squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True) - # Note that this assumes parameters with a shape of length 3 are multihead - # linear parameters--if you wish to apply AGC to 1D convs, you may need - # to modify this line. - elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear - squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True) - elif x.ndim == 4: # Conv kernels of shape HWIO - squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True) - else: - raise ValueError( - f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.') - chex.assert_is_broadcastable(squared_norm.shape, x.shape) - return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape) - - -def unitwise_clip(g_norm: chex.Array, - max_norm: chex.Array, - grad: chex.Array, - div_eps: float = 1e-6) -> chex.Array: - """Applies gradient clipping unit-wise.""" - # This little max(., div_eps) is distinct from the normal eps and just - # prevents division by zero. It technically should be impossible to engage. - clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps)) - chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad)) - return jnp.where(g_norm < max_norm, grad, clipped_grad) +from optax.transforms import _clipping +adaptive_grad_clip = _clipping.adaptive_grad_clip AdaptiveGradClipState = base.EmptyState - - -def adaptive_grad_clip(clipping: float, - eps: float = 1e-3) -> base.GradientTransformation: - """Clips updates to be at most ``clipping * parameter_norm``, unit-wise. - - References: - [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image - Recognition Without Normalization. (https://arxiv.org/abs/2102.06171) - - Args: - clipping: The maximum allowed ratio of update norm to parameter norm. - eps: An epsilon term to prevent clipping of zero-initialized params. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return AdaptiveGradClipState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - g_norm, p_norm = jax.tree_util.tree_map(unitwise_norm, (updates, params)) - # Maximum allowable norm. - max_norm = jax.tree_util.tree_map( - lambda x: clipping * jnp.maximum(x, eps), p_norm) - # If grad norm > clipping * param_norm, rescale. 
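Illustrative usage sketch (not part of this patch) for `adaptive_grad_clip`: each unit's update is rescaled whenever its norm exceeds `clipping * max(parameter_norm, eps)`.

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2))}
grads = {'w': 100.0 * jnp.ones((3, 2))}
agc = optax.adaptive_grad_clip(clipping=0.01)
clipped, _ = agc.update(grads, agc.init(params), params)
# Each column (output unit) of clipped['w'] now has norm
# <= 0.01 * norm of the matching column of params['w'].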
- updates = jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) +ClipState = base.EmptyState +clip = _clipping.clip +clip_by_block_rms = _clipping.clip_by_block_rms +clip_by_global_norm = _clipping.clip_by_global_norm +ClipByGlobalNormState = base.EmptyState +per_example_global_norm_clip = _clipping.per_example_global_norm_clip +per_example_layer_norm_clip = _clipping.per_example_layer_norm_clip +unitwise_norm = _clipping.unitwise_norm +unitwise_clip = _clipping.unitwise_clip diff --git a/optax/_src/combine.py b/optax/_src/combine.py index c75532d5b..a7709b46d 100644 --- a/optax/_src/combine.py +++ b/optax/_src/combine.py @@ -14,235 +14,9 @@ # ============================================================================== """Flexibly compose gradient transformations.""" -from typing import Callable, NamedTuple, Union, Mapping, Hashable +from optax.transforms import _combining -import jax - -from optax._src import base -from optax._src import wrappers - - -def chain( - *args: base.GradientTransformation, -) -> base.GradientTransformationExtraArgs: - """Applies a list of chainable update transformations. - - This function creates a new :func:`optax.GradientTransformation` that applies - a sequence of gradient transformations in order. The ``init`` function of the - new transformation constructs the optimizer state by concatenating the states - of the individual transforms, while the ``update`` function applies the - updates in the given order. - - Examples: - - A transform that scales by -0.1 the adam update: - - >>> import optax - >>> transform1 = optax.scale_by_adam() - >>> transform2 = optax.scale(-0.1) - >>> chained_transform = optax.chain(transform1, transform2) - >>> params = {'a': 1.0} - >>> state = chained_transform.init(params) - >>> updates = {'a': -0.5} - >>> updates, new_state = chained_transform.update(updates, state, params) - - Args: - *args: a sequence of chainable (init_fn, update_fn) tuples. - - Returns: - A :func:`GradientTransformationExtraArgs`, created by chaining the input - transformations. Note that independent of the argument types, the resulting - transformation always supports extra args. Any extra arguments passed to the - returned transformation will be passed only to those transformations in the - chain that support extra args. - """ - - transforms = [base.with_extra_args_support(t) for t in args] - init_fns, update_fns = zip(*transforms) - - def init_fn(params): - return tuple(fn(params) for fn in init_fns) - - def update_fn(updates, state, params=None, **extra_args): - if len(update_fns) != len(state): - raise ValueError('The number of updates and states has to be the same in ' - 'chain! Make sure you have called init first!') - - new_state = [] - for s, fn in zip(state, update_fns): - updates, new_s = fn(updates, s, params, **extra_args) - new_state.append(new_s) - return updates, tuple(new_state) - - # We opt to always return the GradientTransformationExtraArgs type here, - # instead of selecting the return type based on the arguments, since it works - # much better with the currently available type checkers. It also means that - # users will not get unexpected signature errors if they remove all of the - # transformations in a chain accepting extra args. 
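Illustrative sketch (not part of this patch): as the `chain` docstring above notes, the `init` function concatenates the per-transform states, so the resulting optimizer state is simply a tuple with one entry per chained transform.

import optax

tx = optax.chain(optax.scale_by_adam(), optax.scale(-0.1))
state = tx.init({'a': 1.0})
assert isinstance(state, tuple) and len(state) == 2  # one state entry per transform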
- return base.GradientTransformationExtraArgs(init_fn, update_fn) - - -def named_chain( - *transforms: tuple[str, base.GradientTransformation] -) -> base.GradientTransformationExtraArgs: - """Chains optax gradient transformations. - - The `transforms` are `(name, transformation)` pairs, constituted of a string - `name` and an associated gradient transformation `transformation`. The - gradient transformation must be an instance of either `GradientTransformation` - or `GradientTransformationExtraArgs`. - - Each `name` is used as key for the state of the corresponding transformation - within the `named_chain` state. Thus the state of the gradient transformation - with a given `name` can be retrieved as `opt_state[name]`. - - Example: - - # tx1 is a GradientTransformation with no extra_args. - # tx2 is a GradientTransformationExtraArgs that requires `loss`. - # tx3 is a GradientTransformationExtraArgs that requires `temperature`. - - tx = named_chain(('one', tx1), ('two', tx2), ('three', tx3)) - extra_args={'loss': 0.3, 'temperature': 0.01} - tx.init(params) - tx.update(grads, state, params, **extra_args) - - Args: - *transforms: an arbitrary number of `(name, tx)` pairs, constituted of a - string `name` and an associated gradient transformation `tx`. The latter - is a `GradientTransformation` or `GradientTransformationExtraArgs`. - - Returns: - A single (init_fn, update_fn) tuple. - """ - - names = [name for name, _ in transforms] - - if len(names) != len(set(names)): - raise ValueError( - f'Named transformations must have unique names, but got {names}') - - transforms = [ - (name, base.with_extra_args_support(t)) - for name, t in transforms] - - def init_fn(params): - states = {} - for (name, tx) in transforms: - states[name] = tx.init(params) - return states - def update_fn(updates, state, params=None, **extra_args): - new_state = {} - for (name, tx) in transforms: - updates, new_state[name] = tx.update( - updates, state[name], params, **extra_args) - return updates, new_state - - return base.GradientTransformationExtraArgs(init_fn, update_fn) - - -class MultiTransformState(NamedTuple): - inner_states: Mapping[Hashable, base.OptState] - - -def multi_transform( - transforms: Mapping[Hashable, base.GradientTransformation], - param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]], - *, - mask_compatible_extra_args: bool = False, -) -> base.GradientTransformationExtraArgs: - """Partitions params and applies a different transformation to each subset. - - Below is an example where we apply Adam to the weights and SGD to the biases - of a 2-layer neural network:: - - import optax - import jax - import jax.numpy as jnp - - def map_nested_fn(fn): - '''Recursively apply `fn` to the key-value pairs of a nested dict.''' - def map_fn(nested_dict): - return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) - for k, v in nested_dict.items()} - return map_fn - - params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, - 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} - gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients - - label_fn = map_nested_fn(lambda k, _: k) - tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, - label_fn) - state = tx.init(params) - updates, new_state = tx.update(gradients, state, params) - new_params = optax.apply_updates(params, updates) - - Instead of providing a ``label_fn``, you may provide a PyTree of labels - directly. Also, this PyTree may be a prefix of the parameters PyTree. 
This - is demonstrated in the GAN pseudocode below:: - - generator_params = ... - discriminator_params = ... - all_params = (generator_params, discriminator_params) - param_labels = ('generator', 'discriminator') - - tx = optax.multi_transform( - {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, - param_labels) - - If you would like to not optimize some parameters, you may wrap - :func:`optax.multi_transform` with :func:`optax.masked`. - - Args: - transforms: A mapping from labels to transformations. Each transformation - will be only be applied to parameters with the same label. - param_labels: A PyTree that is the same shape or a prefix of the - parameters/updates (or a function that returns one given the parameters as - input). The leaves of this PyTree correspond to the keys of the transforms - (therefore the values at the leaves must be a subset of the keys). - mask_compatible_extra_args: Whether to also apply the same masking to - extra_arg fields with the same tree structure as params/updates. - - Returns: - A :func:`optax.GradientTransformationExtraArgs` that implements an ``init`` - and ``update`` function. - """ - - transforms = { - k: base.with_extra_args_support(v) - for k, v in transforms.items() - } - - def make_mask(labels, group): - return jax.tree_util.tree_map(lambda label: label == group, labels) - - def init_fn(params): - labels = param_labels(params) if callable(param_labels) else param_labels - - label_set = set(jax.tree_util.tree_leaves(labels)) - if not label_set.issubset(transforms.keys()): - raise ValueError('Some parameters have no corresponding transformation.\n' - f'Parameter labels: {list(sorted(label_set))} \n' - f'Transforms keys: {list(sorted(transforms.keys()))} \n') - - inner_states = { - group: wrappers.masked( - tx, make_mask(labels, group), - mask_compatible_extra_args=mask_compatible_extra_args).init(params) - for group, tx in transforms.items() - } - return MultiTransformState(inner_states) - - def update_fn(updates, state, params=None, **extra_args): - labels = param_labels(updates) if callable(param_labels) else param_labels - new_inner_state = {} - for group, tx in transforms.items(): - masked_tx = wrappers.masked( - tx, make_mask(labels, group), - mask_compatible_extra_args=mask_compatible_extra_args) - updates, new_inner_state[group] = masked_tx.update( - updates, state.inner_states[group], params, **extra_args) - return updates, MultiTransformState(new_inner_state) - - return base.GradientTransformationExtraArgs(init_fn, update_fn) +chain = _combining.chain +named_chain = _combining.named_chain +multi_transform = _combining.partition +MultiTransformState = _combining.PartitionState diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 1a5511ccd..0ba835c1c 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -15,7 +15,7 @@ """Gradient transformations.""" import functools -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import NamedTuple, Optional import chex import jax @@ -26,58 +26,11 @@ from optax._src import base from optax._src import numerics from optax._src import utils -from optax._src import wrappers +from optax.transforms import _accumulation +from optax.transforms import _adding -abs_sq = numerics.abs_sq - - -def _init_empty_state(params: base.Params) -> base.EmptyState: - """Init function for an empty state.""" - del params - return base.EmptyState() - - -class TraceState(NamedTuple): - """Holds an aggregation of past updates.""" - trace: base.Params - - -def 
trace( - decay: float, - nesterov: bool = False, - accumulator_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """Compute a trace of past updates. - - Note: `trace` and `ema` have very similar but distinct updates; - `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`. - Both are frequently found in the optimization literature. - - Args: - decay: Decay rate for the trace of past updates. - nesterov: Whether to use Nesterov momentum. - accumulator_dtype: Optional `dtype` to be used for the accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - Returns: - A `GradientTransformation` object. - """ - - accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) - - def init_fn(params): - return TraceState( - trace=otu.tree_zeros_like(params, dtype=accumulator_dtype)) - - def update_fn(updates, state, params=None): - del params - f = lambda g, t: g + decay * t - new_trace = jtu.tree_map(f, updates, state.trace) - updates = jtu.tree_map(f, updates, new_trace) if nesterov else new_trace - new_trace = otu.tree_cast(new_trace, accumulator_dtype) - return updates, TraceState(trace=new_trace) - - return base.GradientTransformation(init_fn, update_fn) +abs_sq = numerics.abs_sq def _reject_complex(params): @@ -85,53 +38,6 @@ def _reject_complex(params): raise ValueError('This transformation does not support complex parameters.') -class EmaState(NamedTuple): - """Holds an exponential moving average of past updates.""" - count: chex.Array # shape=(), dtype=jnp.int32. - ema: base.Params - - -def ema( - decay: float, - debias: bool = True, - accumulator_dtype: Optional[Any] = None -) -> base.GradientTransformation: - """Compute an exponential moving average of past updates. - - Note: `trace` and `ema` have very similar but distinct updates; - `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`. - Both are frequently found in the optimization literature. - - Args: - decay: Decay rate for the exponential moving average. - debias: Whether to debias the transformed gradient. - accumulator_dtype: Optional `dtype` to used for the accumulator; if `None` - then the `dtype` is inferred from `params` and `updates`. - - Returns: - A `GradientTransformation` object. - """ - - accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) - - def init_fn(params): - return EmaState( - count=jnp.zeros([], jnp.int32), - ema=otu.tree_zeros_like(params, dtype=accumulator_dtype)) - - def update_fn(updates, state, params=None): - del params - updates = new_ema = otu.tree_update_moment( - updates, state.ema, decay, order=1) - count_inc = utils.safe_int32_increment(state.count) - if debias: - updates = otu.tree_bias_correction(new_ema, decay, count_inc) - state_ema = otu.tree_cast(new_ema, accumulator_dtype) - return updates, EmaState(count=count_inc, ema=state_ema) - - return base.GradientTransformation(init_fn, update_fn) - - class ScaleByRssState(NamedTuple): """State holding the sum of gradient squares to date.""" sum_of_squares: base.Updates @@ -476,9 +382,6 @@ def update_fn(updates, state, params=None): return base.GradientTransformation(init_fn, update_fn) -ScaleState = base.EmptyState - - def scale( step_size: float ) -> base.GradientTransformation: @@ -491,16 +394,12 @@ def scale( A `GradientTransformation` object. 
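Illustrative sketch (not part of this patch): the `trace`/`ema` docstrings above contrast their update rules (`trace = decay * trace + t` versus `ema = decay * ema + (1 - decay) * t`); a one-step numeric check makes the difference concrete.

import jax.numpy as jnp
import optax

g = jnp.array([1.0])
tr = optax.trace(decay=0.9)
em = optax.ema(decay=0.9, debias=False)
t1, _ = tr.update(g, tr.init(g))  # trace: 0.9 * 0 + 1.0 = 1.0
e1, _ = em.update(g, em.init(g))  # ema:   0.9 * 0 + 0.1 * 1.0 = 0.1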
""" - def init_fn(params): - del params - return ScaleState() - def update_fn(updates, state, params=None): del params updates = jtu.tree_map(lambda g: step_size * g, updates) return updates, state - return base.GradientTransformation(init_fn, update_fn) + return base.GradientTransformation(base.init_empty_state, update_fn) def scale_by_param_block_norm( @@ -526,7 +425,7 @@ def update_fn(updates, state, params): updates, params) return updates, state - return base.GradientTransformation(_init_empty_state, update_fn) + return base.GradientTransformation(base.init_empty_state, update_fn) def scale_by_param_block_rms( @@ -552,7 +451,7 @@ def update_fn(updates, state, params): updates, params) return updates, state - return base.GradientTransformation(_init_empty_state, update_fn) + return base.GradientTransformation(base.init_empty_state, update_fn) class ScaleByAdaDeltaState(NamedTuple): @@ -845,45 +744,6 @@ def update_fn(updates, state, params=None): return base.GradientTransformation(init_fn, update_fn) -AddDecayedWeightsState = base.EmptyState - - -def add_decayed_weights( - weight_decay: Union[float, jax.Array] = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None -) -> base.GradientTransformation: - """Add parameter scaled by `weight_decay`. - - Args: - weight_decay: A scalar weight decay rate. - mask: A tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return AddDecayedWeightsState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - updates = jtu.tree_map( - lambda g, p: g + weight_decay * p, updates, params) - return updates, state - - # If mask is not `None`, apply mask to the gradient transformation. - # E.g. it is common to skip weight decay on bias units and batch stats. - if mask is not None: - return wrappers.masked( - base.GradientTransformation(init_fn, update_fn), mask) - return base.GradientTransformation(init_fn, update_fn) - - class ScaleByScheduleState(NamedTuple): """Maintains count for scale scheduling.""" count: chex.Array # shape=(), dtype=jnp.int32 @@ -941,10 +801,6 @@ def update_fn(updates, state, params=None): return base.GradientTransformation(init_fn, update_fn) -class ScaleByTrustRatioState(NamedTuple): - """The scale and decay trust ratio transformation is stateless.""" - - def scale_by_trust_ratio( min_norm: float = 0.0, trust_coefficient: float = 1., @@ -964,10 +820,6 @@ def scale_by_trust_ratio( A `GradientTransformation` object. """ - def init_fn(params): - del params - return ScaleByTrustRatioState() - def update_fn(updates, state, params): if params is None: raise ValueError(base.NO_PARAMS_MSG) @@ -990,57 +842,7 @@ def _scale_update(update, param): updates = jtu.tree_map(_scale_update, updates, params) return updates, state - return base.GradientTransformation(init_fn, update_fn) - - -class AddNoiseState(NamedTuple): - """State for adding gradient noise. Contains a count for annealing.""" - count: chex.Array - rng_key: chex.PRNGKey - - -def add_noise( - eta: float, - gamma: float, - seed: int -) -> base.GradientTransformation: - """Add gradient noise. 
- - References: - [Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807) - - Args: - eta: Base variance of the gaussian noise added to the gradient. - gamma: Decay exponent for annealing of the variance. - seed: Seed for random number generation. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return AddNoiseState( - count=jnp.zeros([], jnp.int32), - rng_key=jax.random.PRNGKey(seed)) - - def update_fn(updates, state, params=None): # pylint: disable=missing-docstring - del params - num_vars = len(jtu.tree_leaves(updates)) - treedef = jtu.tree_structure(updates) - count_inc = numerics.safe_int32_increment(state.count) - variance = eta / count_inc**gamma - standard_deviation = jnp.sqrt(variance) - all_keys = jax.random.split(state.rng_key, num=num_vars + 1) - noise = jtu.tree_map( - lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), - updates, jtu.tree_unflatten(treedef, all_keys[1:])) - updates = jtu.tree_map( - lambda g, n: g + standard_deviation.astype(g.dtype) * n, - updates, noise) - return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0]) - - return base.GradientTransformation(init_fn, update_fn) + return base.GradientTransformation(base.init_empty_state, update_fn) class ApplyEvery(NamedTuple): @@ -1419,7 +1221,7 @@ def update_fn( updates = otu.tree_scalar_mul(step, updates) return updates, state - return base.GradientTransformationExtraArgs(_init_empty_state, update_fn) + return base.GradientTransformationExtraArgs(base.init_empty_state, update_fn) ### Legacy symbols to be removed. ### @@ -1433,3 +1235,14 @@ def cast_tree( dtype: Optional[chex.ArrayDType] ) -> chex.ArrayTree: return otu.tree_cast(tree, dtype) + +trace = _accumulation.trace +TraceState = _accumulation.TraceState +ema = _accumulation.ema +EmaState = _accumulation.EmaState +add_noise = _adding.add_noise +AddNoiseState = _adding.AddNoiseState +add_decayed_weights = _adding.add_decayed_weights +AddDecayedWeightsState = base.EmptyState +ScaleState = base.EmptyState +ScaleByTrustRatioState = base.EmptyState diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index 39a914499..a6e5d6fcd 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -22,7 +22,6 @@ import chex import jax import jax.numpy as jnp -import numpy as np from optax._src import alias from optax._src import combine @@ -75,80 +74,6 @@ def test_scalers(self, scaler_constr): jax.tree_util.tree_map( lambda *args: chex.assert_equal_shape(args), params, updates) - @chex.all_variants - def test_add_decayed_weights(self): - # Define a transform that add decayed weights. - # We can define a mask either as a pytree, or as a function that - # returns the pytree. Below we define the pytree directly. - mask = (True, dict(a=True, b=False)) - tx = transform.add_decayed_weights(0.1, mask=mask) - # Define input updates and weights. - updates = ( - jnp.zeros((2,), dtype=jnp.float32), - dict( - a=jnp.zeros((2,), dtype=jnp.float32), - b=jnp.zeros((2,), dtype=jnp.float32),)) - weights = ( - jnp.ones((2,), dtype=jnp.float32), - dict( - a=jnp.ones((2,), dtype=jnp.float32), - b=jnp.ones((2,), dtype=jnp.float32),)) - # This mask means that we will add decayed weights to the first two - # terms in the input updates, but not to the last element. 
- expected_tx_updates = ( - 0.1*jnp.ones((2,), dtype=jnp.float32), - dict( - a=0.1*jnp.ones((2,), dtype=jnp.float32), - b=jnp.zeros((2,), dtype=jnp.float32),)) - # Apply transform - state = tx.init(weights) - transform_fn = self.variant(tx.update) - new_updates, _ = transform_fn(updates, state, weights) - # Assert output as expected. - chex.assert_trees_all_close(new_updates, expected_tx_updates) - - @chex.all_variants - def test_ema(self): - values = jnp.array([5.0, 7.0]) - decay = 0.9 - d = decay - - ema = transform.ema(decay=decay, debias=False) - state = ema.init(values[0]) # init to zeroes - - transform_fn = self.variant(ema.update) - mean, state = transform_fn(values[0], state) - np.testing.assert_allclose(mean, (1-d) * values[0], atol=1e-4) - - mean, state = transform_fn(values[1], state) - np.testing.assert_allclose( - mean, - (1 - d) * (values[1] + d * values[0]), atol=1e-2) - - @chex.all_variants - def test_ema_debias(self): - values = jnp.array([5.0, 7.0]) - decay = 0.9 - d = decay - - ema = transform.ema(decay=decay) - state = ema.init(values[0]) - - transform_fn = self.variant(ema.update) - mean, state = transform_fn(values[0], state) - np.testing.assert_allclose(mean, values[0], atol=1e-4) - - mean, state = transform_fn(values[1], state) - np.testing.assert_allclose( - mean, - ((1 - d) * values[1] + d * (1 - d) * values[0]) / (1 - d**2), - atol=1e-2) - # The state must not be debiased. - np.testing.assert_allclose( - state.ema, - (1 - d) * values[1] + d * (1 - d) * values[0], - atol=1e-2) - @chex.all_variants def test_apply_every(self): # The frequency of the application of sgd @@ -218,33 +143,6 @@ def test_centralize(self, inputs, outputs): centralized_inputs, _ = centralizer.update(inputs, {}) chex.assert_trees_all_close(centralized_inputs, outputs) - @chex.all_variants - def test_add_noise_has_correct_variance_scaling(self): - # Prepare to compare noise with a rescaled unit-variance substitute. - eta = 0.3 - gamma = 0.55 - seed = 314 - noise = transform.add_noise(eta, gamma, seed) - noise_unit = transform.add_noise(1.0, 0.0, seed) - - params = self.init_params - state = noise.init(params) - state_unit = noise_unit.init(params) - - # Check the noise itself by adding it to zeros. 
- updates = jax.tree_util.tree_map(jnp.zeros_like, params) - - for i in range(1, STEPS + 1): - updates_i, state = self.variant(noise.update)(updates, state) - updates_i_unit, state_unit = noise_unit.update(updates, state_unit) - - scale = jnp.sqrt(eta / i**gamma) - - updates_i_rescaled = jax.tree_util.tree_map( - lambda g, s=scale: g * s, updates_i_unit) - - chex.assert_trees_all_close(updates_i, updates_i_rescaled, rtol=1e-4) - def test_scale_by_optimistic_gradient(self): def f(params: jnp.ndarray) -> jnp.ndarray: diff --git a/optax/transforms/_accumulation_test.py b/optax/transforms/_accumulation_test.py index ba54e3b30..a32e95e5a 100644 --- a/optax/transforms/_accumulation_test.py +++ b/optax/transforms/_accumulation_test.py @@ -31,6 +31,48 @@ class AccumulationTest(chex.TestCase): + @chex.all_variants + def test_ema(self): + values = jnp.array([5.0, 7.0]) + decay = 0.9 + d = decay + + ema = _accumulation.ema(decay=decay, debias=False) + state = ema.init(values[0]) # init to zeroes + + transform_fn = self.variant(ema.update) + mean, state = transform_fn(values[0], state) + np.testing.assert_allclose(mean, (1-d) * values[0], atol=1e-4) + + mean, _ = transform_fn(values[1], state) + np.testing.assert_allclose( + mean, + (1 - d) * (values[1] + d * values[0]), atol=1e-2) + + @chex.all_variants + def test_ema_debias(self): + values = jnp.array([5.0, 7.0]) + decay = 0.9 + d = decay + + ema = _accumulation.ema(decay=decay) + state = ema.init(values[0]) + + transform_fn = self.variant(ema.update) + mean, state = transform_fn(values[0], state) + np.testing.assert_allclose(mean, values[0], atol=1e-4) + + mean, state = transform_fn(values[1], state) + np.testing.assert_allclose( + mean, + ((1 - d) * values[1] + d * (1 - d) * values[0]) / (1 - d**2), + atol=1e-2) + # The state must not be debiased. + np.testing.assert_allclose( + state.ema, + (1 - d) * values[1] + d * (1 - d) * values[0], + atol=1e-2) + def test_skip_not_finite(self): step = jnp.zeros([], dtype=jnp.int32) diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py new file mode 100644 index 000000000..a52892e3b --- /dev/null +++ b/optax/transforms/_adding.py @@ -0,0 +1,105 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additive components in gradient transformations.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +import jax +from jax import tree_util as jtu +import jax.numpy as jnp + +from optax import tree_utils as otu +from optax._src import base +from optax._src import numerics +from optax._src import wrappers + + +def add_decayed_weights( + weight_decay: Union[float, jax.Array] = 0.0, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None +) -> base.GradientTransformation: + """Add parameter scaled by `weight_decay`. + + Args: + weight_decay: A scalar weight decay rate. 
+ mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + updates = jtu.tree_map( + lambda g, p: g + weight_decay * p, updates, params) + return updates, state + + # If mask is not `None`, apply mask to the gradient transformation. + # E.g. it is common to skip weight decay on bias units and batch stats. + if mask is not None: + return wrappers.masked( + base.GradientTransformation(base.init_empty_state, update_fn), mask) + return base.GradientTransformation(base.init_empty_state, update_fn) + + +class AddNoiseState(NamedTuple): + """State for adding gradient noise. Contains a count for annealing.""" + count: chex.Array + rng_key: chex.PRNGKey + + +def add_noise( + eta: float, + gamma: float, + seed: int +) -> base.GradientTransformation: + """Add gradient noise. + + References: + [Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807) + + Args: + eta: Base variance of the gaussian noise added to the gradient. + gamma: Decay exponent for annealing of the variance. + seed: Seed for random number generation. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return AddNoiseState( + count=jnp.zeros([], jnp.int32), + rng_key=jax.random.PRNGKey(seed)) + + def update_fn(updates, state, params=None): # pylint: disable=missing-docstring + del params + count_inc = numerics.safe_int32_increment(state.count) + standard_deviation = jnp.sqrt(eta / count_inc**gamma) + + rng_key, sample_key = jax.random.split(state.rng_key) + noise = otu.tree_random_like( + sample_key, target_tree=updates, sampler=jax.random.normal) + updates = otu.tree_add_scalar_mul( + tree_x=updates, scalar=standard_deviation, tree_y=noise) + return updates, AddNoiseState(count=count_inc, rng_key=rng_key) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax/transforms/_adding_test.py b/optax/transforms/_adding_test.py new file mode 100644 index 000000000..34a35f171 --- /dev/null +++ b/optax/transforms/_adding_test.py @@ -0,0 +1,96 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for optax.transforms._adding.""" + +from absl.testing import absltest + +import chex +from jax import tree_util as jtu +import jax.numpy as jnp + +from optax.transforms import _adding + +STEPS = 50 + + +class AddingTest(chex.TestCase): + + def setUp(self): + super().setUp() + self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) + self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) + + @chex.all_variants + def test_add_decayed_weights(self): + # Define a transform that add decayed weights. + # We can define a mask either as a pytree, or as a function that + # returns the pytree. Below we define the pytree directly. + mask = (True, dict(a=True, b=False)) + tx = _adding.add_decayed_weights(0.1, mask=mask) + # Define input updates and weights. + updates = ( + jnp.zeros((2,), dtype=jnp.float32), + dict( + a=jnp.zeros((2,), dtype=jnp.float32), + b=jnp.zeros((2,), dtype=jnp.float32),)) + weights = ( + jnp.ones((2,), dtype=jnp.float32), + dict( + a=jnp.ones((2,), dtype=jnp.float32), + b=jnp.ones((2,), dtype=jnp.float32),)) + # This mask means that we will add decayed weights to the first two + # terms in the input updates, but not to the last element. + expected_tx_updates = ( + 0.1*jnp.ones((2,), dtype=jnp.float32), + dict( + a=0.1*jnp.ones((2,), dtype=jnp.float32), + b=jnp.zeros((2,), dtype=jnp.float32),)) + # Apply transform + state = tx.init(weights) + transform_fn = self.variant(tx.update) + new_updates, _ = transform_fn(updates, state, weights) + # Assert output as expected. + chex.assert_trees_all_close(new_updates, expected_tx_updates) + + @chex.all_variants + def test_add_noise_has_correct_variance_scaling(self): + # Prepare to compare noise with a rescaled unit-variance substitute. + eta = 0.3 + gamma = 0.55 + seed = 314 + noise = _adding.add_noise(eta, gamma, seed) + noise_unit = _adding.add_noise(1.0, 0.0, seed) + + params = self.init_params + state = noise.init(params) + state_unit = noise_unit.init(params) + + # Check the noise itself by adding it to zeros. + updates = jtu.tree_map(jnp.zeros_like, params) + + for i in range(1, STEPS + 1): + updates_i, state = self.variant(noise.update)(updates, state) + updates_i_unit, state_unit = noise_unit.update(updates, state_unit) + + scale = jnp.sqrt(eta / i**gamma) + + updates_i_rescaled = jtu.tree_map( + lambda g, s=scale: g * s, updates_i_unit) + + chex.assert_trees_all_close(updates_i, updates_i_rescaled, rtol=1e-4) + + +if __name__ == "__main__": + absltest.main() diff --git a/optax/transforms/_clipping.py b/optax/transforms/_clipping.py new file mode 100644 index 000000000..cd111fafc --- /dev/null +++ b/optax/transforms/_clipping.py @@ -0,0 +1,282 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient clipping transformations. 
+ +Note that complex numbers are also supported, see +https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 +""" + +import chex +import jax +from jax import tree_util as jtu +import jax.numpy as jnp + +from optax import tree_utils as otu +from optax._src import base +from optax._src import linear_algebra +from optax._src import numerics + + +def clip(max_delta: chex.Numeric) -> base.GradientTransformation: + """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``. + + Args: + max_delta: The maximum absolute value for each element in the update. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params=None): + del params + return otu.tree_clip(updates, -max_delta, max_delta), state + + return base.GradientTransformation(base.init_empty_state, update_fn) + + +def clip_by_block_rms(threshold: float) -> base.GradientTransformation: + """Clips updates to a max rms for the gradient of each param vector or matrix. + + A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix + (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. + + Args: + threshold: The maximum rms for the gradient of each param vector or matrix. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params=None): + del params + + def _clip_fn(u): + clip_denom = jnp.maximum( + 1.0, + jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold) + return u / clip_denom + + updates = jtu.tree_map(_clip_fn, updates) + return updates, state + + return base.GradientTransformation(base.init_empty_state, update_fn) + + +def clip_by_global_norm(max_norm: float) -> base.GradientTransformation: + """Clips updates using their global norm. + + References: + [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) + + Args: + max_norm: The maximum global norm for an update. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params=None): + del params + g_norm = linear_algebra.global_norm(updates) + # TODO(b/163995078): revert back to the following (faster) implementation + # once analysed how it affects backprop through update (e.g. meta-gradients) + # g_norm = jnp.maximum(max_norm, g_norm) + # updates = jtu.tree_map(lambda t: (t / g_norm) * max_norm, updates) + trigger = jnp.squeeze(g_norm < max_norm) + chex.assert_shape(trigger, ()) # A scalar. + + def clip_fn(t): + return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm) + + updates = jtu.tree_map(clip_fn, updates) + return updates, state + + return base.GradientTransformation(base.init_empty_state, update_fn) + + +def per_example_global_norm_clip( + grads: list[chex.Array], l2_norm_clip: float +) -> tuple[list[chex.Array], jax.Array]: + """Applies gradient clipping per-example using their global norm. + + References: + [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) + + Args: + grads: flattened update; the function expects these to have a batch + dimension on the 0th axis. + l2_norm_clip: maximum L2 norm of the per-example gradients. + + Returns: + A tuple containing sum of the clipped per-example grads, and the number of + per-example grads that were clipped. 
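Illustrative usage sketch (not part of this patch) for `per_example_global_norm_clip`: it takes a list of per-example gradient arrays (batch on axis 0) and returns the sum of the clipped per-example gradients together with the number of examples that were clipped.

import jax.numpy as jnp
import optax

grads = [jnp.ones((4, 3)), jnp.ones((4, 2))]  # 4 examples, 2 layers
clipped_sum, num_clipped = optax.per_example_global_norm_clip(
    grads, l2_norm_clip=1.0)
# Each example has global norm sqrt(5) > 1, so all 4 examples are clipped.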
+ """ + bsize = grads[0].shape[0] + + if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): + raise ValueError( + 'Unlike other transforms, `per_example_global_norm_clip` expects' + ' `grads` to have a batch dimension in the 0th axis.') + + global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads) + divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0) + num_clipped = jnp.greater(divisors, 1.0).sum() + clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads] + return clipped_sum, num_clipped + + +def per_example_layer_norm_clip( + grads: list[chex.Array], + global_l2_norm_clip: float, + uniform: bool = True, + eps: float = 1e-8, +) -> tuple[list[chex.Array], list[chex.Array]]: + """Applies gradient clipping per-example using per-layer norms. + + References: + [McMahan et al, 2012](https://arxiv.org/abs/1710.06963)] + + Args: + grads: flattened update; i.e. a list of gradients in which each item is + the gradient for one layer; the function expects these to have a batch + dimension on the 0th axis. + global_l2_norm_clip: overall L2 clip norm to use. + uniform: If `True`, per-layer clip norm is global_l2_norm_clip/sqrt(L), + where L is the number of layers. Otherwise, per-layer clip norm is + global_l2_norm_clip * sqrt(f), where f is the fraction of total model + parameters that are in this layer. + eps: Small positive value to add to norms to avoid possible division by + zero. + + Let C = `global_l2_norm_clip value`. Then per-layer clipping is done as + follows: + (1) If `uniform` is `True`, each of the K layers has an individual clip + norm of C / sqrt(K). + (2) If `uniform` is `False`, each of the K layers has an individual clip + norm of C * sqrt(D_i / D) where D_i is the number of parameters in + layer i, and D is the total number of parameters in the model. + + Returns: + A tuple containing sum of the clipped per-example grads and the number of + per-example grads that were clipped for each layer. + """ + bsize = grads[0].shape[0] + + if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): + raise ValueError( + 'Unlike other transforms, `per_example_layer_norm_clip` expects' + ' `grads` to have a batch dimension in the 0th axis; got shapes:' + f' {(g.shape for g in grads)}.' + ) + + num_layers = len(grads) + + # Compute per-layer clip norms, based on whether we are using uniform + # variant or not. + if uniform: + # Create list of length `num_layers` of per-layer clip norm. + layer_clip_norms = ( + global_l2_norm_clip * (1.0 / num_layers) ** 0.5, + ) * num_layers + else: + total_params = sum(g[0].size for g in grads) + layer_clip_norms = tuple( + global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5 + for g in grads + ) + + # Compute per-layer grad norms. + def map_layer_norm(grads_list): + return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list] + + layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads) + + # Perform clipping. 
+ divisors = ( + tuple( + jnp.maximum( + layer_grad_norm / (layer_clip_norm + eps), 1.0 + ) + for layer_grad_norm, layer_clip_norm in zip( + layer_grad_norms_per_example, layer_clip_norms + ) + ) + ) + num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors] + clipped_sum = [ + (g / jnp.expand_dims(d, axis=[i for i in range(1, g.ndim)])).sum(0) + for g, d in zip(grads, divisors) + ] + return clipped_sum, num_clipped + + +def unitwise_norm(x: chex.Array) -> chex.Array: + """Computes norms of each output unit separately.""" + if jnp.squeeze(x).ndim <= 1: # Scalars and vectors + squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True) + # Note that this assumes parameters with a shape of length 3 are multihead + # linear parameters--if you wish to apply AGC to 1D convs, you may need + # to modify this line. + elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear + squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True) + elif x.ndim == 4: # Conv kernels of shape HWIO + squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True) + else: + raise ValueError( + f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.') + chex.assert_is_broadcastable(squared_norm.shape, x.shape) + return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape) + + +def unitwise_clip(g_norm: chex.Array, + max_norm: chex.Array, + grad: chex.Array, + div_eps: float = 1e-6) -> chex.Array: + """Applies gradient clipping unit-wise.""" + # This little max(., div_eps) is distinct from the normal eps and just + # prevents division by zero. It technically should be impossible to engage. + clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps)) + chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad)) + return jnp.where(g_norm < max_norm, grad, clipped_grad) + + +def adaptive_grad_clip(clipping: float, + eps: float = 1e-3) -> base.GradientTransformation: + """Clips updates to be at most ``clipping * parameter_norm``, unit-wise. + + References: + [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image + Recognition Without Normalization. (https://arxiv.org/abs/2102.06171) + + Args: + clipping: The maximum allowed ratio of update norm to parameter norm. + eps: An epsilon term to prevent clipping of zero-initialized params. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + g_norm, p_norm = jtu.tree_map(unitwise_norm, (updates, params)) + # Maximum allowable norm. + max_norm = jtu.tree_map( + lambda x: clipping * jnp.maximum(x, eps), p_norm) + # If grad norm > clipping * param_norm, rescale. + updates = jtu.tree_map(unitwise_clip, g_norm, max_norm, updates) + return updates, state + + return base.GradientTransformation(base.init_empty_state, update_fn) diff --git a/optax/_src/clipping_test.py b/optax/transforms/_clipping_test.py similarity index 91% rename from optax/_src/clipping_test.py rename to optax/transforms/_clipping_test.py index 185e17306..54317b73e 100644 --- a/optax/_src/clipping_test.py +++ b/optax/transforms/_clipping_test.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for `clipping.py`.""" +"""Tests for optax.transforms._clipping.""" from absl.testing import absltest - import chex import jax import jax.numpy as jnp -from optax._src import clipping from optax._src import linear_algebra +from optax.transforms import _clipping + STEPS = 50 LR = 1e-2 @@ -37,11 +37,11 @@ def setUp(self): def test_clip(self): updates = self.per_step_updates # For a sufficiently high delta the update should not be changed. - clipper = clipping.clip(1e6) + clipper = _clipping.clip(1e6) clipped_updates, _ = clipper.update(updates, None) chex.assert_trees_all_close(clipped_updates, clipped_updates) # Clipping at delta=1 should make all updates exactly 1. - clipper = clipping.clip(1.) + clipper = _clipping.clip(1.) clipped_updates, _ = clipper.update(updates, None) chex.assert_trees_all_close( clipped_updates, jax.tree_util.tree_map(jnp.ones_like, updates)) @@ -50,7 +50,7 @@ def test_clip_by_block_rms(self): rmf_fn = lambda t: jnp.sqrt(jnp.mean(t**2)) updates = self.per_step_updates for i in range(1, STEPS + 1): - clipper = clipping.clip_by_block_rms(1. / i) + clipper = _clipping.clip_by_block_rms(1. / i) # Check that the clipper actually works and block rms is <= threshold updates, _ = clipper.update(updates, None) self.assertAlmostEqual(rmf_fn(updates[0]), 1. / i) @@ -62,7 +62,7 @@ def test_clip_by_block_rms(self): def test_clip_by_global_norm(self): updates = self.per_step_updates for i in range(1, STEPS + 1): - clipper = clipping.clip_by_global_norm(1. / i) + clipper = _clipping.clip_by_global_norm(1. / i) # Check that the clipper actually works and global norm is <= max_norm updates, _ = clipper.update(updates, None) self.assertAlmostEqual( @@ -76,12 +76,12 @@ def test_adaptive_grad_clip(self): params = self.init_params for i in range(1, STEPS + 1): clip_r = 1. / i - clipper = clipping.adaptive_grad_clip(clip_r) + clipper = _clipping.adaptive_grad_clip(clip_r) # Check that the clipper actually works and upd_norm is < c * param_norm. updates, _ = clipper.update(updates, None, params) u_norm, p_norm = jax.tree_util.tree_map( - clipping.unitwise_norm, (updates, params)) + _clipping.unitwise_norm, (updates, params)) cmp = jax.tree_util.tree_map( lambda u, p, c=clip_r: u - c * p < 1e-6, u_norm, p_norm) for leaf in jax.tree_util.tree_leaves(cmp): @@ -101,7 +101,7 @@ def test_per_example_layer_norm_clip(self): ] with self.subTest(name='Uniform Variant'): - sum_clipped_grads, num_clipped = clipping.per_example_layer_norm_clip( + sum_clipped_grads, num_clipped = _clipping.per_example_layer_norm_clip( grads_flat, global_l2_norm_clip=jnp.sqrt(2), uniform=True ) @@ -119,7 +119,7 @@ def test_per_example_layer_norm_clip(self): self.assertEqual(num_clipped[1], 4) with self.subTest(name='Scaled Variant'): - sum_clipped_grads, num_clipped = clipping.per_example_layer_norm_clip( + sum_clipped_grads, num_clipped = _clipping.per_example_layer_norm_clip( grads_flat, global_l2_norm_clip=jnp.sqrt(19), uniform=False ) diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py new file mode 100644 index 000000000..857c3dacf --- /dev/null +++ b/optax/transforms/_combining.py @@ -0,0 +1,255 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flexibly compose gradient transformations.""" + +from typing import Callable, NamedTuple, Union, Mapping, Hashable + +import jax + +from optax._src import base +from optax._src import wrappers + + +def chain( + *args: base.GradientTransformation, +) -> base.GradientTransformationExtraArgs: + """Applies a list of chainable update transformations. + + This function creates a new :func:`optax.GradientTransformation` that applies + a sequence of gradient transformations in order. The ``init`` function of the + new transformation constructs the optimizer state by concatenating the states + of the individual transforms, while the ``update`` function applies the + updates in the given order. + + Examples: + + A transform that scales by -0.1 the adam update: + + >>> import optax + >>> transform1 = optax.scale_by_adam() + >>> transform2 = optax.scale(-0.1) + >>> chained_transform = optax.chain(transform1, transform2) + >>> params = {'a': 1.0} + >>> state = chained_transform.init(params) + >>> updates = {'a': -0.5} + >>> updates, new_state = chained_transform.update(updates, state, params) + + Args: + *args: a sequence of chainable (init_fn, update_fn) tuples. + + Returns: + A :func:`GradientTransformationExtraArgs`, created by chaining the input + transformations. Note that independent of the argument types, the resulting + transformation always supports extra args. Any extra arguments passed to the + returned transformation will be passed only to those transformations in the + chain that support extra args. + """ + + transforms = [base.with_extra_args_support(t) for t in args] + init_fns, update_fns = zip(*transforms) + + def init_fn(params): + return tuple(fn(params) for fn in init_fns) + + def update_fn(updates, state, params=None, **extra_args): + if len(update_fns) != len(state): + raise ValueError('The number of updates and states has to be the same in ' + 'chain! Make sure you have called init first!') + + new_state = [] + for s, fn in zip(state, update_fns): + updates, new_s = fn(updates, s, params, **extra_args) + new_state.append(new_s) + return updates, tuple(new_state) + + # We opt to always return the GradientTransformationExtraArgs type here, + # instead of selecting the return type based on the arguments, since it works + # much better with the currently available type checkers. It also means that + # users will not get unexpected signature errors if they remove all of the + # transformations in a chain accepting extra args. + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +def named_chain( + *transforms: tuple[str, base.GradientTransformation] +) -> base.GradientTransformationExtraArgs: + """Chains optax gradient transformations. + + A variant of :func:`optax.chain` that allows to name each transformation. + + Here the ``transforms`` are ``(name, transformation)`` pairs, constituted of a + string ``name`` and an associated transformation ``transformation``. 
The + gradient transformation must be an instance of :func:`GradientTransformation` + or :func:`GradientTransformationExtraArgs`. + + Each ``name`` is used as key for the state of the corresponding transformation + within the ``named_chain`` state. Thus the state of the transformation + with a given ``name`` can be easily retrieved as ``opt_state[name]``. + + Examples: + + >>> # tx1 is a GradientTransformation with no extra_args. + >>> # tx2 is a GradientTransformationExtraArgs that requires `loss`. + >>> # tx3 is a GradientTransformationExtraArgs that requires `temperature`. + >>> tx = named_chain(('one', tx1), ('two', tx2), ('three', tx3)) + >>> extra_args={'loss': 0.3, 'temperature': 0.01} + >>> tx.init(params) + >>> tx.update(grads, state, params, **extra_args) + + Args: + *transforms: an arbitrary number of ``(name, tx)`` pairs, constituted of a + string ``name`` and an associated transformation ``tx``. The latter is a + :func:`GradientTransformation` or :func:`GradientTransformationExtraArgs`. + + Returns: + A single (init_fn, update_fn) tuple. + """ + + names = [name for name, _ in transforms] + + if len(names) != len(set(names)): + raise ValueError( + f'Named transformations must have unique names, but got {names}') + + transforms = [ + (name, base.with_extra_args_support(t)) + for name, t in transforms] + + def init_fn(params): + states = {} + for (name, tx) in transforms: + states[name] = tx.init(params) + return states + def update_fn(updates, state, params=None, **extra_args): + new_state = {} + for (name, tx) in transforms: + updates, new_state[name] = tx.update( + updates, state[name], params, **extra_args) + return updates, new_state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +class PartitionState(NamedTuple): + inner_states: Mapping[Hashable, base.OptState] + + +def partition( + transforms: Mapping[Hashable, base.GradientTransformation], + param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]], + *, + mask_compatible_extra_args: bool = False, +) -> base.GradientTransformationExtraArgs: + """Partitions params and applies a different transformation to each subset. + + Sometimes you may want to apply different transformations to different + parameters. For example, you may want to apply Adam to the weights of a + neural network, but SGD to the biases. This function allows you to do that. + + Examples: + + Below is an example where we apply Adam to the weights and SGD to the biases + of a 2-layer neural network:: + + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + + >>> def map_nested_fn(fn): + ... '''Recursively apply `fn` to key-value pairs of a nested dict.''' + ... def map_fn(nested_dict): + ... return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) + ... for k, v in nested_dict.items()} + ... return map_fn + + >>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, + ... 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} + >>> gradients = jtu.tree_map(jnp.ones_like, params) # dummy gradients + + >>> label_fn = map_nested_fn(lambda k, _: k) + >>> tx = optax.multi_transform( + ... {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn) + >>> state = tx.init(params) + >>> updates, new_state = tx.update(gradients, state, params) + >>> new_params = optax.apply_updates(params, updates) + + Instead of providing a ``label_fn``, you may provide a PyTree of labels + directly. Also, this PyTree may be a prefix of the parameters PyTree. 
This + is demonstrated in the GAN pseudocode below:: + + >>> generator_params = ... + >>> discriminator_params = ... + >>> all_params = (generator_params, discriminator_params) + >>> param_labels = ('generator', 'discriminator') + + >>> tx = optax.multi_transform( + >>> {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, + >>> param_labels) + + If you would like to not optimize some parameters, you may wrap + :func:`optax.multi_transform` with :func:`optax.masked`. + + Args: + transforms: A mapping from labels to transformations. Each transformation + will be only be applied to parameters with the same label. + param_labels: A PyTree that is the same shape or a prefix of the + parameters/updates (or a function that returns one given the parameters as + input). The leaves of this PyTree correspond to the keys of the transforms + (therefore the values at the leaves must be a subset of the keys). + mask_compatible_extra_args: Whether to also apply the same masking to + extra_arg fields with the same tree structure as params/updates. + + Returns: + A :func:`optax.GradientTransformationExtraArgs` that implements an ``init`` + and ``update`` function. + """ + + transforms = { + k: base.with_extra_args_support(v) + for k, v in transforms.items() + } + + def make_mask(labels, group): + return jax.tree_util.tree_map(lambda label: label == group, labels) + + def init_fn(params): + labels = param_labels(params) if callable(param_labels) else param_labels + + label_set = set(jax.tree_util.tree_leaves(labels)) + if not label_set.issubset(transforms.keys()): + raise ValueError('Some parameters have no corresponding transformation.\n' + f'Parameter labels: {list(sorted(label_set))} \n' + f'Transforms keys: {list(sorted(transforms.keys()))} \n') + + inner_states = { + group: wrappers.masked( + tx, make_mask(labels, group), + mask_compatible_extra_args=mask_compatible_extra_args).init(params) + for group, tx in transforms.items() + } + return PartitionState(inner_states) + + def update_fn(updates, state, params=None, **extra_args): + labels = param_labels(updates) if callable(param_labels) else param_labels + new_inner_state = {} + for group, tx in transforms.items(): + masked_tx = wrappers.masked( + tx, make_mask(labels, group), + mask_compatible_extra_args=mask_compatible_extra_args) + updates, new_inner_state[group] = masked_tx.update( + updates, state.inner_states[group], params, **extra_args) + return updates, PartitionState(new_inner_state) + + return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/optax/transforms/_combining_test.py b/optax/transforms/_combining_test.py new file mode 100644 index 000000000..4ae908d90 --- /dev/null +++ b/optax/transforms/_combining_test.py @@ -0,0 +1,284 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
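Under the hood, ``partition`` wraps each group's transform in a masked transformation whose boolean tree is ``label == group``, so leaves outside a group pass through that transform untouched. A hedged usage sketch via the public ``optax.multi_transform`` entry point (which the docstring above refers to), here with a label tree given directly rather than through a label function:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}
labels = {'w': 'adam', 'b': 'sgd'}          # same structure as params
tx = optax.multi_transform(
    {'adam': optax.adam(1e-3), 'sgd': optax.sgd(1e-1)}, labels)
state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)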
+# ============================================================================== +"""Tests for `optax.transforms._combining.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp + +from optax._src import alias +from optax._src import base +from optax._src import transform +from optax._src import update +from optax.transforms import _accumulation +from optax.transforms import _combining + +STEPS = 50 +LR = 1e-2 + + +class CombiningTest(chex.TestCase): + + def setUp(self): + super().setUp() + self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) + self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) + + @chex.all_variants + def test_chain(self): + transformations = [ + transform.scale_by_adam(), + _accumulation.trace(decay=0, nesterov=False), + transform.scale(-LR)] + + # Apply updates with chain. + chain_params = self.init_params + chained_transforms = _combining.chain(*transformations) + state = chained_transforms.init(chain_params) + self.assertIsInstance(state, tuple) + + @self.variant + def update_fn(updates, state): + return chained_transforms.update(updates, state) + + for _ in range(STEPS): + updates, state = update_fn(self.per_step_updates, state) + self.assertIsInstance(state, tuple) + chain_params = update.apply_updates(chain_params, updates) + + # Manually apply sequence of transformations. + manual_params = self.init_params + states = [t.init(manual_params) for t in transformations] + for _ in range(STEPS): + updates = self.per_step_updates + new_states = [] + for t, s in zip(transformations, states): + updates, state = t.update(updates, s) + new_states.append(state) + manual_params = update.apply_updates(manual_params, updates) + states = new_states + + # Check equivalence. + chex.assert_trees_all_close(manual_params, chain_params, rtol=1e-4) + + +def _map_keys_fn(fn): + def map_fn(nested_dict): + return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) + for k, v in nested_dict.items()} + return map_fn + + +class ExtraArgsTest(chex.TestCase): + + def test_extra_args(self): + def init_fn(params): + del params + return tuple() + + # Arguments required by a transformation should be keyword-only. + # For example, the loss argument in this transformation. + def update_fn(updates, state, params=None, *, loss, **extra_args): + # Extra args should always be accepted. + del extra_args, params + assert loss == 1 + return updates, state + + t = base.GradientTransformationExtraArgs(init_fn, update_fn) + result = _combining.chain(alias.adam(1e-3), t) + self.assertIsInstance(result, base.GradientTransformationExtraArgs) + + params = {'a': 1, 'b': 2} + state = result.init(params) + result.update(params, state, loss=1, ignored_kwarg='hi') + + def test_extra_args_chaining(self): + def init_fn(params): + del params + return {} + def update_fn(updates, state, params=None): + del params + return updates, state + + # Possible gotcha: Chaining regular gradient transformations results in + # a transformation that supports extra args. 
+ t1 = base.GradientTransformation(init_fn, update_fn) + t2 = _combining.chain(t1, t1) + self.assertIsInstance(t2, base.GradientTransformation) + self.assertIsInstance(t2, base.GradientTransformationExtraArgs) + + t3 = base.with_extra_args_support(t2) + self.assertIsInstance(t3, base.GradientTransformationExtraArgs) + + def test_extra_args_positional_params(self): + def init_fn(params): + del params + return tuple() + + def update_fn(updates, state, params=None): + assert params is not None + return updates, state + + def update_fn_kwargs(updates, state, params=None, **extra_args): + del extra_args + assert params is not None + return updates, state + + t1 = base.GradientTransformation(init_fn, update_fn) + t2 = base.GradientTransformationExtraArgs(init_fn, update_fn_kwargs) + opt = _combining.chain(t1, t2) + params = {'a': 1, 'b': 2} + state = opt.init(params) + opt.update(params, state, params, ignored_kwarg='hi') + opt.update(params, state, params=params, ignored_kwarg='hi') + + +class PartitionTest(chex.TestCase): + """Tests for the partition wrapper.""" + + @chex.all_variants + @parameterized.parameters(True, False) + def test_partition(self, use_fn): + params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}} + params = jax.tree_util.tree_map(jnp.asarray, params) + input_updates = jax.tree_util.tree_map(lambda x: x / 10.0, params) + tx_dict = {'a': transform.scale(-1.0), + 'b': transform.ema(0.0), # stateful + 'c': transform.scale(2.0)} + param_labels = _map_keys_fn(lambda k, _: k[0]) + if not use_fn: + param_labels = param_labels(params) + tx = _combining.partition(tx_dict, param_labels) + update_fn = self.variant(tx.update) + state = self.variant(tx.init)(params) + + correct_update_fn = _map_keys_fn( + lambda k, v: {'a': -v, 'b': v, 'c': 2.0*v}[k[0]]) + + updates, state = update_fn(input_updates, state, params) + correct_updates = correct_update_fn(input_updates) + chex.assert_trees_all_close(updates, correct_updates) + + # Check repeated application, this time with no params. 
+ correct_updates = correct_update_fn(correct_updates) + updates, _ = update_fn(updates, state) + chex.assert_trees_all_close(updates, correct_updates) + + def test_extra_args(self): + + class ArgNotEqual1Error(ValueError): + """Raised when argument not set as expected.""" + + def init(params): + return {'mu': params} + + def update_with_arg(updates, state, params=None, *, arg, **extra_args): + del params, extra_args + if arg != 1: + raise ArgNotEqual1Error() + return updates, state + + def update_without_arg(updates, state, params=None): + del params + return updates, state + + opt_no_arg = base.GradientTransformation(init, update_without_arg) + opt_extra_arg = base.GradientTransformationExtraArgs(init, update_with_arg) + + opt = _combining.partition( + { + 'a': opt_no_arg, + 'b': opt_extra_arg, + }, + ('a', 'b'), + ) + + fake_params = ({'u': jnp.array([1])}, {'v': jnp.array([1])}) + state = opt.init(fake_params) + + with self.assertRaises(TypeError): + opt.update(fake_params, state) + with self.assertRaises(ArgNotEqual1Error): + opt.update(fake_params, state, arg=2, ignored_kwarg='hi') + opt.update(fake_params, state, arg=1, ignored_kwarg='hi') + + @parameterized.parameters(list, tuple, dict) + def test_empty(self, container): + init_fn, update_fn = _combining.partition( + {0: alias.sgd(1.)}, lambda _: 0) + updates, _ = update_fn(container(), init_fn(container())) + self.assertEqual(updates, container()) + + @chex.all_variants + @parameterized.parameters( + (False, False), (False, True), (True, False), (True, True)) + def test_labels_mismatch(self, use_extra_label, use_fn): + # The labels from label_fn must be a subet of the keys for the tx. + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} + params = jax.tree_util.tree_map(jnp.asarray, params) + label_tree = {'a': 0, 'b': [1, 0], 'c': 1} # prefix of params + + if use_extra_label: + label_tree['a'] = 3 + + transforms = {0: alias.sgd(1.), + 1: alias.adam(1., b1=0., b2=0.), + 2: _accumulation.trace(1.0)} + init_fn, update_fn = _combining.partition( + transforms, (lambda _: label_tree) if use_fn else label_tree) + + if use_extra_label: + with self.assertRaises(ValueError): + self.variant(init_fn)(params) + else: + state = self.variant(init_fn)(params) + updates = jax.tree_util.tree_map(lambda x: x / 10.0, params) + self.variant(update_fn)(updates, state) + + +def scale_by_loss(): + """Scale the gradient by the absolute value of the loss.""" + + def update_fn(updates, state, params, *, loss, **extra_args): + del params, extra_args + updates = jax.tree_util.tree_map( + lambda u: u / loss, updates) + return updates, state + + return base.GradientTransformationExtraArgs(base.init_empty_state, update_fn) + + +class NamedChainTest(absltest.TestCase): + + def test_named_chain(self): + tx = _combining.named_chain( + ('scale', transform.scale(0.1)), + ('scale_loss', scale_by_loss()), + ) + + params = {'a': jnp.ones((4,))} + grads = params + + opt_state = tx.init(params) + updates, _ = tx.update(grads, opt_state, params, loss=0.1) + + chex.assert_trees_all_close(updates, {'a': jnp.ones((4,))}) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index a37e50433..0d3204926 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -35,8 +35,22 @@ def setUp(self): self.tree_a = (rng.randn(20, 10) + 1j * rng.randn(20, 10), rng.randn(20)) self.tree_b = (rng.randn(20, 10), rng.randn(20)) - self.tree_a_dict = (1.0, {'k1': 1.0, 'k2': 
(1.0, 1.0)}, 1.0) - self.tree_b_dict = (1.0, {'k1': 2.0, 'k2': (3.0, 4.0)}, 5.0) + self.tree_a_dict = jtu.tree_map( + jnp.asarray, + ( + 1.0, + {'k1': 1.0, 'k2': (1.0, 1.0)}, + 1.0 + ) + ) + self.tree_b_dict = jtu.tree_map( + jnp.asarray, + ( + 1.0, + {'k1': 2.0, 'k2': (3.0, 4.0)}, + 5.0 + ) + ) self.array_a = rng.randn(20) + 1j * rng.randn(20) self.array_b = rng.randn(20) diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 32d53f39d..d017a1084 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -83,7 +83,10 @@ def tree_div(tree_x: Any, tree_y: Any) -> Any: return jtu.tree_map(operator.truediv, tree_x, tree_y) -def tree_scalar_mul(scalar: Union[float, jax.Array], tree: Any) -> Any: +def tree_scalar_mul( + scalar: Union[float, jax.Array], + tree: Any, +) -> Any: r"""Multiply a tree by a scalar. In infix notation, the function performs ``out = scalar * tree``. @@ -113,7 +116,11 @@ def tree_add_scalar_mul( Returns: a pytree with the same structure as ``tree_x`` and ``tree_y``. """ - return jtu.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y) + scalar = jnp.asarray(scalar) + return jtu.tree_map( + lambda x, y: x + scalar.astype(x.dtype) * y, + tree_x, + tree_y) _vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)
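The ``astype`` added to ``tree_add_scalar_mul`` keeps each leaf's dtype when the scalar arrives as a concrete array (for example a float32 step size), so low-precision trees are not silently promoted. A standalone sketch of the same idea, written against plain jax.tree_util rather than the optax tree_utils API:

import jax
import jax.numpy as jnp

def add_scalar_mul(tree_x, scalar, tree_y):
  # Mirror of the patched pattern: cast the scalar to each leaf's dtype so
  # bfloat16 leaves stay bfloat16 instead of being promoted to float32.
  scalar = jnp.asarray(scalar)
  return jax.tree_util.tree_map(
      lambda x, y: x + scalar.astype(x.dtype) * y, tree_x, tree_y)

x = {'w': jnp.ones((2,), dtype=jnp.bfloat16)}
y = {'w': jnp.full((2,), 2.0, dtype=jnp.bfloat16)}
out = add_scalar_mul(x, jnp.asarray(0.5, dtype=jnp.float32), y)
print(out['w'].dtype)  # bfloat16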