diff --git a/configs/mixnet/README.MD b/configs/mixnet/README.MD new file mode 100644 index 00000000..21078092 --- /dev/null +++ b/configs/mixnet/README.MD @@ -0,0 +1,89 @@ +# MixNet +> [MixConv: Mixed Depthwise Convolutional Kernels](https://arxiv.org/abs/1907.09595) + +## Introduction + +Depthwise convolution is becoming increasingly popular in modern efficient ConvNets, but its kernel size is often +overlooked. In this paper, the authors systematically study the impact of different kernel sizes, and observe that +combining the benefits of multiple kernel sizes can lead to better accuracy and efficiency. Based on this observation, +the authors propose a new mixed depthwise convolution (MixConv), which naturally mixes up multiple kernel sizes in a +single convolution. As a simple drop-in replacement of vanilla depthwise convolution, our MixConv improves the accuracy +and efficiency for existing MobileNets on both ImageNet classification and COCO object detection.[[1](#references)] + +

+ +

+

+ Figure 1. Architecture of MixNet [1] +

+ +## Results + +Our reproduced model performance on ImageNet-1K is reported as follows. + +
+ +| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | +|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------| +| MixNet_s | D910x8-G | 75.63 | 92.52 | 4.17 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/mixnet/mixnet_s_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/mixnet/mixnet_s-2a5ef3a3.ckpt) | + +
+ +#### Notes + +- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode. +- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K. + +## Quick Start + +### Preparation + +#### Installation +Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV. + +#### Dataset Preparation +Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation. + +### Training + +* Distributed Training + +It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run + +```shell +# distrubted training on multiple GPU/Ascend devices +mpirun -n 8 python train.py --config configs/mixnet/mixnet_s_ascend.yaml --data_dir /path/to/imagenet +``` + +> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`. + +Similarly, you can train the model on multiple GPU devices with the above `mpirun` command. + +For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py). + +**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size. + +* Standalone Training + +If you want to train or finetune the model on a smaller dataset without distributed training, please run: + +```shell +# standalone training on a CPU/GPU/Ascend device +python train.py --config configs/mixnet/mixnet_s_ascend.yaml --data_dir /path/to/dataset --distribute False +``` + +### Validation + +To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`. + +```shell +python validate.py -c configs/mixnet/mixnet_s_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt +``` + +### Deployment + +Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV. + +## References + +[1] Tan M, Le Q V. Mixconv: Mixed depthwise convolutional kernels[J]. arXiv preprint arXiv:1907.09595, 2019. diff --git a/configs/mixnet/mixnet_s_ascend.yaml b/configs/mixnet/mixnet_s_ascend.yaml new file mode 100644 index 00000000..2eb770fd --- /dev/null +++ b/configs/mixnet/mixnet_s_ascend.yaml @@ -0,0 +1,55 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +val_while_train: True + +# dataset +dataset: "imagenet" +data_dir: "path/to/imagenet" +shuffle: True +dataset_download: False +batch_size: 128 +drop_remainder: True + +# augmentation +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +interpolation: "bicubic" +auto_augment: "randaug-m9-mstd0.5" +re_prob: 0.25 +crop_pct: 0.875 +mixup: 0.2 +cutmix: 1.0 + +# model +model: "mixnet_s" +num_classes: 1000 +pretrained: False +ckpt_path: '' +keep_checkpoint_max: 10 +ckpt_save_dir: "./ckpt" +epoch_size: 600 +dataset_sink_mode: True +amp_level: "O3" + +# loss +loss: "CE" +label_smoothing: 0.1 + +# lr scheduler +scheduler: "warmup_cosine_decay" +lr: 0.2 +min_lr: 0.00001 +decay_epochs: 585 +warmup_epochs: 15 + +# optimizer +opt: "momentum" +filter_bias_and_bn: True +momentum: 0.9 +weight_decay: 0.00002 +loss_scale: 256 +use_nesterov: False diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py index 8d7dbc33..d0521eff 100644 --- a/mindcv/models/__init__.py +++ b/mindcv/models/__init__.py @@ -13,6 +13,7 @@ inception_v3, inception_v4, layers, + mixnet, mnasnet, mobilenet_v1, mobilenet_v2, @@ -54,6 +55,7 @@ from .inception_v3 import * from .inception_v4 import * from .layers import * +from .mixnet import * from .mnasnet import * from .mobilenet_v1 import * from .mobilenet_v2 import * @@ -99,6 +101,7 @@ __all__.extend(["InceptionV3", "inception_v3"]) __all__.extend(["InceptionV4", "inception_v4"]) __all__.extend(layers.__all__) +__all__.extend(mixnet.__all__) __all__.extend(mnasnet.__all__) __all__.extend(mobilenet_v1.__all__) __all__.extend(mobilenet_v2.__all__) diff --git a/mindcv/models/mixnet.py b/mindcv/models/mixnet.py new file mode 100644 index 00000000..423fb809 --- /dev/null +++ b/mindcv/models/mixnet.py @@ -0,0 +1,416 @@ +""" +MindSpore implementation of `MixNet`. +Refer to MixConv: Mixed Depthwise Convolutional Kernels +""" + +import math +from typing import Optional + +import mindspore.common.initializer as init +from mindspore import Tensor, nn, ops + +from .layers.pooling import GlobalAvgPooling +from .layers.squeeze_excite import SqueezeExcite +from .registry import register_model +from .utils import load_pretrained + +__all__ = [ + "MixNet", + "mixnet_s", + "mixnet_m", + "mixnet_l", +] + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "first_conv": "stem_conv.0", "classifier": "classifier", + **kwargs, + } + + +default_cfgs = { + "mixnet_s": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/mixnet/mixnet_s-2a5ef3a3.ckpt"), + "mixnet_m": _cfg(url=""), + "mixnet_l": _cfg(url=""), +} + + +def _roundchannels(filters: float, divisor: int = 8, min_depth: Optional[int] = None) -> int: + if min_depth is None: + min_depth = divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return new_filters + + +def _splitchannels(channels: int, num_groups: int) -> list: + split_channels = [channels // num_groups for _ in range(num_groups)] + split_channels[0] += channels - sum(split_channels) + return split_channels + + +class Swish(nn.Cell): + def __init__(self) -> None: + super(Swish, self).__init__() + self.sigmoid = ops.Sigmoid() + + def construct(self, x: Tensor) -> Tensor: + return x * self.sigmoid(x) + + +class GroupedConv2d(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: list, + stride: int = 1, + padding: int = 0, + ) -> None: + super(GroupedConv2d, self).__init__() + self.num_groups = len(kernel_size) + if self.num_groups == 1: + self.grouped_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size[0], + stride=stride, + pad_mode="pad", + padding=padding, + has_bias=False + ) + else: + self.split_in_channels = _splitchannels(in_channels, self.num_groups) + self.split_out_channels = _splitchannels(out_channels, self.num_groups) + + self.grouped_conv = nn.CellList() + for i in range(self.num_groups): + self.grouped_conv.append(nn.Conv2d( + self.split_in_channels[i], + self.split_out_channels[i], + kernel_size[i], + stride=stride, + pad_mode="pad", + padding=padding, + has_bias=False + )) + + def construct(self, x: Tensor) -> Tensor: + if self.num_groups == 1: + return self.grouped_conv(x) + + output = [] + start, end = 0, 0 + for i in range(self.num_groups): + start, end = end, end + self.split_in_channels[i] + x_split = x[:, start:end] + + conv = self.grouped_conv[i] + output.append(conv(x_split)) + + return ops.concat(output, axis=1) + + +class MDConv(nn.Cell): + """Mixed Depth-wise Convolution""" + + def __init__(self, channels: int, kernel_size: list, stride: int) -> None: + super(MDConv, self).__init__() + self.num_groups = len(kernel_size) + + if self.num_groups == 1: + self.mixed_depthwise_conv = nn.Conv2d( + channels, + channels, + kernel_size[0], + stride=stride, + pad_mode="pad", + padding=kernel_size[0] // 2, + group=channels, + has_bias=False + ) + else: + self.split_channels = _splitchannels(channels, self.num_groups) + + self.mixed_depthwise_conv = nn.CellList() + for i in range(self.num_groups): + self.mixed_depthwise_conv.append(nn.Conv2d( + self.split_channels[i], + self.split_channels[i], + kernel_size[i], + stride=stride, + pad_mode="pad", + padding=kernel_size[i] // 2, + group=self.split_channels[i], + has_bias=False + )) + + def construct(self, x: Tensor) -> Tensor: + if self.num_groups == 1: + return self.mixed_depthwise_conv(x) + + output = [] + start, end = 0, 0 + for i in range(self.num_groups): + start, end = end, end + self.split_channels[i] + x_split = x[:, start:end] + + conv = self.mixed_depthwise_conv[i] + output.append(conv(x_split)) + + return ops.concat(output, axis=1) + + +class MixNetBlock(nn.Cell): + """Basic Block of MixNet""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: list = [3], + expand_ksize: list = [1], + project_ksize: list = [1], + stride: int = 1, + expand_ratio: int = 1, + activation: str = "ReLU", + se_ratio: float = 0.0, + ) -> None: + super(MixNetBlock, self).__init__() + assert activation in ["ReLU", "Swish"] + self.activation = Swish if activation == "Swish" else nn.ReLU + + expand_channels = in_channels * expand_ratio + self.residual_connection = (stride == 1 and in_channels == out_channels) + + conv = [] + if expand_ratio != 1: + # expand + conv.extend([ + GroupedConv2d(in_channels, expand_channels, expand_ksize), + nn.BatchNorm2d(expand_channels), + self.activation() + ]) + + # depthwise + conv.extend([ + MDConv(expand_channels, kernel_size, stride), + nn.BatchNorm2d(expand_channels), + self.activation() + ]) + + if se_ratio > 0: + squeeze_channels = int(in_channels * se_ratio) + squeeze_excite = SqueezeExcite(expand_channels, rd_channels=squeeze_channels) + conv.append(squeeze_excite) + + # projection phase + conv.extend([ + GroupedConv2d(expand_channels, out_channels, project_ksize), + nn.BatchNorm2d(out_channels) + ]) + + self.convs = nn.SequentialCell(conv) + + def construct(self, x: Tensor) -> Tensor: + if self.residual_connection: + return x + self.convs(x) + else: + return self.convs(x) + + +class MixNet(nn.Cell): + r"""MixNet model class, based on + `"MixConv: Mixed Depthwise Convolutional Kernels" `_ + + Args: + arch: size of the architecture. "small", "medium" or "large". Default: "small". + num_classes: number of classification classes. Default: 1000. + in_channels: number of the channels of the input. Default: 3. + feature_size: numbet of the channels of the output features. Default: 1536. + drop_rate: rate of dropout for classifier. Default: 0.2. + depth_multiplier: expansion coefficient of channels. Default: 1.0. + """ + + def __init__( + self, + arch: str = "small", + num_classes: int = 1000, + in_channels: int = 3, + feature_size: int = 1536, + drop_rate: float = 0.2, + depth_multiplier: float = 1.0 + ) -> None: + super(MixNet, self).__init__() + if arch == "small": + block_configs = [ + [16, 16, [3], [1], [1], 1, 1, "ReLU", 0.0], + [16, 24, [3], [1, 1], [1, 1], 2, 6, "ReLU", 0.0], + [24, 24, [3], [1, 1], [1, 1], 1, 3, "ReLU", 0.0], + [24, 40, [3, 5, 7], [1], [1], 2, 6, "Swish", 0.5], + [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [40, 80, [3, 5, 7], [1], [1, 1], 2, 6, "Swish", 0.25], + [80, 80, [3, 5], [1], [1, 1], 1, 6, "Swish", 0.25], + [80, 80, [3, 5], [1], [1, 1], 1, 6, "Swish", 0.25], + [80, 120, [3, 5, 7], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5], + [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5], + [120, 200, [3, 5, 7, 9, 11], [1], [1], 2, 6, "Swish", 0.5], + [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5], + [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5] + ] + stem_channels = 16 + drop_rate = drop_rate + else: + block_configs = [ + [24, 24, [3], [1], [1], 1, 1, "ReLU", 0.0], + [24, 32, [3, 5, 7], [1, 1], [1, 1], 2, 6, "ReLU", 0.0], + [32, 32, [3], [1, 1], [1, 1], 1, 3, "ReLU", 0.0], + [32, 40, [3, 5, 7, 9], [1], [1], 2, 6, "Swish", 0.5], + [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5], + [40, 80, [3, 5, 7], [1], [1], 2, 6, "Swish", 0.25], + [80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, "Swish", 0.25], + [80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, "Swish", 0.25], + [80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, "Swish", 0.25], + [80, 120, [3], [1], [1], 1, 6, "Swish", 0.5], + [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5], + [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5], + [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5], + [120, 200, [3, 5, 7, 9], [1], [1], 2, 6, "Swish", 0.5], + [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5], + [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5], + [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5] + ] + if arch == "medium": + stem_channels = 24 + drop_rate = drop_rate + elif arch == "large": + stem_channels = 24 + depth_multiplier *= 1.3 + drop_rate = drop_rate + else: + raise ValueError(f"Unsupported model type {arch}") + + if depth_multiplier != 1.0: + stem_channels = _roundchannels(stem_channels * depth_multiplier) + + for i, conf in enumerate(block_configs): + conf_ls = list(conf) + conf_ls[0] = _roundchannels(conf_ls[0] * depth_multiplier) + conf_ls[1] = _roundchannels(conf_ls[1] * depth_multiplier) + block_configs[i] = tuple(conf_ls) + + # stem convolution + self.stem_conv = nn.SequentialCell([ + nn.Conv2d(in_channels, stem_channels, 3, stride=2, pad_mode="pad", padding=1), + nn.BatchNorm2d(stem_channels), + nn.ReLU() + ]) + + # building MixNet blocks + layers = [] + for inc, outc, k, ek, pk, s, er, ac, se in block_configs: + layers.append(MixNetBlock( + inc, + outc, + kernel_size=k, + expand_ksize=ek, + project_ksize=pk, + stride=s, + expand_ratio=er, + activation=ac, + se_ratio=se + )) + self.layers = nn.SequentialCell(layers) + + # head + self.head_conv = nn.SequentialCell([ + nn.Conv2d(block_configs[-1][1], feature_size, 1, pad_mode="pad", padding=0), + nn.BatchNorm2d(feature_size), + nn.ReLU() + ]) + + self.pool = GlobalAvgPooling() + self.dropout = nn.Dropout(keep_prob=1 - drop_rate) + self.classifier = nn.Dense(feature_size, num_classes) + + self._initialize_weights() + + def _initialize_weights(self) -> None: + """Initialize weights for cells.""" + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + fan_out = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels + cell.weight.set_data( + init.initializer(init.Normal(math.sqrt(2.0 / fan_out)), + cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data( + init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, nn.BatchNorm2d): + cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) + cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) + elif isinstance(cell, nn.Dense): + cell.weight.set_data( + init.initializer(init.Uniform(1.0 / math.sqrt(cell.weight.shape[0])), + cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + + def forward_features(self, x: Tensor) -> Tensor: + x = self.stem_conv(x) + x = self.layers(x) + x = self.head_conv(x) + return x + + def forward_head(self, x: Tensor) -> Tensor: + x = self.pool(x) + x = self.dropout(x) + x = self.classifier(x) + return x + + def construct(self, x: Tensor) -> Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +@register_model +def mixnet_s(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + default_cfg = default_cfgs["mixnet_s"] + model = MixNet(arch="small", in_channels=in_channels, num_classes=num_classes, **kwargs) + + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model + + +@register_model +def mixnet_m(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + default_cfg = default_cfgs["mixnet_m"] + model = MixNet(arch="medium", in_channels=in_channels, num_classes=num_classes, **kwargs) + + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model + + +@register_model +def mixnet_l(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + default_cfg = default_cfgs["mixnet_l"] + model = MixNet(arch="large", in_channels=in_channels, num_classes=num_classes, **kwargs) + + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model