From e6d82f7d46fe6eaf3ab48e379e7122fd56594480 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Mar 2022 12:36:28 +0000 Subject: [PATCH] Adding EfficientNetV2 architecture (#5450) * Extend the EfficientNet class to support v1 and v2. * Refactor config/builder methods and add prototype builders * Refactoring weight info. * Update dropouts based on TF config ref * Update BN eps on TF base_config * Use Conv2dNormActivation. * Adding pre-trained weights for EfficientNetV2-s * Add Medium and Large weights * Update stats with single batch run. * Add accuracies in the docs. --- docs/source/models.rst | 14 +- hubconf.py | 3 + references/classification/README.md | 22 +- ...elTester.test_efficientnet_v2_l_expect.pkl | Bin 0 -> 939 bytes ...elTester.test_efficientnet_v2_m_expect.pkl | Bin 0 -> 939 bytes ...elTester.test_efficientnet_v2_s_expect.pkl | Bin 0 -> 939 bytes torchvision/models/efficientnet.py | 354 ++++++++++++++---- torchvision/prototype/models/efficientnet.py | 229 ++++++++--- 8 files changed, 507 insertions(+), 115 deletions(-) create mode 100644 test/expect/ModelTester.test_efficientnet_v2_l_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_v2_m_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_v2_s_expect.pkl diff --git a/docs/source/models.rst b/docs/source/models.rst index 58bd0d81cd0..84fee191a8e 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -38,7 +38,7 @@ architectures for image classification: - `ResNeXt`_ - `Wide ResNet`_ - `MNASNet`_ -- `EfficientNet`_ +- `EfficientNet`_ v1 & v2 - `RegNet`_ - `VisionTransformer`_ - `ConvNeXt`_ @@ -70,6 +70,9 @@ You can construct a model with random weights by calling its constructor: efficientnet_b5 = models.efficientnet_b5() efficientnet_b6 = models.efficientnet_b6() efficientnet_b7 = models.efficientnet_b7() + efficientnet_v2_s = models.efficientnet_v2_s() + efficientnet_v2_m = models.efficientnet_v2_m() + efficientnet_v2_l = models.efficientnet_v2_l() regnet_y_400mf = models.regnet_y_400mf() regnet_y_800mf = models.regnet_y_800mf() regnet_y_1_6gf = models.regnet_y_1_6gf() @@ -122,6 +125,9 @@ These can be constructed by passing ``pretrained=True``: efficientnet_b5 = models.efficientnet_b5(pretrained=True) efficientnet_b6 = models.efficientnet_b6(pretrained=True) efficientnet_b7 = models.efficientnet_b7(pretrained=True) + efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True) + efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True) + efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True) regnet_y_400mf = models.regnet_y_400mf(pretrained=True) regnet_y_800mf = models.regnet_y_800mf(pretrained=True) regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True) @@ -238,6 +244,9 @@ EfficientNet-B4 83.384 96.594 EfficientNet-B5 83.444 96.628 EfficientNet-B6 84.008 96.916 EfficientNet-B7 84.122 96.908 +EfficientNetV2-s 84.228 96.878 +EfficientNetV2-m 85.112 97.156 +EfficientNetV2-l 85.810 97.792 regnet_x_400mf 72.834 90.950 regnet_x_800mf 75.212 92.348 regnet_x_1_6gf 77.040 93.440 @@ -439,6 +448,9 @@ EfficientNet efficientnet_b5 efficientnet_b6 efficientnet_b7 + efficientnet_v2_s + efficientnet_v2_m + efficientnet_v2_l RegNet ------------ diff --git a/hubconf.py b/hubconf.py index 5c2ad8e9e0d..c3de4f2da9a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -13,6 +13,9 @@ efficientnet_b5, efficientnet_b6, efficientnet_b7, + efficientnet_v2_s, + efficientnet_v2_m, + efficientnet_v2_l, ) from torchvision.models.googlenet import googlenet from torchvision.models.inception import inception_v3 diff --git a/references/classification/README.md b/references/classification/README.md index e75336f23ca..173fb454995 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -88,7 +88,7 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@ and [#3354](https://github.com/pytorch/vision/pull/3354) for details. -### EfficientNet +### EfficientNet-V1 The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108). @@ -114,6 +114,26 @@ torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bic --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained ``` + +### EfficientNet-V2 +``` +torchrun --nproc_per_node=8 train.py \ +--model $MODEL --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr \ +--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \ +--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.00002 --norm-weight-decay 0.0 \ +--train-crop-size $TRAIN_SIZE --model-ema --val-crop-size $EVAL_SIZE --val-resize-size $EVAL_SIZE \ +--ra-sampler --ra-reps 4 +``` +Here `$MODEL` is one of `efficientnet_v2_s` and `efficientnet_v2_m`. +Note that the Small variant had a `$TRAIN_SIZE` of `300` and a `$EVAL_SIZE` of `384`, while the Medium `384` and `480` respectively. + +Note that the above command corresponds to training on a single node with 8 GPUs. +For generatring the pre-trained weights, we trained with 4 nodes, each with 8 GPUs (for a total of 32 GPUs), +and `--batch_size 32`. + +The weights of the Large variant are ported from the original paper rather than trained from scratch. See the `EfficientNet_V2_L_Weights` entry for their exact preprocessing transforms. + + ### RegNet #### Small models diff --git a/test/expect/ModelTester.test_efficientnet_v2_l_expect.pkl b/test/expect/ModelTester.test_efficientnet_v2_l_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f3ca5315337c7f74d8ad7b249ad315b25b2b6e33 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=ZXk1(PMcUyNU9^O`bzjNTYiKgW2-?On!^CN(DFdzP3;u)HwRy`W%Xc1yvO zcSoYB?P5cd57*;O6a01D^0I6UEX5C{d-feC!;CHPm`@OYn`?lZJcbftyFl+tAqtxTK2MU zjWzjT@*@6-$+dfNTb?}-H+kl+Y7+h}am!U{zpYERKQTE{7P-}S=He|$Rg+C_to>`^ zW~*SDlzQA46k0EqeSCfb7*ZgNJGA&2tl^ zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5+}K&-l!|ab@RVFvo^gv@Xx4JH+}OAt;-t|o~<^_GGW~OY_6Ws0f+ZSKD#3} zooQ$_YARl~NzCVx(XV)UquLJL%^j*0o6JKWZM-7$X4CU%=1ng?EHnz}7cmMCn`*2x znROGxoOe1O8hAGJ&U$84#KE|EhO)JBRh`*p-OVb|Jk=09gUl6v|299xMA0N zqaU6Qn`bOsYs3}w+bHGEUZYf2RpXAYO-3TCE*l26${B0z+iG;+^`!Nt2aPr}d<)+E zRjY8*vo(^NL7}Dh-1Y1UU`T;5?$F|Au!e_LS!z)+Fc#dL%!v#xq>zI!jk!QJU!0d7 z$^^6(gaf=8K@>bqBFCWsNCE|*r%-g=$bRBO(fJC4k zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=S<5r>W+Sf2PNF@|)c-&NaO|!+Tr8ic_0)EEP;W?s}Onp2}ouF~{2MmS^17 zgimbSdg2b5?sH(+c4%Rg=>dl=o0*DtY-71pXLh*Q(R9MkH6{la12ZL36dwuYW5G@COc z!gM0Pqp2hBLX*Q6CYY@(U0{0ckcHWmH}_0^Y;4UooRl)tFj6(W;8J16)pOBQNorA+@ZzKU=0tgvecqtU@W*fnG+dYNFfJd8gqeczBn&E zlnH1n2nTpGf+%>JM2 int: + return _make_divisible(channels * width_mult, 8, min_value) + + +class MBConvConfig(_MBConvConfig): + # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper def __init__( self, expand_ratio: float, @@ -50,38 +75,39 @@ def __init__( input_channels: int, out_channels: int, num_layers: int, - width_mult: float, - depth_mult: float, + width_mult: float = 1.0, + depth_mult: float = 1.0, + block: Optional[Callable[..., nn.Module]] = None, ) -> None: - self.expand_ratio = expand_ratio - self.kernel = kernel - self.stride = stride - self.input_channels = self.adjust_channels(input_channels, width_mult) - self.out_channels = self.adjust_channels(out_channels, width_mult) - self.num_layers = self.adjust_depth(num_layers, depth_mult) - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"expand_ratio={self.expand_ratio}" - f", kernel={self.kernel}" - f", stride={self.stride}" - f", input_channels={self.input_channels}" - f", out_channels={self.out_channels}" - f", num_layers={self.num_layers}" - f")" - ) - return s - - @staticmethod - def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: - return _make_divisible(channels * width_mult, 8, min_value) + input_channels = self.adjust_channels(input_channels, width_mult) + out_channels = self.adjust_channels(out_channels, width_mult) + num_layers = self.adjust_depth(num_layers, depth_mult) + if block is None: + block = MBConv + super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) @staticmethod def adjust_depth(num_layers: int, depth_mult: float): return int(math.ceil(num_layers * depth_mult)) +class FusedMBConvConfig(_MBConvConfig): + # Stores information listed at Table 4 of the EfficientNetV2 paper + def __init__( + self, + expand_ratio: float, + kernel: int, + stride: int, + input_channels: int, + out_channels: int, + num_layers: int, + block: Optional[Callable[..., nn.Module]] = None, + ) -> None: + if block is None: + block = FusedMBConv + super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) + + class MBConv(nn.Module): def __init__( self, @@ -149,27 +175,88 @@ def forward(self, input: Tensor) -> Tensor: return result +class FusedMBConv(nn.Module): + def __init__( + self, + cnf: FusedMBConvConfig, + stochastic_depth_prob: float, + norm_layer: Callable[..., nn.Module], + ) -> None: + super().__init__() + + if not (1 <= cnf.stride <= 2): + raise ValueError("illegal stride value") + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels + + layers: List[nn.Module] = [] + activation_layer = nn.SiLU + + expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) + if expanded_channels != cnf.input_channels: + # fused expand + layers.append( + Conv2dNormActivation( + cnf.input_channels, + expanded_channels, + kernel_size=cnf.kernel, + stride=cnf.stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + # project + layers.append( + Conv2dNormActivation( + expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None + ) + ) + else: + layers.append( + Conv2dNormActivation( + cnf.input_channels, + cnf.out_channels, + kernel_size=cnf.kernel, + stride=cnf.stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + self.block = nn.Sequential(*layers) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.out_channels = cnf.out_channels + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.use_res_connect: + result = self.stochastic_depth(result) + result += input + return result + + class EfficientNet(nn.Module): def __init__( self, - inverted_residual_setting: List[MBConvConfig], + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, stochastic_depth_prob: float = 0.2, num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, + last_channel: Optional[int] = None, **kwargs: Any, ) -> None: """ - EfficientNet main class + EfficientNet V1 and V2 main class Args: - inverted_residual_setting (List[MBConvConfig]): Network structure + inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure dropout (float): The droupout probability stochastic_depth_prob (float): The stochastic depth probability num_classes (int): Number of classes - block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + last_channel (int): The number of channels on the penultimate layer """ super().__init__() _log_api_usage_once(self) @@ -178,12 +265,19 @@ def __init__( raise ValueError("The inverted_residual_setting should not be empty") elif not ( isinstance(inverted_residual_setting, Sequence) - and all([isinstance(s, MBConvConfig) for s in inverted_residual_setting]) + and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting]) ): raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") - if block is None: - block = MBConv + if "block" in kwargs: + warnings.warn( + "The parameter 'block' is deprecated since 0.13 and will be removed 0.15. " + "Please pass this information on 'MBConvConfig.block' instead." + ) + if kwargs["block"] is not None: + for s in inverted_residual_setting: + if isinstance(s, MBConvConfig): + s.block = kwargs["block"] if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -215,14 +309,14 @@ def __init__( # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks - stage.append(block(block_cnf, sd_prob, norm_layer)) + stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer)) stage_block_id += 1 layers.append(nn.Sequential(*stage)) # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels - lastconv_output_channels = 4 * lastconv_input_channels + lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels layers.append( Conv2dNormActivation( lastconv_input_channels, @@ -269,24 +363,14 @@ def forward(self, x: Tensor) -> Tensor: def _efficientnet( arch: str, - width_mult: float, - depth_mult: float, + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, + last_channel: Optional[int], pretrained: bool, progress: bool, **kwargs: Any, ) -> EfficientNet: - bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) - inverted_residual_setting = [ - bneck_conf(1, 3, 1, 32, 16, 1), - bneck_conf(6, 3, 2, 16, 24, 2), - bneck_conf(6, 5, 2, 24, 40, 2), - bneck_conf(6, 3, 2, 40, 80, 3), - bneck_conf(6, 5, 1, 80, 112, 3), - bneck_conf(6, 5, 2, 112, 192, 4), - bneck_conf(6, 3, 1, 192, 320, 1), - ] - model = EfficientNet(inverted_residual_setting, dropout, **kwargs) + model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) if pretrained: if model_urls.get(arch, None) is None: raise ValueError(f"No checkpoint is available for model type {arch}") @@ -295,6 +379,61 @@ def _efficientnet( return model +def _efficientnet_conf( + arch: str, + **kwargs: Any, +) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] + if arch.startswith("efficientnet_b"): + bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult")) + inverted_residual_setting = [ + bneck_conf(1, 3, 1, 32, 16, 1), + bneck_conf(6, 3, 2, 16, 24, 2), + bneck_conf(6, 5, 2, 24, 40, 2), + bneck_conf(6, 3, 2, 40, 80, 3), + bneck_conf(6, 5, 1, 80, 112, 3), + bneck_conf(6, 5, 2, 112, 192, 4), + bneck_conf(6, 3, 1, 192, 320, 1), + ] + last_channel = None + elif arch.startswith("efficientnet_v2_s"): + inverted_residual_setting = [ + FusedMBConvConfig(1, 3, 1, 24, 24, 2), + FusedMBConvConfig(4, 3, 2, 24, 48, 4), + FusedMBConvConfig(4, 3, 2, 48, 64, 4), + MBConvConfig(4, 3, 2, 64, 128, 6), + MBConvConfig(6, 3, 1, 128, 160, 9), + MBConvConfig(6, 3, 2, 160, 256, 15), + ] + last_channel = 1280 + elif arch.startswith("efficientnet_v2_m"): + inverted_residual_setting = [ + FusedMBConvConfig(1, 3, 1, 24, 24, 3), + FusedMBConvConfig(4, 3, 2, 24, 48, 5), + FusedMBConvConfig(4, 3, 2, 48, 80, 5), + MBConvConfig(4, 3, 2, 80, 160, 7), + MBConvConfig(6, 3, 1, 160, 176, 14), + MBConvConfig(6, 3, 2, 176, 304, 18), + MBConvConfig(6, 3, 1, 304, 512, 5), + ] + last_channel = 1280 + elif arch.startswith("efficientnet_v2_l"): + inverted_residual_setting = [ + FusedMBConvConfig(1, 3, 1, 32, 32, 4), + FusedMBConvConfig(4, 3, 2, 32, 64, 7), + FusedMBConvConfig(4, 3, 2, 64, 96, 7), + MBConvConfig(4, 3, 2, 96, 192, 10), + MBConvConfig(6, 3, 1, 192, 224, 19), + MBConvConfig(6, 3, 2, 224, 384, 25), + MBConvConfig(6, 3, 1, 384, 640, 7), + ] + last_channel = 1280 + else: + raise ValueError(f"Unsupported model type {arch}") + + return inverted_residual_setting, last_channel + + def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: """ Constructs a EfficientNet B0 architecture from @@ -304,7 +443,9 @@ def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _efficientnet("efficientnet_b0", 1.0, 1.0, 0.2, pretrained, progress, **kwargs) + arch = "efficientnet_b0" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.0) + return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -316,7 +457,9 @@ def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _efficientnet("efficientnet_b1", 1.0, 1.1, 0.2, pretrained, progress, **kwargs) + arch = "efficientnet_b1" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.1) + return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -328,7 +471,9 @@ def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _efficientnet("efficientnet_b2", 1.1, 1.2, 0.3, pretrained, progress, **kwargs) + arch = "efficientnet_b2" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.1, depth_mult=1.2) + return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -340,7 +485,9 @@ def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _efficientnet("efficientnet_b3", 1.2, 1.4, 0.3, pretrained, progress, **kwargs) + arch = "efficientnet_b3" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.2, depth_mult=1.4) + return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -352,7 +499,9 @@ def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _efficientnet("efficientnet_b4", 1.4, 1.8, 0.4, pretrained, progress, **kwargs) + arch = "efficientnet_b4" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.4, depth_mult=1.8) + return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs) def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -364,11 +513,13 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ + arch = "efficientnet_b5" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.6, depth_mult=2.2) return _efficientnet( - "efficientnet_b5", - 1.6, - 2.2, + arch, + inverted_residual_setting, 0.4, + last_channel, pretrained, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), @@ -385,11 +536,13 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ + arch = "efficientnet_b6" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.8, depth_mult=2.6) return _efficientnet( - "efficientnet_b6", - 1.8, - 2.6, + arch, + inverted_residual_setting, 0.5, + last_channel, pretrained, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), @@ -406,13 +559,84 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ + arch = "efficientnet_b7" + inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=2.0, depth_mult=3.1) return _efficientnet( - "efficientnet_b7", - 2.0, - 3.1, + arch, + inverted_residual_setting, 0.5, + last_channel, pretrained, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) + + +def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs an EfficientNetV2-S architecture from + `"EfficientNetV2: Smaller Models and Faster Training" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + arch = "efficientnet_v2_s" + inverted_residual_setting, last_channel = _efficientnet_conf(arch) + return _efficientnet( + arch, + inverted_residual_setting, + 0.2, + last_channel, + pretrained, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs an EfficientNetV2-M architecture from + `"EfficientNetV2: Smaller Models and Faster Training" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + arch = "efficientnet_v2_m" + inverted_residual_setting, last_channel = _efficientnet_conf(arch) + return _efficientnet( + arch, + inverted_residual_setting, + 0.3, + last_channel, + pretrained, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs an EfficientNetV2-L architecture from + `"EfficientNetV2: Smaller Models and Faster Training" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + arch = "efficientnet_v2_l" + inverted_residual_setting, last_channel = _efficientnet_conf(arch) + return _efficientnet( + arch, + inverted_residual_setting, + 0.4, + last_channel, + pretrained, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 1fa2ea4d294..2619709764f 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -1,11 +1,11 @@ from functools import partial -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union from torch import nn from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode -from ...models.efficientnet import EfficientNet, MBConvConfig +from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param @@ -21,6 +21,9 @@ "EfficientNet_B5_Weights", "EfficientNet_B6_Weights", "EfficientNet_B7_Weights", + "EfficientNet_V2_S_Weights", + "EfficientNet_V2_M_Weights", + "EfficientNet_V2_L_Weights", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", @@ -29,13 +32,16 @@ "efficientnet_b5", "efficientnet_b6", "efficientnet_b7", + "efficientnet_v2_s", + "efficientnet_v2_m", + "efficientnet_v2_l", ] def _efficientnet( - width_mult: float, - depth_mult: float, + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, + last_channel: Optional[int], weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, @@ -43,18 +49,7 @@ def _efficientnet( if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) - inverted_residual_setting = [ - bneck_conf(1, 3, 1, 32, 16, 1), - bneck_conf(6, 3, 2, 16, 24, 2), - bneck_conf(6, 5, 2, 24, 40, 2), - bneck_conf(6, 3, 2, 40, 80, 3), - bneck_conf(6, 5, 1, 80, 112, 3), - bneck_conf(6, 5, 2, 112, 192, 4), - bneck_conf(6, 3, 1, 192, 320, 1), - ] - - model = EfficientNet(inverted_residual_setting, dropout, **kwargs) + model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) @@ -64,12 +59,26 @@ def _efficientnet( _COMMON_META = { "task": "image_classification", + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", +} + + +_COMMON_META_V1 = { + **_COMMON_META, "architecture": "EfficientNet", "publication_year": 2019, - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BICUBIC, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", + "min_size": (1, 1), +} + + +_COMMON_META_V2 = { + **_COMMON_META, + "architecture": "EfficientNetV2", + "publication_year": 2021, + "interpolation": InterpolationMode.BILINEAR, + "min_size": (33, 33), } @@ -78,7 +87,7 @@ class EfficientNet_B0_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 5288548, "size": (224, 224), "acc@1": 77.692, @@ -93,7 +102,7 @@ class EfficientNet_B1_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 7794184, "size": (240, 240), "acc@1": 78.642, @@ -104,7 +113,7 @@ class EfficientNet_B1_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 7794184, "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", "interpolation": InterpolationMode.BILINEAR, @@ -121,7 +130,7 @@ class EfficientNet_B2_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 9109994, "size": (288, 288), "acc@1": 80.608, @@ -136,7 +145,7 @@ class EfficientNet_B3_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 12233232, "size": (300, 300), "acc@1": 82.008, @@ -151,7 +160,7 @@ class EfficientNet_B4_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 19341616, "size": (380, 380), "acc@1": 83.384, @@ -166,7 +175,7 @@ class EfficientNet_B5_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 30389784, "size": (456, 456), "acc@1": 83.444, @@ -181,7 +190,7 @@ class EfficientNet_B6_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 43040704, "size": (528, 528), "acc@1": 84.008, @@ -196,7 +205,7 @@ class EfficientNet_B7_Weights(WeightsEnum): url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), meta={ - **_COMMON_META, + **_COMMON_META_V1, "num_params": 66347960, "size": (600, 600), "acc@1": 84.122, @@ -206,13 +215,76 @@ class EfficientNet_B7_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +class EfficientNet_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", + transforms=partial( + ImageNetEval, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 21458488, + "size": (384, 384), + "acc@1": 84.228, + "acc@5": 96.878, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_M_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", + transforms=partial( + ImageNetEval, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 54139356, + "size": (480, 480), + "acc@1": 85.112, + "acc@5": 97.156, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_L_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", + transforms=partial( + ImageNetEval, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BICUBIC, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), + meta={ + **_COMMON_META_V2, + "num_params": 118515272, + "size": (480, 480), + "acc@1": 85.808, + "acc@5": 97.788, + }, + ) + DEFAULT = IMAGENET1K_V1 + + @handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) def efficientnet_b0( *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: weights = EfficientNet_B0_Weights.verify(weights) - return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) @handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) @@ -221,7 +293,8 @@ def efficientnet_b1( ) -> EfficientNet: weights = EfficientNet_B1_Weights.verify(weights) - return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) @handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) @@ -230,7 +303,8 @@ def efficientnet_b2( ) -> EfficientNet: weights = EfficientNet_B2_Weights.verify(weights) - return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) @handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) @@ -239,7 +313,8 @@ def efficientnet_b3( ) -> EfficientNet: weights = EfficientNet_B3_Weights.verify(weights) - return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) @handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) @@ -248,7 +323,8 @@ def efficientnet_b4( ) -> EfficientNet: weights = EfficientNet_B4_Weights.verify(weights) - return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) + return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) @handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) @@ -257,12 +333,13 @@ def efficientnet_b5( ) -> EfficientNet: weights = EfficientNet_B5_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) return _efficientnet( - width_mult=1.6, - depth_mult=2.2, - dropout=0.4, - weights=weights, - progress=progress, + inverted_residual_setting, + 0.4, + last_channel, + weights, + progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) @@ -274,12 +351,13 @@ def efficientnet_b6( ) -> EfficientNet: weights = EfficientNet_B6_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) return _efficientnet( - width_mult=1.8, - depth_mult=2.6, - dropout=0.5, - weights=weights, - progress=progress, + inverted_residual_setting, + 0.5, + last_channel, + weights, + progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) @@ -291,12 +369,67 @@ def efficientnet_b7( ) -> EfficientNet: weights = EfficientNet_B7_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) return _efficientnet( - width_mult=2.0, - depth_mult=3.1, - dropout=0.5, - weights=weights, - progress=progress, + inverted_residual_setting, + 0.5, + last_channel, + weights, + progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) +def efficientnet_v2_s( + *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + weights = EfficientNet_V2_S_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") + return _efficientnet( + inverted_residual_setting, + 0.2, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) +def efficientnet_v2_m( + *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + weights = EfficientNet_V2_M_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") + return _efficientnet( + inverted_residual_setting, + 0.3, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) +def efficientnet_v2_l( + *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + weights = EfficientNet_V2_L_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") + return _efficientnet( + inverted_residual_setting, + 0.4, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + )