@@ -412,6 +412,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    # Split the dims into outer (batch) dims and inner (normalized) dims.
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = []
+    outer_dim_indices = []
+    for i in range(input_ndim):
+        if i >= axis:
+            inner_dim_indices.append(i)
+        else:
+            outer_dim_indices.append(i)
+
+    N = float(prod(inner_dims))
+    M = prod(outer_dims)
+    if M <= 0 or N <= 0.0:
+        # Degenerate input: grad_weight/grad_bias have the shape of the normalized dims.
+        return (aten.new_empty(input, input_shape), aten.new_empty(input[axis:], input_shape[axis:]), aten.new_empty(input[axis:], input_shape[axis:]))
+
+    # x_hat is the normalized input: (input - mean) * rstd.
+    x_hat = aten.mul(aten.sub(input, mean), rstd)
+    if weight is not None:
+        grad_x_hat = aten.mul(grad_out, weight)
+    else:
+        grad_x_hat = grad_out
+    a = aten.mul(grad_x_hat, N)
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = aten.mul(grad_x_hat, x_hat)
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = aten.mul(x_hat, c2)
+
+    inner = aten.sub(aten.sub(a, b), c3)
+
+    if output_mask[0]:
+        # d_input = rstd / N * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat))
+        d_input = aten.mul(aten.div(rstd, N), inner)
+    else:
+        d_input = None
+
+    if output_mask[1] and weight is not None:
+        # Reduce over the outer dims so d_weight matches weight's shape.
+        if len(outer_dim_indices) > 0:
+            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
+        else:
+            d_weight = aten.mul(grad_out, x_hat)
+    else:
+        d_weight = None
+
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias = aten.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    else:
+        d_bias = None
+    return (d_input, d_weight, d_bias)
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
+
 
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
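As a quick sanity check (an editor's sketch, not part of this commit), the new decomposition can be compared against PyTorch autograd on a small double-precision input. This assumes it runs in the module the diff patches, so that `aten = torch.ops.aten`, the file's `prod` helper, and `native_layer_norm_backward` are all in scope; the wrapper name `check_native_layer_norm_backward` is invented for illustration.

import torch

def check_native_layer_norm_backward():
    torch.manual_seed(0)
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    w = torch.randn(8, dtype=torch.double, requires_grad=True)
    b = torch.randn(8, dtype=torch.double, requires_grad=True)
    shape = [8]

    # Forward pass keeps mean/rstd, the exact statistics the backward expects.
    out, mean, rstd = torch.native_layer_norm(x, shape, w, b, 1e-5)
    grad_out = torch.randn_like(out)

    # Reference gradients from autograd.
    ref = torch.autograd.grad(out, (x, w, b), grad_out)

    # Gradients from the decomposition, all three outputs requested.
    got = native_layer_norm_backward(grad_out, x, shape, mean, rstd, w, b, [True, True, True])

    for g, r in zip(got, ref):
        assert torch.allclose(g, r, atol=1e-9)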
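The commented-out addmm mirrors the live decomposition in the hunk's context: alpha * mm(mat1, mat2) plus beta * self, with an early return when beta == 0 so that self never enters the computation (addmm documents that a zero beta must not propagate inf/nan from the input). A minimal check of that algebra, again only an illustrative sketch:

import torch

torch.manual_seed(0)
inp = torch.randn(3, 5)
m1, m2 = torch.randn(3, 4), torch.randn(4, 5)
beta, alpha = 0.5, 2.0

got = beta * inp + alpha * torch.mm(m1, m2)
want = torch.addmm(inp, m1, m2, beta=beta, alpha=alpha)
assert torch.allclose(got, want, atol=1e-6)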