Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add different swish implementations #88

Merged
merged 1 commit into from
Oct 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions efficientnet_pytorch/model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from torch.nn import functional as F

from .utils import (
relu_fn,
round_filters,
round_repeats,
drop_connect,
get_same_padding_conv2d,
get_model_params,
efficientnet_params,
load_pretrained_weights,
Swish,
MemoryEfficientSwish,
)

class MBConvBlock(nn.Module):
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, block_args, global_params):
final_oup = self._block_args.output_filters
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
self._swish = MemoryEfficientSwish()

def forward(self, inputs, drop_connect_rate=None):
"""
Expand All @@ -72,13 +74,13 @@ def forward(self, inputs, drop_connect_rate=None):
# Expansion and Depthwise Convolution
x = inputs
if self._block_args.expand_ratio != 1:
x = relu_fn(self._bn0(self._expand_conv(inputs)))
x = relu_fn(self._bn1(self._depthwise_conv(x)))
x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self._swish(self._bn1(self._depthwise_conv(x)))

# Squeeze and Excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._se_expand(relu_fn(self._se_reduce(x_squeezed)))
x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
x = torch.sigmoid(x_squeezed) * x

x = self._bn2(self._project_conv(x))
Expand All @@ -91,6 +93,12 @@ def forward(self, inputs, drop_connect_rate=None):
x = x + inputs # skip connection
return x

def set_swish(self, memory_efficient=True):
if memory_efficient:
self._swish = MemoryEfficientSwish()
else:
self._swish = Swish()


class EfficientNet(nn.Module):
"""
Expand Down Expand Up @@ -153,12 +161,23 @@ def __init__(self, blocks_args=None, global_params=None):
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()

def set_swish(self, memory_efficient=True):
if memory_efficient:
self._swish = MemoryEfficientSwish()
else:
self._swish = Swish()

for block in self._blocks:
block.set_swish(memory_efficient)


def extract_features(self, inputs):
""" Returns output of the final convolution layer """

# Stem
x = relu_fn(self._bn0(self._conv_stem(inputs)))
x = self._swish(self._bn0(self._conv_stem(inputs)))

# Blocks
for idx, block in enumerate(self._blocks):
Expand All @@ -168,7 +187,7 @@ def extract_features(self, inputs):
x = block(x, drop_connect_rate=drop_connect_rate)

# Head
x = relu_fn(self._bn1(self._conv_head(x)))
x = self._swish(self._bn1(self._conv_head(x)))

return x

Expand Down
10 changes: 5 additions & 5 deletions efficientnet_pytorch/utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def backward(ctx, grad_output):
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
@staticmethod
def forward(x):
class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)


relu_fn = Swish()
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)


def round_filters(filters, global_params):
Expand Down