diff --git a/dataloader.py b/dataloader.py index 2240707..95285ff 100644 --- a/dataloader.py +++ b/dataloader.py @@ -9,7 +9,8 @@ from torch import nn from torch.nn.functional import pad from torch.utils.data import Dataset -from torchvision.transforms import ColorJitter, ToTensor, RandomResizedCrop, Compose, Normalize, transforms, Grayscale +from torchvision.transforms import ColorJitter, ToTensor, RandomResizedCrop, Compose, Normalize, transforms, Grayscale, \ + RandomGrayscale from torchvision.transforms.functional import resized_crop, to_tensor use_cuda = torch.cuda.is_available() @@ -86,7 +87,7 @@ def __init__(self, image_foler, name_tag_dict, mean, std, print("Find {} images. ".format(len(self.images))) self.name_tag_dict = name_tag_dict - self.img_transform = self.transformer(image_size, mean, std) + self.img_transform = self.transformer(mean, std) # one hot encoding self.onehot = torch.eye(num_class) @@ -103,8 +104,10 @@ def __getitem__(self, item): return image, LongTensor(target) @staticmethod - def transformer(image_size, mean, std): - m = Compose([RandomResizedCrop(image_size, scale=(0.5, 2.0)), + def transformer(mean, std): + m = Compose([RandomGrayscale(p=0.2), + # RandomHorizontalFlip(p=0.2), + # RandomVerticalFlip(p=0.2), ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), ToTensor(), Normalize(mean, std)]) diff --git a/loss.py b/loss.py index 572fd00..276f477 100644 --- a/loss.py +++ b/loss.py @@ -135,3 +135,8 @@ def total_variation_loss(image): loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \ torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :])) return loss + + +0.775 +0.465 +0.17 diff --git a/models/BaseModels.py b/models/BaseModels.py index 96e859e..4958317 100644 --- a/models/BaseModels.py +++ b/models/BaseModels.py @@ -1,16 +1,20 @@ import math -import warnings from contextlib import contextmanager from torch import nn -warnings.simplefilter('ignore') +try: + from .inplace_abn import InPlaceABN + inplace_batch_norm = True +except ImportError: + inplace_batch_norm = False # +++++++++++++++++++++++++++++++++++++ # Add more functions to PyTorch's base model # ------------------------------------- + class BaseModule(nn.Module): def __init__(self): self.act_fn = None @@ -27,23 +31,18 @@ def selu_init_params(self): m.bias.data.zero_() elif isinstance(m, nn.Linear) and m.weight.requires_grad: - n = m.weight.size(1) - m.weight.data.normal_(0, n) + m.weight.data.normal_(0, 1.0 / math.sqrt(m.weight.numel())) m.bias.data.zero_() def initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d) and m.weight.requires_grad: - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad: m.weight.data.fill_(1) m.bias.data.zero_() - elif isinstance(m, nn.Linear) and m.weight.requires_grad: - n = m.weight.size(1) - m.weight.data.normal_(0, n) - m.bias.data.zero_() def load_state_dict(self, state_dict, strict=True): own_state = self.state_dict() @@ -83,13 +82,26 @@ def forward(self, *x): # ------------------------------------- -def Conv_block(in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True, - BN=False, activation=None): +def Conv_block(in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, BN=False, activation=None): m = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)] if BN: - m.append(nn.BatchNorm2d(out_channels)) - if activation: - m.append(activation) + m += activated_batch_norm(out_channels, activation, inplace_abn=inplace_batch_norm) + if BN is False and activation is not None: + m += [activation] + return m + + +def activated_batch_norm(in_channels, activation, inplace_abn=inplace_batch_norm): + m = [] + if inplace_abn: + if activation: + m.append(InPlaceABN(in_channels, activation="leaky_relu", slope=0.3)) + else: + m.append(InPlaceABN(in_channels, activation='none')) + else: + m.append(nn.BatchNorm2d(in_channels)) + if activation: + m.append(activation) return m diff --git a/models/MobileNetV2.py b/models/MobileNetV2.py index 0692b32..a2f74b5 100644 --- a/models/MobileNetV2.py +++ b/models/MobileNetV2.py @@ -6,12 +6,13 @@ from torch import nn from torch.utils.checkpoint import checkpoint +from models.common import SpatialChannelSqueezeExcitation from .BaseModels import BaseModule, Conv_block from .partial_convolution import partial_gated_conv_block class MobileNetV2(BaseModule): - def __init__(self, width_mult=1, add_partial=False, activation=nn.ReLU6(), bias=False): + def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_sece=False, add_partial=False, ): super(MobileNetV2, self).__init__() self.add_partial = add_partial @@ -32,10 +33,10 @@ def __init__(self, width_mult=1, add_partial=False, activation=nn.ReLU6(), bias= [6, 320, 1, 1, 1], ] self.last_channel = 0 # last one is avg pool - self.features = self.make_inverted_resblocks(self.inverted_residual_setting) + self.features = self.make_inverted_resblocks(self.inverted_residual_setting, add_sece) - def make_inverted_resblocks(self, settings): - in_channel = int(32 * self.width_mult) + def make_inverted_resblocks(self, settings, add_sece): + in_channel = self._make_divisible(32 * self.width_mult, divisor=8) # first_layer features = [nn.Sequential(*self.conv_block(3, in_channel, kernel_size=3, stride=2, @@ -43,15 +44,16 @@ def make_inverted_resblocks(self, settings): BN=True, activation=self.act_fn))] for t, c, n, s, d in settings: - out_channel = int(c * self.width_mult) + out_channel = self._make_divisible(c * self.width_mult, divisor=8) + # out_channel = int(c * self.width_mult) block = [] for i in range(n): if i == 0: block.append(self.res_block(in_channel, out_channel, s, t, d, - activation=self.act_fn, bias=self.bias)) + activation=self.act_fn, bias=self.bias, add_sece=add_sece)) else: - block.append(self.res_block(in_channel, out_channel, 1, t, 1, - activation=self.act_fn, bias=self.bias)) + block.append(self.res_block(in_channel, out_channel, 1, t, d, + activation=self.act_fn, bias=self.bias, add_sece=add_sece)) in_channel = out_channel features.append(nn.Sequential(*block)) # last layer @@ -81,6 +83,17 @@ def freeze_params(self, free_last_blocks=2): print("{}/{} layers in the encoder are freezed.".format(len(self.features) - free_last_blocks, len(self.features))) + def _make_divisible(self, v, divisor=8, min_value=None): + # https://github.com/tensorflow/models/blob/7367d494135368a7790df6172206a58a2a2f3d40/research/slim/nets/mobilenet/mobilenet.py#L62 + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + + return new_v + # for partial conv ---- will not use # def load_state_dict(self, state_dict, strict=True): # own_state = self.state_dict() @@ -108,7 +121,7 @@ def forward_checkpoint(self, x): class InvertedResidual(BaseModule): def __init__(self, in_channel, out_channel, stride, expand_ratio, dilation, conv_block_fn=Conv_block, - activation=nn.ReLU6(), bias=False): + activation=nn.ReLU6(), bias=False, add_sece=False): super(InvertedResidual, self).__init__() self.conv_bloc = conv_block_fn self.stride = stride @@ -119,9 +132,9 @@ def __init__(self, in_channel, out_channel, stride, expand_ratio, dilation, conv # assert stride in [1, 2] self.res_connect = self.stride == 1 and in_channel == out_channel - self.conv = self.make_body(in_channel, out_channel, stride, expand_ratio, dilation) + self.conv = self.make_body(in_channel, out_channel, stride, expand_ratio, dilation, add_sece) - def make_body(self, in_channel, out_channel, stride, expand_ratio, dilation): + def make_body(self, in_channel, out_channel, stride, expand_ratio, dilation, add_sece): # standard convolution mid_channel = in_channel * expand_ratio m = self.conv_bloc(in_channel, mid_channel, @@ -133,6 +146,8 @@ def make_body(self, in_channel, out_channel, stride, expand_ratio, dilation): BN=True, activation=self.act_fn) # linear to preserve info : see the section: linear bottleneck. Removing the activation improves the result m += self.conv_bloc(mid_channel, out_channel, 1, 1, 0, bias=self.bias, BN=True, activation=None) + if add_sece: + m += [SpatialChannelSqueezeExcitation(out_channel, reduction=16, activation=self.act_fn)] return nn.Sequential(*m) def forward(self, x): @@ -170,15 +185,11 @@ def forward(self, args): else: return self.conv(args) - def forward_checkpoint(self, args): - with torch.no_grad(): - return self.forward(args) - class DilatedMobileNetV2(MobileNetV2): - def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_partial=False, ): + def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_sece=False, add_partial=False, ): super(DilatedMobileNetV2, self).__init__(width_mult=width_mult, activation=activation, - bias=bias, add_partial=add_partial, ) + bias=bias, add_sece=add_sece, add_partial=add_partial, ) self.add_partial = add_partial self.bias = bias self.width_mult = width_mult @@ -186,41 +197,61 @@ def __init__(self, width_mult=1, activation=nn.ReLU6(), bias=False, add_partial= self.out_stride = 8 # # Rethinking Atrous Convolution for Semantic Image Segmentation self.inverted_residual_setting = [ - # t, c, n, s, dila # input size - [1, 16, 1, 1, 1], # 1/2 - [6, 24, 2, 2, 1], # 1/4 - [6, 32, 3, 2, 1], # 1/8 + # t, c, n, s, dila # input output + [1, 16, 1, 1, 1], # 1/2 ---> 1/2 + [6, 24, 2, 2, 1], # 1/2 ---> 1/4 + [6, 32, 3, 2, 1], # 1/4 ---> 1/8 [6, 64, 4, 1, 2], # <-- add astrous conv and keep 1/8 [6, 96, 3, 1, 4], [6, 160, 3, 1, 8], [6, 320, 1, 1, 16], ] - self.features = self.make_inverted_resblocks(self.inverted_residual_setting) + self.features = self.make_inverted_resblocks(self.inverted_residual_setting, add_sece=add_sece) class MobileNetV2Classifier(BaseModule): - def __init__(self, num_class, input_size=512, width_mult=1.4): + def __init__(self, num_class, width_mult=1.4, add_sece=False): super(MobileNetV2Classifier, self).__init__() - self.act_fn = nn.SELU(inplace=True) - self.feature = DilatedMobileNetV2(width_mult=width_mult, activation=self.act_fn, - bias=False, add_partial=False) - self.pre_classifier = nn.Sequential( - *Conv_block(self.feature.last_channel, 1024, kernel_size=1, - bias=False, BN=True, activation=self.act_fn), - nn.AvgPool2d(input_size // self.feature.out_stride)) # b,c, 1,1 - - self.classifier = nn.Sequential( - nn.Dropout(0.2, inplace=True), - nn.Linear(1024, num_class) - ) - if isinstance(self.act_fn, nn.SELU): - self.selu_init_params() - else: - self.initialize_weights() + self.act_fn = nn.LeakyReLU(0.3, inplace=True) # nn.SELU(inplace=True) + self.encoder = DilatedMobileNetV2(width_mult=width_mult, activation=self.act_fn, + bias=False, add_sece=add_sece, add_partial=False) + + # if width multiple is 1.4, then there are 944 channels + cat_feat_num = sum([i[0].out_channels for i in self.encoder.features[3:]]) + self.conv_classifier = self.make_conv_classifier(cat_feat_num, num_class) + # self.linear = nn.Sequential(nn.AlphaDropout(0.05), # recommend by selu's authors + # nn.Linear(num_class, num_class // 16), + # nn.SELU(), + # nn.Linear(num_class // 16, num_class)) + # if isinstance(self.act_fn, nn.SELU): + # self.selu_init_params() + # else: + # self.initialize_weights() + + def make_conv_classifier(self, in_channel, out_channel): + m = nn.Sequential( + InvertedResidual(in_channel, out_channel, stride=3, expand_ratio=1, dilation=1, conv_block_fn=Conv_block, + activation=self.act_fn, bias=False, add_sece=False), + InvertedResidual(out_channel, out_channel, stride=3, expand_ratio=2, dilation=1, conv_block_fn=Conv_block, + activation=self.act_fn, bias=False, add_sece=False), + *Conv_block(out_channel, out_channel, kernel_size=3, padding=1, + groups=out_channel, BN=False, activation=self.act_fn), + nn.Conv2d(out_channel, out_channel, kernel_size=1), + nn.AdaptiveAvgPool2d(1)) + return m def forward(self, x): - x = self.feature(x) - x = self.pre_classifier(x) + for layer in self.encoder.features[:3]: + x = layer(x) + + feature_maps = [] + for layer in self.encoder.features[3:]: + x = layer(x) + feature_maps.append(x) + + # all feature maps are 1/8 of input size + x = torch.cat(feature_maps, dim=1) + del feature_maps + x = self.conv_classifier(x) x = x.view(x.size(0), -1) - x = self.classifier(x) return x diff --git a/models/text_segmentation.py b/models/text_segmentation.py index a5e7020..b78baa6 100644 --- a/models/text_segmentation.py +++ b/models/text_segmentation.py @@ -7,6 +7,7 @@ from torch.nn import functional as F from torch.utils.checkpoint import checkpoint_sequential +from models.common import SpatialChannelSqueezeExcitation, RFB from .BaseModels import BaseModule, Conv_block from .MobileNetV2 import DilatedMobileNetV2 @@ -59,135 +60,6 @@ # len(self.features))) -class SpatialChannelSqueezeExcitation(BaseModule): - # https://arxiv.org/pdf/1803.02579v1.pdf - def __init__(self, in_channel): - super(SpatialChannelSqueezeExcitation, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.channel_excite = nn.Sequential( - nn.Linear(in_channel, in_channel // 2), - nn.ReLU(inplace=True), - nn.Linear(in_channel // 2, in_channel), - nn.Sigmoid() - ) - self.spatial_excite = nn.Sequential( - nn.Conv2d(in_channel, 1, kernel_size=1, stride=1, padding=0, bias=False), - nn.Sigmoid() - ) - - def forward(self, x): - b, c, h, w = x.size() - channel = self.avg_pool(x).view(b, c) - cSE = self.channel_excite(channel).view(b, c, 1, 1) - x_cSE = torch.mul(x, cSE) - - sSE = self.spatial_excite(x) - x_sSE = torch.mul(x, sSE) - return torch.add(x_cSE, x_sSE) - - -def add_SCSE_block(model_block, in_channel=None): - if in_channel is None: - # the first layer is assumed to be conv - in_channel = model_block[0].out_channels - model_block.add_module("SCSE", SpatialChannelSqueezeExcitation(in_channel)) - - -class ASP(BaseModule): - # Atrous Spatial Pyramid Pooling with Image Pooling - # add Vortex pooling https://arxiv.org/pdf/1804.06242v1.pdf - def __init__(self, in_channel=256, out_channel=256): - super(ASP, self).__init__() - asp_rate = [5, 17, 29] - self.asp = nn.Sequential( - nn.Sequential(*Conv_block(in_channel, out_channel, kernel_size=1, stride=1, padding=0, - bias=False, BN=True, activation=None)), - nn.Sequential(nn.AvgPool2d(kernel_size=asp_rate[0], stride=1, padding=(asp_rate[0] - 1) // 2), - *Conv_block(in_channel, out_channel, kernel_size=3, stride=1, padding=asp_rate[0], - dilation=asp_rate[0], bias=False, BN=True, activation=None)), - nn.Sequential(nn.AvgPool2d(kernel_size=asp_rate[1], stride=1, padding=(asp_rate[1] - 1) // 2), - *Conv_block(in_channel, out_channel, kernel_size=3, stride=1, padding=asp_rate[1], - dilation=asp_rate[1], bias=False, BN=True, activation=None)), - nn.Sequential(nn.AvgPool2d(kernel_size=asp_rate[2], stride=1, padding=(asp_rate[2] - 1) // 2), - *Conv_block(in_channel, out_channel, kernel_size=3, stride=1, padding=asp_rate[2], - dilation=asp_rate[2], bias=False, BN=True, activation=None)) - ) - - """ To see why adding gobal average, please refer to 3.1 Global Context in https://www.cs.unc.edu/~wliu/papers/parsenet.pdf """ - self.img_pooling_1 = nn.AdaptiveAvgPool2d(1) - self.img_pooling_2 = nn.Sequential( - *Conv_block(in_channel, out_channel, kernel_size=1, bias=False, BN=True, activation=None)) - - # self.initialize_weights() - # self.selu_init_params() - - def forward(self, x): - avg_pool = self.img_pooling_1(x) - avg_pool = F.upsample(avg_pool, size=x.shape[2:], mode='bilinear') - avg_pool = [x, self.img_pooling_2(avg_pool)] - asp_pool = [layer(x) for layer in self.asp.children()] - return torch.cat(avg_pool + asp_pool, dim=1) - - -class RFB(BaseModule): - # receptive fiedl block https://arxiv.org/abs/1711.07767 with some changes - # reference: https://github.com/ansleliu/LightNet/blob/master/modules/rfblock.py - # https://github.com/ruinmessi/RFBNet/blob/master/models/RFB_Net_mobile.py - def __init__(self, in_channel, out_channel, activation): - super(RFB, self).__init__() - asp_rate = [5, 17, 29] - # self.act_fn = activation - self.input_down_channel = nn.Sequential( - *Conv_block(in_channel, out_channel, kernel_size=1, bias=False, BN=True, activation=False)) - - self.rfb_linear_conv = nn.Conv2d(out_channel * 4, out_channel, kernel_size=1, bias=False) - self.rfb = nn.Sequential( - self.make_pooling_branch(in_channel, out_channel // 2, out_channel, conv_kernel=1, - astro_rate=1, half_conv=False), - self.make_pooling_branch(in_channel, out_channel // 2, out_channel, conv_kernel=3, - astro_rate=asp_rate[0], half_conv=True), - self.make_pooling_branch(in_channel, out_channel // 2, out_channel, conv_kernel=5, - astro_rate=asp_rate[1], half_conv=True), - self.make_pooling_branch(in_channel, out_channel // 2, out_channel, conv_kernel=7, - astro_rate=asp_rate[2], half_conv=True) - ) - - @staticmethod - def make_pooling_branch(in_channel, mid_channel, out_channel, conv_kernel, astro_rate, half_conv=False): - # from the paper: we use a 1 x n plus an nx1 conv-layer to take place of the original nxn convlayer - # similar to EffNet style - if half_conv: - m = nn.Sequential( - *Conv_block(in_channel, mid_channel, kernel_size=1, padding=0, - bias=False, BN=True, activation=None), - *Conv_block(mid_channel, 3 * mid_channel // 2, kernel_size=(1, conv_kernel), - padding=(0, (conv_kernel - 1) // 2), - bias=False, BN=True, activation=None), - *Conv_block(3 * mid_channel // 2, out_channel, kernel_size=(conv_kernel, 1), - padding=((conv_kernel - 1) // 2, 0), - bias=False, BN=True, activation=None), - *Conv_block(out_channel, out_channel, kernel_size=3, dilation=astro_rate, padding=astro_rate, - bias=False, BN=True, activation=None, groups=out_channel)) - else: - m = nn.Sequential( - *Conv_block(in_channel, mid_channel, kernel_size=conv_kernel, padding=(conv_kernel - 1) // 2, - bias=False, BN=True, activation=None), - *Conv_block(mid_channel, out_channel, kernel_size=3, dilation=astro_rate, padding=astro_rate, - bias=False, BN=True, activation=None, groups=mid_channel)) - - return m - - def forward(self, x): - # feature poolings - rfb_pool = [layer(x) for layer in self.rfb.children()] - rfb_pool = torch.cat(rfb_pool, dim=1) - rfb_pool = self.rfb_linear_conv(rfb_pool) - - # skip connection - resi = self.input_down_channel(x) - return rfb_pool + resi - - class TextSegament(BaseModule): def __init__(self, encoder_checkpoint=None, free_last_blocks=-1, width_mult=1): super(TextSegament, self).__init__()