naive binary multi-class classification of mobilenetv2 on 100K images
Former-commit-id: fcd587c
Former-commit-id: 9b7b6909d349853824fcd1db9a6dfa80d06228b9
yu45020 committed Jul 7, 2018
1 parent afb0611 commit 2ce7a39
Showing 5 changed files with 113 additions and 190 deletions.
11 changes: 7 additions & 4 deletions dataloader.py
@@ -9,7 +9,8 @@
from torch import nn
from torch.nn.functional import pad
from torch.utils.data import Dataset
from torchvision.transforms import ColorJitter, ToTensor, RandomResizedCrop, Compose, Normalize, transforms, Grayscale
from torchvision.transforms import ColorJitter, ToTensor, RandomResizedCrop, Compose, Normalize, transforms, Grayscale, \
RandomGrayscale
from torchvision.transforms.functional import resized_crop, to_tensor

use_cuda = torch.cuda.is_available()
@@ -86,7 +87,7 @@ def __init__(self, image_foler, name_tag_dict, mean, std,
print("Find {} images. ".format(len(self.images)))

self.name_tag_dict = name_tag_dict
self.img_transform = self.transformer(image_size, mean, std)
self.img_transform = self.transformer(mean, std)
# one hot encoding
self.onehot = torch.eye(num_class)
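
As an aside (not part of the diff), the one-hot encoding above works by indexing rows of an identity matrix; a minimal illustration:

```python
import torch

onehot = torch.eye(5)          # one identity row per class
labels = torch.tensor([0, 3])  # integer class ids
print(onehot[labels])          # rows 0 and 3 as one-hot vectors
```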

Expand All @@ -103,8 +104,10 @@ def __getitem__(self, item):
return image, LongTensor(target)

@staticmethod
def transformer(image_size, mean, std):
m = Compose([RandomResizedCrop(image_size, scale=(0.5, 2.0)),
def transformer(mean, std):
m = Compose([RandomGrayscale(p=0.2),
# RandomHorizontalFlip(p=0.2),
# RandomVerticalFlip(p=0.2),
ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
ToTensor(),
Normalize(mean, std)])
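For reference, a minimal sketch (not part of the commit) of how the updated pipeline behaves: RandomResizedCrop is dropped, so images are expected to arrive at a fixed size, and RandomGrayscale now runs before the color jitter. The mean/std values below are placeholders.

```python
from PIL import Image
from torchvision.transforms import (ColorJitter, Compose, Normalize,
                                    RandomGrayscale, ToTensor)

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]  # placeholder statistics

transform = Compose([RandomGrayscale(p=0.2),  # new in this commit
                     ColorJitter(brightness=0.2, contrast=0.2,
                                 saturation=0.2, hue=0.2),
                     ToTensor(),              # PIL image -> [0, 1] tensor
                     Normalize(mean, std)])

img = Image.new("RGB", (224, 224))  # stand-in for a dataset image
x = transform(img)                  # tensor of shape (3, 224, 224)
```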
5 changes: 5 additions & 0 deletions loss.py
@@ -135,3 +135,8 @@ def total_variation_loss(image):
loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
return loss


0.775
0.465
0.17
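
The three bare numbers appended above look like logged metrics rather than code; total_variation_loss itself is unchanged. As a quick sanity check (not in the commit), the loss is the mean absolute difference between adjacent pixels, so a constant image scores exactly zero:

```python
import torch

def total_variation_loss(image):  # copied from loss.py above
    loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
           torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
    return loss

flat = torch.ones(1, 3, 32, 32)     # constant image -> zero variation
noisy = torch.rand(1, 3, 32, 32)    # i.i.d. uniform noise
print(total_variation_loss(flat))   # tensor(0.)
print(total_variation_loss(noisy))  # ~0.67 (two terms of E|u - v| = 1/3 each)
```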
42 changes: 27 additions & 15 deletions models/BaseModels.py
@@ -1,16 +1,20 @@
import math
import warnings
from contextlib import contextmanager

from torch import nn

warnings.simplefilter('ignore')
try:
from .inplace_abn import InPlaceABN

inplace_batch_norm = True
except ImportError:
inplace_batch_norm = False

# +++++++++++++++++++++++++++++++++++++
# Add more functions to PyTorch's base model
# -------------------------------------


class BaseModule(nn.Module):
def __init__(self):
self.act_fn = None
Expand All @@ -27,23 +31,18 @@ def selu_init_params(self):
m.bias.data.zero_()

elif isinstance(m, nn.Linear) and m.weight.requires_grad:
n = m.weight.size(1)
m.weight.data.normal_(0, n)
m.weight.data.normal_(0, 1.0 / math.sqrt(m.weight.numel()))
m.bias.data.zero_()

def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) and m.weight.requires_grad:
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear) and m.weight.requires_grad:
n = m.weight.size(1)
m.weight.data.normal_(0, n)
m.bias.data.zero_()

def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
@@ -83,13 +82,26 @@ def forward(self, *x):
# -------------------------------------


def Conv_block(in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
BN=False, activation=None):
def Conv_block(in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=True, BN=False, activation=None):
m = [nn.Conv2d(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)]
if BN:
m.append(nn.BatchNorm2d(out_channels))
if activation:
m.append(activation)
m += activated_batch_norm(out_channels, activation, inplace_abn=inplace_batch_norm)
if BN is False and activation is not None:
m += [activation]
return m


def activated_batch_norm(in_channels, activation, inplace_abn=inplace_batch_norm):
m = []
if inplace_abn:
if activation:
m.append(InPlaceABN(in_channels, activation="leaky_relu", slope=0.3))
else:
m.append(InPlaceABN(in_channels, activation='none'))
else:
m.append(nn.BatchNorm2d(in_channels))
if activation:
m.append(activation)
return m
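
A usage sketch (not in the commit) of the reworked Conv_block: with BN=True it now defers to activated_batch_norm, which uses InPlaceABN when the optional inplace_abn package imports and otherwise falls back to BatchNorm2d plus the given activation. Assuming the fallback path:

```python
import torch
from torch import nn
from models.BaseModels import Conv_block  # module shown above

layers = Conv_block(3, 32, kernel_size=3, stride=2, padding=1,
                    bias=False, BN=True, activation=nn.ReLU6())
block = nn.Sequential(*layers)  # Conv2d -> BatchNorm2d -> ReLU6
out = block(torch.randn(1, 3, 64, 64))
print(out.shape)                # torch.Size([1, 32, 32, 32])
```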
115 changes: 73 additions & 42 deletions models/MobileNetV2.py
@@ -6,12 +6,13 @@
from torch import nn
from torch.utils.checkpoint import checkpoint

from models.common import SpatialChannelSqueezeExcitation
from .BaseModels import BaseModule, Conv_block
from .partial_convolution import partial_gated_conv_block


class MobileNetV2(BaseModule):
def __init__(self, width_mult=1, add_partial=False, activation=nn.ReLU6(), bias=False):
def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_sece=False, add_partial=False, ):

super(MobileNetV2, self).__init__()
self.add_partial = add_partial
Expand All @@ -32,26 +33,27 @@ def __init__(self, width_mult=1, add_partial=False, activation=nn.ReLU6(), bias=
[6, 320, 1, 1, 1],
]
self.last_channel = 0 # last one is avg pool
self.features = self.make_inverted_resblocks(self.inverted_residual_setting)
self.features = self.make_inverted_resblocks(self.inverted_residual_setting, add_sece)

def make_inverted_resblocks(self, settings):
in_channel = int(32 * self.width_mult)
def make_inverted_resblocks(self, settings, add_sece):
in_channel = self._make_divisible(32 * self.width_mult, divisor=8)

# first_layer
features = [nn.Sequential(*self.conv_block(3, in_channel, kernel_size=3, stride=2,
padding=(3 - 1) // 2, bias=self.bias,
BN=True, activation=self.act_fn))]

for t, c, n, s, d in settings:
out_channel = int(c * self.width_mult)
out_channel = self._make_divisible(c * self.width_mult, divisor=8)
# out_channel = int(c * self.width_mult)
block = []
for i in range(n):
if i == 0:
block.append(self.res_block(in_channel, out_channel, s, t, d,
activation=self.act_fn, bias=self.bias))
activation=self.act_fn, bias=self.bias, add_sece=add_sece))
else:
block.append(self.res_block(in_channel, out_channel, 1, t, 1,
activation=self.act_fn, bias=self.bias))
block.append(self.res_block(in_channel, out_channel, 1, t, d,
activation=self.act_fn, bias=self.bias, add_sece=add_sece))
in_channel = out_channel
features.append(nn.Sequential(*block))
# last layer
@@ -81,6 +83,17 @@ def freeze_params(self, free_last_blocks=2):
print("{}/{} layers in the encoder are freezed.".format(len(self.features) - free_last_blocks,
len(self.features)))

def _make_divisible(self, v, divisor=8, min_value=None):
# https://github.com/tensorflow/models/blob/7367d494135368a7790df6172206a58a2a2f3d40/research/slim/nets/mobilenet/mobilenet.py#L62
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor

return new_v
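
A worked check (not part of the commit) of what this rounding gives at width_mult=1.4; the channel counts of the five blocks that the classifier below concatenates (c = 32, 64, 96, 160, 320) sum to the 944 mentioned in its comment:

```python
# standalone copy of the helper above, for illustration only
def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than 10%
        new_v += divisor
    return new_v

widths = [make_divisible(c * 1.4) for c in (32, 64, 96, 160, 320)]
print(widths)       # [48, 88, 136, 224, 448]
print(sum(widths))  # 944
```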

# for partial conv ---- will not use
# def load_state_dict(self, state_dict, strict=True):
# own_state = self.state_dict()
@@ -108,7 +121,7 @@ def forward_checkpoint(self, x):

class InvertedResidual(BaseModule):
def __init__(self, in_channel, out_channel, stride, expand_ratio, dilation, conv_block_fn=Conv_block,
activation=nn.ReLU6(), bias=False):
activation=nn.ReLU6(), bias=False, add_sece=False):
super(InvertedResidual, self).__init__()
self.conv_bloc = conv_block_fn
self.stride = stride
Expand All @@ -119,9 +132,9 @@ def __init__(self, in_channel, out_channel, stride, expand_ratio, dilation, conv
# assert stride in [1, 2]

self.res_connect = self.stride == 1 and in_channel == out_channel
self.conv = self.make_body(in_channel, out_channel, stride, expand_ratio, dilation)
self.conv = self.make_body(in_channel, out_channel, stride, expand_ratio, dilation, add_sece)

def make_body(self, in_channel, out_channel, stride, expand_ratio, dilation):
def make_body(self, in_channel, out_channel, stride, expand_ratio, dilation, add_sece):
# standard convolution
mid_channel = in_channel * expand_ratio
m = self.conv_bloc(in_channel, mid_channel,
Expand All @@ -133,6 +146,8 @@ def make_body(self, in_channel, out_channel, stride, expand_ratio, dilation):
BN=True, activation=self.act_fn)
# linear projection to preserve info: see the paper's section on linear bottlenecks; removing the activation improves the result
m += self.conv_bloc(mid_channel, out_channel, 1, 1, 0, bias=self.bias, BN=True, activation=None)
if add_sece:
m += [SpatialChannelSqueezeExcitation(out_channel, reduction=16, activation=self.act_fn)]
return nn.Sequential(*m)

def forward(self, x):
@@ -170,57 +185,73 @@ def forward(self, args):
else:
return self.conv(args)

def forward_checkpoint(self, args):
with torch.no_grad():
return self.forward(args)


class DilatedMobileNetV2(MobileNetV2):
def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_partial=False, ):
def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_sece=False, add_partial=False, ):
super(DilatedMobileNetV2, self).__init__(width_mult=width_mult, activation=activation,
bias=bias, add_partial=add_partial, )
bias=bias, add_sece=add_sece, add_partial=add_partial, )
self.add_partial = add_partial
self.bias = bias
self.width_mult = width_mult
self.act_fn = activation
self.out_stride = 8
# Rethinking Atrous Convolution for Semantic Image Segmentation
self.inverted_residual_setting = [
# t, c, n, s, dila # input size
[1, 16, 1, 1, 1], # 1/2
[6, 24, 2, 2, 1], # 1/4
[6, 32, 3, 2, 1], # 1/8
# t, c, n, s, dila # input output
[1, 16, 1, 1, 1], # 1/2 ---> 1/2
[6, 24, 2, 2, 1], # 1/2 ---> 1/4
[6, 32, 3, 2, 1], # 1/4 ---> 1/8
[6, 64, 4, 1, 2], # <-- add atrous conv and keep 1/8
[6, 96, 3, 1, 4],
[6, 160, 3, 1, 8],
[6, 320, 1, 1, 16],
]
self.features = self.make_inverted_resblocks(self.inverted_residual_setting)
self.features = self.make_inverted_resblocks(self.inverted_residual_setting, add_sece=add_sece)
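
The dilation column above follows the atrous-convolution recipe: once the output stride reaches 1/8, later stages use stride 1 with growing dilation, so the feature map stays at 1/8 resolution while the receptive field keeps expanding. A minimal illustration (not from the commit):

```python
import torch
from torch import nn

x = torch.randn(1, 32, 64, 64)  # a 1/8-resolution feature map
# with padding = dilation, a 3x3 stride-1 conv is size-preserving
conv = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=2, dilation=2)
print(conv(x).shape)            # torch.Size([1, 32, 64, 64])
```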


class MobileNetV2Classifier(BaseModule):
def __init__(self, num_class, input_size=512, width_mult=1.4):
def __init__(self, num_class, width_mult=1.4, add_sece=False):
super(MobileNetV2Classifier, self).__init__()
self.act_fn = nn.SELU(inplace=True)
self.feature = DilatedMobileNetV2(width_mult=width_mult, activation=self.act_fn,
bias=False, add_partial=False)
self.pre_classifier = nn.Sequential(
*Conv_block(self.feature.last_channel, 1024, kernel_size=1,
bias=False, BN=True, activation=self.act_fn),
nn.AvgPool2d(input_size // self.feature.out_stride)) # b,c, 1,1

self.classifier = nn.Sequential(
nn.Dropout(0.2, inplace=True),
nn.Linear(1024, num_class)
)
if isinstance(self.act_fn, nn.SELU):
self.selu_init_params()
else:
self.initialize_weights()
self.act_fn = nn.LeakyReLU(0.3, inplace=True) # nn.SELU(inplace=True)
self.encoder = DilatedMobileNetV2(width_mult=width_mult, activation=self.act_fn,
bias=False, add_sece=add_sece, add_partial=False)

# if the width multiplier is 1.4, there are 944 channels in total
cat_feat_num = sum([i[0].out_channels for i in self.encoder.features[3:]])
self.conv_classifier = self.make_conv_classifier(cat_feat_num, num_class)
# self.linear = nn.Sequential(nn.AlphaDropout(0.05), # recommend by selu's authors
# nn.Linear(num_class, num_class // 16),
# nn.SELU(),
# nn.Linear(num_class // 16, num_class))
# if isinstance(self.act_fn, nn.SELU):
# self.selu_init_params()
# else:
# self.initialize_weights()

def make_conv_classifier(self, in_channel, out_channel):
m = nn.Sequential(
InvertedResidual(in_channel, out_channel, stride=3, expand_ratio=1, dilation=1, conv_block_fn=Conv_block,
activation=self.act_fn, bias=False, add_sece=False),
InvertedResidual(out_channel, out_channel, stride=3, expand_ratio=2, dilation=1, conv_block_fn=Conv_block,
activation=self.act_fn, bias=False, add_sece=False),
*Conv_block(out_channel, out_channel, kernel_size=3, padding=1,
groups=out_channel, BN=False, activation=self.act_fn),
nn.Conv2d(out_channel, out_channel, kernel_size=1),
nn.AdaptiveAvgPool2d(1))
return m

def forward(self, x):
x = self.feature(x)
x = self.pre_classifier(x)
for layer in self.encoder.features[:3]:
x = layer(x)

feature_maps = []
for layer in self.encoder.features[3:]:
x = layer(x)
feature_maps.append(x)

# all feature maps are 1/8 of input size
x = torch.cat(feature_maps, dim=1)
del feature_maps
x = self.conv_classifier(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
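
Putting it together, a hedged usage sketch (not from the commit; assumes the repo's modules import cleanly): the classifier runs the first three encoder blocks, concatenates the remaining 1/8-resolution feature maps (944 channels at width 1.4), and reduces them with the convolutional head defined above. num_class and the input size are placeholders.

```python
import torch
from models.MobileNetV2 import MobileNetV2Classifier

model = MobileNetV2Classifier(num_class=100, width_mult=1.4)
model.eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)  # expected shape: (1, 100)
```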