Description
Describe the feature and the current behavior/state.
Currently, a decay_var_list can be passed to the minimize/apply_gradients calls of an optimizer built with DecoupledWeightDecayExtension. However, this option is lost when the optimizer is wrapped in tf.keras.mixed_precision.experimental.LossScaleOptimizer, which does not forward the extra argument. This makes it impossible to apply weight decay to only a subset of variables under mixed precision.
Instead, could there be an option to pass the decay variables during initialization? Something like:
```python
class DecoupledWeightDecayExtension:
    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        decay_var_list: Optional[List] = None,
        **kwargs,
    ):
        wd = kwargs.pop("weight_decay", weight_decay)
        decay_var_list = kwargs.pop("decay_var_list", decay_var_list)
        self._decay_var_list = (
            set(v.ref() for v in decay_var_list) if decay_var_list else False
        )
        super().__init__(**kwargs)
        self._set_hyper("weight_decay", wd)
```
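To illustrate the intended behavior without a TensorFlow dependency, here is a minimal, self-contained sketch of the same initialization pattern. The `_BaseOptimizer` and `_Var` classes are plain-Python stand-ins for the real optimizer base class and `tf.Variable` (whose `ref()` method returns a hashable reference), and `_do_use_weight_decay` mirrors the membership check the extension performs per variable; this is only a sketch of the kwargs handling, not a working optimizer:

```python
class _BaseOptimizer:
    # Stand-in for the real optimizer base class; just records hyperparameters.
    def __init__(self, **kwargs):
        self._hypers = {}

    def _set_hyper(self, name, value):
        self._hypers[name] = value


class _Var:
    # Stand-in for tf.Variable; ref() returns a hashable key,
    # as tf.Variable.ref() does.
    def __init__(self, name):
        self.name = name

    def ref(self):
        return self.name


class DecoupledWeightDecaySketch(_BaseOptimizer):
    def __init__(self, weight_decay, decay_var_list=None, **kwargs):
        # Mirror the proposal: accept decay_var_list at construction time,
        # so it survives even if a wrapper optimizer does not forward it later.
        wd = kwargs.pop("weight_decay", weight_decay)
        decay_var_list = kwargs.pop("decay_var_list", decay_var_list)
        self._decay_var_list = (
            set(v.ref() for v in decay_var_list) if decay_var_list else False
        )
        super().__init__(**kwargs)
        self._set_hyper("weight_decay", wd)

    def _do_use_weight_decay(self, var):
        # Decay applies to every variable unless a subset was given.
        return not self._decay_var_list or var.ref() in self._decay_var_list


kernel, bias = _Var("kernel"), _Var("bias")
opt = DecoupledWeightDecaySketch(1e-4, decay_var_list=[kernel])
print(opt._do_use_weight_decay(kernel))  # True: in the decay set
print(opt._do_use_weight_decay(bias))    # False: excluded from decay
```

With no decay_var_list given, `self._decay_var_list` stays falsy and every variable is decayed, matching the current default behavior.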
Relevant information
- Are you willing to contribute it (yes/no): yes
- Are you willing to maintain it going forward? (yes/no): yes
- Is there a relevant academic paper? (if so, where): no
- Is there already an implementation in another framework? (if so, where): no
- Was it part of tf.contrib? (if so, where): no
Which API type would this fall under (layer, metric, optimizer, etc.)
optimizer
Who will benefit with this feature?
anyone using mixed precision
Any other info.