This repository was archived by the owner on Aug 21, 2025. It is now read-only.
Merged
68 changes: 67 additions & 1 deletion functorch/_src/decompositions.py
@@ -359,7 +359,7 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
    if M > 0:
        input_reshaped = input.view(1, M, -1)
    else:
-       return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
+       return (input, aten.new_zeros(input, (0,)), aten.new_zeros(input, (0,)))

    # Unlike Batch Normalization, which applies scalar scale and bias for each
    # entire channel/plane with the affine option, Layer Normalization applies
@@ -417,6 +417,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
        return out
    return beta * self + out

@register_decomposition(aten.native_layer_norm_backward)
def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = []
    outer_dim_indices = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = float(prod(inner_dims))
    M = prod(outer_dims)
    if M <= 0 or N <= 0.0:
        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))

    x_hat = aten.mul(aten.sub(input, mean), rstd)
    if weight is not None:
        grad_x_hat = aten.mul(grad_out, weight)
    else:
        grad_x_hat = grad_out
    a = aten.mul(grad_x_hat, N)
    b = aten.sum(grad_x_hat, inner_dim_indices, True)
    c1 = aten.mul(grad_x_hat, x_hat)
    c2 = aten.sum(c1, inner_dim_indices, True)
    c3 = aten.mul(x_hat, c2)

    inner = aten.sub(aten.sub(a, b), c3)

    if output_mask[0]:
        d_input = aten.mul(aten.div(rstd, N), inner)
    else:
        d_input = None

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
        else:
            d_weight = aten.mul(grad_out, x_hat)
    else:
        d_weight = None

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias = aten.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out
    else:
        d_bias = None
    return (d_input, d_weight, d_bias)

# @register_decomposition(aten.addmm)
# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
#     if not self.is_floating_point():
#         beta = int(beta)
#         alpha = int(alpha)
#     out = alpha * aten.mm(mat1, mat2)
#     if beta == 0:
#         return out
#     return beta * self + out


@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
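Reviewer note (not part of the diff): the new decomposition re-implements aten.native_layer_norm_backward, so one quick way to sanity-check it is to compare its outputs against autograd on torch.native_layer_norm for a small input. The sketch below is illustrative only; the import path, input shapes, eps, and tolerances are assumptions, not code from this PR.

import torch
import torch.testing

# Illustrative sanity check, not part of this PR. Assumes the decorated
# function can be imported by name from functorch's decompositions module;
# adjust the import to match your checkout.
from functorch._src.decompositions import native_layer_norm_backward

x = torch.randn(4, 8, 16, dtype=torch.double, requires_grad=True)
w = torch.randn(16, dtype=torch.double, requires_grad=True)
b = torch.randn(16, dtype=torch.double, requires_grad=True)
normalized_shape = [16]

# Forward pass producing the saved mean/rstd that the backward needs.
out, mean, rstd = torch.native_layer_norm(x, normalized_shape, w, b, 1e-5)
grad_out = torch.randn_like(out)

# Reference gradients from autograd.
d_x_ref, d_w_ref, d_b_ref = torch.autograd.grad(out, (x, w, b), grad_out)

# Gradients from the Python decomposition under review.
d_x, d_w, d_b = native_layer_norm_backward(
    grad_out, x.detach(), normalized_shape, mean.detach(), rstd.detach(),
    w.detach(), b.detach(), [True, True, True])

torch.testing.assert_close(d_x, d_x_ref, rtol=1e-6, atol=1e-6)
torch.testing.assert_close(d_w, d_w_ref, rtol=1e-6, atol=1e-6)
torch.testing.assert_close(d_b, d_b_ref, rtol=1e-6, atol=1e-6)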
8 changes: 6 additions & 2 deletions test/test_ops.py
@@ -1108,7 +1108,8 @@ def op_assert_equal(op, a, b, arg_string):
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
-       (torch.float32, aten.native_layer_norm.default): (1e-3, 1e-3),
+       (torch.float32, aten.native_layer_norm): (1e-3, 1e-3),
+       (torch.float32, aten.native_layer_norm_backward): (1e-3, 1e-3),
    }
    if (b.dtype, op) in tol_table:
        rtol, atol = tol_table[(b.dtype, op)]
@@ -1230,6 +1231,9 @@ def call_op(func, map_fn, *args, **kwargs):
    real_out = call_op(func, unwrap_tensor, *args, **kwargs)
    assert(len(real_out) == len(decomp_out))
    for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
+       if orig is None:
+           assert(decomp is None)
+           continue
        orig = orig.to(dtype=TEST_DTYPE)
        decomp = decomp.to(dtype=TEST_DTYPE)
        if DO_RELATIVE_CHECK and ref.dtype.is_floating_point:
@@ -1308,7 +1312,7 @@ def get_names(inpt):
f.write(f'{op}\n')

    def test_decompositions_torchscriptable(self, device):
-       skip_list = []
+       skip_list = [torch.ops.aten.native_layer_norm_backward]
        for op, decomposition in decomposition_table.items():
            if op in skip_list:
                continue