 import torch
 from torch import Tensor
 from typing import Optional, List, Tuple
 from enum import Enum

@@ -381,6 +381,43 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
     return (out, mean, rstd)


+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    normalized_shape: List[int],
+    mean: Tensor,
+    rstd: Tensor,
+    weight: Optional[Tensor],
+    bias: Optional[Tensor],
+    output_mask: List[bool],
+) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    # The last len(normalized_shape) dims are the normalized ones; the rest are batch dims.
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = list(range(axis, input_ndim))
+    outer_dim_indices = list(range(axis))
+    N = prod(inner_dims)
+
+    # Recompute the normalized input from the saved statistics.
+    x_hat = (input - mean) * rstd
+    if weight is not None:
+        grad_x_hat = grad_out * weight
+    else:
+        grad_x_hat = grad_out
+
+    # d_input = rstd / N * (N * g - sum(g) - x_hat * sum(g * x_hat)), with g = grad_x_hat
+    # and the sums taken over the normalized (inner) dims.
+    a = grad_x_hat * N
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = grad_x_hat * x_hat
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = x_hat * c2
+
+    inner = (a - b) - c3
+
+    d_input: Optional[Tensor] = None
+    d_weight: Optional[Tensor] = None
+    d_bias: Optional[Tensor] = None
+
+    if output_mask[0]:
+        d_input = (rstd / N) * inner
+    if output_mask[1]:
+        d_weight = aten.sum(grad_out * x_hat, outer_dim_indices, False)
+    if output_mask[2]:
+        d_bias = aten.sum(grad_out, outer_dim_indices, False)
+
+    return (d_input, d_weight, d_bias)
+
 # @register_decomposition(aten.addmm)
 # def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
 #     if not self.is_floating_point():
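
As a quick sanity check for the decomposition above, the sketch below compares it numerically against the reference aten kernel. It assumes it runs where `native_layer_norm_backward` and the module-level `aten` and `prod` helpers from this file are importable; the shapes, seed, and tolerance are illustrative only, not part of the commit.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.double)
w = torch.randn(8, dtype=torch.double)
b = torch.randn(8, dtype=torch.double)
normalized_shape = [8]

# Forward pass via the reference op to obtain the saved mean/rstd.
out, mean, rstd = torch.ops.aten.native_layer_norm(x, normalized_shape, w, b, 1e-5)
grad_out = torch.randn_like(out)

# Reference backward vs. the Python decomposition.
ref = torch.ops.aten.native_layer_norm_backward(
    grad_out, x, normalized_shape, mean, rstd, w, b, [True, True, True]
)
dec = native_layer_norm_backward(
    grad_out, x, normalized_shape, mean, rstd, w, b, [True, True, True]
)
for r, d in zip(ref, dec):
    assert torch.allclose(r, d)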