
Commit

add ghostnet backbone unit test
RangiLyu committed Jul 4, 2021
1 parent 0c3906b commit effea88
Showing 3 changed files with 106 additions and 24 deletions.
48 changes: 33 additions & 15 deletions nanodet/model/backbone/ghostnet.py
@@ -10,6 +10,7 @@
"""
import logging
import math
import warnings

import torch
import torch.nn as nn
@@ -20,7 +21,7 @@

def get_url(width_mult=1.0):
if width_mult == 1.0:
return "https://github.com/huawei-noah/ghostnet/raw/master/pytorch/models/state_dict_93.98.pth" # noqa E501
return "https://raw.githubusercontent.com/huawei-noah/CV-Backbones/master/ghostnet_pytorch/models/state_dict_73.98.pth" # noqa E501
else:
logging.info("GhostNet only has 1.0 pretrain model. ")
return None
@@ -55,7 +56,7 @@ def __init__(
in_chs,
se_ratio=0.25,
reduced_base_chs=None,
act="ReLU",
activation="ReLU",
gate_fn=hard_sigmoid,
divisor=4,
**_
@@ -65,7 +66,7 @@
reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layers(act)
self.act1 = act_layers(activation)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

def forward(self, x):
@@ -78,13 +79,13 @@ def forward(self, x):


class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size, stride=1, act="ReLU"):
def __init__(self, in_chs, out_chs, kernel_size, stride=1, activation="ReLU"):
super(ConvBnAct, self).__init__()
self.conv = nn.Conv2d(
in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False
)
self.bn1 = nn.BatchNorm2d(out_chs)
self.act1 = act_layers(act)
self.act1 = act_layers(activation)

def forward(self, x):
x = self.conv(x)
@@ -95,7 +96,7 @@ def forward(self, x):

class GhostModule(nn.Module):
def __init__(
self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, act="ReLU"
self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, activation="ReLU"
):
super(GhostModule, self).__init__()
self.oup = oup
@@ -107,7 +108,7 @@ def __init__(
inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False
),
nn.BatchNorm2d(init_channels),
act_layers(act) if act else nn.Sequential(),
act_layers(activation) if activation else nn.Sequential(),
)

self.cheap_operation = nn.Sequential(
@@ -121,7 +122,7 @@ def __init__(
bias=False,
),
nn.BatchNorm2d(new_channels),
act_layers(act) if act else nn.Sequential(),
act_layers(activation) if activation else nn.Sequential(),
)

def forward(self, x):
@@ -141,15 +142,15 @@ def __init__(
out_chs,
dw_kernel_size=3,
stride=1,
act="ReLU",
activation="ReLU",
se_ratio=0.0,
):
super(GhostBottleneck, self).__init__()
has_se = se_ratio is not None and se_ratio > 0.0
self.stride = stride

# Point-wise expansion
self.ghost1 = GhostModule(in_chs, mid_chs, act=act)
self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation)

# Depth-wise convolution
if self.stride > 1:
@@ -171,7 +172,7 @@ def __init__(
self.se = None

# Point-wise linear projection
self.ghost2 = GhostModule(mid_chs, out_chs, act=None)
self.ghost2 = GhostModule(mid_chs, out_chs, activation=None)

# shortcut
if in_chs == out_chs and self.stride == 1:
@@ -215,8 +216,16 @@ def forward(self, x):


class GhostNet(nn.Module):
def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=True):
def __init__(
self,
width_mult=1.0,
out_stages=(4, 6, 9),
activation="ReLU",
pretrain=True,
act=None,
):
super(GhostNet, self).__init__()
assert set(out_stages).issubset(i for i in range(10))
self.width_mult = width_mult
self.out_stages = out_stages
# setting of inverted residual blocks
@@ -250,11 +259,18 @@ def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=Tr
]
# ------conv+bn+act----------# 9 1/32

self.activation = activation
if act is not None:
warnings.warn(
"Warning! act argument has been deprecated, " "use activation instead!"
)
self.activation = act

# building first layer
output_channel = _make_divisible(16 * width_mult, 4)
self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False)
self.bn1 = nn.BatchNorm2d(output_channel)
self.act1 = act_layers(act)
self.act1 = act_layers(self.activation)
input_channel = output_channel

# building inverted residual blocks
@@ -272,7 +288,7 @@ def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=Tr
output_channel,
k,
s,
act=act,
activation=self.activation,
se_ratio=se_ratio,
)
)
@@ -281,7 +297,9 @@ def __init__(self, width_mult=1.0, out_stages=(4, 6, 9), act="ReLU", pretrain=Tr

output_channel = _make_divisible(exp_size * width_mult, 4)
stages.append(
nn.Sequential(ConvBnAct(input_channel, output_channel, 1, act=act))
nn.Sequential(
ConvBnAct(input_channel, output_channel, 1, activation=self.activation)
)
) # 9

self.blocks = nn.Sequential(*stages)
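
The net effect of the ghostnet.py changes: every internal module now takes an `activation` keyword instead of `act`, and the public `GhostNet` constructor keeps `act` only as a deprecated alias that is copied into `self.activation`. A minimal usage sketch (constructor signature and import path are taken from this diff and the new test file below; the input size is an arbitrary example):

import torch

from nanodet.model.backbone import GhostNet

# New-style keyword: `activation` is the renamed argument.
backbone = GhostNet(
    width_mult=1.0, out_stages=(4, 6, 9), activation="ReLU", pretrain=False
)

# Old-style keyword still works for now, but it triggers the warning added
# above before being copied into self.activation.
legacy = GhostNet(
    width_mult=1.0, out_stages=(4, 6, 9), act="LeakyReLU", pretrain=False
)

feats = backbone(torch.rand(1, 3, 64, 64))  # one feature map per index in out_stages
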
38 changes: 29 additions & 9 deletions nanodet/model/backbone/mobilenetv2.py
@@ -1,13 +1,21 @@
from __future__ import absolute_import, division, print_function

import warnings

import torch.nn as nn

from ..module.activation import act_layers


class ConvBNReLU(nn.Sequential):
def __init__(
self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, act="ReLU"
self,
in_planes,
out_planes,
kernel_size=3,
stride=1,
groups=1,
activation="ReLU",
):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
@@ -21,12 +29,12 @@ def __init__(
bias=False,
),
nn.BatchNorm2d(out_planes),
act_layers(act),
act_layers(activation),
)


class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio, act="ReLU"):
def __init__(self, inp, oup, stride, expand_ratio, activation="ReLU"):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
@@ -37,12 +45,18 @@ def __init__(self, inp, oup, stride, expand_ratio, act="ReLU"):
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, act=act))
layers.append(
ConvBNReLU(inp, hidden_dim, kernel_size=1, activation=activation)
)
layers.extend(
[
# dw
ConvBNReLU(
hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, act=act
hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
activation=activation,
),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
@@ -65,6 +79,7 @@ def __init__(
out_stages=(1, 2, 4, 6),
last_channel=1280,
activation="ReLU",
act=None,
):
super(MobileNetV2, self).__init__()
# TODO: support load torchvison pretrained weight
@@ -74,6 +89,11 @@ def __init__(
input_channel = 32
self.last_channel = last_channel
self.activation = activation
if act is not None:
warnings.warn(
"Warning! act argument has been deprecated, " "use activation instead!"
)
self.activation = act
self.interverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
@@ -88,7 +108,7 @@ def __init__(
# building first layer
self.input_channel = int(input_channel * width_mult)
self.first_layer = ConvBNReLU(
3, self.input_channel, stride=2, act=self.activation
3, self.input_channel, stride=2, activation=self.activation
)
# building inverted residual blocks
for i in range(7):
@@ -109,7 +129,7 @@ def build_mobilenet_stage(self, stage_num):
output_channel,
s,
expand_ratio=t,
act=self.activation,
activation=self.activation,
)
)
else:
@@ -119,7 +139,7 @@ def build_mobilenet_stage(self, stage_num):
output_channel,
1,
expand_ratio=t,
act=self.activation,
activation=self.activation,
)
)
self.input_channel = output_channel
@@ -128,7 +148,7 @@ def build_mobilenet_stage(self, stage_num):
self.input_channel,
self.last_channel,
kernel_size=1,
act=self.activation,
activation=self.activation,
)
stage.append(last_layer)
stage = nn.Sequential(*stage)
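
mobilenetv2.py gets the same treatment: `ConvBNReLU` and `InvertedResidual` now take `activation`, and `MobileNetV2` accepts `act` only as a deprecated alias. A sketch of the compatibility path (the `MobileNetV2` import is assumed to mirror the `GhostNet` export used in the new test; keyword names come from this diff):

import warnings

from nanodet.model.backbone import MobileNetV2  # assumed export path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Deprecated keyword: still accepted, but it warns and is copied into
    # self.activation before the stages are built.
    model = MobileNetV2(width_mult=1.0, act="ReLU6")

assert any("deprecated" in str(w.message) for w in caught)
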
44 changes: 44 additions & 0 deletions tests/test_models/test_backbone/test_ghostnet.py
@@ -0,0 +1,44 @@
import pytest
import torch

from nanodet.model.backbone import GhostNet, build_backbone


def test_ghostnet():
with pytest.raises(AssertionError):
cfg = dict(name="GhostNet", width_mult=1.0, out_stages=(11, 12), pretrain=False)
build_backbone(cfg)

input = torch.rand(1, 3, 64, 64)
out_stages = [i for i in range(10)]
model = GhostNet(
width_mult=1.0, out_stages=out_stages, activation="ReLU6", pretrain=True
)
output = model(input)

assert output[0].shape == torch.Size([1, 16, 32, 32])
assert output[1].shape == torch.Size([1, 24, 16, 16])
assert output[2].shape == torch.Size([1, 24, 16, 16])
assert output[3].shape == torch.Size([1, 40, 8, 8])
assert output[4].shape == torch.Size([1, 40, 8, 8])
assert output[5].shape == torch.Size([1, 80, 4, 4])
assert output[6].shape == torch.Size([1, 112, 4, 4])
assert output[7].shape == torch.Size([1, 160, 2, 2])
assert output[8].shape == torch.Size([1, 160, 2, 2])
assert output[9].shape == torch.Size([1, 960, 2, 2])

model = GhostNet(
width_mult=0.75, out_stages=out_stages, activation="LeakyReLU", pretrain=False
)
output = model(input)

assert output[0].shape == torch.Size([1, 12, 32, 32])
assert output[1].shape == torch.Size([1, 20, 16, 16])
assert output[2].shape == torch.Size([1, 20, 16, 16])
assert output[3].shape == torch.Size([1, 32, 8, 8])
assert output[4].shape == torch.Size([1, 32, 8, 8])
assert output[5].shape == torch.Size([1, 60, 4, 4])
assert output[6].shape == torch.Size([1, 84, 4, 4])
assert output[7].shape == torch.Size([1, 120, 2, 2])
assert output[8].shape == torch.Size([1, 120, 2, 2])
assert output[9].shape == torch.Size([1, 720, 2, 2])
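
The expected channel counts in the width_mult=0.75 case follow from rounding each stage width with `_make_divisible(c * width_mult, 4)`. A quick check, assuming the usual MobileNet-style rounding helper (its body is not shown in this diff):

def make_divisible(v, divisor=4, min_value=None):
    # Round to the nearest multiple of `divisor`, never dropping more than
    # ~10% below the original value.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# Stage widths from the ghostnet.py config, scaled by width_mult=0.75:
for c in (16, 24, 40, 80, 112, 160, 960):
    print(c, "->", make_divisible(c * 0.75))  # 12, 20, 32, 60, 84, 120, 720

Running just this file with `pytest tests/test_models/test_backbone/test_ghostnet.py` exercises both configurations; note that the first model is built with pretrain=True and will attempt to fetch the pretrained checkpoint.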
