Skip to content

Commit 761ac94

Browse files
authored
Merge pull request #208 from nwschurink/#192_include_top
#192 include top
2 parents a746930 + a78e84e commit 761ac94

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

efficientnet_pytorch/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ class EfficientNet(nn.Module):
152152
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)
153153
154154
Example:
155-
>>> import torch
155+
156+
157+
import torch
156158
>>> from efficientnet.model import EfficientNet
157159
>>> inputs = torch.rand(1, 3, 224, 224)
158160
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
@@ -307,13 +309,12 @@ def forward(self, inputs):
307309
"""
308310
# Convolution layers
309311
x = self.extract_features(inputs)
310-
311312
# Pooling and final linear layer
312313
x = self._avg_pooling(x)
313-
x = x.flatten(start_dim=1)
314-
x = self._dropout(x)
315-
x = self._fc(x)
316-
314+
if self._global_params.include_top:
315+
x = x.flatten(start_dim=1)
316+
x = self._dropout(x)
317+
x = self._fc(x)
317318
return x
318319

319320
@classmethod

efficientnet_pytorch/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
GlobalParams = collections.namedtuple('GlobalParams', [
4040
'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
4141
'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
42-
'drop_connect_rate', 'depth_divisor', 'min_depth'])
42+
'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
4343

4444
# Parameters for an individual model block
4545
BlockArgs = collections.namedtuple('BlockArgs', [
@@ -475,7 +475,7 @@ def efficientnet_params(model_name):
475475

476476

477477
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
478-
dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000):
478+
dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
479479
"""Create BlockArgs and GlobalParams for efficientnet model.
480480
481481
Args:
@@ -517,6 +517,7 @@ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None
517517
drop_connect_rate=drop_connect_rate,
518518
depth_divisor=8,
519519
min_depth=None,
520+
include_top=include_top,
520521
)
521522

522523
return blocks_args, global_params

0 commit comments

Comments
 (0)