Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add model script, training configs and training weights of resnetv2 #515

Merged
merged 1 commit into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions configs/resnetv2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# ResNetV2

> [Identity Mappings in Deep Residual Networks](https://arxiv.org/abs/1603.05027)

## Introduction

Author analyzes the propagation formulations behind the residual building blocks, which suggest that the forward and backward signals can be directly propagated from one block
to any other block, when using identity mappings as the skip connections and after-addition activation.

<p align="center">
<img src="https://user-images.githubusercontent.com/52945530/224595993-ba8617da-e55d-4d19-a487-3340026393c9.png" width=300 height=400 />
</p>
<p align="center">
<em>Figure 1. Architecture of ResNetV2 [<a href="#references">1</a>] </em>
</p>

## Results

Our reproduced model performance on ImageNet-1K is reported as follows.

<div align="center">

| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|-----------------|-----------|-----------|-----------|-------|-------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|
| ResNetv2_50 | D910x8-G | 76.90 | 93.37 | 25.60 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/resnetv2/resnetv2_50_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/resnetv2/resnetv2_50-8da5c0f4.ckpt) |
| ResNetv2_101 | D910x8-G | 78.48 | 94.23 | 44.55 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/resnetv2/resnetv2_101_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/resnetv2/resnetv2_101-c14199e9.ckpt) |

</div>

#### Notes
- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.


## Quick Start

### Preparation

#### Installation
Please refer to the [installation instruction](https://github.com/mindspore-lab/mindcv#installation) in MindCV.

#### Dataset Preparation
Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.

### Training

* Distributed Training

It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run

```shell
# distributed training on multiple GPU/Ascend devices
mpirun -n 8 python train.py --config configs/resnetv2/resnetv2_50_ascend.yaml --data_dir /path/to/imagenet
```

> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.

Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.

For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).

**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.

* Standalone Training

If you want to train or finetune the model on a smaller dataset without distributed training, please run:

```shell
# standalone training on a CPU/GPU/Ascend device
python train.py --config configs/resnetv2/resnetv2_50_ascend.yaml --data_dir /path/to/dataset --distribute False
```

### Validation

To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.

```shell
python validate.py -c configs/resnetv2/resnetv2_50_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
```

### Deployment

To deploy online inference services with the trained model efficiently, please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md).

## References

[1] He K, Zhang X, Ren S, et al. Identity mappings in deep residual networks[C]//Computer Vision–ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11–14, 2016, Proceedings, Part IV 14. Springer International Publishing, 2016: 630-645.
51 changes: 51 additions & 0 deletions configs/resnetv2/resnetv2_101_ascend.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# system
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True

# dataset
dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 32
drop_remainder: True

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
interpolation: "bilinear"
crop_pct: 0.875

# model
model: "resnetv2_101"
num_classes: 1000
pretrained: False
ckpt_path: ""
keep_checkpoint_max: 10
ckpt_save_dir: "./ckpt"
epoch_size: 120
dataset_sink_mode: True
amp_level: "O2"

# loss
loss: "CE"
label_smoothing: 0.1

# lr scheduler
scheduler: "cosine_decay"
min_lr: 0.0
lr: 0.1
warmup_epochs: 0
decay_epochs: 120

# optimizer
opt: "momentum"
filter_bias_and_bn: True
momentum: 0.9
weight_decay: 0.0001
loss_scale: 1024
use_nesterov: False
51 changes: 51 additions & 0 deletions configs/resnetv2/resnetv2_50_ascend.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# system
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True

# dataset
dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 32
drop_remainder: True

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
interpolation: "bilinear"
crop_pct: 0.875

# model
model: "resnetv2_50"
num_classes: 1000
pretrained: False
ckpt_path: ""
keep_checkpoint_max: 30
ckpt_save_dir: "./ckpt"
epoch_size: 120
dataset_sink_mode: True
amp_level: "O2"

# loss
loss: "CE"
label_smoothing: 0.1

# lr scheduler
scheduler: "cosine_decay"
min_lr: 0.0
lr: 0.1
warmup_epochs: 0
decay_epochs: 120

# optimizer
opt: "momentum"
filter_bias_and_bn: True
momentum: 0.9
weight_decay: 0.0001
loss_scale: 1024
use_nesterov: False
3 changes: 3 additions & 0 deletions mindcv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
repvgg,
res2net,
resnet,
resnetv2,
rexnet,
senet,
shufflenetv1,
Expand Down Expand Up @@ -72,6 +73,7 @@
from .repvgg import *
from .res2net import *
from .resnet import *
from .resnetv2 import *
from .rexnet import *
from .senet import *
from .shufflenetv1 import *
Expand Down Expand Up @@ -118,6 +120,7 @@
__all__.extend(["RepVGG", "repvgg"])
__all__.extend(res2net.__all__)
__all__.extend(resnet.__all__)
__all__.extend(resnetv2.__all__)
__all__.extend(rexnet.__all__)
__all__.extend(senet.__all__)
__all__.extend(shufflenetv1.__all__)
Expand Down
119 changes: 119 additions & 0 deletions mindcv/models/resnetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
MindSpore implementation of `ResNetV2`.
Refer to Identity Mappings in Deep Residual Networks.
"""

from typing import Optional

from mindspore import Tensor, nn

from .registry import register_model
from .resnet import ResNet
from .utils import load_pretrained

__all__ = [
"resnetv2_50",
"resnetv2_101",
]


def _cfg(url='', **kwargs):
return {
"url": url,
"num_classes": 1000,
"first_conv": "conv1",
"classifier": "classifier",
**kwargs
}


default_cfgs = {
"resnetv2_50": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnetv2/resnetv2_50-8da5c0f4.ckpt"),
"resnetv2_101": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnetv2/resnetv2_101-c14199e9.ckpt"),
}


class PreActBottleneck(nn.Cell):
expansion: int = 4

def __init__(self,
in_channels: int,
channels: int,
stride: int = 1,
groups: int = 1,
base_width: int = 64,
norm: Optional[nn.Cell] = None,
down_sample: Optional[nn.Cell] = None
) -> None:
super().__init__()
if norm is None:
norm = nn.BatchNorm2d

width = int(channels * (base_width / 64.0)) * groups

self.bn1 = norm(in_channels)
self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)

self.bn2 = norm(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
padding=1, pad_mode='pad', group=groups)

self.bn3 = norm(width)
self.conv3 = nn.Conv2d(width, channels * self.expansion,
kernel_size=1, stride=1)

self.relu = nn.ReLU()
self.down_sample = down_sample

def construct(self, x: Tensor) -> Tensor:
identity = x

out = self.bn1(x)
out = self.relu(out)

residual = out

out = self.conv1(out)

out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)

out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)

if self.down_sample is not None:
identity = self.down_sample(residual)

out += identity

return out


@register_model
def resnetv2_50(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
"""Get 50 layers ResNetV2 model.
Refer to the base class `models.ResNet` for more details.
"""
default_cfg = default_cfgs['resnetv2_50']
model = ResNet(PreActBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels, **kwargs)

if pretrained:
load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

return model


@register_model
def resnetv2_101(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
"""Get 101 layers ResNetV2 model.
Refer to the base class `models.ResNet` for more details.
"""
default_cfg = default_cfgs["resnetv2_101"]
model = ResNet(PreActBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_channels=in_channels, **kwargs)

if pretrained:
load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

return model