
Improve the accuracy of Classification models by using SOTA recipes and primitives #3995

@datumbox

Description

🚀 Feature

Update the weights of all pre-trained models to improve their accuracy.

Motivation

New Recipe + FixRes mitigations

torchrun --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--train-crop-size 176 --model-ema --val-resize-size 232

Using a recipe that includes warmup, cosine annealing, label smoothing, MixUp, CutMix, Random Erasing, TrivialAugment, no BN weight decay, EMA, long training cycles, and the optional FixRes mitigations, we are able to improve the ResNet50 accuracy by over 4.5 points. For more information on the training recipe, check here.
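Of those primitives, the no-BN-weight-decay flag is the least standard. Below is a minimal sketch of how normalization parameters can be placed in a separate optimizer group with zero weight decay, matching `--weight-decay 0.00002 --norm-weight-decay 0.0`; this is only an illustration, not the reference script's exact helper:

```python
import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet50()

# Parameters that live in normalization layers get weight decay 0.0,
# everything else keeps the regular weight decay of 2e-5.
norm_classes = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                nn.SyncBatchNorm, nn.GroupNorm, nn.LayerNorm)
decay_params, no_decay_params = [], []
for module in model.modules():
    for param in module.parameters(recurse=False):
        if not param.requires_grad:
            continue
        (no_decay_params if isinstance(module, norm_classes) else decay_params).append(param)

optimizer = torch.optim.SGD(
    [
        {"params": decay_params, "weight_decay": 2e-5},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=0.5,
    momentum=0.9,
)
```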

Old ResNet50:
Acc@1 76.130 Acc@5 92.862

New ResNet50:
Acc@1 80.674 Acc@5 95.166

Running other models through the same recipe achieves the following improved accuracies:

ResNet101:
Acc@1 81.728 Acc@5 95.670

ResNet152:
Acc@1 82.042 Acc@5 95.926

ResNeXt50_32x4d:
Acc@1 81.116 Acc@5 95.478

ResNeXt101_32x8d:
Acc@1 82.834 Acc@5 96.228

MobileNetV3 Large:
Acc@1 74.938 Acc@5 92.496

Wide ResNet50_2:
Acc@1 81.602 Acc@5 95.758 (@prabhat00155)

Wide ResNet101_2:
Acc@1 82.492 Acc@5 96.110 (@prabhat00155)

regnet_x_400mf:
Acc@1 74.864 Acc@5 92.322 (@kazhang)

regnet_x_800mf:
Acc@1 77.522 Acc@5 93.826 (@kazhang)

regnet_x_1_6gf:
Acc@1 79.668 Acc@5 94.922 (@kazhang)
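
The --model-ema flag in the recipe above keeps an exponential moving average of the model weights and evaluates that copy. A minimal sketch built on torch.optim.swa_utils follows; the reference scripts use their own EMA wrapper, and the decay value here is only an assumption:

```python
import torch
import torchvision

model = torchvision.models.resnet50()
decay = 0.9999  # assumed EMA decay; the actual --model-ema-decay value may differ

# AveragedModel with a custom averaging function implements the EMA update.
ema_model = torch.optim.swa_utils.AveragedModel(
    model,
    avg_fn=lambda ema_param, param, num_averaged: decay * ema_param + (1.0 - decay) * param,
)

# In the training loop, after every optimizer.step():
#     ema_model.update_parameters(model)
# At validation time, evaluate ema_model (or ema_model.module) instead of model.
```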

New Recipe (without FixRes mitigations)

torchrun --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--model-ema --val-resize-size 232

Removing the optional FixRes mitigations seems to yield better results for some deeper architectures and variants with larger receptive fields (accuracies listed after the sketch below):
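For context, the FixRes mitigation only changes the training-time crop size relative to the validation transforms. A minimal illustration with plain torchvision transforms is given below; the actual reference pipelines also add TrivialAugment, Random Erasing, and normalization:

```python
import torchvision.transforms as T

# With the mitigation: train on smaller 176px crops (--train-crop-size 176) ...
train_transform_fixres = T.Compose([
    T.RandomResizedCrop(176),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])

# ... while validation resizes to 232 and center-crops 224 (--val-resize-size 232).
val_transform = T.Compose([
    T.Resize(232),
    T.CenterCrop(224),
    T.ToTensor(),
])

# Without the mitigation, training simply keeps the default 224px crop:
train_transform_plain = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
```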

ResNet101:
Acc@1 81.886 Acc@5 95.780

ResNet152:
Acc@1 82.284 Acc@5 96.002

ResNeXt50_32x4d:
Acc@1 81.198 Acc@5 95.340

ResNeXt101_32x8d:
Acc@1 82.812 Acc@5 96.226

MobileNetV3 Large:
Acc@1 75.152 Acc@5 92.634

Wide ResNet50_2:
Acc@1 81.452 Acc@5 95.544 (@prabhat00155)

Wide ResNet101_2:
Acc@1 82.510 Acc@5 96.020 (@prabhat00155)

regnet_x_3_2gf:
Acc@1 81.196 Acc@5 95.430

regnet_x_8gf:
Acc@1 81.682 Acc@5 95.678

regnet_x_16gf:
Acc@1 82.716 Acc@5 96.196

regnet_x_32gf:
Acc@1 83.014 Acc@5 96.288

regnet_y_400mf:
Acc@1 75.804 Acc@5 92.742

regnet_y_800mf:
Acc@1 78.828 Acc@5 94.502

regnet_y_1_6gf:
Acc@1 80.876 Acc@5 95.444

regnet_y_3_2gf:
Acc@1 81.982 Acc@5 95.972

regnet_y_8gf:
Acc@1 82.828 Acc@5 96.330

regnet_y_16gf:
Acc@1 82.886 Acc@5 96.328

regnet_y_32gf:
Acc@1 83.368 Acc@5 96.498

New Recipe + Regularization tuning

torchrun --nproc_per_node=8 train.py --model $MODEL_NAME --batch-size 128 --lr 0.5 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.00001 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--model-ema --val-resize-size 232

Slightly adjusting the regularization helps improve the following:

MobileNetV3 Large:
Acc@1 75.274 Acc@5 92.566

In addition to the regularization adjustment, we can also apply the Repeated Augmentation trick (--ra-sampler --ra-reps 4):

MobileNetV2:
Acc@1 72.154 Acc@5 90.822
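
Repeated Augmentation draws fewer unique images per epoch and feeds each of them through the random training transforms several times. The reference scripts implement this as a distributed sampler; the single-process sketch below only illustrates the idea, and the class name and details are mine rather than the reference implementation:

```python
import torch
from torch.utils.data import Sampler

class ToyRepeatedAugmentSampler(Sampler):
    """Yield each selected index `reps` times so every image is augmented repeatedly."""

    def __init__(self, dataset_size, reps=4, shuffle=True):
        self.dataset_size = dataset_size
        self.reps = reps
        self.shuffle = shuffle

    def __iter__(self):
        order = torch.randperm(self.dataset_size) if self.shuffle else torch.arange(self.dataset_size)
        # Repeat every index `reps` times, then truncate so the epoch length stays the same.
        repeated = order.repeat_interleave(self.reps)[: self.dataset_size]
        return iter(repeated.tolist())

    def __len__(self):
        return self.dataset_size
```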

Post-Training Quantized models

ResNet50:
Acc@1 80.282 Acc@5 94.976

ResNeXt101_32x8d:
Acc@1 82.574 Acc@5 96.132
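
For reference, a minimal sketch of an eager-mode post-training quantization flow with the quantizable torchvision ResNet50 is shown below; the calibration data and exact settings used for the published quantized weights may differ:

```python
import torch
import torchvision

# Quantization-ready model definition; pretrained=True loads the published float weights.
model = torchvision.models.quantization.resnet50(pretrained=True).eval()
model.fuse_model()  # fuse Conv+BN+ReLU blocks before quantization
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)

# Calibrate the observers on representative batches; random tensors are only a placeholder here.
with torch.inference_mode():
    for _ in range(10):
        model(torch.randn(8, 3, 224, 224))

torch.quantization.convert(model, inplace=True)  # produce the int8 model
```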

New Recipe (LR+weight_decay+train_crop_size tuning)

torchrun --nproc_per_node=8 --nnodes=1 train.py --model $MODEL_NAME --batch-size 128 --lr 1 \
--lr-scheduler cosineannealinglr --lr-warmup-epochs 5 --lr-warmup-method linear \
--auto-augment ta_wide --epochs 600 --random-erase 0.1 --weight-decay 0.000002 \
--norm-weight-decay 0.0 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 \
--train-crop-size 208 --model-ema --val-crop-size 240 --val-resize-size 255

EfficientNet-B1:
Acc@1 79.838 Acc@5 94.934

Pitch

To improve the accuracy of the pre-trained models, we first need to complete the "Batteries Included" work tracked in #3911. Moreover, we will need to extend our existing model builders to support multiple weights, as described in #4611. Then we will be able to:

  • Update our reference scripts for classification to support the new primitives added by the "Batteries Included" initiative.
  • Find a good training recipe for the most important pre-trained models and re-train them. Note that different training configurations might be required for different types of models (for example, mobile models are less likely to overfit compared to bigger models and thus make use of different recipes/primitives).
  • Update the weights of the models in the library.
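
Once the multi-weight support from #4611 lands, exposing both the old and the new weights could look roughly like the sketch below; the enum names follow the API torchvision later shipped and should be treated as an assumption in the context of this issue:

```python
from torchvision.models import resnet50, ResNet50_Weights

old_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)  # legacy 76.130 Acc@1 weights
new_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)  # weights from the new recipe
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()      # matching inference transforms
```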

cc @datumbox @vfdev-5
