
Commit 32b7fc6 (1 parent: 34006db)

Add decomposition for aten.native_layer_norm_backward op.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>

File tree

  functorch/_src/decompositions.py
  test/test_ops.py

2 files changed: +70 -0 lines changed

functorch/_src/decompositions.py

Lines changed: 66 additions & 0 deletions
@@ -412,6 +412,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = []
+    outer_dim_indices = []
+    for i in range(input_ndim):
+        if i >= axis:
+            inner_dim_indices.append(i)
+        else:
+            outer_dim_indices.append(i)
+
+    N = float(prod(inner_dims))
+    M = prod(outer_dims)
+    if M <= 0 or N <= 0.0:
+        return (aten.new_empty(input, input_shape), aten.new_empty(input[axis:], input_shape[axis:]), aten.new_empty(input[axis:], input_shape[axis:]))
+
+    x_hat = aten.mul(aten.sub(input, mean), rstd)
+    if weight is not None:
+        grad_x_hat = aten.mul(grad_out, weight)
+    else:
+        grad_x_hat = grad_out
+    a = aten.mul(grad_x_hat, N)
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = aten.mul(grad_x_hat, x_hat)
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = aten.mul(x_hat, c2)
+
+    inner = aten.sub(aten.sub(a, b), c3)
+
+    if output_mask[0]:
+        d_input = aten.mul(aten.div(rstd, N), inner)
+    else:
+        d_input = None
+
+    if output_mask[1] and weight is not None:
+        if len(outer_dim_indices) > 0:
+            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
+        else:
+            d_weight = aten.mul(grad_out, x_hat)
+    else:
+        d_weight = None
+
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias = aten.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    else:
+        d_bias = None
+    return (d_input, d_weight, d_bias)
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
+
 
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
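Note (not part of the commit): with x_hat = (input - mean) * rstd and grad_x_hat = grad_out * weight, the a/b/c3 terms above are the standard layer-norm input gradient, d_input = (rstd / N) * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat)), with the sums taken over the normalized dimensions. The following is a minimal numerical sketch of that formula checked against autograd; the 4x8 shape, tensor names, and eps value are made up for illustration, and plain tensor ops stand in for the aten.* calls used in the decomposition.

import torch
import torch.nn.functional as F

eps = 1e-5
x = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(8, requires_grad=True)
bias = torch.randn(8, requires_grad=True)

# Reference gradients from autograd through F.layer_norm.
out = F.layer_norm(x, [8], weight, bias, eps)
grad_out = torch.randn_like(out)
ref_dx, ref_dw, ref_db = torch.autograd.grad(out, (x, weight, bias), grad_out)

# Recompute the saved statistics the way layer norm defines them.
mean = x.mean(dim=-1, keepdim=True)
rstd = (x.var(dim=-1, unbiased=False, keepdim=True) + eps).rsqrt()

# Same algebra as the decomposition: (a - b - c3) scaled by rstd / N.
N = 8.0
x_hat = (x - mean) * rstd
grad_x_hat = grad_out * weight
a = grad_x_hat * N
b = grad_x_hat.sum(dim=1, keepdim=True)
c3 = x_hat * (grad_x_hat * x_hat).sum(dim=1, keepdim=True)
d_input = (rstd / N) * (a - b - c3)
d_weight = (grad_out * x_hat).sum(dim=0)
d_bias = grad_out.sum(dim=0)

for ref, got in zip((ref_dx, ref_dw, ref_db), (d_input, d_weight, d_bias)):
    torch.testing.assert_close(ref, got, rtol=1e-3, atol=1e-3)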

test/test_ops.py

Lines changed: 4 additions & 0 deletions
@@ -1044,6 +1044,7 @@ def op_assert_equal(op, a, b, arg_string):
     tol_table = {
         # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
         (torch.float32, aten.native_layer_norm): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm_backward): (1e-3, 1e-3),
     }
     if (b.dtype, op) in tol_table:
         rtol, atol = tol_table[(b.dtype, op)]
@@ -1165,6 +1166,9 @@ def call_op(func, map_fn, *args, **kwargs):
         real_out = call_op(func, unwrap_tensor, *args, **kwargs)
         assert(len(real_out) == len(decomp_out))
         for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
+            if orig is None:
+                assert(decomp is None)
+                continue
             orig = orig.to(dtype=TEST_DTYPE)
             decomp = decomp.to(dtype=TEST_DTYPE)
             if DO_RELATIVE_CHECK and ref.dtype.is_floating_point:
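The None guard matters because native_layer_norm_backward only produces gradients whose output_mask entry is True; the other slots come back as None, and the decomposition above mirrors that, so those positions cannot be cast or compared. A small illustrative sketch of that behavior (not part of the commit; shapes and tensor names are made up, and the masked outputs are assumed to surface as None from the Python binding):

import torch

x = torch.randn(4, 8)
w, b = torch.randn(8), torch.randn(8)
out, mean, rstd = torch.native_layer_norm(x, [8], w, b, 1e-5)
grads = torch.ops.aten.native_layer_norm_backward(
    torch.ones_like(out), x, [8], mean, rstd, w, b, [True, False, False])
print([g is None for g in grads])  # expected here: [False, True, True]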
