
Commit 7d52c15

Add decomposition for aten.native_layer_norm_backward op.
Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
1 parent 0c0f325 · commit 7d52c15

1 file changed: +38 -1 lines changed

functorch/_src/decompositions.py

Lines changed: 38 additions & 1 deletion
@@ -1,5 +1,5 @@
 import torch
-from torch import Tensor
+from torch import Tensor, reciprocal_
 from typing import Optional, List, Tuple
 from enum import Enum

@@ -381,6 +381,43 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
     return (out, mean, rstd)


+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    # Layer norm normalizes over the trailing `normalized_shape` dimensions.
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    N = prod(inner_dims)
+
+    x_hat = (input - mean) * rstd
+    # Gradient w.r.t. the normalized input x_hat.
+    if weight is not None:
+        grad_x_hat = grad_out * weight
+    else:
+        grad_x_hat = grad_out
+    a = grad_x_hat * N
+    b = aten.sum(grad_x_hat, inner_dims, True)
+    c1 = grad_x_hat * x_hat
+    c2 = aten.sum(c1, inner_dims, True)
+    c3 = x_hat * c2
+
+    inner = (a - b) - c3
+
+    d_input = None
+    d_weight = None
+    d_bias = None
+
+    if output_mask[0]:
+        d_input = (rstd / N) * inner
+    if output_mask[1]:
+        d_weight = aten.sum(grad_out * x_hat, outer_dims, False)
+    if output_mask[2]:
+        d_bias = aten.sum(grad_out, outer_dims, False)
+
+    return (d_input, d_weight, d_bias)
+
 # @register_decomposition(aten.addmm)
 # def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
 #     if not self.is_floating_point():
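For reference, the added decomposition implements the standard layer-norm backward formula d_input = (rstd / N) * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat)), with the sums taken over the normalized dimensions. The sketch below is a minimal sanity check, not part of this commit: it re-derives the same three gradients in plain PyTorch and compares them against autograd on torch.nn.functional.layer_norm. The shapes, the eps value, and the helper name check_layer_norm_backward are assumptions chosen for the example.

import torch
import torch.nn.functional as F

def check_layer_norm_backward():
    # Illustrative example only: small 2-D input, normalizing over the last dim.
    torch.manual_seed(0)
    x = torch.randn(4, 8, requires_grad=True)
    weight = torch.randn(8, requires_grad=True)
    bias = torch.randn(8, requires_grad=True)
    normalized_shape = [8]
    eps = 1e-5  # default eps of F.layer_norm

    out = F.layer_norm(x, normalized_shape, weight, bias, eps)
    grad_out = torch.randn_like(out)
    d_x, d_w, d_b = torch.autograd.grad(out, (x, weight, bias), grad_out)

    # Re-derive the same gradients with the formula used by the decomposition.
    with torch.no_grad():
        mean = x.mean(dim=-1, keepdim=True)
        rstd = (x.var(dim=-1, unbiased=False, keepdim=True) + eps).rsqrt()
        x_hat = (x - mean) * rstd
        grad_x_hat = grad_out * weight
        N = x.shape[-1]
        inner = (grad_x_hat * N
                 - grad_x_hat.sum(-1, keepdim=True)
                 - x_hat * (grad_x_hat * x_hat).sum(-1, keepdim=True))
        d_x_ref = (rstd / N) * inner          # corresponds to output_mask[0]
        d_w_ref = (grad_out * x_hat).sum(0)   # corresponds to output_mask[1]
        d_b_ref = grad_out.sum(0)             # corresponds to output_mask[2]

    assert torch.allclose(d_x, d_x_ref, atol=1e-5)
    assert torch.allclose(d_w, d_w_ref, atol=1e-5)
    assert torch.allclose(d_b, d_b_ref, atol=1e-5)

check_layer_norm_backward()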
