Adding FitNets notebook
desinurch committed Aug 13, 2020
1 parent 9f0ac58 commit 5e2c544
Showing 2 changed files with 67 additions and 67 deletions.
1 change: 1 addition & 0 deletions src/Distillation/Fitnets.ipynb

Large diffs are not rendered by default.

133 changes: 66 additions & 67 deletions src/Distillation/network_colab.py
@@ -35,13 +35,13 @@ def __init__(self, in_channels, out_channels):
self.downsample = (in_channels != out_channels)
if self.downsample:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
self.ds = nn.Sequential(*[
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(out_channels)
])
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.ds = None
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
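The forward method of resblock is collapsed in this diff. For reference, a minimal sketch of the usual post-activation residual pass, inferred from the layers defined above (the bn2 after conv2 is an assumption based on the standard BasicBlock, not code visible here):

    # Hypothetical reconstruction -- the actual forward is collapsed in this diff.
    def forward(self, x):
        residual = x if self.ds is None else self.ds(x)  # 1x1 conv aligns shape when downsampling
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))                  # bn2 assumed, as in the standard BasicBlock
        out += residual                                  # skip connection
        return self.relu(out)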
@@ -69,15 +69,15 @@ class resnet1(nn.Module):
def __init__(self, num_class):
super(resnet1, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)

self.res1 = self.make_layer(resblock, 1, 16, 16)
self.res2 = self.make_layer(resblock, 1, 16, 32)
self.res3 = self.make_layer(resblock, 1, 32, 64)

self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_class)

for m in self.modules():
if isinstance(m, nn.Conv2d):
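The weight-initialization loop is collapsed here and in the sibling classes below; the full version is visible in PreResNet110 further down, and each resnet class presumably uses the same scheme:

    # Initialization as shown in full in PreResNet110 below (assumed identical here).
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)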
@@ -111,15 +111,15 @@ class resnet10(nn.Module):
def __init__(self, num_class):
super(resnet10, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)

self.res1 = self.make_layer(resblock, 2, 16, 16)
self.res2 = self.make_layer(resblock, 2, 16, 32)
self.res3 = self.make_layer(resblock, 2, 32, 64)

self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_class)

for m in self.modules():
if isinstance(m, nn.Conv2d):
@@ -153,15 +153,15 @@ class resnet20(nn.Module):
def __init__(self, num_class):
super(resnet20, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)

self.res1 = self.make_layer(resblock, 3, 16, 16)
self.res2 = self.make_layer(resblock, 3, 16, 32)
self.res3 = self.make_layer(resblock, 3, 32, 64)

self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_class)

for m in self.modules():
if isinstance(m, nn.Conv2d):
@@ -195,15 +195,15 @@ class resnet56(nn.Module):
def __init__(self, num_class):
super(resnet56, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)

self.res1 = self.make_layer(resblock, 9, 16, 16)
self.res2 = self.make_layer(resblock, 9, 16, 32)
self.res3 = self.make_layer(resblock, 9, 32, 64)

self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_class)

for m in self.modules():
if isinstance(m, nn.Conv2d):
@@ -237,15 +237,15 @@ class resnet110(nn.Module):
def __init__(self, num_class):
super(resnet110, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)

self.res1 = self.make_layer(resblock, 18, 16, 16)
self.res2 = self.make_layer(resblock, 18, 16, 32)
self.res3 = self.make_layer(resblock, 18, 32, 64)

self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_class)

for m in self.modules():
if isinstance(m, nn.Conv2d):
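The five classes differ only in the number of residual blocks per stage (1, 2, 3, 9, 18). For resnet20, resnet56 and resnet110 this matches the CIFAR convention of depth 6n + 2 with n = 3, 9 and 18. A quick smoke test, illustrative only and not from the diff:

    # The forward methods are collapsed here; PreResNet110 below returns a
    # tuple of intermediate features plus logits, and these classes may too.
    import torch
    model = resnet20(num_class=10)
    n_params = sum(p.numel() for p in model.parameters())
    print(n_params)  # roughly 0.27M, the standard figure for CIFAR ResNet-20
    out = model(torch.randn(4, 3, 32, 32))  # 32x32 CIFAR-sized images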
@@ -279,81 +279,80 @@ def forward(self, x):
# from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py
########################################################################
class PreActBlock(nn.Module):
'''Pre-activation version of the BasicBlock.'''
expansion = 1

def __init__(self, in_channels, out_channels):
super(PreActBlock, self).__init__()
self.downsample = (in_channels != out_channels)
if self.downsample:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
self.ds = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.ds = None
self.bn1 = nn.BatchNorm2d(in_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
residual = x

out = self.relu(self.bn1(x))
out = self.conv1(out)
out = self.conv2(self.relu(self.bn2(out)))

if self.downsample:
residual = self.ds(x)

out += residual

return out

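Unlike resblock above, PreActBlock normalizes and activates before each convolution, the pre-activation ordering of He et al.; this is also why its downsample branch is a bare 1x1 convolution with no BatchNorm of its own. The two orderings side by side:

    # resblock (post-activation):   out = relu(bn1(conv1(x)))  -- assumed from its layers
    # PreActBlock (pre-activation): out = conv1(relu(bn1(x)))  -- as in the forward above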
class PreResNet110(nn.Module):
def __init__(self, num_class=10):
super(PreResNet110, self).__init__()
# Model type specifies number of layers for CIFAR-10 model
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64 * PreActBlock.expansion)
self.relu = nn.ReLU(inplace=True)

self.res1 = self.make_layer(PreActBlock, 18, 16, 16)
self.res2 = self.make_layer(PreActBlock, 18, 16, 32)
self.res3 = self.make_layer(PreActBlock, 18, 32, 64)
self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64 * PreActBlock.expansion, num_class)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def make_layer(self, block, num, in_channels, out_channels):
layers = [block(in_channels, out_channels)]
for i in range(num-1):
layers.append(block(out_channels, out_channels))
return nn.Sequential(*layers)

def forward(self, x):
pre = self.conv1(x)

rb1 = self.res1(pre) # 32x32
rb2 = self.res2(rb1) # 16x16
rb3 = self.res3(rb2) # 8x8
x = self.bn1(rb3)
x = self.relu(x)

x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)

return pre, rb1, rb2, rb3, x
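
Because PreResNet110.forward returns the intermediate activations (pre, rb1, rb2, rb3) alongside the logits, it can act as the teacher for the FitNets notebook this commit adds. A minimal hint-training sketch, assuming a student whose forward returns the same tuple and an illustrative 1x1-conv regressor (both assumptions, not code from this file):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    teacher = PreResNet110(num_class=10).eval()
    student = resnet20(num_class=10)  # assumed to also return (pre, rb1, rb2, rb3, logits)

    # Hypothetical regressor from the student's hint layer to the teacher's
    # guided layer; both rb2 stages are 32 channels at 16x16 here.
    regressor = nn.Conv2d(32, 32, kernel_size=1)

    x = torch.randn(8, 3, 32, 32)
    with torch.no_grad():
        _, _, t_rb2, _, _ = teacher(x)
    _, _, s_rb2, _, _ = student(x)
    hint_loss = F.mse_loss(regressor(s_rb2), t_rb2)  # FitNets stage-1 hint loss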

###########################################################################

