Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swish and mish activation functions #1235

Merged
merged 6 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ Blocks
.. autoclass:: ResidualUnit
:members:

`Swish`
~~~~~~~
.. autoclass:: Swish
:members:

`Mish`
~~~~~~
.. autoclass:: Mish
:members:

`GCN Module`
~~~~~~~~~~~~
.. autoclass:: GCN
Expand Down
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from .acti_norm import ADN
from .activation import Mish, Swish
from .aspp import SimpleASPP
from .convolutions import Convolution, ResidualUnit
from .downsample import MaxAvgPool
Expand Down
69 changes: 69 additions & 0 deletions monai/networks/blocks/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn


class Swish(nn.Module):
r"""Applies the element-wise function:

.. math::
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) for constant value alpha.

Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.


Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input


Examples::

>>> m = Act['swish']()
>>> input = torch.randn(2)
>>> output = m(input)
"""

def __init__(self, alpha=1.0):
super().__init__()
self.alpha = alpha

def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(self.alpha * input)


class Mish(nn.Module):
r"""Applies the element-wise function:

.. math::
\text{Mwish}(x) = x * tanh(\text{softplus}(x)).

Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.


Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input


Examples::

>>> m = Act['mish']()
>>> input = torch.randn(2)
>>> output = m(input)
"""

def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.tanh(torch.nn.functional.softplus(input))
14 changes: 14 additions & 0 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,20 @@ def sync_batch_factory(_dim: Optional[int] = None) -> Type[nn.SyncBatchNorm]:
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)


@Act.factory_function("swish")
def swish_factory():
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_function("mish")
def mish_factory():
from monai.networks.blocks.activation import Mish

return Mish


@Conv.factory_function("conv")
def conv_factory(dim: int) -> Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]]:
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks.layers.factories import Act
from monai.transforms import Activations

TEST_CASE_1 = [
Expand All @@ -37,6 +38,24 @@
(1, 1, 2, 2),
]

TEST_CASE_4 = [
"swish",
torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32),
torch.tensor(
[[[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]]]
),
(1, 1, 2, 5),
]

TEST_CASE_5 = [
"mish",
torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32),
torch.tensor(
[[[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]]]
),
(1, 1, 2, 5),
]


class TestActivations(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
Expand All @@ -45,6 +64,13 @@ def test_value_shape(self, input_param, img, out, expected_shape):
torch.testing.assert_allclose(result, out)
self.assertTupleEqual(result.shape, expected_shape)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5])
def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):
act = Act[input_param]()
result = act(img)
torch.testing.assert_allclose(result, out, rtol=1e-2, atol=1e-5)
self.assertTupleEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()