[fbsync] Add shufflenetv2 1.5 and 2.0 weights (#5906)
Summary:
* Add shufflenetv2 1.5 and 2.0 weights

* Update recipe

* Add to docs

* Use resize_size=232 for eval and update the result

* Add quantized shufflenetv2 large

* Update docs and readme

* Format with ufmt

* Add to hubconf.py

* Update readme for classification reference

* Fix reference classification readme

* Fix typo on readme

* Update reference/classification/readme

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095677

fbshipit-source-id: 74a575c6272df397852dba325f9c1b1e5a1c0231
YosuaMichael authored and facebook-github-bot committed May 6, 2022
1 parent 3156a3e commit ccc0a92
Showing 5 changed files with 163 additions and 6 deletions.
8 changes: 6 additions & 2 deletions docs/source/models.rst
@@ -176,8 +176,10 @@ Densenet-201 76.896 93.370
Densenet-161 77.138 93.560
Inception v3 77.294 93.450
GoogleNet 69.778 89.530
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x0.5 60.552 81.746
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x1.5 72.996 91.086
ShuffleNet V2 x2.0 76.230 93.006
MobileNet V2 71.878 90.286
MobileNet V3 Large 74.042 91.340
MobileNet V3 Small 67.668 87.402
@@ -499,8 +501,10 @@ Model Acc@1 Acc@5
================================ ============= =============
MobileNet V2 71.658 90.150
MobileNet V3 Large 73.004 90.858
ShuffleNet V2 x1.0 68.360 87.582
ShuffleNet V2 x0.5 57.972 79.780
ShuffleNet V2 x1.0 68.360 87.582
ShuffleNet V2 x1.5 72.052 90.700
ShuffleNet V2 x2.0 75.354 92.488
ResNet 18 69.494 88.882
ResNet 50 75.920 92.814
ResNext 101 32x8d 78.986 94.480
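
To read the two tables side by side, the short Python sketch below computes how much accuracy each ShuffleNet V2 variant loses after quantization. The numbers are copied from the float and quantized tables in this diff; the dictionary and function names are illustrative only and are not part of torchvision.

```python
# (Acc@1, Acc@5) copied from the float and quantized tables in docs/source/models.rst.
FLOAT_ACC = {
    "x0.5": (60.552, 81.746),
    "x1.0": (69.362, 88.316),
    "x1.5": (72.996, 91.086),
    "x2.0": (76.230, 93.006),
}
QUANTIZED_ACC = {
    "x0.5": (57.972, 79.780),
    "x1.0": (68.360, 87.582),
    "x1.5": (72.052, 90.700),
    "x2.0": (75.354, 92.488),
}


def quantization_drop(variant):
    """Return (Acc@1 drop, Acc@5 drop) in percentage points, rounded to 3 places."""
    float_top1, float_top5 = FLOAT_ACC[variant]
    quant_top1, quant_top5 = QUANTIZED_ACC[variant]
    return round(float_top1 - quant_top1, 3), round(float_top5 - quant_top5, 3)


for variant in FLOAT_ACC:
    top1_drop, top5_drop = quantization_drop(variant)
    print(f"ShuffleNet V2 {variant}: Acc@1 drop {top1_drop}, Acc@5 drop {top5_drop}")
```

For the two variants added here, the Acc@1 drop stays below one percentage point.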
7 changes: 6 additions & 1 deletion hubconf.py
@@ -59,7 +59,12 @@
    deeplabv3_mobilenet_v3_large,
    lraspp_mobilenet_v3_large,
)
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.shufflenetv2 import (
    shufflenet_v2_x0_5,
    shufflenet_v2_x1_0,
    shufflenet_v2_x1_5,
    shufflenet_v2_x2_0,
)
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.vision_transformer import (
29 changes: 29 additions & 0 deletions references/classification/README.md
@@ -236,6 +236,20 @@ torchrun --nproc_per_node=8 train.py\
Note that `--val-resize-size` was optimized in a post-training step; see each model's `Weights` entry for the exact value.
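
As a rough illustration of what a `--val-resize-size` value means at evaluation time, the sketch below resizes an image's shorter side and then takes a center crop. It uses 232 and 224, the `resize_size` and `crop_size` that appear in the ShuffleNet V2 weight entries of this commit. The helper names are hypothetical, and the rounding only approximates what torchvision's transforms do.

```python
def resize_shorter_side(width, height, size):
    """Scale (width, height) so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """Return (left, top, right, bottom) of a centered crop-by-crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# A hypothetical 500x375 input image.
resized = resize_shorter_side(500, 375, 232)
box = center_crop_box(resized[0], resized[1], 224)
print(resized, box)
```

With a resize of 232 and a crop of 224, only a 4-pixel border is discarded along the shorter side.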


### ShuffleNet V2
```
torchrun --nproc_per_node=8 train.py \
--model=$MODEL --batch-size=128 \
--lr=0.5 --lr-scheduler=cosineannealinglr --lr-warmup-epochs=5 --lr-warmup-method=linear \
--auto-augment=ta_wide --epochs=600 --random-erase=0.1 --weight-decay=0.00002 \
--norm-weight-decay=0.0 --label-smoothing=0.1 --mixup-alpha=0.2 --cutmix-alpha=1.0 \
--train-crop-size=176 --model-ema --val-resize-size=232 --ra-sampler --ra-reps=4
```
Here `$MODEL` is either `shufflenet_v2_x1_5` or `shufflenet_v2_x2_0`.

The models `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0` were contributed by the community. See [PR-849](https://github.com/pytorch/vision/pull/849#issuecomment-483391686) for details.
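
The learning-rate flags in the recipe above (`--lr=0.5`, five epochs of linear warmup, cosine annealing over 600 epochs) can be sketched in closed form as follows. This is an approximation for intuition, not the schedulers that `train.py` builds: the warmup start factor of 0.01 and the final learning rate of 0 are assumptions, since the recipe does not state them, and the function name is hypothetical.

```python
import math

BASE_LR = 0.5        # --lr
EPOCHS = 600         # --epochs
WARMUP_EPOCHS = 5    # --lr-warmup-epochs
WARMUP_START = 0.01  # assumed warmup start factor; not stated in the recipe
MIN_LR = 0.0         # assumed floor of the cosine schedule; not stated in the recipe


def lr_at(epoch):
    """Approximate learning rate at the start of `epoch` (0-based)."""
    if epoch < WARMUP_EPOCHS:
        # Linear ramp from WARMUP_START * BASE_LR up to BASE_LR.
        factor = WARMUP_START + (1.0 - WARMUP_START) * epoch / WARMUP_EPOCHS
        return BASE_LR * factor
    # Cosine decay over the epochs that remain after warmup.
    t = epoch - WARMUP_EPOCHS
    t_max = EPOCHS - WARMUP_EPOCHS
    return MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (1.0 + math.cos(math.pi * t / t_max))


for epoch in (0, 5, 300, 600):
    print(epoch, lr_at(epoch))
```

Under these assumptions the rate starts at 0.005, reaches 0.5 once warmup ends, and decays to 0 by epoch 600.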


## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU for PyTorch can be enabled with [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).

@@ -263,6 +277,21 @@ python train_quantization.py --device='cpu' --post-training-quantize --backend='
Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.

### Quantized ShuffleNet V2

Here are the commands we used to quantize the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models.
```
# For shufflenet_v2_x1_5
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
--model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \
--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
# For shufflenet_v2_x2_0
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
--model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \
--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
```

### QAT MobileNetV2

For Mobilenet-v2, the model was trained with quantization-aware training; the settings used are:
95 changes: 94 additions & 1 deletion torchvision/models/quantization/shufflenetv2.py
@@ -10,16 +10,25 @@
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
from ..shufflenetv2 import (
    ShuffleNet_V2_X0_5_Weights,
    ShuffleNet_V2_X1_0_Weights,
    ShuffleNet_V2_X1_5_Weights,
    ShuffleNet_V2_X2_0_Weights,
)
from .utils import _fuse_modules, _replace_relu, quantize_model


__all__ = [
    "QuantizableShuffleNetV2",
    "ShuffleNet_V2_X0_5_QuantizedWeights",
    "ShuffleNet_V2_X1_0_QuantizedWeights",
    "ShuffleNet_V2_X1_5_QuantizedWeights",
    "ShuffleNet_V2_X2_0_QuantizedWeights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]


@@ -143,6 +152,42 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/5906",
            "num_params": 3503624,
            "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 72.052,
                "acc@5": 90.700,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/5906",
            "num_params": 7393996,
            "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 75.354,
                "acc@5": 92.488,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


@handle_legacy_interface(
    weights=(
        "pretrained",
@@ -205,3 +250,51 @@ def shufflenet_v2_x1_0(
    return _shufflenetv2(
        [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )


def shufflenet_v2_x1_5(
    *,
    weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.
    Args:
        weights (ShuffleNet_V2_X1_5_QuantizedWeights or ShuffleNet_V2_X1_5_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights)
    return _shufflenetv2(
        [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )


def shufflenet_v2_x2_0(
    *,
    weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.
    Args:
        weights (ShuffleNet_V2_X2_0_QuantizedWeights or ShuffleNet_V2_X2_0_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights)
    return _shufflenetv2(
        [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs
    )
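
The two list arguments passed to `_shufflenetv2` above are easier to compare when unpacked. The sketch below assumes the first channel value is the stem width, the middle three are the widths of the three repeated stages, and the last is the width of the final convolution; that reading and the helper name `describe` are assumptions made for illustration, not something this diff states.

```python
# Stage repeats and channel widths, copied from the _shufflenetv2 calls in this diff.
CONFIGS = {
    "shufflenet_v2_x1_0": ([4, 8, 4], [24, 116, 232, 464, 1024]),
    "shufflenet_v2_x1_5": ([4, 8, 4], [24, 176, 352, 704, 1024]),
    "shufflenet_v2_x2_0": ([4, 8, 4], [24, 244, 488, 976, 2048]),
}


def describe(name):
    """Split a config into stem width, (repeats, width) per stage, and head width."""
    repeats, channels = CONFIGS[name]
    stem, *stages, head = channels
    return {"stem": stem, "stages": list(zip(repeats, stages)), "head": head}


for name in CONFIGS:
    print(name, describe(name))
```

In all three configurations the stage width doubles from one stage to the next, and only the x2.0 variant widens the final layer to 2048 channels.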
30 changes: 28 additions & 2 deletions torchvision/models/shufflenetv2.py
@@ -223,11 +223,37 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):


class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
    pass
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/5906",
            "num_params": 3503624,
            "metrics": {
                "acc@1": 72.996,
                "acc@5": 91.086,
            },
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
    pass
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/5906",
            "num_params": 7393996,
            "metrics": {
                "acc@1": 76.230,
                "acc@5": 93.006,
            },
        },
    )
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
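
Each checkpoint URL added in this commit ends in a short hexadecimal suffix such as `-3c479a10`. The sketch below assumes that suffix follows the usual PyTorch naming convention, in which it is a prefix of the file's SHA-256 digest; this diff does not state that, and both helper names are hypothetical. It extracts the suffix from a URL and checks downloaded bytes against it.

```python
import hashlib
import re

# Matches the "-<hex>." suffix that precedes the file extension.
HASH_SUFFIX = re.compile(r"-([a-f0-9]+)\.")


def expected_hash_prefix(url):
    """Return the hex suffix embedded in the checkpoint file name, or None."""
    filename = url.rsplit("/", 1)[-1]
    match = HASH_SUFFIX.search(filename)
    return match.group(1) if match else None


def matches_prefix(data, prefix):
    """Check whether the SHA-256 digest of `data` starts with `prefix`."""
    return hashlib.sha256(data).hexdigest().startswith(prefix)


url = "https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth"
print(expected_hash_prefix(url))
```

A caller would read the downloaded file as bytes and pass it to `matches_prefix` together with the extracted suffix.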