
Commit ff0ae21

Add decomposition for aten.native_layer_norm_backward op.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>

1 parent: 6f3e137

2 files changed: +73 -3 lines

functorch/_src/decompositions.py

Lines changed: 67 additions & 1 deletion
@@ -359,7 +359,7 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
     if M > 0:
         input_reshaped = input.view(1, M, -1)
     else:
-        return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
+        return (input, aten.new_zeros(input, (0,)), aten.new_zeros(input, (0,)))
 
     # Unlike Batch Normalization, which applies scalar scale and bias for each
     # entire channel/plane with the affine option, Layer Normalization applies

@@ -417,6 +417,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = []
+    outer_dim_indices = []
+    for i in range(input_ndim):
+        if(i >= axis):
+            inner_dim_indices.append(i)
+        else:
+            outer_dim_indices.append(i)
+
+    N = float(prod(inner_dims))
+    M = prod(outer_dims)
+    if M <= 0 or N <= 0.0:
+        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))
+
+    x_hat = aten.mul(aten.sub(input, mean), rstd)
+    if weight is not None:
+        grad_x_hat = aten.mul(grad_out, weight)
+    else:
+        grad_x_hat = grad_out
+    a = aten.mul(grad_x_hat, N)
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = aten.mul(grad_x_hat, x_hat)
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = aten.mul(x_hat, c2)
+
+    inner = aten.sub(aten.sub(a, b), c3)
+
+    if output_mask[0]:
+        d_input = aten.mul(aten.div(rstd, N), inner)
+    else:
+        d_input = None
+
+    if output_mask[1] and weight is not None:
+        if len(outer_dim_indices) > 0:
+            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
+        else:
+            d_weight = aten.mul(grad_out, x_hat)
+    else:
+        d_weight = None
+
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias = aten.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    else:
+        d_bias = None
+    return (d_input, d_weight, d_bias)
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
+
 
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
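The decomposition above spells out the standard layer-norm backward formulas with explicit aten calls: with x_hat = (input - mean) * rstd and grad_x_hat = grad_out * weight, it computes d_input = rstd / N * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat)), where the sums run over the normalized dimensions, and d_weight / d_bias are reductions of grad_out * x_hat and grad_out over the outer dimensions. Below is a minimal sketch (not part of this commit; the shapes, eps, and tolerances are illustrative assumptions) that checks these formulas against autograd through torch.nn.functional.layer_norm:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)

out = F.layer_norm(x, [8], w, b, eps)
grad_out = torch.randn_like(out)
d_x, d_w, d_b = torch.autograd.grad(out, (x, w, b), grad_out)

with torch.no_grad():
    # Recompute mean/rstd the way native_layer_norm returns them.
    mean = x.mean(dim=-1, keepdim=True)
    rstd = (x.var(dim=-1, unbiased=False, keepdim=True) + eps).rsqrt()
    x_hat = (x - mean) * rstd
    grad_x_hat = grad_out * w          # weight is not None in this example
    N = x.shape[-1]
    # d_input = rstd / N * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat))
    inner = (N * grad_x_hat
             - grad_x_hat.sum(-1, keepdim=True)
             - x_hat * (grad_x_hat * x_hat).sum(-1, keepdim=True))
    ref_d_x = rstd / N * inner
    ref_d_w = (grad_out * x_hat).sum(0)  # reduce over the outer (batch) dims
    ref_d_b = grad_out.sum(0)

print(torch.allclose(d_x, ref_d_x, atol=1e-5),
      torch.allclose(d_w, ref_d_w, atol=1e-5),
      torch.allclose(d_b, ref_d_b, atol=1e-5))
```

The a/b/c1/c2/c3 temporaries in the decomposition correspond one-to-one to the terms of `inner` in this sketch.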

test/test_ops.py

Lines changed: 6 additions & 2 deletions
@@ -1108,7 +1108,8 @@ def op_assert_equal(op, a, b, arg_string):
     # Before adding an entry to this table, make sure your decomposition is right :)
     tol_table = {
         # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
-        (torch.float32, aten.native_layer_norm.default): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm_backward): (1e-3, 1e-3),
     }
     if (b.dtype, op) in tol_table:
         rtol, atol = tol_table[(b.dtype, op)]

@@ -1230,6 +1231,9 @@ def call_op(func, map_fn, *args, **kwargs):
         real_out = call_op(func, unwrap_tensor, *args, **kwargs)
         assert(len(real_out) == len(decomp_out))
         for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
+            if orig is None:
+                assert(decomp is None)
+                continue
             orig = orig.to(dtype=TEST_DTYPE)
             decomp = decomp.to(dtype=TEST_DTYPE)
             if DO_RELATIVE_CHECK and ref.dtype.is_floating_point:

@@ -1308,7 +1312,7 @@ def get_names(inpt):
             f.write(f'{op}\n')
 
     def test_decompositions_torchscriptable(self, device):
-        skip_list = []
+        skip_list = [torch.ops.aten.native_layer_norm_backward]
         for op, decomposition in decomposition_table.items():
             if op in skip_list:
                 continue
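The None check added to the comparison loop above is needed because native_layer_norm_backward only produces the gradients requested by output_mask (and skips d_weight / d_bias when weight or bias is absent), so both the real and the decomposed outputs can contain None entries that must be matched rather than compared numerically. A minimal sketch of that pattern, with made-up tensors for illustration:

```python
import torch

d_input = torch.randn(3)
# e.g. output_mask = [True, False, False]: only the input gradient is produced.
real_out = (d_input, None, None)
decomp_out = (d_input.clone(), None, None)

for orig, decomp in zip(real_out, decomp_out):
    if orig is None:
        assert decomp is None  # a missing gradient must be missing in both
        continue
    assert torch.allclose(orig, decomp)
```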
