import torch

# Optional fused CUDA kernels for mixed-precision LayerNorm. Catch only
# ImportError: a bare `except:` would also hide real bugs raised while the
# extension initializes (and even SystemExit/KeyboardInterrupt).
try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    # Extension not built/installed; callers can test for None to fall back.
    fused_mix_prec_layer_norm_cuda = None
8+
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
    r"""Autograd wrapper around the fused mixed-precision LayerNorm CUDA kernels.

    Forward normalizes ``input`` over ``normalized_shape`` and applies the
    affine transform ``weight * x_hat + bias``; backward returns gradients
    for input, weight and bias via the fused backward kernel.

    NOTE(review): requires the ``fused_mix_prec_layer_norm_cuda`` extension;
    when it failed to import at module load this class is unusable.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        # Keep non-tensor hyper-parameters on ctx for the backward pass.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps

        # The CUDA kernels expect contiguous buffers.
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()

        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )
        # mean/invvar are cached so backward_affine need not recompute them.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = (
            fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps
            )
        )
        # The last two Nones match the non-tensor forward args
        # (normalized_shape, eps), which take no gradient.
        return grad_input, grad_weight, grad_bias, None, None