exclude_from_weight_decay for AdamW and SGDW #2624

Merged
6 commits merged on Jan 3, 2022
39 changes: 13 additions & 26 deletions tensorflow_addons/optimizers/lamb.py
@@ -18,14 +18,14 @@
76 minutes](https://arxiv.org/abs/1904.00962).
"""

import re
import warnings

from typing import Optional, Union, Callable, List
from typeguard import typechecked

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


@tf.keras.utils.register_keras_serializable(package="Addons")
@@ -163,12 +163,11 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
v_sqrt = tf.sqrt(v_t_hat)
update = m_t_hat / (v_sqrt + coefficients["epsilon"])

var_name = self._get_variable_name(var.name)
if self._do_use_weight_decay(var_name):
if self._do_use_weight_decay(var):
update += coefficients["weight_decay"] * var

ratio = 1.0
if self._do_layer_adaptation(var_name):
if self._do_layer_adaptation(var):
w_norm = tf.norm(var, ord=2)
g_norm = tf.norm(update, ord=2)
ratio = tf.where(
@@ -206,12 +205,11 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
v_sqrt = tf.sqrt(v_t_hat)
update = m_t_hat / (v_sqrt + coefficients["epsilon"])

var_name = self._get_variable_name(var.name)
if self._do_use_weight_decay(var_name):
if self._do_use_weight_decay(var):
update += coefficients["weight_decay"] * var

ratio = 1.0
if self._do_layer_adaptation(var_name):
if self._do_layer_adaptation(var):
w_norm = tf.norm(var, ord=2)
g_norm = tf.norm(update, ord=2)
ratio = tf.where(
@@ -241,26 +239,15 @@ def get_config(self):
)
return config

def _do_use_weight_decay(self, param_name):
def _do_use_weight_decay(self, variable):
"""Whether to use L2 weight decay for `param_name`."""
if self.exclude_from_weight_decay:
for r in self.exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
return not is_variable_matched_by_regexes(
variable, self.exclude_from_weight_decay
)

def _do_layer_adaptation(self, param_name):
def _do_layer_adaptation(self, variable):
"""Whether to do layer-wise learning rate adaptation for
`param_name`."""
if self.exclude_from_layer_adaptation:
for r in self.exclude_from_layer_adaptation:
if re.search(r, param_name) is not None:
return False
return True

def _get_variable_name(self, param_name):
"""Get the variable name from the tensor name."""
m = re.match("^(.*):\\d+$", param_name)
if m is not None:
param_name = m.group(1)
return param_name
return not is_variable_matched_by_regexes(
variable, self.exclude_from_layer_adaptation
)
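
For orientation, a minimal sketch (not part of the diff) of how the refactored LAMB checks behave from the caller's side, assuming a tensorflow-addons build that includes this change; variable and pattern names are illustrative:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# The exclusion checks now take the variable itself; matching is a regex
# search against the variable's name.
opt = tfa.optimizers.LAMB(
    learning_rate=0.01,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=["bias", "LayerNorm"],
)

kernel = tf.Variable(tf.zeros([3, 3]), name="dense/kernel")
bias = tf.Variable(tf.zeros([3]), name="dense/bias")

assert opt._do_use_weight_decay(kernel)    # no pattern matches "dense/kernel:0"
assert not opt._do_use_weight_decay(bias)  # "bias" matches "dense/bias:0"
```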
16 changes: 9 additions & 7 deletions tensorflow_addons/optimizers/tests/lamb_test.py
@@ -335,20 +335,22 @@ def test_get_config():

def test_exclude_weight_decay():
opt = lamb.LAMB(0.01, weight_decay=0.01, exclude_from_weight_decay=["var1"])
assert opt._do_use_weight_decay("var0")
assert not opt._do_use_weight_decay("var1")
assert not opt._do_use_weight_decay("var1_weight")
assert opt._do_use_weight_decay(tf.Variable([], name="var0"))
assert not opt._do_use_weight_decay(tf.Variable([], name="var1"))
assert not opt._do_use_weight_decay(tf.Variable([], name="var1_weight"))


def test_exclude_layer_adaptation():
opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
assert opt._do_layer_adaptation("var0")
assert not opt._do_layer_adaptation("var1")
assert not opt._do_layer_adaptation("var1_weight")
assert opt._do_layer_adaptation(tf.Variable([], name="var0"))
assert not opt._do_layer_adaptation(tf.Variable([], name="var1"))
assert not opt._do_layer_adaptation(tf.Variable([], name="var1_weight"))


def test_serialization():
optimizer = lamb.LAMB(1e-4)
optimizer = lamb.LAMB(
1e-4, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()
tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
@@ -80,6 +80,7 @@ def do_test(
opt = optimizer(**optimizer_kwargs)
# Create the update op.
# Run 3 steps of the optimizer
optimizer_kwargs.pop("exclude_from_weight_decay", None)
for _ in range(3):
if do_decay_var_list:
opt.apply_gradients(
@@ -241,6 +242,31 @@ def test_basic_decay_var_list_adamw(dtype):
)


def test_exclude_weight_decay_adamw():
optimizer = weight_decay_optimizers.AdamW(
learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))


@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
def test_var_list_with_exclude_list_adamw(dtype):
do_test(
dtype,
weight_decay_optimizers.AdamW,
adamw_update_numpy,
do_decay_var_list=True,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-8,
weight_decay=WEIGHT_DECAY,
exclude_from_weight_decay=["var0_*", "var1_*"],
)


def test_keras_fit():
"""Check if calling model.fit works."""
model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
@@ -341,6 +367,30 @@ def test_basic_decay_var_list_sgdw(dtype):
)


def test_exclude_weight_decay_sgdw():
optimizer = weight_decay_optimizers.SGDW(
learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
def test_var_list_with_exclude_list_sgdw(dtype):
do_test(
dtype,
weight_decay_optimizers.SGDW,
sgdw_update_numpy,
do_decay_var_list=True,
learning_rate=0.001,
momentum=0.9,
weight_decay=WEIGHT_DECAY,
exclude_from_weight_decay=["var0_*", "var1_*"],
)


@pytest.mark.parametrize(
"optimizer",
[
@@ -379,7 +429,9 @@ def test_optimizer_sparse(dtype, optimizer):


def test_serialization():
optimizer = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
optimizer = weight_decay_optimizers.AdamW(
learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()
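
To make the documented precedence between decay_var_list and exclude_from_weight_decay concrete, a small sketch (not part of the PR) assuming a tensorflow-addons build with this change; the variable names are illustrative:

```python
import tensorflow as tf
import tensorflow_addons as tfa

var0 = tf.Variable([1.0, 2.0], name="var0")
var1 = tf.Variable([3.0, 4.0], name="var1")

opt = tfa.optimizers.AdamW(
    learning_rate=1e-3, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
)

# Without decay_var_list, var1 is skipped because its name matches "var1".
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var0 ** 2) + tf.reduce_sum(var1 ** 2)
grads = tape.gradient(loss, [var0, var1])
opt.apply_gradients(zip(grads, [var0, var1]))

# An explicit decay_var_list takes priority over the exclude patterns,
# so var1 is decayed on this step even though it matches "var1".
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var0 ** 2) + tf.reduce_sum(var1 ** 2)
grads = tape.gradient(loss, [var0, var1])
opt.apply_gradients(zip(grads, [var0, var1]), decay_var_list=[var1])

# The patterns are part of get_config, so they survive serialization,
# which is what the updated serialization tests check.
config = tf.keras.optimizers.serialize(opt)
restored = tf.keras.optimizers.deserialize(config)
assert restored.get_config() == opt.get_config()
```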
22 changes: 22 additions & 0 deletions tensorflow_addons/optimizers/utils.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Additional Utilities used for tfa.optimizers."""

import re
import tensorflow as tf
from typing import List


def fit_bn(model, *args, **kwargs):
@@ -51,3 +53,23 @@ def fit_bn(model, *args, **kwargs):

model.trainable = _trainable
model._metrics = _metrics


def get_variable_name(variable) -> str:
"""Get the variable name from the variable tensor."""
param_name = variable.name
m = re.match("^(.*):\\d+$", param_name)
if m is not None:
param_name = m.group(1)
return param_name


def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool:
"""Whether variable is matched in regexes list by its name."""
if regexes:
# var_name = get_variable_name(variable)
var_name = variable.name
for r in regexes:
if re.search(r, var_name):
return True
return False
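
A quick illustration of the new helper's semantics (illustrative names, not from the PR): the match is a plain re.search against variable.name, so any substring match counts, and an empty or None pattern list matches nothing.

```python
import tensorflow as tf
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

kernel = tf.Variable(tf.zeros([2, 2]), name="encoder/kernel")
bias = tf.Variable(tf.zeros([2]), name="encoder/bias")

assert not is_variable_matched_by_regexes(kernel, ["bias", "beta"])  # no substring match
assert is_variable_matched_by_regexes(bias, ["bias", "beta"])        # "bias" matches "encoder/bias:0"
assert not is_variable_matched_by_regexes(bias, [])                  # empty list matches nothing
```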
75 changes: 54 additions & 21 deletions tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -16,9 +16,10 @@

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

from typeguard import typechecked
from typing import Union, Callable, Type
from typing import Union, Callable, Type, Optional, List


class DecoupledWeightDecayExtension:
@@ -71,24 +72,40 @@ def __init__(self, weight_decay, *args, **kwargs):
"""

@typechecked
def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
def __init__(
self,
weight_decay: Union[FloatTensorLike, Callable],
exclude_from_weight_decay: Optional[List[str]] = None,
[Review thread on this line]
Contributor: Please add this to the Args section below.
Contributor (author): Added.
Contributor: Thanks! Could you also add a sentence explaining that decay_var_list in minimize takes priority over exclude_from_weight_decay if specified (and a corresponding sentence to the minimize documentation)?
Contributor (author): Updated the DecoupledWeightDecayExtension __init__, minimize, and apply_gradients docs, extended the extend_with_decoupled_weight_decay doc, and added exclude_from_weight_decay to the AdamW and SGDW **kwargs docs.

**kwargs,
):
"""Extension class that adds weight decay to an optimizer.

Args:
weight_decay: A `Tensor`, a floating point value, or a schedule
that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
to decay the variable by, in the update step.
exclude_from_weight_decay: List of regex patterns of
variables excluded from weight decay. Variables whose name
contain a substring matching the pattern will be excluded.
Note `decay_var_list` in `minimize` or `apply_gradients` takes
priority over `exclude_from_weight_decay` if specified.
**kwargs: Optional list or tuple or set of `Variable` objects to
decay.
"""
wd = kwargs.pop("weight_decay", weight_decay)
super().__init__(**kwargs)
self._decay_var_list = None # is set in minimize or apply_gradients
self._set_hyper("weight_decay", wd)
self.exclude_from_weight_decay = exclude_from_weight_decay

def get_config(self):
config = super().get_config()
config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")})
config.update(
{
"weight_decay": self._serialize_hyperparameter("weight_decay"),
"exclude_from_weight_decay": self.exclude_from_weight_decay,
}
)
return config

@classmethod
@@ -130,7 +147,8 @@ def minimize(
grad_loss: Optional. A `Tensor` holding the gradient computed for
`loss`.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
to all variables in var_list. Note `decay_var_list` takes
priority over `exclude_from_weight_decay` if specified.
name: Optional name for the returned operation.
tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
`Tensor`, the tape that computed the `loss` must be provided.
@@ -154,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs):

Args:
grads_and_vars: List of (gradient, variable) pairs.
name: Optional name for the returned operation. Default to the
name passed to the `Optimizer` constructor.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
to all variables in var_list. Note `decay_var_list` takes
priority over `exclude_from_weight_decay` if specified.
**kwargs: Additional arguments to pass to the base optimizer's
apply_gradient method, e.g., TF2.2 added an argument
`experimental_aggregate_gradients`.
@@ -173,7 +192,7 @@
return super().apply_gradients(grads_and_vars, name=name, **kwargs)

def _decay_weights_op(self, var, apply_state=None):
if not self._decay_var_list or var.ref() in self._decay_var_list:
if self._do_use_weight_decay(var):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = (apply_state or {}).get(
(var_device, var_dtype)
@@ -183,7 +202,7 @@ def _decay_weights_op(self, var, apply_state=None):
return tf.no_op()

def _decay_weights_sparse_op(self, var, indices, apply_state=None):
if not self._decay_var_list or var.ref() in self._decay_var_list:
if self._do_use_weight_decay(var):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = (apply_state or {}).get(
(var_device, var_dtype)
@@ -226,6 +245,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
grad, var, indices, apply_state=apply_state
)

def _do_use_weight_decay(self, var):
"""Whether to use L2 weight decay for `var`."""
if self._decay_var_list and var.ref() in self._decay_var_list:
return True
return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)


@typechecked
def extend_with_decoupled_weight_decay(
@@ -243,9 +268,13 @@ def extend_with_decoupled_weight_decay(
The API of the new optimizer class slightly differs from the API of the
base optimizer:
- The first argument to the constructor is the weight decay rate.
- Optional keyword argument `exclude_from_weight_decay` accepts list of
regex patterns of variables excluded from weight decay. Variables whose
name contain a substring matching the pattern will be excluded.
- `minimize` and `apply_gradients` accept the optional keyword argument
`decay_var_list`, which specifies the variables that should be decayed.
If `None`, all variables that are optimized are decayed.
Note this takes priority over `exclude_from_weight_decay` if specified.
If both `None`, all variables that are optimized are decayed.

Usage example:
```python
@@ -376,12 +405,14 @@ def __init__(
nesterov: boolean. Whether to apply Nesterov momentum.
name: Optional name prefix for the operations created when applying
gradients. Defaults to 'SGD'.
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse decay
of learning rate. `lr` is included for backward compatibility,
recommended to use `learning_rate` instead.
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
`lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
gradients by norm; `clipvalue` is clip gradients by value.
`decay` is included for backward compatibility to allow time
inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
`exclude_from_weight_decay` accepts list of regex patterns of
variables excluded from weight decay.
"""
super().__init__(
weight_decay,
@@ -466,12 +497,14 @@ def __init__(
beyond".
name: Optional name for the operations created when applying
gradients. Defaults to "AdamW".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse decay
of learning rate. `lr` is included for backward compatibility,
recommended to use `learning_rate` instead.
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
`lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
gradients by norm; `clipvalue` is clip gradients by value.
`decay` is included for backward compatibility to allow time
inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
`exclude_from_weight_decay` accepts list of regex patterns of
variables excluded from weight decay.
"""
super().__init__(
weight_decay,
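
Finally, a minimal sketch (not part of the diff) of the extension factory described above, assuming a tensorflow-addons build with this change; the class and pattern names are illustrative:

```python
import tensorflow as tf
from tensorflow_addons.optimizers import extend_with_decoupled_weight_decay

# Build an SGD variant with decoupled weight decay; the generated class now
# also accepts the exclude_from_weight_decay keyword.
MySGDW = extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)
opt = MySGDW(
    weight_decay=1e-4,
    learning_rate=0.01,
    exclude_from_weight_decay=["bias", "bn/"],
)
```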