Skip to content

Commit

Permalink
Implementation of activations as pytorch modules (#15616)
Browse files Browse the repository at this point in the history
* Implement activations as pytorch modules

* Apply fixup

* Add missing tests for activations

* Update docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
eldarkurtic and patrickvonplaten authored Feb 16, 2022
1 parent 66828a1 commit f65fe36
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 41 deletions.
139 changes: 98 additions & 41 deletions src/transformers/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,94 +16,140 @@

import torch
from packaging import version
from torch import nn
from torch import Tensor, nn

from .utils import logging


logger = logging.get_logger(__name__)


def gelu_python(x):
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""

def __init__(self):
super().__init__()

def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))


class GELUActivation(nn.Module):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def __init__(self, use_gelu_python: bool = False):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu

def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

def gelu_new(x):
def forward(self, input: Tensor) -> Tensor:
return self.act(input)


class FastGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def __init__(self):
super().__init__()

if version.parse(torch.__version__) < version.parse("1.4"):
gelu = gelu_python
else:
gelu = nn.functional.gelu
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
class QuickGELUActivation(nn.Module):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""

def __init__(self):
super().__init__()

def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
def forward(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(1.702 * input)


def _silu_python(x):
class SiLUActivation(nn.Module):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)

def __init__(self):
if version.parse(torch.__version__) < version.parse("1.7"):
self.act = self._silu_python
else:
self.act = nn.functional.silu

if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = nn.functional.silu
def _silu_python(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(input)

def forward(self, input: Tensor) -> Tensor:
return self.act(input)

def _mish_python(x):

class MishActivation(nn.Module):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
return x * torch.tanh(nn.functional.softplus(x))

def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.9"):
self.act = self._mish_python
else:
self.act = nn.functional.mish

if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python
else:
mish = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * torch.tanh(nn.functional.softplus(input))

def forward(self, input: Tensor) -> Tensor:
return self.act(input)

def linear_act(x):
return x

class LinearActivation(nn.Module):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""

def __init__(self):
super().__init__()

def forward(self, input: Tensor) -> Tensor:
return input


ACT2FN = {
"relu": nn.functional.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_python": gelu_python,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"quick_gelu": quick_gelu,
"mish": mish,
"linear": linear_act,
"sigmoid": torch.sigmoid,
"relu": nn.ReLU(),
"silu": SiLUActivation(),
"swish": SiLUActivation(),
"gelu": GELUActivation(),
"tanh": nn.Tanh(),
"gelu_python": GELUActivation(use_gelu_python=True),
"gelu_new": NewGELUActivation(),
"gelu_fast": FastGELUActivation(),
"quick_gelu": QuickGELUActivation(),
"mish": MishActivation(),
"linear": LinearActivation(),
"sigmoid": nn.Sigmoid(),
}


Expand All @@ -112,3 +158,14 @@ def get_activation(activation_string):
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")


# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
4 changes: 4 additions & 0 deletions tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def test_get_activation(self):
get_activation("gelu_new")
get_activation("gelu_fast")
get_activation("gelu_python")
get_activation("quick_gelu")
get_activation("mish")
get_activation("linear")
get_activation("sigmoid")
with self.assertRaises(KeyError):
get_activation("bogus")
with self.assertRaises(KeyError):
Expand Down

0 comments on commit f65fe36

Please sign in to comment.