-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvggface.py
110 lines (93 loc) · 3.27 KB
/
vggface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from PIL import Image
import torchvision.transforms.functional as TF
import itertools
import torch.utils.data as data_utils
from backbones.countFLOPS import _calc_width, count_model_flops
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
class VGG_16(nn.Module):
    """VGG-16 network laid out as five convolutional stages plus three FC layers.

    Stage ``s`` (1-based) holds ``block_size[s - 1]`` 3x3 convolutions named
    ``conv_{s}_{i}`` (stride 1, padding 1), each followed by ReLU; every stage
    ends with a 2x2 / stride-2 max-pool. The flattened features then pass
    through ``fc6`` and ``fc7`` (ReLU + dropout p=0.5 while training) and a
    final ``fc8`` with 2622 outputs.
    """

    def __init__(self):
        """Register the 13 convolutions and 3 linear layers in network order."""
        super().__init__()
        self.block_size = [2, 2, 3, 3, 3]

        # Output channel width of each stage; the first conv of a stage changes
        # the width, the remaining convs of that stage keep it.
        stage_widths = [64, 128, 256, 512, 512]
        channels = 3
        for stage, (depth, width) in enumerate(zip(self.block_size, stage_widths), start=1):
            for layer in range(1, depth + 1):
                # setattr goes through nn.Module.__setattr__, so each conv is
                # registered as a submodule under the name conv_<stage>_<layer>.
                setattr(
                    self,
                    f"conv_{stage}_{layer}",
                    nn.Conv2d(channels, width, 3, stride=1, padding=1),
                )
                channels = width

        # Five 2x2 poolings shrink a 224x224 input to 7x7 with 512 channels.
        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 2622)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: input batch of shape (N, 3, 224, 224). The 224x224 size is
               required because fc6 expects 512 * 7 * 7 flattened features.

        Returns:
            Tensor of shape (N, 2622) produced by fc8.
        """
        for stage, depth in enumerate(self.block_size, start=1):
            for layer in range(1, depth + 1):
                conv = getattr(self, f"conv_{stage}_{layer}")
                x = F.relu(conv(x))
            x = F.max_pool2d(x, 2, 2)

        x = x.view(x.size(0), -1)
        for hidden in (self.fc6, self.fc7):
            x = F.relu(hidden(x))
            # Dropout is active only when the module is in training mode.
            x = F.dropout(x, 0.5, self.training)
        return self.fc8(x)
def _test():
    """Smoke-test the models defined in this module.

    For each model class: instantiate it, print it, print the results of the
    project helpers ``_calc_width`` and ``count_model_flops`` (at 224x224),
    then run a forward and backward pass on one random 3x224x224 image and
    check the output shape.

    Raises:
        AssertionError: if the model output is not of shape (1, 2622).
    """
    model_classes = [
        VGG_16,
    ]
    for model_cls in model_classes:
        net = model_cls()
        print(net)
        weight_count = _calc_width(net)
        print("m={}, {}".format(model_cls.__name__, weight_count))
        flops = count_model_flops(net, input_res=[224, 224])
        print("m={}, {}".format(model_cls.__name__, flops))
        net.eval()
        x = torch.randn(1, 3, 224, 224)
        y = net(x)
        y.sum().backward()
        # fc8 is nn.Linear(4096, 2622), so one image yields a (1, 2622) output;
        # the previous (1, 512) expectation could never hold.
        assert tuple(y.size()) == (1, 2622), tuple(y.size())


if __name__ == "__main__":
    _test()