@@ -50,7 +50,7 @@ class MBConvBlock(nn.Module):
     def __init__(self, block_args, global_params, image_size=None):
         super().__init__()
         self._block_args = block_args
-        self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
+        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
         self._bn_eps = global_params.batch_norm_epsilon
         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
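
The `1 - batch_norm_momentum` conversion that this comment alludes to is worth spelling out: TensorFlow treats the momentum as the decay applied to the running statistics, while PyTorch's `momentum` is the weight given to the new batch statistic, so the value must be inverted when porting. A minimal sketch (the concrete numbers are illustrative, not from this PR):

```python
import torch.nn as nn

# TensorFlow: running = decay * running + (1 - decay) * batch_stat   (decay ~ 0.99)
# PyTorch:    running = (1 - momentum) * running + momentum * batch_stat
batch_norm_momentum = 0.99  # TensorFlow-style value carried in global_params
bn = nn.BatchNorm2d(32, momentum=1 - batch_norm_momentum, eps=1e-3)
```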
@@ -152,9 +152,7 @@ class EfficientNet(nn.Module):
     [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
 
     Example:
-
-
-        import torch
+        >>> import torch
         >>> from efficientnet.model import EfficientNet
         >>> inputs = torch.rand(1, 3, 224, 224)
         >>> model = EfficientNet.from_pretrained('efficientnet-b0')
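
The doctest is truncated by the diff context here; a complete, runnable version of the same flow (assuming the `efficientnet.model` import path shown above and network access for the pretrained weights) would look like:

```python
import torch
from efficientnet.model import EfficientNet

inputs = torch.rand(1, 3, 224, 224)                      # dummy NCHW batch
model = EfficientNet.from_pretrained('efficientnet-b0')  # downloads weights on first use
model.eval()                                             # disable dropout, use running BN stats
with torch.no_grad():
    outputs = model(inputs)                              # class logits, shape (1, 1000)
```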
@@ -198,7 +196,7 @@ def __init__(self, blocks_args=None, global_params=None):
             # The first block needs to take care of stride and filter size increase.
             self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
             image_size = calculate_output_image_size(image_size, block_args.stride)
-            if block_args.num_repeat > 1: # modify block_args to keep same output size
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                 block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
             for _ in range(block_args.num_repeat - 1):
                 self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
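
The repeat logic in this hunk relies on `BlockArgs` being a namedtuple, so `_replace` returns a modified copy. Stripped of the model specifics, the pattern is (field names follow the hunk; the concrete numbers are made up):

```python
from collections import namedtuple

BlockArgs = namedtuple('BlockArgs', ['num_repeat', 'input_filters', 'output_filters', 'stride'])

ba = BlockArgs(num_repeat=3, input_filters=16, output_filters=24, stride=2)
blocks = [ba]  # first block downsamples (stride 2) and widens (16 -> 24)
if ba.num_repeat > 1:  # remaining repeats must preserve the output size
    ba = ba._replace(input_filters=ba.output_filters, stride=1)
blocks += [ba] * (ba.num_repeat - 1)  # stride-1, 24 -> 24 blocks
```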
@@ -213,16 +211,18 @@ def __init__(self, blocks_args=None, global_params=None):
 
         # Final linear layer
         self._avg_pooling = nn.AdaptiveAvgPool2d(1)
-        self._dropout = nn.Dropout(self._global_params.dropout_rate)
-        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        if self._global_params.include_top:
+            self._dropout = nn.Dropout(self._global_params.dropout_rate)
+            self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+        # set activation to memory efficient swish by default
         self._swish = MemoryEfficientSwish()
 
 
     def set_swish(self, memory_efficient=True):
         """Sets swish function as memory efficient (for training) or standard (for export).
 
         Args:
             memory_efficient (bool): Whether to use memory-efficient version of swish.
-
         """
         self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
         for block in self._blocks:
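
Two practical consequences of this hunk, sketched below. Passing `include_top` through `**override_params` is an assumption about how `GlobalParams` gets populated (that plumbing is not in this diff), and the dropout/fc guard implies `forward()` must skip the head the same way:

```python
import torch
from efficientnet.model import EfficientNet

# Headless backbone: with include_top=False, no _dropout/_fc are created
# (include_top as an override keyword is an assumption, see note above).
backbone = EfficientNet.from_name('efficientnet-b0', include_top=False)
feats = backbone.extract_features(torch.rand(1, 3, 224, 224))  # (1, 1280, 7, 7) for b0

# Export path: replace the memory-efficient swish (a custom autograd Function)
# with the plain Swish module before tracing, as the set_swish docstring advises.
backbone.set_swish(memory_efficient=False)
```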
@@ -261,17 +261,17 @@ def extract_endpoints(self, inputs):
         for idx, block in enumerate(self._blocks):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
-                drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
             x = block(x, drop_connect_rate=drop_connect_rate)
             if prev_x.size(2) > x.size(2):
-                endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
             elif idx == len(self._blocks) - 1:
-                endpoints['reduction_{}'.format(len(endpoints)+1)] = x
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
             prev_x = x
 
         # Head
         x = self._swish(self._bn1(self._conv_head(x)))
-        endpoints['reduction_{}'.format(len(endpoints)+1)] = x
+        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
 
         return endpoints
 
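
What `extract_endpoints` collects: every time the spatial size halves, the last feature map at the previous resolution is recorded, and the post-head map is appended at the end. A quick way to inspect the result (exact keys and shapes depend on the variant, so this sketch just prints them):

```python
import torch
from efficientnet.model import EfficientNet

model = EfficientNet.from_name('efficientnet-b0')
endpoints = model.extract_endpoints(torch.rand(1, 3, 224, 224))
for name, feat in endpoints.items():
    print(name, tuple(feat.shape))
# expect progressively halved maps: reduction_1 at 112x112 down to the 7x7 head output
```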
@@ -292,7 +292,7 @@ def extract_features(self, inputs):
         for idx, block in enumerate(self._blocks):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
-                drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
             x = block(x, drop_connect_rate=drop_connect_rate)
 
         # Head
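
The scaling touched in both loops is linear in depth: block `idx` of `n` gets `rate * idx / n`, so early blocks are almost never dropped and the deepest block approaches the full rate. With b0's 16 MBConv blocks and the reference default of 0.2:

```python
drop_connect_rate = 0.2  # reference default for efficientnet-b0
num_blocks = 16          # b0 stacks 16 MBConv blocks
for idx in range(num_blocks):
    print(idx, drop_connect_rate * float(idx) / num_blocks)
# block 0 -> 0.0, block 8 -> 0.1, block 15 -> 0.1875
```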
@@ -322,7 +322,7 @@ def forward(self, inputs):
 
     @classmethod
     def from_name(cls, model_name, in_channels=3, **override_params):
-        """create an efficientnet model according to name.
+        """Create an efficientnet model according to name.
 
         Args:
             model_name (str): Name for efficientnet.
@@ -348,7 +348,7 @@ def from_name(cls, model_name, in_channels=3, **override_params):
     @classmethod
     def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                         in_channels=3, num_classes=1000, **override_params):
-        """create an efficientnet model according to name.
+        """Create an efficientnet model according to name.
 
         Args:
             model_name (str): Name for efficientnet.
@@ -375,7 +375,8 @@ def from_pretrained(cls, model_name, weights_path=None, advprop=False,
             A pretrained efficientnet model.
         """
         model = cls.from_name(model_name, num_classes=num_classes, **override_params)
-        load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
+        load_pretrained_weights(model, model_name, weights_path=weights_path,
+                                load_fc=(num_classes == 1000), advprop=advprop)
         model._change_in_channels(in_channels)
         return model
 
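
The re-wrapped call makes the `load_fc=(num_classes == 1000)` detail easier to see: override `num_classes` and the pretrained classifier is skipped, leaving a freshly initialized head of the requested width. A usage sketch:

```python
import torch
from efficientnet.model import EfficientNet

# 10-way head for fine-tuning: backbone weights load, fc stays randomly
# initialized because load_fc is False whenever num_classes != 1000.
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=10)

# in_channels != 3 reconfigures the stem conv via _change_in_channels().
gray = EfficientNet.from_pretrained('efficientnet-b0', in_channels=1)
logits = gray(torch.rand(2, 1, 224, 224))  # (2, 1000)
```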