@@ -152,9 +152,7 @@ class EfficientNet(nn.Module):
152
152
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)
153
153
154
154
Example:
155
-
156
-
157
- import torch
155
+ >>> import torch
158
156
>>> from efficientnet.model import EfficientNet
159
157
>>> inputs = torch.rand(1, 3, 224, 224)
160
158
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
@@ -213,16 +211,18 @@ def __init__(self, blocks_args=None, global_params=None):
213
211
214
212
# Final linear layer
215
213
self ._avg_pooling = nn .AdaptiveAvgPool2d (1 )
216
- self ._dropout = nn .Dropout (self ._global_params .dropout_rate )
217
- self ._fc = nn .Linear (out_channels , self ._global_params .num_classes )
214
+ if self ._global_params .include_top :
215
+ self ._dropout = nn .Dropout (self ._global_params .dropout_rate )
216
+ self ._fc = nn .Linear (out_channels , self ._global_params .num_classes )
217
+
218
+ # set activation to memory efficient swish by default
218
219
self ._swish = MemoryEfficientSwish ()
219
220
220
221
def set_swish (self , memory_efficient = True ):
221
222
"""Sets swish function as memory efficient (for training) or standard (for export).
222
223
223
224
Args:
224
225
memory_efficient (bool): Whether to use memory-efficient version of swish.
225
-
226
226
"""
227
227
self ._swish = MemoryEfficientSwish () if memory_efficient else Swish ()
228
228
for block in self ._blocks :
@@ -238,17 +238,18 @@ def extract_endpoints(self, inputs):
238
238
Returns:
239
239
Dictionary of last intermediate features
240
240
with reduction levels i in [1, 2, 3, 4, 5].
241
- Example:
242
- >>> import torch
243
- >>> from efficientnet.model import EfficientNet
244
- >>> inputs = torch.rand(1, 3, 224, 224)
245
- >>> model = EfficientNet.from_pretrained('efficientnet-b0')
246
- >>> endpoints = model.extract_endpoints(inputs)
247
- >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
248
- >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249
- >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250
- >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251
- >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
241
+
242
+ Example:
243
+ >>> import torch
244
+ >>> from efficientnet.model import EfficientNet
245
+ >>> inputs = torch.rand(1, 3, 224, 224)
246
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
247
+ >>> endpoints = model.extract_endpoints(inputs)
248
+ >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
249
+ >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
250
+ >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
251
+ >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
252
+ >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
252
253
"""
253
254
endpoints = dict ()
254
255
@@ -319,7 +320,7 @@ def forward(self, inputs):
319
320
320
321
@classmethod
321
322
def from_name (cls , model_name , in_channels = 3 , ** override_params ):
322
- """create an efficientnet model according to name.
323
+ """Create an efficientnet model according to name.
323
324
324
325
Args:
325
326
model_name (str): Name for efficientnet.
@@ -345,7 +346,7 @@ def from_name(cls, model_name, in_channels=3, **override_params):
345
346
@classmethod
346
347
def from_pretrained (cls , model_name , weights_path = None , advprop = False ,
347
348
in_channels = 3 , num_classes = 1000 , ** override_params ):
348
- """create an efficientnet model according to name.
349
+ """Create an efficientnet model according to name.
349
350
350
351
Args:
351
352
model_name (str): Name for efficientnet.
0 commit comments