diff --git a/src/flag_gems/ops/layernorm.py b/src/flag_gems/ops/layernorm.py
index 2d4a7a77..83fe1aa3 100644
--- a/src/flag_gems/ops/layernorm.py
+++ b/src/flag_gems/ops/layernorm.py
@@ -6,6 +6,7 @@
 import triton.language as tl
 
 from ..utils import libentry
+from ..utils.type_utils import get_accumulator_dtype
 
 
 @triton.jit
@@ -321,8 +322,12 @@ def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True)
         weight = weight.contiguous()
         bias = bias.contiguous()
         y = torch.empty_like(x)
-        mean = torch.empty(M, dtype=x.dtype, device=x.device)
-        rstd = torch.empty(M, dtype=x.dtype, device=x.device)
+
+        # NOTE: when the input is half-precision (either float16 or bfloat16),
+        # the mean and rstd saved for backward are stored in single precision
+        acc_type = get_accumulator_dtype(x.dtype)
+        mean = torch.empty(M, dtype=acc_type, device=x.device)
+        rstd = torch.empty(M, dtype=acc_type, device=x.device)
 
         with torch.cuda.device(x.device):
             if N <= 128:
@@ -382,14 +387,17 @@ def backward(ctx, out_grad, mean_grad, rstd_grad):
         (x, weight, mean, rstd) = ctx.saved_tensors
         M = ctx.M
         N = ctx.N
-        in_grad = torch.empty_like(x)
-        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
-        layer_norm_backward_kernel[grid](out_grad, x, weight, mean, rstd, in_grad, M, N)
 
-        grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
-        weight_grad = torch.empty_like(weight)
-        bias_grad = torch.empty_like(weight)
         with torch.cuda.device(x.device):
+            in_grad = torch.empty_like(x)
+            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
+            layer_norm_backward_kernel[grid](
+                out_grad, x, weight, mean, rstd, in_grad, M, N
+            )
+
+            grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
+            weight_grad = torch.empty_like(weight)
+            bias_grad = torch.empty_like(weight)
             weight_bias_backward_kernel[grid](
                 out_grad, x, mean, rstd, weight_grad, bias_grad, M, N
             )
diff --git a/src/flag_gems/utils/type_utils.py b/src/flag_gems/utils/type_utils.py
index 91a86ddd..baff2bd0 100644
--- a/src/flag_gems/utils/type_utils.py
+++ b/src/flag_gems/utils/type_utils.py
@@ -1,3 +1,4 @@
+import torch
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes
 
 
@@ -7,3 +8,14 @@ def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
         type_promotion_kind=type_promotion,
     )
     return computation_dtype, result_dtype
+
+
+_accumulator_dtype_map = {
+    torch.bfloat16: torch.float32,
+    torch.float16: torch.float32,
+    torch.complex32: torch.complex64,
+}
+
+
+def get_accumulator_dtype(dtype: torch.dtype) -> torch.dtype:
+    return _accumulator_dtype_map.get(dtype, dtype)
diff --git a/tests/test_norm_ops.py b/tests/test_norm_ops.py
index 2a3fba8e..89004858 100644
--- a/tests/test_norm_ops.py
+++ b/tests/test_norm_ops.py
@@ -74,11 +74,19 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype):
     gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW)
 
 
-# TODO: failed at (1, 2) (2~32, 40499) (200, 2~64) (200~4096, 40999)
 @pytest.mark.layer_norm
 @pytest.mark.native_layer_norm
 @pytest.mark.parametrize(
-    "shape", [(1, 40999)] if QUICK_MODE else [(1, 40999), (4096, 256), (4096, 100)]
+    "shape",
+    [(1, 40999)]
+    if QUICK_MODE
+    else [
+        (200, 36),
+        (4096, 100),
+        (1, 40999),
+        (100, 40499),
+        (4096, 256),
+    ],
 )
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_layernorm(shape, dtype):
@@ -110,8 +118,8 @@ def test_accuracy_layernorm(shape, dtype):
     ref_mean = torch.mean(ref_inp, dim=1)
     ref_var = torch.var(ref_inp, dim=1, correction=0)
     ref_rstd = torch.rsqrt(ref_var + eps)
-    gems_assert_close(res_mean, ref_mean, dtype)
-    gems_assert_close(res_rstd, ref_rstd, dtype)
+    gems_assert_close(res_mean, ref_mean, res_mean.dtype)
+    gems_assert_close(res_rstd, ref_rstd, res_rstd.dtype)
     gems_assert_close(res_out, ref_out, dtype)
 
     out_grad = torch.randn_like(inp)
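For reference, a minimal sketch (not part of the patch) of how the new get_accumulator_dtype helper is expected to behave, assuming flag_gems is importable from the installed package: half-precision inputs map to their single-precision accumulator dtype, which is why the test now compares res_mean/res_rstd in their own dtype rather than the input dtype.

import torch

from flag_gems.utils.type_utils import get_accumulator_dtype

# float16/bfloat16 statistics (mean, rstd) are kept in float32,
# complex32 in complex64; any other dtype passes through unchanged.
assert get_accumulator_dtype(torch.float16) == torch.float32
assert get_accumulator_dtype(torch.bfloat16) == torch.float32
assert get_accumulator_dtype(torch.complex32) == torch.complex64
assert get_accumulator_dtype(torch.float32) == torch.float32
assert get_accumulator_dtype(torch.float64) == torch.float64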