forked from Sanghyun-Hong/DeepSloth
-
Notifications
You must be signed in to change notification settings - Fork 0
/
profiler.py
212 lines (164 loc) · 5.94 KB
/
profiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
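Utilities to compute the inference cost (operation count, reported in GFLOPs
by profile_sdn) and the number of parameters of a CNN or of an SDN (a CNN with
internal classifiers).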
"""
# torch
import torch
import torch.nn as nn
# custom libs
import utils
import networks.SDNs.VGG_SDN as vgg_sdn
def count_conv2d(m, x, y):
    """Forward hook: accumulate the op count of an ``nn.Conv2d`` call.

    Each output element is a dot product over a ``kh x kw x (in_channels //
    groups)`` window: that many multiplications, one fewer additions, plus one
    addition for the bias when the conv has one.

    Args:
        m: the conv module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs to the module (unused).
        y: the module's output tensor.
    """
    in_per_group = m.in_channels // m.groups
    kh, kw = m.kernel_size
    window = kh * kw * in_per_group
    bias_ops = 1 if m.bias is not None else 0
    ops_per_output = window + (window - 1) + bias_ops
    # y.numel() already covers every output channel of every group, and each
    # output only reads its own group's inputs, so the count must NOT be
    # multiplied by m.groups again (doing so over-counts grouped/depthwise convs).
    total_ops = y.numel() * ops_per_output
    # += in case the same conv module is applied more than once per forward
    m.total_ops += torch.Tensor([int(total_ops)])
def count_bn2d(m, x, y):
    """Forward hook: accumulate the op count of an ``nn.BatchNorm2d`` call.

    Charges one subtraction and one division per input element (no other
    arithmetic is counted).

    Args:
        m: the batch-norm module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs; only ``x[0]`` is inspected.
        y: the module's output (unused).
    """
    element_count = x[0].numel()
    subtractions, divisions = element_count, element_count
    m.total_ops += torch.Tensor([int(subtractions + divisions)])
def count_relu(m, x, y):
    """Forward hook: accumulate the op count of an ``nn.ReLU`` call.

    Charges one operation per input element.

    Args:
        m: the activation module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs; only ``x[0]`` is inspected.
        y: the module's output (unused).
    """
    activation_input = x[0]
    m.total_ops += torch.Tensor([int(activation_input.numel())])
def count_softmax(m, x, y):
    """Forward hook: accumulate the op count of a softmax call.

    ``x[0]`` must be 2-D ``(batch, features)``; any other rank fails to unpack.
    Per row: ``features`` exponentials, ``features - 1`` additions for the
    normaliser and ``features`` divisions.

    Args:
        m: the softmax module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs; only ``x[0]`` is inspected.
        y: the module's output (unused).
    """
    rows, width = x[0].size()
    ops_per_row = width + (width - 1) + width
    m.total_ops += torch.Tensor([int(rows * ops_per_row)])
def count_maxpool(m, x, y):
    """Forward hook: accumulate the op count of an ``nn.MaxPoolNd`` call.

    A max over a window of ``window`` values takes ``window - 1``
    comparisons; that is charged once per output element.

    Args:
        m: the pooling module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs (unused).
        y: the module's output tensor.
    """
    kernel_size = m.kernel_size
    if isinstance(kernel_size, int):
        # A bare int means "the same size along every pooled dimension";
        # without this expansion MaxPool2d(2) would be costed as a window of
        # 2 values instead of 2x2 = 4.
        if isinstance(m, nn.MaxPool3d):
            spatial_dims = 3
        elif isinstance(m, nn.MaxPool2d):
            spatial_dims = 2
        else:
            spatial_dims = 1
        kernel_size = (kernel_size,) * spatial_dims
    window = 1
    for k in kernel_size:
        window *= k
    kernel_ops = window - 1
    m.total_ops += torch.Tensor([int(kernel_ops * y.numel())])
def count_avgpool(m, x, y):
    """Forward hook: accumulate the op count of an ``nn.AvgPoolNd`` call.

    Each output element sums the ``window`` values under the kernel
    (``window - 1`` additions) and divides once, i.e. ``window`` ops.

    Args:
        m: the pooling module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs (unused).
        y: the module's output tensor.
    """
    kernel_size = m.kernel_size
    if isinstance(kernel_size, int):
        # A bare int means "the same size along every pooled dimension";
        # without this expansion AvgPool2d(2) would be costed as a window of
        # 2 values instead of 2x2 = 4.
        if isinstance(m, nn.AvgPool3d):
            spatial_dims = 3
        elif isinstance(m, nn.AvgPool2d):
            spatial_dims = 2
        else:
            spatial_dims = 1
        kernel_size = (kernel_size,) * spatial_dims
    window = 1
    for k in kernel_size:
        window *= k
    total_add = window - 1
    total_div = 1
    kernel_ops = total_add + total_div
    m.total_ops += torch.Tensor([int(kernel_ops * y.numel())])
def count_linear(m, x, y):
    """Forward hook: accumulate the op count of an ``nn.Linear`` call.

    Each output element is a dot product over ``in_features`` inputs:
    ``in_features`` multiplications and ``in_features - 1`` additions, plus
    one addition for the bias when present. The bias add is counted the same
    way as in ``count_conv2d``, so a Linear layer and the equivalent 1x1
    conv get the same cost.

    Args:
        m: the linear module; must carry a ``total_ops`` tensor buffer.
        x: tuple of positional inputs (unused).
        y: the module's output tensor.
    """
    total_mul = m.in_features
    total_add = m.in_features - 1
    bias_ops = 1 if m.bias is not None else 0
    ops_per_output = total_mul + total_add + bias_ops
    m.total_ops += torch.Tensor([int(ops_per_output * y.numel())])
def profile_sdn(model, input_size, device):
    """Compute the cumulative cost of an SDN up to each of its outputs.

    A zero image of shape ``(1, 3, input_size, input_size)`` is pushed through
    ``model`` once, with counting hooks on every leaf module. The modules are
    then walked in ``model.modules()`` order, accumulating ops and parameters
    of the leaves, and a snapshot is recorded a fixed number of modules after
    each internal-classifier output.

    Args:
        model: the SDN to profile. It is switched to eval mode (and left there).
        input_size: height/width of the square dummy input.
        device: device the dummy input is moved to (should match ``model``).

    Returns:
        ``(output_total_ops, output_total_params)``: dicts keyed by output index
        0, 1, ... holding cumulative GFLOPs (ops / 1e9) and millions of
        parameters (params / 1e6). The highest key holds the whole-model totals.

    The temporary forward hooks and ``total_ops``/``total_params`` buffers are
    removed before returning, so later forward passes are not counted, the
    model's ``state_dict`` is unchanged, and repeated calls do not double count.
    """
    model.eval()

    # (module types, counting hook) pairs, checked in order. Dropout and any
    # type not listed here contribute no ops.
    counter_hooks = (
        (nn.Conv2d, count_conv2d),
        (nn.BatchNorm2d, count_bn2d),
        (nn.ReLU, count_relu),
        ((nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d), count_maxpool),
        ((nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d), count_avgpool),
        (nn.Linear, count_linear),
    )
    handles = []
    instrumented = set()

    def add_hooks(m):
        # Only leaf modules are instrumented; a leaf reachable from several
        # parents is instrumented once so its hook does not fire twice.
        if len(list(m.children())) > 0 or m in instrumented:
            return
        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))
        instrumented.add(m)
        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])
        for types, hook in counter_hooks:
            if isinstance(m, types):
                handles.append(m.register_forward_hook(hook))
                break

    try:
        model.apply(add_hooks)
        # gradients are never needed for counting; skip building the graph
        with torch.no_grad():
            model(torch.zeros(1, 3, input_size, input_size).to(device))

        output_total_ops = {}
        output_total_params = {}
        total_ops = 0
        total_params = 0
        cur_output_id = 0
        # Module index of the most recent output head. The -10 sentinels make
        # the snapshot test below unreachable until a head has been seen.
        cur_output_layer_id = -10
        wait_for = -10
        vgg = False
        for layer_id, m in enumerate(model.modules()):
            if isinstance(m, utils.InternalClassifier):
                cur_output_layer_id = layer_id
            elif isinstance(m, vgg_sdn.FcBlockWOutput) and m.no_output == False:  # noqa: E712
                vgg = True
                cur_output_layer_id = layer_id
            # NOTE(review): the offsets 4 / 1 / 3 are hard-coded assumptions
            # about how many modules follow each output head before its last
            # leaf -- TODO confirm against the SDN architectures.
            if layer_id == cur_output_layer_id + 1:
                if vgg:
                    wait_for = 4
                elif isinstance(m, nn.Linear):
                    wait_for = 1
                else:
                    wait_for = 3
            if len(list(m.children())) > 0:
                continue
            total_ops += m.total_ops
            total_params += m.total_params
            if layer_id == cur_output_layer_id + wait_for:
                output_total_ops[cur_output_id] = total_ops.numpy()[0] / 1e9
                output_total_params[cur_output_id] = total_params.numpy()[0] / 1e6
                cur_output_id += 1
        # the final output covers the whole model
        output_total_ops[cur_output_id] = total_ops.numpy()[0] / 1e9
        output_total_params[cur_output_id] = total_params.numpy()[0] / 1e6
        return output_total_ops, output_total_params
    finally:
        for handle in handles:
            handle.remove()
        for m in instrumented:
            del m.total_ops
            del m.total_params
def profile(model, input_size, device):
    """Count the ops and parameters of a whole model.

    A zero image of shape ``(1, 3, input_size, input_size)`` is pushed through
    ``model`` once, with counting hooks on every leaf module, and the per-leaf
    counts are summed.

    Args:
        model: the network to profile. It is switched to eval mode (and left
            there).
        input_size: height/width of the square dummy input.
        device: device the dummy input is moved to (should match ``model``).

    Returns:
        ``(total_ops, total_params)``: raw (unscaled) counts summed over all
        leaf modules, each a one-element float tensor.

    The temporary forward hooks and ``total_ops``/``total_params`` buffers are
    removed before returning, so later forward passes are not counted, the
    model's ``state_dict`` is unchanged, and repeated calls do not double count.
    """
    model.eval()

    # (module types, counting hook) pairs, checked in order. Dropout and any
    # type not listed here contribute no ops.
    counter_hooks = (
        (nn.Conv2d, count_conv2d),
        (nn.BatchNorm2d, count_bn2d),
        (nn.ReLU, count_relu),
        ((nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d), count_maxpool),
        ((nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d), count_avgpool),
        (nn.Linear, count_linear),
    )
    handles = []
    instrumented = set()

    def add_hooks(m):
        # Only leaf modules are instrumented; a leaf reachable from several
        # parents is instrumented once so its hook does not fire twice.
        if len(list(m.children())) > 0 or m in instrumented:
            return
        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))
        instrumented.add(m)
        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])
        for types, hook in counter_hooks:
            if isinstance(m, types):
                handles.append(m.register_forward_hook(hook))
                break

    try:
        model.apply(add_hooks)
        # gradients are never needed for counting; skip building the graph
        with torch.no_grad():
            model(torch.zeros(1, 3, input_size, input_size).to(device))

        total_ops = 0
        total_params = 0
        for m in model.modules():
            if len(list(m.children())) > 0:
                continue
            # the first += creates a fresh tensor, so the totals never alias
            # the buffers deleted below
            total_ops += m.total_ops
            total_params += m.total_params
        return total_ops, total_params
    finally:
        for handle in handles:
            handle.remove()
        for m in instrumented:
            del m.total_ops
            del m.total_params