diff --git a/configs/mixnet/README.MD b/configs/mixnet/README.MD
new file mode 100644
index 00000000..21078092
--- /dev/null
+++ b/configs/mixnet/README.MD
@@ -0,0 +1,89 @@
+# MixNet
+> [MixConv: Mixed Depthwise Convolutional Kernels](https://arxiv.org/abs/1907.09595)
+
+## Introduction
+
+Depthwise convolution is becoming increasingly popular in modern efficient ConvNets, but its kernel size is often
+overlooked. In this paper, the authors systematically study the impact of different kernel sizes, and observe that
+combining the benefits of multiple kernel sizes can lead to better accuracy and efficiency. Based on this observation,
+the authors propose a new mixed depthwise convolution (MixConv), which naturally mixes up multiple kernel sizes in a
+single convolution. As a simple drop-in replacement of vanilla depthwise convolution, our MixConv improves the accuracy
+and efficiency for existing MobileNets on both ImageNet classification and COCO object detection.[[1](#references)]
+
+
+
+
+
+ Figure 1. Architecture of MixNet [1]
+
+
+## Results
+
+Our reproduced model performance on ImageNet-1K is reported as follows.
+
+
+
+| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
+|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
+| MixNet_s | D910x8-G | 75.63 | 92.52 | 4.17 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/mixnet/mixnet_s_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/mixnet/mixnet_s-2a5ef3a3.ckpt) |
+
+
+
+#### Notes
+
+- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
+- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.
+
+## Quick Start
+
+### Preparation
+
+#### Installation
+Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.
+
+#### Dataset Preparation
+Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.
+
+### Training
+
+* Distributed Training
+
+It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run
+
+```shell
+# distrubted training on multiple GPU/Ascend devices
+mpirun -n 8 python train.py --config configs/mixnet/mixnet_s_ascend.yaml --data_dir /path/to/imagenet
+```
+
+> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
+
+Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.
+
+For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).
+
+**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python train.py --config configs/mixnet/mixnet_s_ascend.yaml --data_dir /path/to/dataset --distribute False
+```
+
+### Validation
+
+To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.
+
+```shell
+python validate.py -c configs/mixnet/mixnet_s_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.
+
+## References
+
+[1] Tan M, Le Q V. Mixconv: Mixed depthwise convolutional kernels[J]. arXiv preprint arXiv:1907.09595, 2019.
diff --git a/configs/mixnet/mixnet_s_ascend.yaml b/configs/mixnet/mixnet_s_ascend.yaml
new file mode 100644
index 00000000..2eb770fd
--- /dev/null
+++ b/configs/mixnet/mixnet_s_ascend.yaml
@@ -0,0 +1,55 @@
+# system
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset
+dataset: "imagenet"
+data_dir: "path/to/imagenet"
+shuffle: True
+dataset_download: False
+batch_size: 128
+drop_remainder: True
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+interpolation: "bicubic"
+auto_augment: "randaug-m9-mstd0.5"
+re_prob: 0.25
+crop_pct: 0.875
+mixup: 0.2
+cutmix: 1.0
+
+# model
+model: "mixnet_s"
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+keep_checkpoint_max: 10
+ckpt_save_dir: "./ckpt"
+epoch_size: 600
+dataset_sink_mode: True
+amp_level: "O3"
+
+# loss
+loss: "CE"
+label_smoothing: 0.1
+
+# lr scheduler
+scheduler: "warmup_cosine_decay"
+lr: 0.2
+min_lr: 0.00001
+decay_epochs: 585
+warmup_epochs: 15
+
+# optimizer
+opt: "momentum"
+filter_bias_and_bn: True
+momentum: 0.9
+weight_decay: 0.00002
+loss_scale: 256
+use_nesterov: False
diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py
index 8d7dbc33..d0521eff 100644
--- a/mindcv/models/__init__.py
+++ b/mindcv/models/__init__.py
@@ -13,6 +13,7 @@
inception_v3,
inception_v4,
layers,
+ mixnet,
mnasnet,
mobilenet_v1,
mobilenet_v2,
@@ -54,6 +55,7 @@
from .inception_v3 import *
from .inception_v4 import *
from .layers import *
+from .mixnet import *
from .mnasnet import *
from .mobilenet_v1 import *
from .mobilenet_v2 import *
@@ -99,6 +101,7 @@
__all__.extend(["InceptionV3", "inception_v3"])
__all__.extend(["InceptionV4", "inception_v4"])
__all__.extend(layers.__all__)
+__all__.extend(mixnet.__all__)
__all__.extend(mnasnet.__all__)
__all__.extend(mobilenet_v1.__all__)
__all__.extend(mobilenet_v2.__all__)
diff --git a/mindcv/models/mixnet.py b/mindcv/models/mixnet.py
new file mode 100644
index 00000000..423fb809
--- /dev/null
+++ b/mindcv/models/mixnet.py
@@ -0,0 +1,416 @@
+"""
+MindSpore implementation of `MixNet`.
+Refer to MixConv: Mixed Depthwise Convolutional Kernels
+"""
+
+import math
+from typing import Optional
+
+import mindspore.common.initializer as init
+from mindspore import Tensor, nn, ops
+
+from .layers.pooling import GlobalAvgPooling
+from .layers.squeeze_excite import SqueezeExcite
+from .registry import register_model
+from .utils import load_pretrained
+
+__all__ = [
+ "MixNet",
+ "mixnet_s",
+ "mixnet_m",
+ "mixnet_l",
+]
+
+
+def _cfg(url="", **kwargs):
+ return {
+ "url": url,
+ "num_classes": 1000,
+ "first_conv": "stem_conv.0", "classifier": "classifier",
+ **kwargs,
+ }
+
+
+default_cfgs = {
+ "mixnet_s": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/mixnet/mixnet_s-2a5ef3a3.ckpt"),
+ "mixnet_m": _cfg(url=""),
+ "mixnet_l": _cfg(url=""),
+}
+
+
+def _roundchannels(filters: float, divisor: int = 8, min_depth: Optional[int] = None) -> int:
+ if min_depth is None:
+ min_depth = divisor
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ return new_filters
+
+
+def _splitchannels(channels: int, num_groups: int) -> list:
+ split_channels = [channels // num_groups for _ in range(num_groups)]
+ split_channels[0] += channels - sum(split_channels)
+ return split_channels
+
+
+class Swish(nn.Cell):
+ def __init__(self) -> None:
+ super(Swish, self).__init__()
+ self.sigmoid = ops.Sigmoid()
+
+ def construct(self, x: Tensor) -> Tensor:
+ return x * self.sigmoid(x)
+
+
+class GroupedConv2d(nn.Cell):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: list,
+ stride: int = 1,
+ padding: int = 0,
+ ) -> None:
+ super(GroupedConv2d, self).__init__()
+ self.num_groups = len(kernel_size)
+ if self.num_groups == 1:
+ self.grouped_conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size[0],
+ stride=stride,
+ pad_mode="pad",
+ padding=padding,
+ has_bias=False
+ )
+ else:
+ self.split_in_channels = _splitchannels(in_channels, self.num_groups)
+ self.split_out_channels = _splitchannels(out_channels, self.num_groups)
+
+ self.grouped_conv = nn.CellList()
+ for i in range(self.num_groups):
+ self.grouped_conv.append(nn.Conv2d(
+ self.split_in_channels[i],
+ self.split_out_channels[i],
+ kernel_size[i],
+ stride=stride,
+ pad_mode="pad",
+ padding=padding,
+ has_bias=False
+ ))
+
+ def construct(self, x: Tensor) -> Tensor:
+ if self.num_groups == 1:
+ return self.grouped_conv(x)
+
+ output = []
+ start, end = 0, 0
+ for i in range(self.num_groups):
+ start, end = end, end + self.split_in_channels[i]
+ x_split = x[:, start:end]
+
+ conv = self.grouped_conv[i]
+ output.append(conv(x_split))
+
+ return ops.concat(output, axis=1)
+
+
+class MDConv(nn.Cell):
+ """Mixed Depth-wise Convolution"""
+
+ def __init__(self, channels: int, kernel_size: list, stride: int) -> None:
+ super(MDConv, self).__init__()
+ self.num_groups = len(kernel_size)
+
+ if self.num_groups == 1:
+ self.mixed_depthwise_conv = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size[0],
+ stride=stride,
+ pad_mode="pad",
+ padding=kernel_size[0] // 2,
+ group=channels,
+ has_bias=False
+ )
+ else:
+ self.split_channels = _splitchannels(channels, self.num_groups)
+
+ self.mixed_depthwise_conv = nn.CellList()
+ for i in range(self.num_groups):
+ self.mixed_depthwise_conv.append(nn.Conv2d(
+ self.split_channels[i],
+ self.split_channels[i],
+ kernel_size[i],
+ stride=stride,
+ pad_mode="pad",
+ padding=kernel_size[i] // 2,
+ group=self.split_channels[i],
+ has_bias=False
+ ))
+
+ def construct(self, x: Tensor) -> Tensor:
+ if self.num_groups == 1:
+ return self.mixed_depthwise_conv(x)
+
+ output = []
+ start, end = 0, 0
+ for i in range(self.num_groups):
+ start, end = end, end + self.split_channels[i]
+ x_split = x[:, start:end]
+
+ conv = self.mixed_depthwise_conv[i]
+ output.append(conv(x_split))
+
+ return ops.concat(output, axis=1)
+
+
+class MixNetBlock(nn.Cell):
+ """Basic Block of MixNet"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: list = [3],
+ expand_ksize: list = [1],
+ project_ksize: list = [1],
+ stride: int = 1,
+ expand_ratio: int = 1,
+ activation: str = "ReLU",
+ se_ratio: float = 0.0,
+ ) -> None:
+ super(MixNetBlock, self).__init__()
+ assert activation in ["ReLU", "Swish"]
+ self.activation = Swish if activation == "Swish" else nn.ReLU
+
+ expand_channels = in_channels * expand_ratio
+ self.residual_connection = (stride == 1 and in_channels == out_channels)
+
+ conv = []
+ if expand_ratio != 1:
+ # expand
+ conv.extend([
+ GroupedConv2d(in_channels, expand_channels, expand_ksize),
+ nn.BatchNorm2d(expand_channels),
+ self.activation()
+ ])
+
+ # depthwise
+ conv.extend([
+ MDConv(expand_channels, kernel_size, stride),
+ nn.BatchNorm2d(expand_channels),
+ self.activation()
+ ])
+
+ if se_ratio > 0:
+ squeeze_channels = int(in_channels * se_ratio)
+ squeeze_excite = SqueezeExcite(expand_channels, rd_channels=squeeze_channels)
+ conv.append(squeeze_excite)
+
+ # projection phase
+ conv.extend([
+ GroupedConv2d(expand_channels, out_channels, project_ksize),
+ nn.BatchNorm2d(out_channels)
+ ])
+
+ self.convs = nn.SequentialCell(conv)
+
+ def construct(self, x: Tensor) -> Tensor:
+ if self.residual_connection:
+ return x + self.convs(x)
+ else:
+ return self.convs(x)
+
+
+class MixNet(nn.Cell):
+ r"""MixNet model class, based on
+ `"MixConv: Mixed Depthwise Convolutional Kernels" `_
+
+ Args:
+ arch: size of the architecture. "small", "medium" or "large". Default: "small".
+ num_classes: number of classification classes. Default: 1000.
+ in_channels: number of the channels of the input. Default: 3.
+ feature_size: numbet of the channels of the output features. Default: 1536.
+ drop_rate: rate of dropout for classifier. Default: 0.2.
+ depth_multiplier: expansion coefficient of channels. Default: 1.0.
+ """
+
+ def __init__(
+ self,
+ arch: str = "small",
+ num_classes: int = 1000,
+ in_channels: int = 3,
+ feature_size: int = 1536,
+ drop_rate: float = 0.2,
+ depth_multiplier: float = 1.0
+ ) -> None:
+ super(MixNet, self).__init__()
+ if arch == "small":
+ block_configs = [
+ [16, 16, [3], [1], [1], 1, 1, "ReLU", 0.0],
+ [16, 24, [3], [1, 1], [1, 1], 2, 6, "ReLU", 0.0],
+ [24, 24, [3], [1, 1], [1, 1], 1, 3, "ReLU", 0.0],
+ [24, 40, [3, 5, 7], [1], [1], 2, 6, "Swish", 0.5],
+ [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [40, 80, [3, 5, 7], [1], [1, 1], 2, 6, "Swish", 0.25],
+ [80, 80, [3, 5], [1], [1, 1], 1, 6, "Swish", 0.25],
+ [80, 80, [3, 5], [1], [1, 1], 1, 6, "Swish", 0.25],
+ [80, 120, [3, 5, 7], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5],
+ [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5],
+ [120, 200, [3, 5, 7, 9, 11], [1], [1], 2, 6, "Swish", 0.5],
+ [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5],
+ [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5]
+ ]
+ stem_channels = 16
+ drop_rate = drop_rate
+ else:
+ block_configs = [
+ [24, 24, [3], [1], [1], 1, 1, "ReLU", 0.0],
+ [24, 32, [3, 5, 7], [1, 1], [1, 1], 2, 6, "ReLU", 0.0],
+ [32, 32, [3], [1, 1], [1, 1], 1, 3, "ReLU", 0.0],
+ [32, 40, [3, 5, 7, 9], [1], [1], 2, 6, "Swish", 0.5],
+ [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [40, 40, [3, 5], [1, 1], [1, 1], 1, 6, "Swish", 0.5],
+ [40, 80, [3, 5, 7], [1], [1], 2, 6, "Swish", 0.25],
+ [80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, "Swish", 0.25],
+ [80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, "Swish", 0.25],
+ [80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, "Swish", 0.25],
+ [80, 120, [3], [1], [1], 1, 6, "Swish", 0.5],
+ [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5],
+ [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5],
+ [120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, "Swish", 0.5],
+ [120, 200, [3, 5, 7, 9], [1], [1], 2, 6, "Swish", 0.5],
+ [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5],
+ [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5],
+ [200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, "Swish", 0.5]
+ ]
+ if arch == "medium":
+ stem_channels = 24
+ drop_rate = drop_rate
+ elif arch == "large":
+ stem_channels = 24
+ depth_multiplier *= 1.3
+ drop_rate = drop_rate
+ else:
+ raise ValueError(f"Unsupported model type {arch}")
+
+ if depth_multiplier != 1.0:
+ stem_channels = _roundchannels(stem_channels * depth_multiplier)
+
+ for i, conf in enumerate(block_configs):
+ conf_ls = list(conf)
+ conf_ls[0] = _roundchannels(conf_ls[0] * depth_multiplier)
+ conf_ls[1] = _roundchannels(conf_ls[1] * depth_multiplier)
+ block_configs[i] = tuple(conf_ls)
+
+ # stem convolution
+ self.stem_conv = nn.SequentialCell([
+ nn.Conv2d(in_channels, stem_channels, 3, stride=2, pad_mode="pad", padding=1),
+ nn.BatchNorm2d(stem_channels),
+ nn.ReLU()
+ ])
+
+ # building MixNet blocks
+ layers = []
+ for inc, outc, k, ek, pk, s, er, ac, se in block_configs:
+ layers.append(MixNetBlock(
+ inc,
+ outc,
+ kernel_size=k,
+ expand_ksize=ek,
+ project_ksize=pk,
+ stride=s,
+ expand_ratio=er,
+ activation=ac,
+ se_ratio=se
+ ))
+ self.layers = nn.SequentialCell(layers)
+
+ # head
+ self.head_conv = nn.SequentialCell([
+ nn.Conv2d(block_configs[-1][1], feature_size, 1, pad_mode="pad", padding=0),
+ nn.BatchNorm2d(feature_size),
+ nn.ReLU()
+ ])
+
+ self.pool = GlobalAvgPooling()
+ self.dropout = nn.Dropout(keep_prob=1 - drop_rate)
+ self.classifier = nn.Dense(feature_size, num_classes)
+
+ self._initialize_weights()
+
+ def _initialize_weights(self) -> None:
+ """Initialize weights for cells."""
+ for _, cell in self.cells_and_names():
+ if isinstance(cell, nn.Conv2d):
+ fan_out = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels
+ cell.weight.set_data(
+ init.initializer(init.Normal(math.sqrt(2.0 / fan_out)),
+ cell.weight.shape, cell.weight.dtype))
+ if cell.bias is not None:
+ cell.bias.set_data(
+ init.initializer("zeros", cell.bias.shape, cell.bias.dtype))
+ elif isinstance(cell, nn.BatchNorm2d):
+ cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype))
+ cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype))
+ elif isinstance(cell, nn.Dense):
+ cell.weight.set_data(
+ init.initializer(init.Uniform(1.0 / math.sqrt(cell.weight.shape[0])),
+ cell.weight.shape, cell.weight.dtype))
+ if cell.bias is not None:
+ cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype))
+
+ def forward_features(self, x: Tensor) -> Tensor:
+ x = self.stem_conv(x)
+ x = self.layers(x)
+ x = self.head_conv(x)
+ return x
+
+ def forward_head(self, x: Tensor) -> Tensor:
+ x = self.pool(x)
+ x = self.dropout(x)
+ x = self.classifier(x)
+ return x
+
+ def construct(self, x: Tensor) -> Tensor:
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
+@register_model
+def mixnet_s(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ default_cfg = default_cfgs["mixnet_s"]
+ model = MixNet(arch="small", in_channels=in_channels, num_classes=num_classes, **kwargs)
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model
+
+
+@register_model
+def mixnet_m(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ default_cfg = default_cfgs["mixnet_m"]
+ model = MixNet(arch="medium", in_channels=in_channels, num_classes=num_classes, **kwargs)
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model
+
+
+@register_model
+def mixnet_l(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ default_cfg = default_cfgs["mixnet_l"]
+ model = MixNet(arch="large", in_channels=in_channels, num_classes=num_classes, **kwargs)
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model