
Commit db9c48d

bugfix: fix the misaligned address bug of norm kernels for certain shapes (#636)
This PR fixes issue #634, which was introduced by #592. To use 16-byte vectorized reads/writes, we must ensure the address is aligned to 16 bytes. When `num_warps` is not a multiple of 4 (4 * sizeof(float) = 16), the address `smem + num_warps` might not be 16-byte aligned. We fix this by shifting the start offset of the vectorized reads/writes to `smem + ceil_div(num_warps, 4) * 4`, which forces the alignment. cc @ovowei @Abatom
1 parent ae501ed commit db9c48d
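
For context, a minimal standalone sketch (not part of this commit) of the alignment argument: `extern __shared__` memory starts at a 16-byte-aligned base, so an offset of `num_warps` floats is 16-byte aligned only when `num_warps` is a multiple of 4, while the padded offset `ceil_div(num_warps, 4) * 4` always is. The buffer name and loop bounds below are illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Round-up integer division, as ceil_div is used in the kernel.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // extern __shared__ float smem[] is 16-byte aligned on the device;
  // model it here with an aligned host buffer.
  alignas(16) static float smem[64];

  for (uint32_t num_warps = 1; num_warps <= 8; ++num_warps) {
    // Old layout: the vectorized region started num_warps floats into smem.
    const float* old_start = smem + num_warps;
    // New layout: pad the scratch region up to a multiple of 4 floats (16 bytes).
    const float* new_start = smem + ceil_div(num_warps, 4) * 4;

    const bool old_aligned = reinterpret_cast<std::uintptr_t>(old_start) % 16 == 0;
    const bool new_aligned = reinterpret_cast<std::uintptr_t>(new_start) % 16 == 0;
    std::printf("num_warps=%u old_aligned=%d new_aligned=%d\n", num_warps,
                old_aligned, new_aligned);
    assert(new_aligned);  // the padded start offset is always 16-byte aligned
  }
  return 0;
}
```

Running this shows `old_aligned` is 0 for every `num_warps` that is not a multiple of 4, while `new_aligned` stays 1, which is exactly the misalignment the 16-byte vectorized accesses tripped over.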

File tree: 2 files changed, +9 / -7 lines

include/flashinfer/norm.cuh

Lines changed: 5 additions & 3 deletions

@@ -127,6 +127,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
   const uint32_t num_threads = num_warps * warp_size;
   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
   extern __shared__ float smem[];
+  float* smem_x = smem + ceil_div(num_warps, 4) * 4;

   float sum_sq = 0.f;

@@ -151,7 +152,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }

@@ -185,7 +186,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.load(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {

@@ -247,7 +248,8 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 bytes
+  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
   float weight_bias = 1.f;
   void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
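
As a quick sanity check of the new `smem_size` formula, here is a small sketch; the `d` and `num_warps` values are hypothetical picks for illustration, not taken from this commit's launch configuration.

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t d = 3584;                 // one of the hidden sizes added to the tests below
  for (uint32_t num_warps : {4u, 14u}) {   // hypothetical warp counts
    const uint32_t old_size = (num_warps + d) * sizeof(float);
    const uint32_t new_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
    // num_warps = 4: both formulas give the same size.
    // num_warps = 14: the new formula pads the scratch region from 14 to 16 floats
    //                 (+8 bytes), so the region after it stays 16-byte aligned.
    std::printf("num_warps=%u old_smem=%u new_smem=%u bytes\n", num_warps,
                old_size, new_size);
  }
  return 0;
}
```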

tests/test_norm.py

Lines changed: 4 additions & 4 deletions

@@ -65,7 +65,7 @@ def fused_add_rms_norm(x, residual, weight, eps):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):

@@ -83,7 +83,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    eps = 1e-6

@@ -105,7 +105,7 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):

@@ -123,7 +123,7 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    eps = 1e-6
