Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ Blocks
.. autoclass:: Mish
:members:

`GEGLU`
~~~~~~~
.. autoclass:: GEGLU
:members:

`GCN Module`
~~~~~~~~~~~~
.. autoclass:: GCN
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

from .acti_norm import ADN
from .activation import MemoryEfficientSwish, Mish, Swish
from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
from .aspp import SimpleASPP
from .backbone_fpn_utils import BackboneWithFPN
from .convolutions import Convolution, ResidualUnit
Expand Down
20 changes: 20 additions & 0 deletions monai/networks/blocks/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,23 @@ def __init__(self, inplace: bool = False):

def forward(self, input: torch.Tensor):
return monai_mish(input, self.inplace)


class GEGLU(nn.Module):
r"""Applies the element-wise function:

.. math::
\text{GEGLU}(x) = x_1 * \text{Sigmoid}(x_2)

where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension.

Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.

Shape:
- Input: :math:`(N, *, 2 * D)`
- Output: :math:`(N, *, D)`, where `*` means, any number of additional dimensions
"""

def forward(self, input: torch.Tensor):
x, gate = input.chunk(2, dim=-1)
return x * nn.functional.gelu(gate)
4 changes: 2 additions & 2 deletions monai/networks/blocks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
dropout_rate: faction of the input units to drop.
act: activation type and arguments. Defaults to GELU.
act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
dropout_mode: dropout mode, can be "vit" or "swin".
"vit" mode uses two dropout instances as implemented in
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
Expand All @@ -48,7 +48,7 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim)
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
self.drop1 = nn.Dropout(dropout_rate)
Expand Down
7 changes: 7 additions & 0 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ def mish_factory():
return Mish


@Act.factory_function("geglu")
def geglu_factory():
from monai.networks.blocks.activation import GEGLU

return GEGLU


@Conv.factory_function("conv")
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
Expand Down
9 changes: 8 additions & 1 deletion tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@
(1, 2, 5),
]

TEST_CASE_7 = [
"geglu",
torch.tensor([[[-10, -8, -6, -4, -2, 0], [0, 2, 4, 6, 8, 10]]], dtype=torch.float32),
torch.tensor([[[1.27e-03, 3.64e-01, 0.00e00], [0.00e00, 1.60e01, 4.00e01]]]),
(1, 2, 3),
]


class TestActivations(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand All @@ -101,7 +108,7 @@ def _compare(ret, out, shape):
else:
_compare(result, out, expected_shape)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):
act = Act[input_param]()
result = act(img)
Expand Down