@@ -359,7 +359,7 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
     if M > 0:
         input_reshaped = input.view(1, M, -1)
     else:
-        return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
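+        # A zero-sized batch: return the input unchanged with empty mean/rstd placeholders.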
+        return (input, aten.new_zeros(input, (0,)), aten.new_zeros(input, (0,)))
 
     # Unlike Batch Normalization, which applies scalar scale and bias for each
     # entire channel/plane with the affine option, Layer Normalization applies
@@ -417,6 +417,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
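+    # Split the input dimensions into outer (batch) dims and the inner dims
+    # covered by normalized_shape; the reductions below run over the inner dims.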
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = []
+    outer_dim_indices = []
+    for i in range(input_ndim):
+        if (i >= axis):
+            inner_dim_indices.append(i)
+        else:
+            outer_dim_indices.append(i)
+
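+    # N: number of elements normalized per sample; M: number of samples.
+    # Degenerate (empty) inputs get empty/zero gradients of the expected shapes.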
+    N = float(prod(inner_dims))
+    M = prod(outer_dims)
+    if M <= 0 or N <= 0.0:
+        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))
+
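+    # x_hat is the normalized input; grad_x_hat is the gradient w.r.t. x_hat
+    # (scaled by the affine weight when one is given).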
+    x_hat = aten.mul(aten.sub(input, mean), rstd)
+    if weight is not None:
+        grad_x_hat = aten.mul(grad_out, weight)
+    else:
+        grad_x_hat = grad_out
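+    # d_input = rstd / N * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat)),
+    # assembled from the terms a, b, and c3 below (sums taken over the inner dims).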
+    a = aten.mul(grad_x_hat, N)
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = aten.mul(grad_x_hat, x_hat)
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = aten.mul(x_hat, c2)
+
+    inner = aten.sub(aten.sub(a, b), c3)
+
+    if output_mask[0]:
+        d_input = aten.mul(aten.div(rstd, N), inner)
+    else:
+        d_input = None
+
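+    # Gradient for the affine weight: grad_out * x_hat, reduced over the outer dims.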
+    if output_mask[1] and weight is not None:
+        if len(outer_dim_indices) > 0:
+            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
+        else:
+            d_weight = aten.mul(grad_out, x_hat)
+    else:
+        d_weight = None
+
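+    # Gradient for the affine bias: grad_out reduced over the outer dims.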
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias = aten.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    else:
+        d_bias = None
+    return (d_input, d_weight, d_bias)
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
+
 
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):