Commit fc32e19

save mean and rstd in float32 when the inputs are in half-precision fp types (#221)

* fix layer_norm_backward: save mean and rstd in float32 when the inputs are in half-precision floating-point dtypes to avoid numerical instability or errors
* add a util get_accumulator_dtype
iclementine authored Sep 24, 2024
1 parent 674a978 commit fc32e19
Showing 3 changed files with 40 additions and 12 deletions.
24 changes: 16 additions & 8 deletions src/flag_gems/ops/layernorm.py
@@ -6,6 +6,7 @@
import triton.language as tl

from ..utils import libentry
from ..utils.type_utils import get_accumulator_dtype


@triton.jit
@@ -321,8 +322,12 @@ def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True)
weight = weight.contiguous()
bias = bias.contiguous()
y = torch.empty_like(x)
mean = torch.empty(M, dtype=x.dtype, device=x.device)
rstd = torch.empty(M, dtype=x.dtype, device=x.device)

# NOTE: when the input is half-precision (either float16 or bfloat16),
# the statistics saved for backward are kept in single precision
acc_type = get_accumulator_dtype(x.dtype)
mean = torch.empty(M, dtype=acc_type, device=x.device)
rstd = torch.empty(M, dtype=acc_type, device=x.device)

with torch.cuda.device(x.device):
if N <= 128:
@@ -382,14 +387,17 @@ def backward(ctx, out_grad, mean_grad, rstd_grad):
(x, weight, mean, rstd) = ctx.saved_tensors
M = ctx.M
N = ctx.N
in_grad = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
layer_norm_backward_kernel[grid](out_grad, x, weight, mean, rstd, in_grad, M, N)
grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
weight_grad = torch.empty_like(weight)
bias_grad = torch.empty_like(weight)

with torch.cuda.device(x.device):
in_grad = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
layer_norm_backward_kernel[grid](
out_grad, x, weight, mean, rstd, in_grad, M, N
)

grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
weight_grad = torch.empty_like(weight)
bias_grad = torch.empty_like(weight)
weight_bias_backward_kernel[grid](
out_grad, x, mean, rstd, weight_grad, bias_grad, M, N
)
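For context, here is a minimal sketch of the idea in plain PyTorch rather than the Triton kernels above: when the input is float16 or bfloat16, the row statistics are computed and stored in float32 so the values saved for backward do not lose precision. The helper layernorm_stats is hypothetical and for illustration only.

import torch

def layernorm_stats(x: torch.Tensor, eps: float = 1e-5):
    # Promote half-precision inputs to float32 for the statistics,
    # mirroring the change to the forward pass above.
    acc_dtype = torch.float32 if x.dtype in (torch.float16, torch.bfloat16) else x.dtype
    x_acc = x.to(acc_dtype)
    mean = x_acc.mean(dim=-1)
    rstd = torch.rsqrt(x_acc.var(dim=-1, correction=0) + eps)
    return mean, rstd  # kept in acc_dtype, as the forward now does

x = torch.randn(4, 256, dtype=torch.float16)
mean, rstd = layernorm_stats(x)
assert mean.dtype == torch.float32 and rstd.dtype == torch.float32
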
12 changes: 12 additions & 0 deletions src/flag_gems/utils/type_utils.py
@@ -1,3 +1,4 @@
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes


@@ -7,3 +8,14 @@ def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
type_promotion_kind=type_promotion,
)
return computation_dtype, result_dtype


_accumulator_dtype_map = {
torch.bfloat16: torch.float32,
torch.float16: torch.float32,
torch.complex32: torch.complex64,
}


def get_accumulator_dtype(dtype: torch.dtype) -> torch.dtype:
return _accumulator_dtype_map.get(dtype, dtype)
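
A quick usage sketch of the new helper, assuming flag_gems is installed so the module path matches the relative import used in layernorm.py above:

import torch
from flag_gems.utils.type_utils import get_accumulator_dtype

# Half-precision floating-point (and complex32) dtypes map to a wider
# accumulator dtype; anything not in the map is returned unchanged.
assert get_accumulator_dtype(torch.float16) == torch.float32
assert get_accumulator_dtype(torch.bfloat16) == torch.float32
assert get_accumulator_dtype(torch.complex32) == torch.complex64
assert get_accumulator_dtype(torch.float32) == torch.float32
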
16 changes: 12 additions & 4 deletions tests/test_norm_ops.py
@@ -74,11 +74,19 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype):
gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW)


# TODO: failed at (1, 2) (2~32, 40499) (200, 2~64) (200~4096, 40999)
@pytest.mark.layer_norm
@pytest.mark.native_layer_norm
@pytest.mark.parametrize(
"shape", [(1, 40999)] if QUICK_MODE else [(1, 40999), (4096, 256), (4096, 100)]
"shape",
[(1, 40999)]
if QUICK_MODE
else [
(200, 36),
(4096, 100),
(1, 40999),
(100, 40499),
(4096, 256),
],
)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_layernorm(shape, dtype):
@@ -110,8 +118,8 @@ def test_accuracy_layernorm(shape, dtype):
ref_mean = torch.mean(ref_inp, dim=1)
ref_var = torch.var(ref_inp, dim=1, correction=0)
ref_rstd = torch.rsqrt(ref_var + eps)
gems_assert_close(res_mean, ref_mean, dtype)
gems_assert_close(res_rstd, ref_rstd, dtype)
gems_assert_close(res_mean, ref_mean, res_mean.dtype)
gems_assert_close(res_rstd, ref_rstd, res_rstd.dtype)
gems_assert_close(res_out, ref_out, dtype)

out_grad = torch.randn_like(inp)
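The assertion change above follows from the dtype change: res_mean and res_rstd are now float32 even when the input is half-precision, so they are checked against tolerances keyed by their own dtype rather than by the input dtype. Below is a minimal sketch of that kind of dtype-keyed comparison; the tolerance values are illustrative, not the ones gems_assert_close actually uses.

import torch

# Illustrative per-dtype tolerances; gems_assert_close keeps its own table.
_TOLERANCES = {
    torch.float16: dict(rtol=1e-3, atol=1e-3),
    torch.bfloat16: dict(rtol=1.6e-2, atol=1e-2),
    torch.float32: dict(rtol=1.3e-6, atol=1e-5),
}

def assert_close_by_dtype(res: torch.Tensor, ref: torch.Tensor, dtype: torch.dtype):
    # Compare in the reference's precision, with tolerances picked by `dtype`.
    torch.testing.assert_close(res.to(ref.dtype), ref, **_TOLERANCES[dtype])

# mean/rstd saved by layer_norm are float32 now, so they are compared at
# float32 tolerance even when the input tensor was float16 or bfloat16.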
