Skip to content

Commit 3bdeeee

Browse files
committed
Minor improvements
1 parent 5fbffa4 commit 3bdeeee

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

efficientnet_pytorch/model.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,7 @@ class EfficientNet(nn.Module):
152152
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)
153153
154154
Example:
155-
156-
157-
import torch
155+
>>> import torch
158156
>>> from efficientnet.model import EfficientNet
159157
>>> inputs = torch.rand(1, 3, 224, 224)
160158
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
@@ -213,16 +211,18 @@ def __init__(self, blocks_args=None, global_params=None):
213211

214212
# Final linear layer
215213
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
216-
self._dropout = nn.Dropout(self._global_params.dropout_rate)
217-
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
214+
if self._global_params.include_top:
215+
self._dropout = nn.Dropout(self._global_params.dropout_rate)
216+
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
217+
218+
# set activation to memory efficient swish by default
218219
self._swish = MemoryEfficientSwish()
219220

220221
def set_swish(self, memory_efficient=True):
221222
"""Sets swish function as memory efficient (for training) or standard (for export).
222223
223224
Args:
224225
memory_efficient (bool): Whether to use memory-efficient version of swish.
225-
226226
"""
227227
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
228228
for block in self._blocks:
@@ -238,17 +238,18 @@ def extract_endpoints(self, inputs):
238238
Returns:
239239
Dictionary of last intermediate features
240240
with reduction levels i in [1, 2, 3, 4, 5].
241-
Example:
242-
>>> import torch
243-
>>> from efficientnet.model import EfficientNet
244-
>>> inputs = torch.rand(1, 3, 224, 224)
245-
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
246-
>>> endpoints = model.extract_endpoints(inputs)
247-
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
248-
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249-
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250-
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251-
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
241+
242+
Example:
243+
>>> import torch
244+
>>> from efficientnet.model import EfficientNet
245+
>>> inputs = torch.rand(1, 3, 224, 224)
246+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
247+
>>> endpoints = model.extract_endpoints(inputs)
248+
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
249+
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
250+
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
251+
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
252+
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
252253
"""
253254
endpoints = dict()
254255

@@ -319,7 +320,7 @@ def forward(self, inputs):
319320

320321
@classmethod
321322
def from_name(cls, model_name, in_channels=3, **override_params):
322-
"""create an efficientnet model according to name.
323+
"""Create an efficientnet model according to name.
323324
324325
Args:
325326
model_name (str): Name for efficientnet.
@@ -345,7 +346,7 @@ def from_name(cls, model_name, in_channels=3, **override_params):
345346
@classmethod
346347
def from_pretrained(cls, model_name, weights_path=None, advprop=False,
347348
in_channels=3, num_classes=1000, **override_params):
348-
"""create an efficientnet model according to name.
349+
"""Create an efficientnet model according to name.
349350
350351
Args:
351352
model_name (str): Name for efficientnet.

efficientnet_pytorch/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def backward(ctx, grad_output):
7171
sigmoid_i = torch.sigmoid(i)
7272
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
7373

74+
7475
class MemoryEfficientSwish(nn.Module):
7576
def forward(self, x):
7677
return SwishImplementation.apply(x)

0 commit comments

Comments
 (0)