2 files changed, +31 −0 lines changed

src/transformers/activations.py
@@ -18,6 +18,7 @@
 import torch
 from torch import Tensor, nn
 
+from .integrations.hub_kernels import use_kernel_forward_from_hub
 from .utils import logging
 from .utils.import_utils import is_torchdynamo_compiling
 
@@ -38,6 +39,7 @@ def forward(self, input: Tensor) -> Tensor:
         return nn.functional.gelu(input, approximate="tanh")
 
 
+@use_kernel_forward_from_hub("NewGELU")
 class NewGELUActivation(nn.Module):
     """
     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
@@ -70,6 +72,7 @@ def forward(self, input: Tensor) -> Tensor:
         return self.act(input)
 
 
+@use_kernel_forward_from_hub("FastGELU")
 class FastGELUActivation(nn.Module):
     """
     Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
@@ -79,6 +82,7 @@ def forward(self, input: Tensor) -> Tensor:
         return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
 
 
+@use_kernel_forward_from_hub("QuickGELU")
 class QuickGELUActivation(nn.Module):
     """
     Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
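Note: `use_kernel_forward_from_hub` comes from the `kernels` package and is re-exported through `.integrations.hub_kernels` (which falls back to a no-op decorator when `kernels` is not installed), so the decorated classes keep their pure-PyTorch `forward` until a kernel is actually applied. A minimal sketch of the pattern, assuming the `kernels` package is available; the sigmoid formula matches QuickGELU's existing fallback:

```python
import torch
from torch import Tensor, nn
from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("QuickGELU")
class QuickGELUActivation(nn.Module):
    """Pure-PyTorch fallback; kernelize() may later swap in a hub kernel."""

    def forward(self, input: Tensor) -> Tensor:
        # Sigmoid-based GELU approximation; runs on any device until a
        # mapped CUDA kernel replaces this forward.
        return input * torch.sigmoid(1.702 * input)
```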
src/transformers/integrations/hub_kernels.py
@@ -84,6 +84,33 @@
                 )
             },
         },
+        "FastGELU": {
+            "cuda": {
+                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                    repo_id="kernels-community/activation",
+                    layer_name="FastGELU",
+                    version=">=0.0.4,<0.1.0",
+                )
+            }
+        },
+        "QuickGELU": {
+            "cuda": {
+                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                    repo_id="kernels-community/activation",
+                    layer_name="QuickGELU",
+                    version=">=0.0.4,<0.1.0",
+                )
+            }
+        },
+        "NewGELU": {
+            "cuda": {
+                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                    repo_id="kernels-community/activation",
+                    layer_name="NewGELU",
+                    version=">=0.0.4,<0.1.0",
+                )
+            }
+        },
     }
 
     register_kernel_mapping(_KERNEL_MAPPING)
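These mapping entries declare that, on CUDA, each decorated activation may be served by the `kernels-community/activation` hub repository for inference, including under `torch.compile`. A hedged usage sketch, assuming a recent `kernels` release and a CUDA device; `gpt2` is only an illustrative checkpoint (its `gelu_new` activation resolves to `NewGELUActivation`):

```python
from kernels import Mode, kernelize
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

# kernelize() walks the module tree and, for each class registered via
# use_kernel_forward_from_hub, looks up the LayerRepository matching the
# requested mode and device, then swaps in that kernel's forward.
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```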