1 file changed: +1 / −14 lines.
Diff hunk: @@ -337,20 +337,7 @@ (context: def num_hidden_layers(self))
337337 from apex .normalization .fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
338338except (ImportError , AttributeError ) as e :
339339 logger .info ("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex ." )
class XLNetLayerNorm(nn.Module):
    """Layer normalization in the TensorFlow style (epsilon inside the square root).

    Fallback used when apex's FusedLayerNorm is unavailable. Normalizes the
    last dimension of the input to zero mean / unit variance (biased variance),
    then applies a learned per-feature affine transform.
    """

    def __init__(self, d_model, eps=1e-12):
        super(XLNetLayerNorm, self).__init__()
        # Affine parameters, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.variance_epsilon = eps

    def forward(self, x):
        # Statistics over the last (feature) dimension, kept broadcastable.
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
340+ from torch .nn import LayerNorm as XLNetLayerNorm
354341
355342class XLNetRelativeAttention (nn .Module ):
356343 def __init__ (self , config ):
0 commit comments.