A minimal working example of incorporating WASAM into an image classification pipeline implemented in PyTorch.
from wasam import WASAM
...
model = YourModel()
base_optimizer = torch.optim.SGD  # define a base optimizer class for the "sharpness-aware" update
optimizer = WASAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
max_epochs, swa_start_coeff = 200, 0.75
swa_start_epoch = int(max_epochs * swa_start_coeff)
...
for epoch in range(1, max_epochs + 1):
    # train one epoch
    for input, output in loader:
        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss

        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.step(closure)  # performs the model update and zeros gradients internally

    # towards the end of training, average weights
    if epoch >= swa_start_epoch:
        optimizer.update_swa()

        # before model evaluation, swap weights with averaged weights
        optimizer.swap_swa_sgd()
        evaluate_model(model)
        # after model evaluation, swap them back (if training continues)
        optimizer.swap_swa_sgd()
...
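For intuition, the closure gives WASAM the second forward-backward pass that a sharpness-aware update needs. Below is a conceptual sketch of what such a step(closure) could do internally, following the standard SAM formulation; it is not the actual WASAM implementation, and sharpness_aware_step is a hypothetical helper (the weight averaging itself happens separately via update_swa):

import torch

def sharpness_aware_step(model, base_optimizer, closure, rho=0.05):
    # hypothetical sketch of the two-step "sharpness-aware" update; gradients
    # from the first forward-backward pass are expected to be in p.grad already
    params = [p for p in model.parameters() if p.grad is not None]

    # 1) climb to the (approximate) worst-case point w + e(w),
    #    with e(w) = rho * grad / ||grad||_2
    grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in params]), p=2)
    scale = rho / (grad_norm + 1e-12)
    perturbations = []
    with torch.no_grad():
        for p in params:
            e_w = p.grad * scale
            p.add_(e_w)
            perturbations.append(e_w)

    # 2) recompute the loss and gradients at the perturbed point
    base_optimizer.zero_grad()
    closure()

    # 3) undo the perturbation and update the original weights with the
    #    gradient evaluated at w + e(w)
    with torch.no_grad():
        for p, e_w in zip(params, perturbations):
            p.sub_(e_w)
    base_optimizer.step()
    base_optimizer.zero_grad()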
This option is slightly more involved but offers greater flexibility.
There are two differences:
- We perform both forward and backward passes directly in the training loop
- We store and update multiple averaged models starting at different times
from wasam import WASAM
from swa_utils import MultipleSWAModels
...
device = torch.device("cuda:0")
model = YourModel()
base_optimizer = torch.optim.SGD  # define a base optimizer class for the "sharpness-aware" update
optimizer = WASAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
max_epochs = 200
swa_starts = [0.5, 0.6, 0.75, 0.9]
swa_models = MultipleSWAModels(model, device, max_epochs, swa_starts)
...
for epoch in range(1, max_epochs + 1):
    # train one epoch
    for input, output in loader:
        # first forward-backward pass
        loss = loss_function(output, model(input))  # use this loss for any training statistics
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass
        loss_function(output, model(input)).backward()  # make sure to do a full forward pass
        optimizer.second_step(zero_grad=True)

    # average weights
    swa_models.update_parameters(model, epoch)  # checks if epoch >= swa_start internally

    # for model evaluation, you can loop over all averaged models
    for model_dict in swa_models.models:
        swa_model, swa_start = model_dict["model"], model_dict["start"]
        if epoch >= swa_start:
            evaluate_model(swa_model)
...
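MultipleSWAModels keeps one running weight average per start point. As a rough illustration, here is a minimal sketch of what such a helper could look like when built on PyTorch's torch.optim.swa_utils.AveragedModel; the actual implementation in swa_utils may differ:

from torch.optim.swa_utils import AveragedModel

class MultipleSWAModels:
    # hypothetical sketch: one running weight average per SWA start point
    def __init__(self, base_model, device, max_epochs, swa_starts):
        # swa_starts are fractions of training, e.g. 0.75 -> start at epoch 150 of 200
        self.models = [
            {"model": AveragedModel(base_model, device=device),
             "start": int(coeff * max_epochs)}
            for coeff in swa_starts
        ]

    def update_parameters(self, base_model, epoch):
        # only average into the models whose start epoch has been reached
        for model_dict in self.models:
            if epoch >= model_dict["start"]:
                model_dict["model"].update_parameters(base_model)

With this sketch, evaluate_model(swa_model) simply runs the forward pass of the AveragedModel wrapper, which delegates to its averaged copy of the network.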
If your model contains BatchNorm layers, you have to update the activation statistics of the averaged model before you can use it. Here is a modified version of the simple-option example.
from wasam import WASAM
...
model = YourModel()
base_optimizer = torch.optim.SGD  # define a base optimizer class for the "sharpness-aware" update
optimizer = WASAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
max_epochs, swa_start_coeff = 200, 0.75
swa_start_epoch = int(max_epochs * swa_start_coeff)
...
for epoch in range(1, max_epochs + 1):
    train_model(loader, model, optimizer)

    # towards the end of training, average weights
    if epoch >= swa_start_epoch:
        optimizer.update_swa()

        # before model evaluation, swap weights with averaged weights
        optimizer.swap_swa_sgd()
        optimizer.bn_update(loader, model)  # <-------------- update BatchNorm statistics
        evaluate_model(model)
        # after model evaluation, swap them back (if training continues)
        optimizer.swap_swa_sgd()
...
Similarly, in the advanced option, the BatchNorm statistics can be updated as follows:
# for model evaluation, you can loop over all averaged models
for model_dict in swa_models.models:
    swa_model, swa_start = model_dict["model"], model_dict["start"]
    if epoch >= swa_start:
        optimizer.bn_update(loader, swa_model)  # <-------------- update BatchNorm statistics
        evaluate_model(swa_model)
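Conceptually, updating the BatchNorm statistics amounts to a single forward pass over the training data with the averaged weights, so that the running means and variances match them. If you prefer to stay within plain PyTorch, torch.optim.swa_utils.update_bn achieves the same effect; a minimal sketch, assuming the loader yields batches whose first element is the input:

from torch.optim.swa_utils import update_bn

# recompute the BatchNorm running statistics of the averaged model with a single
# forward pass over the training data (only the inputs of each batch are used)
update_bn(loader, swa_model, device=device)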
Install the required packages with
pip install -r requirements.txt
Then, you can run
cd example
python main.py
If you find this repository useful, please consider citing the paper.
@inproceedings{kaddour2022when,
title={When Do Flat Minima Optimizers Work?},
author={Jean Kaddour and Linqing Liu and Ricardo Silva and Matt Kusner},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=vDeh2yxTvuh}
}
This codebase builds on other repositories:
- (Adaptive) SAM Optimizer (PyTorch)
- A PyTorch implementation for PyramidNets
- Label Smoothing in PyTorch
- torch-contrib

Thanks a lot to the authors of these repositories!