Add decay_var_list as init option to DecoupledWeightDecayExtension #2018

Closed
@johnbensnyder

Description

Describe the feature and the current behavior/state.

Currently, decay_var_list can be passed to the minimize and apply_gradients methods of an optimizer built with DecoupledWeightDecayExtension. However, this option is lost once the optimizer is wrapped in tf.keras.mixed_precision.experimental.LossScaleOptimizer, because the wrapper's methods do not accept or forward it. As a result, weight decay can no longer be restricted to a subset of variables when using mixed precision.
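To make the problem concrete, here is a minimal pure-Python sketch of how a call-time keyword gets lost behind a wrapper. `InnerOptimizer` and `Wrapper` are hypothetical stand-ins written for illustration, not the real TensorFlow or TensorFlow Addons classes, and the signatures are simplified assumptions.

```python
# Hypothetical stand-ins for illustration only; not the real TensorFlow classes.
class InnerOptimizer:
    """Mimics an optimizer that accepts decay_var_list at call time."""

    def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
        # Report which variables would be decayed (None means "all").
        return {"decayed": decay_var_list}


class Wrapper:
    """Mimics a wrapping optimizer with a fixed apply_gradients signature."""

    def __init__(self, inner):
        self._inner = inner

    def apply_gradients(self, grads_and_vars, name=None):
        # There is no decay_var_list parameter here, so nothing is forwarded.
        return self._inner.apply_gradients(grads_and_vars, name=name)


wrapped = Wrapper(InnerOptimizer())
try:
    wrapped.apply_gradients([], decay_var_list=["kernel"])
    lost = False
except TypeError:
    # The wrapper rejects the keyword it does not know about.
    lost = True
```

In this sketch the only way to reach the inner optimizer is through the wrapper's signature, so any option that exists only as a call-time argument becomes unreachable.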

Instead, could there be an option to pass the decay variables during initialization? Something like:

from typing import Callable, List, Optional, Union

from typeguard import typechecked
from tensorflow_addons.utils.types import FloatTensorLike

class DecoupledWeightDecayExtension:

    @typechecked
    def __init__(self, weight_decay: Union[FloatTensorLike, Callable],
                 decay_var_list: Optional[List] = None,
                 **kwargs):
        wd = kwargs.pop("weight_decay", weight_decay)
        decay_var_list = kwargs.pop("decay_var_list", decay_var_list)
        # Store hashable variable references; None means "decay all variables".
        self._decay_var_list = {v.ref() for v in decay_var_list} if decay_var_list else None
        super().__init__(**kwargs)
        self._set_hyper("weight_decay", wd)
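
The idea can be sketched without TensorFlow. The toy below assumes a plain SGD step with decoupled weight decay, and every name in it (`Var`, `DecayAtInitOptimizer`, `Wrapper`) is hypothetical. It only shows that a decay set fixed at initialization survives a wrapper that forwards nothing but gradients; it is not how the library would implement it.

```python
# Hypothetical toy classes for illustration; not TensorFlow Addons code.
class Var:
    """Stand-in for a variable: a name and a float value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class DecayAtInitOptimizer:
    """Toy SGD with decoupled weight decay; the decay set is fixed at init."""

    def __init__(self, learning_rate, weight_decay, decay_var_list=None):
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # None means "decay every variable"; otherwise remember identities.
        self._decay_ids = (
            None if decay_var_list is None else {id(v) for v in decay_var_list}
        )

    def _should_decay(self, var):
        return self._decay_ids is None or id(var) in self._decay_ids

    def apply_gradients(self, grads_and_vars, name=None):
        for grad, var in grads_and_vars:
            if self._should_decay(var):
                # Decoupled decay: shrink the weight independently of the gradient.
                var.value -= self.weight_decay * var.value
            var.value -= self.learning_rate * grad


class Wrapper:
    """Mimics a wrapping optimizer that forwards only grads_and_vars and name."""

    def __init__(self, inner):
        self._inner = inner

    def apply_gradients(self, grads_and_vars, name=None):
        return self._inner.apply_gradients(grads_and_vars, name=name)


kernel, bias = Var("kernel", 1.0), Var("bias", 1.0)
opt = Wrapper(DecayAtInitOptimizer(0.1, 0.5, decay_var_list=[kernel]))
# Zero gradients isolate the effect of weight decay.
opt.apply_gradients([(0.0, kernel), (0.0, bias)])
```

With zero gradients, only `kernel` changes, since it is the only variable in the decay set given at construction time. The wrapper never needs to know about `decay_var_list`.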

Relevant information

  • Are you willing to contribute it (yes/no):
    yes
  • Are you willing to maintain it going forward? (yes/no):
    yes
  • Is there a relevant academic paper? (if so, where):
    no
  • Is there already an implementation in another framework? (if so, where):
    no
  • Was it part of tf.contrib? (if so, where):
    no

Which API type would this fall under (layer, metric, optimizer, etc.)
optimizer

Who will benefit with this feature?
anyone using mixed precision

Any other info.
