Skip to content

Commit

Permalink
swish and mish activation functions (#1235)
Browse files Browse the repository at this point in the history
* swish and mish activation functions

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* update examples

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* add new classes to __init__

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* networks.rst

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
  • Loading branch information
rijobro authored Nov 16, 2020
1 parent 24611b0 commit 8bbdc41
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ Blocks
.. autoclass:: ResidualUnit
:members:

`Swish`
~~~~~~~
.. autoclass:: Swish
:members:

`Mish`
~~~~~~
.. autoclass:: Mish
:members:

`GCN Module`
~~~~~~~~~~~~
.. autoclass:: GCN
Expand Down
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from .acti_norm import ADN
from .activation import Mish, Swish
from .aspp import SimpleASPP
from .convolutions import Convolution, ResidualUnit
from .downsample import MaxAvgPool
Expand Down
69 changes: 69 additions & 0 deletions monai/networks/blocks/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn


class Swish(nn.Module):
r"""Applies the element-wise function:
.. math::
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) for constant value alpha.
Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.
Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input
Examples::
>>> m = Act['swish']()
>>> input = torch.randn(2)
>>> output = m(input)
"""

def __init__(self, alpha=1.0):
super().__init__()
self.alpha = alpha

def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(self.alpha * input)


class Mish(nn.Module):
r"""Applies the element-wise function:
.. math::
\text{Mwish}(x) = x * tanh(\text{softplus}(x)).
Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.
Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input
Examples::
>>> m = Act['mish']()
>>> input = torch.randn(2)
>>> output = m(input)
"""

def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.tanh(torch.nn.functional.softplus(input))
14 changes: 14 additions & 0 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,20 @@ def sync_batch_factory(_dim: Optional[int] = None) -> Type[nn.SyncBatchNorm]:
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)


@Act.factory_function("swish")
def swish_factory():
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_function("mish")
def mish_factory():
from monai.networks.blocks.activation import Mish

return Mish


@Conv.factory_function("conv")
def conv_factory(dim: int) -> Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]]:
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks.layers.factories import Act
from monai.transforms import Activations

TEST_CASE_1 = [
Expand All @@ -37,6 +38,24 @@
(1, 1, 2, 2),
]

TEST_CASE_4 = [
"swish",
torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32),
torch.tensor(
[[[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]]]
),
(1, 1, 2, 5),
]

TEST_CASE_5 = [
"mish",
torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32),
torch.tensor(
[[[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]]]
),
(1, 1, 2, 5),
]


class TestActivations(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
Expand All @@ -45,6 +64,13 @@ def test_value_shape(self, input_param, img, out, expected_shape):
torch.testing.assert_allclose(result, out)
self.assertTupleEqual(result.shape, expected_shape)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5])
def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):
act = Act[input_param]()
result = act(img)
torch.testing.assert_allclose(result, out, rtol=1e-2, atol=1e-5)
self.assertTupleEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8bbdc41

Please sign in to comment.