From c6710c2ffb03412e4c13ebd851f548e06721f173 Mon Sep 17 00:00:00 2001 From: ctrl-q <34975747+ctrl-q@users.noreply.github.com> Date: Wed, 28 Aug 2019 19:20:39 -0400 Subject: [PATCH] Add models.vgg --- models/vgg.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 models/vgg.py diff --git a/models/vgg.py b/models/vgg.py new file mode 100644 index 0000000..21c2b41 --- /dev/null +++ b/models/vgg.py @@ -0,0 +1,110 @@ +''' +Modified from https://github.com/pytorch/vision.git +''' +import math + +import torch.nn as nn +from .base import PyTorchModel + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +class VGG(nn.Module, PyTorchModel): + ''' + VGG model + ''' + + def __init__(self, features): + super(VGG, self).__init__() + self.features = features + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(512, 512), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(512, 512), + nn.ReLU(True), + nn.Linear(512, 10), + ) + # Initialize weights + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + m.bias.data.zero_() + + self.criterion = nn.CrossEntropyLoss() + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', + 512, 512, 512, 512, 'M'], +} + + +def vgg11(): + """VGG 11-layer model (configuration "A")""" + return VGG(make_layers(cfg['A'])) + + +def vgg11_bn(): + """VGG 11-layer model (configuration "A") with batch normalization""" + return VGG(make_layers(cfg['A'], batch_norm=True)) + + +def vgg13(): + """VGG 13-layer model (configuration "B")""" + return VGG(make_layers(cfg['B'])) + + +def vgg13_bn(): + """VGG 13-layer model (configuration "B") with batch normalization""" + return VGG(make_layers(cfg['B'], batch_norm=True)) + + +def vgg16(): + """VGG 16-layer model (configuration "D")""" + return VGG(make_layers(cfg['D'])) + + +def vgg16_bn(): + """VGG 16-layer model (configuration "D") with batch normalization""" + return VGG(make_layers(cfg['D'], batch_norm=True)) + + +def vgg19(): + """VGG 19-layer model (configuration "E")""" + return VGG(make_layers(cfg['E'])) + + +def vgg19_bn(): + """VGG 19-layer model (configuration 'E') with batch normalization""" + return VGG(make_layers(cfg['E'], batch_norm=True))