Description
There are several optimizer classes which wrap an existing optimizer and add functionality, but they have slightly different APIs. Lookahead, MovingAverage, and SWA each take an optimizer in their constructor and add extra functionality. Similarly, LossScaleOptimizer from Keras, an optimizer which can prevent numeric underflow by scaling the loss and gradients, also takes an optimizer in its constructor. On the other hand, the extend_with_decoupled_weight_decay function instead dynamically creates a new class which subclasses the original optimizer's class and adds functionality. I think we should make the API consistent, so that all optimizer wrappers work the same way.
LossScaleOptimizer is being made non-experimental in TF 2.4, so I need to decide on the final API soon. Ideally, it would be consistent with the optimizer APIs in TF-Addons, which is why I am bringing up this issue. There is a mixed precision RFC describing the proposed changes (no need to read it, I will summarize everything important here).
I summarize the two approaches an optimizer extension can take below:
1. Continue wrapping optimizers with another optimizer class:
This approach continues having optimizers like Lookahead and LossScaleOptimizer wrap other optimizers. An OptimizerWrapper helper class can be created which automatically handles the wrapping logic, such as implementing get_config properly, so each optimizer wrapper can focus on just its added functionality.
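As a rough sketch of what such a helper could look like (the class name and details below are assumptions, not an existing Keras or TF-Addons API), the serialization part of the wrapping logic might be handled roughly like this:

import tensorflow as tf

class OptimizerWrapper(tf.keras.optimizers.Optimizer):
    """Hypothetical base class for optimizers that wrap another optimizer."""

    def __init__(self, inner_optimizer, name="OptimizerWrapper", **kwargs):
        super().__init__(name, **kwargs)
        self._optimizer = inner_optimizer

    def get_config(self):
        # Serialize the inner optimizer so the wrapper round-trips through
        # get_config/from_config like any other optimizer.
        config = super().get_config()
        config["inner_optimizer"] = tf.keras.optimizers.serialize(self._optimizer)
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = dict(config)
        config["inner_optimizer"] = tf.keras.optimizers.deserialize(
            config["inner_optimizer"], custom_objects=custom_objects)
        return cls(**config)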
One question is whether __getattribute__ and __setattr__ should delegate hyperparameters (and potentially other attributes) to the inner optimizer, which could be done in OptimizerWrapper. I proposed doing this in the mixed precision RFC, but can change it. Doing this delegation would allow attributes of the inner optimizer to be accessed through Lookahead, LossScaleOptimizer, and other wrappers:
opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
opt = tfa.optimizers.Lookahead(opt)
... # Do some training
print(opt.momentum)
opt.momentum = 0.05
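For reference, here is a rough sketch of how that delegation could look, building on the hypothetical OptimizerWrapper above. It uses __getattr__ (which only fires when normal lookup fails) rather than overriding __getattribute__ directly, to keep the example short and avoid recursion; the names are illustrative:

class DelegatingOptimizerWrapper(OptimizerWrapper):

    def __getattr__(self, name):
        # Only reached when normal lookup on the wrapper fails, so methods
        # and attributes defined on the wrapper itself still take priority.
        if name == "_optimizer":
            raise AttributeError(name)
        return getattr(self._optimizer, name)

    def __setattr__(self, name, value):
        # Forward assignments like `opt.momentum = 0.05` to the inner
        # optimizer once it exists and already has that attribute. A real
        # implementation would likely restrict this to known hyperparameters.
        inner = self.__dict__.get("_optimizer")
        if inner is not None and hasattr(inner, name):
            setattr(inner, name, value)
        else:
            super().__setattr__(name, value)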
Delegating all attributes in __getattribute__ is tempting, but it has a significant flaw: it can cause the custom logic in optimizer wrappers not to be called:
class MyCustomOptimizer(tf.keras.optimizers.SGD):
    def apply_gradients_zero_min(self, grads_and_vars):
        # Clip negative gradients to zero, then apply them.
        grads_and_vars = [(tf.nn.relu(g), v) for g, v in grads_and_vars]
        self.apply_gradients(grads_and_vars)

opt = MyCustomOptimizer(1.)
opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
opt.apply_gradients_zero_min(...)  # Bug! Does not call the LossScaleOptimizer version of apply_gradients.
The custom logic in LossScaleOptimizer.apply_gradients is not called, because only the inner optimizer's apply_gradients runs. Similar issues occur with any optimizer wrapper.
2. Dynamically create a subclass:
This approach is used by extend_with_decoupled_weight_decay
. It consists of dynamically creating a subclass with two superclasses: The original optimizer (e.g. Adam) and a special extension class which adds the extra optimizer functionality. For example, see the implementation of extend_with_decoupled_weight_decay
. Other optimizers could use the same approach.
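For illustration, the core of this pattern looks roughly like the following. The factory and mixin names here are placeholders, not the actual TF-Addons implementation:

def extend_with_mixin(base_optimizer_class, mixin_class):
    # Dynamically create a class inheriting from both the extension mixin
    # and the original optimizer class, so the MRO resolves the mixin's
    # overrides (e.g. apply_gradients) before the base optimizer's.
    class ExtendedOptimizer(mixin_class, base_optimizer_class):
        pass
    ExtendedOptimizer.__name__ = mixin_class.__name__ + base_optimizer_class.__name__
    return ExtendedOptimizer

# e.g. MyAdam = extend_with_mixin(tf.keras.optimizers.Adam, MyExtensionMixin)
# opt = MyAdam(learning_rate=1e-3)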
We want to allow users to, e.g., create an Adam optimizer and then pass it to a function which returns an optimizer of the new type. To support this, we could have a function which takes the old optimizer, creates the new class, and returns an instance of the new class, e.g.:
opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
opt = tfa.optimizers.copy_with_decoupled_weight_decay(opt)
In this case, copy_with_decoupled_weight_decay would create the new optimizer class, then use opt.get_config to serialize the old optimizer and from_config to deserialize it into the new optimizer. One flaw is that this requires the optimizer to be serializable with get_config and from_config. Another flaw is that users might accidentally modify the old optimizer, not realizing that this does not affect the new optimizer.
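A sketch of what such a copy-style helper could look like, building on the extend_with_mixin placeholder above (again, the names are illustrative, and this assumes the optimizer round-trips through get_config/from_config):

def copy_with_mixin(optimizer, mixin_class, **mixin_kwargs):
    new_class = extend_with_mixin(optimizer.__class__, mixin_class)
    config = optimizer.get_config()
    # Any extra constructor arguments the mixin needs (e.g. weight_decay)
    # are merged into the serialized config.
    config.update(mixin_kwargs)
    return new_class.from_config(config)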
Alternatively, we could have a function monkey-patch opt.__class__ to the new optimizer class, so the optimizer does not need to be serializable with get_config and from_config:
opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
# Monkey-patches opt.__class__ to an optimizer class which subclasses
# both DecoupledWeightDecayExtension and SGD
tfa.optimizers.modify_class_to_use_decoupled_weight_decay(opt)
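A sketch of this monkey-patching variant, with the same caveat that the names are illustrative; note that any state the extension mixin normally sets up in its __init__ would still have to be initialized separately:

def modify_class_to_use_mixin(optimizer, mixin_class):
    # Rebind the instance's class in place. Existing slot variables and
    # hyperparameters are kept, and no get_config/from_config round-trip
    # is required.
    optimizer.__class__ = extend_with_mixin(optimizer.__class__, mixin_class)
    return optimizer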
Which to choose?
Does anyone have any opinions on which approach is superior, and whether all the optimizers should be unified to use the same approach? I think I prefer dynamically creating a subclass because it completely emulates the API of the original optimizer's class.
/CC @CyberZHG, @bhack, @Squadrick, @shreyashpatodia, @PhilJd, who have worked on optimizers in TF-Addons
/CC @omalleyt12, @tomerk, who have worked on Keras optimizers