
Commit c830fc1

Adding activation kernels (#40890)
* first commit
* add mode
* revert modeling
* add compile
* rm print
1 parent f6999b0 commit c830fc1

File tree

2 files changed: +31 -0 lines changed


src/transformers/activations.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@
 import torch
 from torch import Tensor, nn
 
+from .integrations.hub_kernels import use_kernel_forward_from_hub
 from .utils import logging
 from .utils.import_utils import is_torchdynamo_compiling
 
@@ -38,6 +39,7 @@ def forward(self, input: Tensor) -> Tensor:
         return nn.functional.gelu(input, approximate="tanh")
 
 
+@use_kernel_forward_from_hub("NewGELU")
 class NewGELUActivation(nn.Module):
     """
     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
@@ -70,6 +72,7 @@ def forward(self, input: Tensor) -> Tensor:
         return self.act(input)
 
 
+@use_kernel_forward_from_hub("FastGELU")
 class FastGELUActivation(nn.Module):
     """
     Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
@@ -79,6 +82,7 @@ def forward(self, input: Tensor) -> Tensor:
         return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
 
 
+@use_kernel_forward_from_hub("QuickGELU")
 class QuickGELUActivation(nn.Module):
     """
     Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
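
For context, the three decorated modules are eager PyTorch approximations of GELU, and the Hub kernels registered in hub_kernels.py act as drop-in replacements for their forward passes. Only FastGELUActivation's forward is visible in the hunks above; the QuickGELU and NewGELU expressions in the sketch below are the standard approximations and are included here as assumptions, not lines taken from this diff.

import math
import torch

x = torch.randn(8)

# FastGELU, exactly as in the forward() shown in the hunk above:
fast = 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

# QuickGELU (sigmoid form) and NewGELU (tanh form) -- standard formulas,
# assumed rather than copied from this diff:
quick = x * torch.sigmoid(1.702 * x)
new = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

# Both tanh variants stay very close to the exact GELU for typical activation values.
print(torch.allclose(fast, torch.nn.functional.gelu(x), atol=1e-2))
print(torch.allclose(new, torch.nn.functional.gelu(x), atol=1e-2))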

src/transformers/integrations/hub_kernels.py

Lines changed: 27 additions & 0 deletions
@@ -84,6 +84,33 @@
             )
         },
     },
+    "FastGELU": {
+        "cuda": {
+            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="FastGELU",
+                version=">=0.0.4,<0.1.0",
+            )
+        }
+    },
+    "QuickGELU": {
+        "cuda": {
+            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="QuickGELU",
+                version=">=0.0.4,<0.1.0",
+            )
+        }
+    },
+    "NewGELU": {
+        "cuda": {
+            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="NewGELU",
+                version=">=0.0.4,<0.1.0",
+            )
+        }
+    },
 }
 
 register_kernel_mapping(_KERNEL_MAPPING)
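
With the mapping registered, any module decorated with a matching name (NewGELU, FastGELU, QuickGELU) can have its forward replaced by the corresponding kernels-community/activation CUDA kernel when the model is kernelized. A minimal usage sketch, assuming the kernelize entry point of the external kernels package that hub_kernels.py wraps (the exact signature may differ across kernels versions):

# Assumed usage of the external `kernels` package; treat the call as a sketch.
from kernels import Mode, kernelize
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

# Swap the forwards of decorated layers (e.g. NewGELUActivation in GPT-2) for
# the Hub kernels registered in _KERNEL_MAPPING above.
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, device="cuda")

The Mode.INFERENCE | Mode.TORCH_COMPILE key appears to make the same kernel entry apply both to plain inference and to compiled modules, matching the "add compile" step in the commit message.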
