diff --git a/src/kernl/implementations/layer_norm.py b/src/kernl/implementations/layer_norm.py index 94ef1d8d..f8c8058b 100644 --- a/src/kernl/implementations/layer_norm.py +++ b/src/kernl/implementations/layer_norm.py @@ -65,7 +65,10 @@ def layer_norm_xformers( Bias, Mean, Rstd, + stride_row_out, + stride_col_out, stride_row_a, + stride_col_a, N, eps, HAS_BIAS: tl.constexpr, # not used, just to make the signature similar to single pass @@ -85,7 +88,7 @@ def layer_norm_xformers( cols = tl.arange(0, BLOCK_SIZE) mask = cols < N - x_ptrs = A + row * stride_row_a + cols + x_ptrs = A + row * stride_row_a + cols * stride_col_a x = tl.load(x_ptrs, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32) w = tl.load(Weight + cols, mask=mask, other=1.0) @@ -104,7 +107,7 @@ def layer_norm_xformers( tl.store(Rstd + row, rstd) y = y * w + b - y_ptrs = Out + row * stride_row_a + cols + y_ptrs = Out + row * stride_row_out + cols * stride_col_out tl.store(y_ptrs, y, mask=mask) @@ -116,7 +119,10 @@ def _layer_norm_fwd_fused_single_pass( Bias, Mean, Rstd, + stride_row_out, + stride_col_out, stride_row_a, + stride_col_a, N, eps, HAS_BIAS: tl.constexpr, @@ -145,7 +151,7 @@ def _layer_norm_fwd_fused_single_pass( # position of elements processed by this program _idx = tl.program_id(0) a_ptr = A + _idx * stride_row_a - out_ptr = Out + _idx * stride_row_a + out_ptr = Out + _idx * stride_row_out # compute mean mean = 0.0 var = 0.0 @@ -156,7 +162,9 @@ def _layer_norm_fwd_fused_single_pass( mask = column_offset < N # eviction policy below have little impact now because of new implementation. Kept as is. # float32 is used to avoid overflow because of the square operation - a = tl.load(a_ptr + column_offset, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32) + a = tl.load(a_ptr + column_offset * stride_col_a, mask=mask, other=0.0, eviction_policy="evict_last").to( + tl.float32 + ) if IS_RMSNORM: var += tl.sum(a * a, axis=0) else: @@ -184,14 +192,16 @@ def _layer_norm_fwd_fused_single_pass( weight = tl.load(Weight + column_offset, mask=mask) # eviction policy helps to keep weights in cache (reused by other threads) - a = tl.load(a_ptr + column_offset, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32) + a = tl.load(a_ptr + column_offset * stride_col_a, mask=mask, other=0.0, eviction_policy="evict_first").to( + tl.float32 + ) a_hat = (a - mean) * rstd out = a_hat * weight if HAS_BIAS: bias = tl.load(Bias + column_offset, mask=mask) out = out + bias # write-back - tl.store(out_ptr + column_offset, out, mask=mask) + tl.store(out_ptr + column_offset * stride_col_out, out, mask=mask) @triton.jit @@ -202,7 +212,10 @@ def _layer_norm_fwd_fused_multi_pass( Bias, Mean, Rstd, + stride_row_out, + stride_col_out, stride_row_a, + stride_col_a, N, eps, IS_RMSNORM: tl.constexpr, # not used, just to have the same signature than the single pass @@ -217,21 +230,21 @@ def _layer_norm_fwd_fused_multi_pass( """ # position of elements processed by this program row = tl.program_id(0) - Out += row * stride_row_a + Out += row * stride_row_out A += row * stride_row_a # compute mean mean = 0 _mean = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(A + cols, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32) + a = tl.load(A + cols * stride_col_a, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32) _mean += a mean = tl.sum(_mean, axis=0) / N # compute variance _var = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(A + cols, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32) + a = tl.load(A + cols * stride_col_a, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32) a = tl.where(cols < N, a - mean, 0.0) _var += a * a var = tl.sum(_var, axis=0) / N @@ -245,11 +258,11 @@ def _layer_norm_fwd_fused_multi_pass( mask = cols < N weight = tl.load(Weight + cols, mask=mask) bias = tl.load(Bias + cols, mask=mask) - a = tl.load(A + cols, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32) + a = tl.load(A + cols * stride_col_a, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32) a_hat = (a - mean) * rstd out = a_hat * weight + bias # write-back - tl.store(Out + cols, out, mask=mask) + tl.store(Out + cols * stride_col_out, out, mask=mask) class LayerNorm(torch.autograd.Function): @@ -294,7 +307,10 @@ def forward( Bias=bias if bias is not None else a_arg, Mean=mean, Rstd=std, + stride_row_out=out.stride(0), + stride_col_out=out.stride(1), stride_row_a=a_arg.stride(0), + stride_col_a=a_arg.stride(1), N=N, eps=eps, HAS_BIAS=bias is not None, diff --git a/test/test_layer_norm.py b/test/test_layer_norm.py index 653078cf..f6d2ad93 100644 --- a/test/test_layer_norm.py +++ b/test/test_layer_norm.py @@ -105,3 +105,19 @@ def test_benchmark_rms_norm(benchmark, shape: int, dtype, cuda_graphs: bool, imp value = benchmark(fn, x) assert_all_close(value.float(), expected, atol=1e-1) + + +@pytest.mark.parametrize("implementation", implementations_layer_norm.keys()) +def test_stride(implementation): + M = N = 250 + eps = 1e-5 + factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False} + layer_weight = torch.rand((N,), **factory_kwargs) + layer_bias = torch.randn_like(layer_weight) + x = -20 + 0.5 * torch.randn((M, N), **factory_kwargs) + x = x.transpose(-1, -2) + + expected = torch.nn.functional.layer_norm(x, layer_weight.shape, layer_weight, layer_bias, eps) + fn = implementations_layer_norm[implementation](layer_weight, layer_bias, eps) + value = fn(x) + assert_all_close(value.float(), expected, atol=1e-1)