From effea886e333fa1a48b42b234f82e712c6825e4f Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Sun, 4 Jul 2021 13:43:07 +0800
Subject: [PATCH] add ghostnet backbone unit test

---
 nanodet/model/backbone/ghostnet.py            | 48 +++++++++++++------
 nanodet/model/backbone/mobilenetv2.py         | 38 +++++++++++----
 .../test_backbone/test_ghostnet.py            | 44 +++++++++++++++++
 3 files changed, 106 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_models/test_backbone/test_ghostnet.py

diff --git a/nanodet/model/backbone/ghostnet.py b/nanodet/model/backbone/ghostnet.py
index a670d7ab1..06c7119e3 100644
--- a/nanodet/model/backbone/ghostnet.py
+++ b/nanodet/model/backbone/ghostnet.py
@@ -10,6 +10,7 @@
 """
 import logging
 import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -20,7 +21,7 @@
 
 def get_url(width_mult=1.0):
     if width_mult == 1.0:
-        return "https://github.com/huawei-noah/ghostnet/raw/master/pytorch/models/state_dict_93.98.pth"  # noqa E501
+        return "https://raw.githubusercontent.com/huawei-noah/CV-Backbones/master/ghostnet_pytorch/models/state_dict_73.98.pth"  # noqa E501
     else:
         logging.info("GhostNet only has 1.0 pretrain model. ")
         return None
@@ -55,7 +56,7 @@ def __init__(
         in_chs,
         se_ratio=0.25,
         reduced_base_chs=None,
-        act="ReLU",
+        activation="ReLU",
         gate_fn=hard_sigmoid,
         divisor=4,
         **_
@@ -65,7 +66,7 @@ def __init__(
         reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
-        self.act1 = act_layers(act)
+        self.act1 = act_layers(activation)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
 
     def forward(self, x):
@@ -78,13 +79,13 @@ def forward(self, x):
 
 
 class ConvBnAct(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size, stride=1, act="ReLU"):
+    def __init__(self, in_chs, out_chs, kernel_size, stride=1, activation="ReLU"):
         super(ConvBnAct, self).__init__()
         self.conv = nn.Conv2d(
             in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False
         )
         self.bn1 = nn.BatchNorm2d(out_chs)
-        self.act1 = act_layers(act)
+        self.act1 = act_layers(activation)
 
     def forward(self, x):
         x = self.conv(x)
@@ -95,7 +96,7 @@ def forward(self, x):
 
 class GhostModule(nn.Module):
     def __init__(
-        self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, act="ReLU"
+        self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, activation="ReLU"
     ):
         super(GhostModule, self).__init__()
         self.oup = oup
@@ -107,7 +108,7 @@ def __init__(
                 inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False
             ),
             nn.BatchNorm2d(init_channels),
-            act_layers(act) if act else nn.Sequential(),
+            act_layers(activation) if activation else nn.Sequential(),
         )
 
         self.cheap_operation = nn.Sequential(
@@ -121,7 +122,7 @@ def __init__(
                 bias=False,
             ),
             nn.BatchNorm2d(new_channels),
-            act_layers(act) if act else nn.Sequential(),
+            act_layers(activation) if activation else nn.Sequential(),
         )
 
     def forward(self, x):
@@ -141,7 +142,7 @@ def __init__(
         out_chs,
         dw_kernel_size=3,
         stride=1,
-        act="ReLU",
+        activation="ReLU",
         se_ratio=0.0,
     ):
         super(GhostBottleneck, self).__init__()
@@ -149,7 +150,7 @@ def __init__(
         self.stride = stride
 
         # Point-wise expansion
-        self.ghost1 = GhostModule(in_chs, mid_chs, act=act)
+        self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation)
 
         # Depth-wise convolution
         if self.stride > 1:
@@ -171,7 +172,7 @@ def __init__(
             self.se = None
 
         # Point-wise linear projection
-        self.ghost2 = GhostModule(mid_chs, out_chs, act=None)
+        self.ghost2 = GhostModule(mid_chs, out_chs, activation=None)
 
         # shortcut
         if in_chs == out_chs and self.stride == 1:
@@ -215,8 +216,16 @@ def forward(self, x):
 
 
 class GhostNet(nn.Module):
-    def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=True):
+    def __init__(
+        self,
+        width_mult=1.0,
+        out_stages=(4, 6, 9),
+        activation="ReLU",
+        pretrain=True,
+        act=None,
+    ):
         super(GhostNet, self).__init__()
+        assert set(out_stages).issubset(i for i in range(10))
         self.width_mult = width_mult
         self.out_stages = out_stages
         # setting of inverted residual blocks
@@ -250,11 +259,18 @@ def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=Tr
         ]
         # ------conv+bn+act----------# 9 1/32
 
+        self.activation = activation
+        if act is not None:
+            warnings.warn(
+                "Warning! act argument has been deprecated, " "use activation instead!"
+            )
+            self.activation = act
+
         # building first layer
         output_channel = _make_divisible(16 * width_mult, 4)
         self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False)
         self.bn1 = nn.BatchNorm2d(output_channel)
-        self.act1 = act_layers(act)
+        self.act1 = act_layers(self.activation)
         input_channel = output_channel
 
         # building inverted residual blocks
@@ -272,7 +288,7 @@ def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=Tr
                         output_channel,
                         k,
                         s,
-                        act=act,
+                        activation=self.activation,
                         se_ratio=se_ratio,
                     )
                 )
@@ -281,7 +297,9 @@ def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=Tr
 
         output_channel = _make_divisible(exp_size * width_mult, 4)
         stages.append(
-            nn.Sequential(ConvBnAct(input_channel, output_channel, 1, act=act))
+            nn.Sequential(
+                ConvBnAct(input_channel, output_channel, 1, activation=self.activation)
+            )
         )  # 9
         self.blocks = nn.Sequential(*stages)
 
diff --git a/nanodet/model/backbone/mobilenetv2.py b/nanodet/model/backbone/mobilenetv2.py
index d376d9fe0..11d7978a8 100644
--- a/nanodet/model/backbone/mobilenetv2.py
+++ b/nanodet/model/backbone/mobilenetv2.py
@@ -1,5 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
+import warnings
+
 import torch.nn as nn
 
 from ..module.activation import act_layers
@@ -7,7 +9,13 @@
 
 class ConvBNReLU(nn.Sequential):
     def __init__(
-        self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, act="ReLU"
+        self,
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=1,
+        groups=1,
+        activation="ReLU",
     ):
         padding = (kernel_size - 1) // 2
         super(ConvBNReLU, self).__init__(
@@ -21,12 +29,12 @@ def __init__(
                 bias=False,
             ),
             nn.BatchNorm2d(out_planes),
-            act_layers(act),
+            act_layers(activation),
         )
 
 
 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio, act="ReLU"):
+    def __init__(self, inp, oup, stride, expand_ratio, activation="ReLU"):
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]
@@ -37,12 +45,18 @@ def __init__(self, inp, oup, stride, expand_ratio, act="ReLU"):
         layers = []
         if expand_ratio != 1:
             # pw
-            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, act=act))
+            layers.append(
+                ConvBNReLU(inp, hidden_dim, kernel_size=1, activation=activation)
+            )
         layers.extend(
             [
                 # dw
                 ConvBNReLU(
-                    hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, act=act
+                    hidden_dim,
+                    hidden_dim,
+                    stride=stride,
+                    groups=hidden_dim,
+                    activation=activation,
                 ),
                 # pw-linear
                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
@@ -65,6 +79,7 @@ def __init__(
         out_stages=(1, 2, 4, 6),
         last_channel=1280,
         activation="ReLU",
+        act=None,
     ):
         super(MobileNetV2, self).__init__()
         # TODO: support load torchvison pretrained weight
@@ -74,6 +89,11 @@ def __init__(
         input_channel = 32
         self.last_channel = last_channel
         self.activation = activation
+        if act is not None:
+            warnings.warn(
+                "Warning! act argument has been deprecated, " "use activation instead!"
+            )
+            self.activation = act
         self.interverted_residual_setting = [
             # t, c, n, s
             [1, 16, 1, 1],
@@ -88,7 +108,7 @@ def __init__(
         # building first layer
         self.input_channel = int(input_channel * width_mult)
         self.first_layer = ConvBNReLU(
-            3, self.input_channel, stride=2, act=self.activation
+            3, self.input_channel, stride=2, activation=self.activation
        )
         # building inverted residual blocks
         for i in range(7):
@@ -109,7 +129,7 @@ def build_mobilenet_stage(self, stage_num):
                         output_channel,
                         s,
                         expand_ratio=t,
-                        act=self.activation,
+                        activation=self.activation,
                     )
                 )
             else:
@@ -119,7 +139,7 @@ def build_mobilenet_stage(self, stage_num):
                         output_channel,
                         1,
                         expand_ratio=t,
-                        act=self.activation,
+                        activation=self.activation,
                     )
                 )
             self.input_channel = output_channel
@@ -128,7 +148,7 @@ def build_mobilenet_stage(self, stage_num):
                 self.input_channel,
                 self.last_channel,
                 kernel_size=1,
-                act=self.activation,
+                activation=self.activation,
             )
             stage.append(last_layer)
         stage = nn.Sequential(*stage)
diff --git a/tests/test_models/test_backbone/test_ghostnet.py b/tests/test_models/test_backbone/test_ghostnet.py
new file mode 100644
index 000000000..7e10a0009
--- /dev/null
+++ b/tests/test_models/test_backbone/test_ghostnet.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+
+from nanodet.model.backbone import GhostNet, build_backbone
+
+
+def test_ghostnet():
+    with pytest.raises(AssertionError):
+        cfg = dict(name="GhostNet", width_mult=1.0, out_stages=(11, 12), pretrain=False)
+        build_backbone(cfg)
+
+    input = torch.rand(1, 3, 64, 64)
+    out_stages = [i for i in range(10)]
+    model = GhostNet(
+        width_mult=1.0, out_stages=out_stages, activation="ReLU6", pretrain=True
+    )
+    output = model(input)
+
+    assert output[0].shape == torch.Size([1, 16, 32, 32])
+    assert output[1].shape == torch.Size([1, 24, 16, 16])
+    assert output[2].shape == torch.Size([1, 24, 16, 16])
+    assert output[3].shape == torch.Size([1, 40, 8, 8])
+    assert output[4].shape == torch.Size([1, 40, 8, 8])
+    assert output[5].shape == torch.Size([1, 80, 4, 4])
+    assert output[6].shape == torch.Size([1, 112, 4, 4])
+    assert output[7].shape == torch.Size([1, 160, 2, 2])
+    assert output[8].shape == torch.Size([1, 160, 2, 2])
+    assert output[9].shape == torch.Size([1, 960, 2, 2])
+
+    model = GhostNet(
+        width_mult=0.75, out_stages=out_stages, activation="LeakyReLU", pretrain=False
+    )
+    output = model(input)
+
+    assert output[0].shape == torch.Size([1, 12, 32, 32])
+    assert output[1].shape == torch.Size([1, 20, 16, 16])
+    assert output[2].shape == torch.Size([1, 20, 16, 16])
+    assert output[3].shape == torch.Size([1, 32, 8, 8])
+    assert output[4].shape == torch.Size([1, 32, 8, 8])
+    assert output[5].shape == torch.Size([1, 60, 4, 4])
+    assert output[6].shape == torch.Size([1, 84, 4, 4])
+    assert output[7].shape == torch.Size([1, 120, 2, 2])
+    assert output[8].shape == torch.Size([1, 120, 2, 2])
+    assert output[9].shape == torch.Size([1, 720, 2, 2])