
[Feature] Add Discriminative Layer Training that works with tf.keras.model.fit in eager and distributed #958

Closed
@hyang0129

Description


Currently, the only way to apply different learning rates to different layers is to use a custom training loop. This was raised in tensorflow/tensorflow#33036, but that issue was closed.

I am willing to contribute and maintain this feature. fastai has already implemented this functionality; see https://docs.fast.ai/basic_train.html#Discriminative-layer-training.

This should fall under the optimizer API.

This should benefit anyone doing transfer learning. I used this feature in PyTorch when training BERT models for a Kaggle competition.

Proposed implementation method:
For any optimizer, wrap the optimizer.apply_gradients method and scale each gradient according to its variable's learning rate multiplier.

This means that optimizer.minimize() would effectively do:
opt._compute_gradients -> scale_gradients_for_each_variable -> opt.apply_gradients
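
As a rough sketch, the scaling step could look like this (assuming each variable carries a hypothetical lr_mult attribute, defaulting to 1.0 when absent):

def scale_gradients_for_each_variable(grads_and_vars):
    # Multiply each gradient by the lr multiplier attached to its variable.
    # Variables without an lr_mult attribute keep their full gradient.
    scaled = []
    for grad, var in grads_and_vars:
        mult = getattr(var, 'lr_mult', 1.0)
        if grad is not None and mult != 1.0:
            grad = grad * mult
        scaled.append((grad, var))
    return scaled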

This interception of apply_gradients can be done at the instance level of the optimizer, without modifying the optimizer class (see https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance).
Example:
foo = Foo()
def sample_method(self): print('abc')
foo.sample_method = sample_method.__get__(foo)
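
Applied to an optimizer instance, the same trick might look roughly like this (a sketch with a hypothetical wrap_optimizer helper, reusing the scaling function sketched above):

import types

def wrap_optimizer(opt):
    # Keep a reference to the original bound apply_gradients.
    original_apply = opt.apply_gradients

    def apply_gradients(self, grads_and_vars, **kwargs):
        # Scale each gradient by its variable's multiplier before applying.
        return original_apply(scale_gradients_for_each_variable(grads_and_vars), **kwargs)

    # Bind the wrapper to this instance only; the optimizer class is untouched.
    opt.apply_gradients = types.MethodType(apply_gradients, opt)
    return opt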

To make things easier for users, each variable is assigned a multiplier based on its layer's multiplier attribute. If the layer doesn't have a multiplier attribute, the variable is assigned a multiplier of 1.

Layers should recursively apply their multiplier to nested layers. This makes it easier to assign a lower learning rate to the pretrained part of a model, e.g. a ResNet backbone plus a new head. This can be done with the code snippet below.

def get_layers(layer):
    # Keras layers/models expose nested layers via .layers (empty for leaf layers).
    return getattr(layer, 'layers', [])

def get_mult(layer):
    # Default to a multiplier of 1.0 when the user has not set one.
    return getattr(layer, 'lr_mult', 1.0)

def assign_mult(layer, mult):
    # Combine the parent's multiplier with any multiplier already on the layer.
    layer.lr_mult = get_mult(layer) * mult

def get_lowest_layers(model):
    layers = get_layers(model)
    mult = get_mult(model)
    if len(layers) > 0:
        for layer in layers:
            assign_mult(layer, mult)
            for sublayer in get_lowest_layers(layer):
                yield sublayer
    else:
        yield model
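
Once every leaf layer carries a multiplier, the multipliers can be copied onto the variables themselves, which is what the scaling step above keys on. A sketch, assuming attributes can be attached directly to the variables (otherwise a dict keyed by variable reference would work just as well):

def apply_mult_to_variables(model):
    # Walk down to the leaf layers and tag each trainable variable
    # with its layer's multiplier.
    for layer in get_lowest_layers(model):
        mult = get_mult(layer)
        for var in layer.trainable_variables:
            var.lr_mult = mult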

Users could interface with the new discriminative lr wrapper in this way.

model = model_fn()
model.layers[0].lr_mult = 0.05  # 5% of the base LR
model.compile(loss='mse', optimizer='adam')
model = wrap_fn(model)
model.fit(x, y)

Testing needs to be done to ensure that this methodology works in distributed mode. I've already tested the concept in eager mode, but ideally it should be built to work in any mode.
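
For the distributed case, the test I have in mind is roughly the following (a sketch only; model_fn and wrap_fn are the hypothetical helpers from the usage example above):

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = model_fn()
    model.layers[0].lr_mult = 0.05
    model.compile(loss='mse', optimizer='adam')
    model = wrap_fn(model)

# Train on dummy data and check that the wrapped optimizer still runs under the strategy.
x = np.random.rand(64, 10).astype('float32')
y = np.random.rand(64, 1).astype('float32')
model.fit(x, y, epochs=1, batch_size=16)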
