diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index 3d81e8bb1dd6f9..e845e7712e7c45 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -16,7 +16,7 @@
 
 import torch
 from packaging import version
-from torch import nn
+from torch import Tensor, nn
 
 from .utils import logging
 
@@ -24,39 +24,66 @@
 logger = logging.get_logger(__name__)
 
 
-def gelu_python(x):
+class NewGELUActivation(nn.Module):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class GELUActivation(nn.Module):
     """
     Original Implementation of the GELU activation function in Google BERT repo when initially created. For
     information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
     torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
     Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
+    def __init__(self, use_gelu_python: bool = False):
+        super().__init__()
+        if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
+            self.act = self._gelu_python
+        else:
+            self.act = nn.functional.gelu
+
+    def _gelu_python(self, input: Tensor) -> Tensor:
+        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
 
-def gelu_new(x):
+    def forward(self, input: Tensor) -> Tensor:
+        return self.act(input)
+
+
+class FastGELUActivation(nn.Module):
     """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
     """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
+    def __init__(self):
+        super().__init__()
 
-if version.parse(torch.__version__) < version.parse("1.4"):
-    gelu = gelu_python
-else:
-    gelu = nn.functional.gelu
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
 
 
-def gelu_fast(x):
-    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+class QuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
 
+    def __init__(self):
+        super().__init__()
 
-def quick_gelu(x):
-    return x * torch.sigmoid(1.702 * x)
+    def forward(self, input: Tensor) -> Tensor:
+        return input * torch.sigmoid(1.702 * input)
 
 
-def _silu_python(x):
+class SiLUActivation(nn.Module):
     """
     See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
     Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
@@ -64,46 +91,65 @@ def _silu_python(x):
     Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
     later.
""" - return x * torch.sigmoid(x) + def __init__(self): + if version.parse(torch.__version__) < version.parse("1.7"): + self.act = self._silu_python + else: + self.act = nn.functional.silu -if version.parse(torch.__version__) < version.parse("1.7"): - silu = _silu_python -else: - silu = nn.functional.silu + def _silu_python(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(input) + def forward(self, input: Tensor) -> Tensor: + return self.act(input) -def _mish_python(x): + +class MishActivation(nn.Module): """ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also visit the official repository for the paper: https://github.com/digantamisra98/Mish """ - return x * torch.tanh(nn.functional.softplus(x)) + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.9"): + self.act = self._mish_python + else: + self.act = nn.functional.mish -if version.parse(torch.__version__) < version.parse("1.9"): - mish = _mish_python -else: - mish = nn.functional.mish + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + def forward(self, input: Tensor) -> Tensor: + return self.act(input) -def linear_act(x): - return x + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. + """ + + def __init__(self): + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input ACT2FN = { - "relu": nn.functional.relu, - "silu": silu, - "swish": silu, - "gelu": gelu, - "tanh": torch.tanh, - "gelu_python": gelu_python, - "gelu_new": gelu_new, - "gelu_fast": gelu_fast, - "quick_gelu": quick_gelu, - "mish": mish, - "linear": linear_act, - "sigmoid": torch.sigmoid, + "relu": nn.ReLU(), + "silu": SiLUActivation(), + "swish": SiLUActivation(), + "gelu": GELUActivation(), + "tanh": nn.Tanh(), + "gelu_python": GELUActivation(use_gelu_python=True), + "gelu_new": NewGELUActivation(), + "gelu_fast": FastGELUActivation(), + "quick_gelu": QuickGELUActivation(), + "mish": MishActivation(), + "linear": LinearActivation(), + "sigmoid": nn.Sigmoid(), } @@ -112,3 +158,14 @@ def get_activation(activation_string): return ACT2FN[activation_string] else: raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/tests/test_activations.py b/tests/test_activations.py index 2591352f39ff75..71b29913103479 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -40,6 +40,10 @@ def test_get_activation(self): get_activation("gelu_new") get_activation("gelu_fast") get_activation("gelu_python") + get_activation("quick_gelu") + get_activation("mish") + get_activation("linear") + get_activation("sigmoid") with self.assertRaises(KeyError): get_activation("bogus") with self.assertRaises(KeyError):