Fix layer decay to work as intended with optimizers #2532

Closed

Conversation

@Jookare commented Jul 3, 2025

Summary

In the current create_optimizer implementation, the layer_decay parameter does not have any effect. This pull request removes the unused lr_scale and instead directly sets the lr values of the parameter groups, scaled according to their layer depth. This change leads to significantly improved training performance, as the learning rate is now correctly applied per layer.

What was changed

  • Removed the unused lr_scale parameter in param groups.
  • Applied the scaling directly to the lr of each param group.
  • Adjusted optimizer creation logic to ensure lr is resolved correctly (either explicitly or via default).

Results

A simple CIFAR-10 test shows that the previous implementation had minimal effect when setting layer_decay, while the new version properly applies the decay and improves performance.

Experimental Setup

Using the AdamW optimizer from create_optimizer_v2, with PyTorch Lightning (deterministic=True):

  • version_0 (gray): layer_decay = 0 (baseline)
  • version_1 (cyan): layer_decay = 0.75 (old implementation)
  • version_2 (pink): layer_decay = 0.75 (new implementation)

[Figure: layer_decay_test, CIFAR-10 training curves for the three runs above]

I chose to explicitly apply the lr values in the param_groups_layer_decay function, as relying on an lr_scale parameter would have required significantly more effort to propagate and support correctly across all relevant components.
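
A simplified sketch of the idea (not the exact diff; function and variable names here are illustrative):

    # Simplified sketch of the proposed behaviour: the scaled LR is written
    # straight into each parameter group instead of an lr_scale entry that a
    # scheduler would have to consume later. Names are illustrative only.
    def layer_decayed_param_groups(model, base_lr, layer_decay, get_layer_id, num_layers):
        scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
        groups = {}
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            layer_id = get_layer_id(name)  # maps a parameter name to its depth
            group = groups.setdefault(layer_id, {'params': [], 'lr': base_lr * scales[layer_id]})
            group['params'].append(param)
        return list(groups.values())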

@rwightman (Collaborator)

@Jookare layer decay is working just fine. Because of the coupling between pytorch optimizers and schedulers, the lr_scale is applied by the scheduler: it scales the LR after it's modulated by the schedule. If you want to use layer decay with non-timm schedulers you need to find a way to apply it in the scheduler.

If you modify the timm scheduler base class to print scale per param group:

    # inside timm's scheduler base class (Scheduler.update_groups), per param group:
    if 'lr_scale' in param_group:
        print(param_group['lr_scale'], value)
        param_group[self.param_group_field] = value * param_group['lr_scale']
    else:
        param_group[self.param_group_field] = value

You see (layer_decay is 0.7 here):

0.01384128720099999 1.6949152542372884e-06
0.01384128720099999 1.6949152542372884e-06
0.019773267429999988 1.6949152542372884e-06
0.019773267429999988 1.6949152542372884e-06
0.028247524899999984 1.6949152542372884e-06
0.028247524899999984 1.6949152542372884e-06
0.04035360699999998 1.6949152542372884e-06
0.04035360699999998 1.6949152542372884e-06
0.05764800999999997 1.6949152542372884e-06
0.05764800999999997 1.6949152542372884e-06
0.08235429999999996 1.6949152542372884e-06
0.08235429999999996 1.6949152542372884e-06
0.11764899999999996 1.6949152542372884e-06
0.11764899999999996 1.6949152542372884e-06
0.16806999999999994 1.6949152542372884e-06
0.16806999999999994 1.6949152542372884e-06
0.24009999999999995 1.6949152542372884e-06
0.24009999999999995 1.6949152542372884e-06
0.3429999999999999 1.6949152542372884e-06
0.3429999999999999 1.6949152542372884e-06
0.48999999999999994 1.6949152542372884e-06
0.48999999999999994 1.6949152542372884e-06
0.7 1.6949152542372884e-06
1.0 1.6949152542372884e-06
1.0 1.6949152542372884e-06
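
For non-timm schedulers, one possible approach (a sketch only, not something provided out of the box) is to fold each group's lr_scale into its lr before constructing the scheduler, since PyTorch schedulers capture each group's lr as its base LR at construction:

    # Sketch: folding lr_scale into each group's lr so a plain PyTorch scheduler
    # produces a layer-wise decayed schedule. Model and hyperparameters are placeholders.
    import timm
    import torch
    from timm.optim import create_optimizer_v2

    model = timm.create_model('vit_small_patch16_224', num_classes=10)
    optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-4, layer_decay=0.7)

    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * group.get('lr_scale', 1.0)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)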

@Jookare (Author) commented Jul 3, 2025

Just checked, and yes, that resolves the issue. I saw that the lr_scale was utilized in the scheduler but didn't manage to put two and two together. It is good to know that one needs to employ the timm lr scheduler to get the benefits of layer_decay.
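
For reference, a minimal sketch of wiring layer_decay together with a timm scheduler so the lr_scale is actually applied each step (model name and hyperparameter values are illustrative placeholders, not the settings from the experiment above):

    # Sketch: using layer_decay with a timm scheduler, which applies lr_scale
    # per param group each time it updates the LR.
    import timm
    from timm.optim import create_optimizer_v2
    from timm.scheduler import CosineLRScheduler

    model = timm.create_model('vit_small_patch16_224', num_classes=10)
    optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-4, weight_decay=0.05, layer_decay=0.75)
    scheduler = CosineLRScheduler(optimizer, t_initial=50, lr_min=1e-6, warmup_t=5, warmup_lr_init=1e-6)

    for epoch in range(50):
        # train_one_epoch(model, optimizer, ...)  # training loop elided
        scheduler.step(epoch + 1)  # per-epoch update; lr_scale is applied inside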

A follow-up question regarding the second commit: if layer_decay is set to zero, the lr_scale goes to 0 for all but the last two layers. Would it be better to just skip the param_groups_layer_decay function in that case?

@rwightman (Collaborator)

@Jookare so if layer_decay is set to zero, it doesn't harm anything to leave it as is. It could be made more efficient, though: not by skipping the param_groups_layer_decay function, but by having that function disable grads (and skip adding optimizer groups) for any group whose lr_scale is 'close enough' to zero ... what counts as close enough, I'm not sure, below float32 eps?
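
A rough sketch of that suggestion (purely illustrative, not part of this PR or of timm): inside a modified param_groups_layer_decay, groups with an effectively-zero scale could be frozen and dropped:

    # Illustrative sketch: freeze parameters whose lr_scale is effectively zero
    # instead of handing them to the optimizer with lr ~ 0.
    import torch

    EPS = torch.finfo(torch.float32).eps

    def maybe_add_group(param_groups, params, lr_scale):
        if lr_scale <= EPS:
            for p in params:
                p.requires_grad_(False)  # skip the group: no grads, no optimizer state
            return
        param_groups.append({'params': params, 'lr_scale': lr_scale})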

@Jookare (Author) commented Jul 4, 2025

That is a very interesting idea, but it's actually not what I meant here. Right now, setting layer_decay to 0 sets the lr_scale of all but the last two layers to zero, as shown below. What I was wondering is whether it would be better for no decay to imply that all layers get lr_scale = 1.0?

Here layer_decay is 0.0:
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
0.0 0.0001
1.0 0.0001
1.0 0.0001

@rwightman (Collaborator)

@Jookare I'm not seeing how that makes sense; that's the behaviour when layer-decay is not enabled at all. I view layer_decay=None as disabling it, and layer_decay=0 as the equivalent of a linear probe, i.e. fine-tuning only the head.
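
For reference, the per-layer multiplier follows a layer_decay ** (num_layers - i) pattern (a simplified sketch mirroring timm's grouping; exact grouping details differ), which makes the semantics above concrete:

    # How the per-layer LR multiplier behaves for, say, 12 layer groups:
    #   layer_decay = 1.0  -> 1.0 everywhere (uniform LR)
    #   layer_decay = 0.75 -> geometric falloff towards the input layers
    #   layer_decay = 0.0  -> 0.0 everywhere except the top group, i.e. a linear probe
    def layer_scales(layer_decay, num_layers=12):
        return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    print(layer_scales(0.75))
    print(layer_scales(0.0))  # only the final entry is 1.0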

@Jookare (Author) commented Jul 4, 2025

Ahh, now I see. Sorry, I was mixing up what setting layer-decay to 0 versus 1 actually means. Thanks for the explanation and your time!

rwightman closed this Jul 5, 2025