Skip to content

Commit

Permalink
[fbsync] Adding resnext101 64x4d model (#5935)
Browse files Browse the repository at this point in the history
Summary:
* Add resnext101_64x4d model definition

* Add test for resnext101_64x4d

* Add resnext101_64x4d weight

* Update checkpoint to use EMA weigth

* Add quantization model signature for resnext101_64x4d

* Fix class name and update accuracy using 1 gpu and batch_size=1

* Apply ufmt

* Update the quantized weight and accuracy that we still keep the training log

* Add quantized expect file

* Update docs and fix acc1

* Add recipe for quantized to PR

* Update models.rst

Reviewed By: YosuaMichael

Differential Revision: D36281598

fbshipit-source-id: 300bd36343b8ad8b185a246b794e078bdf67f5c8
  • Loading branch information
datumbox authored and facebook-github-bot committed May 11, 2022
1 parent 680a15b commit 309483c
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ You can construct a model with random weights by calling its constructor:
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
resnext101_32x8d = models.resnext101_32x8d()
resnext101_64x4d = models.resnext101_64x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
efficientnet_b0 = models.efficientnet_b0()
Expand Down Expand Up @@ -185,6 +187,7 @@ MobileNet V3 Large 74.042 91.340
MobileNet V3 Small 67.668 87.402
ResNeXt-50-32x4d 77.618 93.698
ResNeXt-101-32x8d 79.312 94.526
ResNeXt-101-64x4d 83.246 96.454
Wide ResNet-50-2 78.468 94.086
Wide ResNet-101-2 78.848 94.284
MNASNet 1.0 73.456 91.510
Expand Down Expand Up @@ -366,6 +369,7 @@ ResNext

resnext50_32x4d
resnext101_32x8d
resnext101_64x4d

Wide ResNet
-----------
Expand Down Expand Up @@ -481,8 +485,11 @@ a model with random weights by calling its constructor:
resnet18 = models.quantization.resnet18()
resnet50 = models.quantization.resnet50()
resnext101_32x8d = models.quantization.resnext101_32x8d()
resnext101_64x4d = models.quantization.resnext101_64x4d()
shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()
Obtaining a pre-trained quantized model can be done with a few lines of code:

Expand All @@ -508,6 +515,7 @@ ShuffleNet V2 x2.0 75.354 92.488
ResNet 18 69.494 88.882
ResNet 50 75.920 92.814
ResNext 101 32x8d 78.986 94.480
ResNext 101 64x4d 82.898 96.326
Inception V3 77.176 93.354
GoogleNet 69.826 89.404
================================ ============= =============
Expand Down
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def _check_input_backprop(model, inputs):
"convnext_base",
"convnext_large",
"resnext101_32x8d",
"resnext101_64x4d",
"wide_resnet101_2",
"efficientnet_b6",
"efficientnet_b7",
Expand Down
44 changes: 44 additions & 0 deletions torchvision/models/quantization/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ResNet18_Weights,
ResNet50_Weights,
ResNeXt101_32X8D_Weights,
ResNeXt101_64X4D_Weights,
)

from ...transforms._presets import ImageClassification
Expand All @@ -25,9 +26,11 @@
"ResNet18_QuantizedWeights",
"ResNet50_QuantizedWeights",
"ResNeXt101_32X8D_QuantizedWeights",
"ResNeXt101_64X4D_QuantizedWeights",
"resnet18",
"resnet50",
"resnext101_32x8d",
"resnext101_64x4d",
]


Expand Down Expand Up @@ -231,6 +234,24 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_FBGEMM_V2


class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 83455272,
"recipe": "https://github.com/pytorch/vision/pull/5935",
"unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
"metrics": {
"acc@1": 82.898,
"acc@5": 96.326,
},
},
)
DEFAULT = IMAGENET1K_FBGEMM_V1


@handle_legacy_interface(
weights=(
"pretrained",
Expand Down Expand Up @@ -318,3 +339,26 @@ def resnext101_32x8d(
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)


def resnext101_64x4d(
*,
weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
r"""ResNeXt-101 64x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
weights (ResNeXt101_64X4D_QuantizedWeights or ResNeXt101_64X4D_Weights, optional): The pretrained
weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
"""
weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)

_ovewrite_named_param(kwargs, "groups", 64)
_ovewrite_named_param(kwargs, "width_per_group", 4)
return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
48 changes: 48 additions & 0 deletions torchvision/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"ResNet152_Weights",
"ResNeXt50_32X4D_Weights",
"ResNeXt101_32X8D_Weights",
"ResNeXt101_64X4D_Weights",
"Wide_ResNet50_2_Weights",
"Wide_ResNet101_2_Weights",
"resnet18",
Expand All @@ -30,6 +31,7 @@
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"resnext101_64x4d",
"wide_resnet50_2",
"wide_resnet101_2",
]
Expand Down Expand Up @@ -491,6 +493,24 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V2


class ResNeXt101_64X4D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 83455272,
"recipe": "https://github.com/pytorch/vision/pull/5935",
"metrics": {
# Mock
"acc@1": 83.246,
"acc@5": 96.454,
},
},
)
DEFAULT = IMAGENET1K_V1


class Wide_ResNet50_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
Expand Down Expand Up @@ -734,6 +754,34 @@ def resnext101_32x8d(
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


def resnext101_64x4d(
*, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
"""ResNeXt-101 64x4d model from
`Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.
Args:
weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
:members:
"""
weights = ResNeXt101_64X4D_Weights.verify(weights)

_ovewrite_named_param(kwargs, "groups", 64)
_ovewrite_named_param(kwargs, "width_per_group", 4)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
*, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
Expand Down

0 comments on commit 309483c

Please sign in to comment.